leecode更新

This commit is contained in:
markilue 2022-11-08 22:02:46 +08:00
parent 1cf7ba2aa3
commit e18305c54a
4 changed files with 148 additions and 23 deletions

View File

@ -0,0 +1,75 @@
package com.markilue.leecode.greedy;
import org.junit.Test;
/**
* @BelongsProject: Leecode
* @BelongsPackage: com.markilue.leecode.greedy
* @Author: markilue
* @CreateTime: 2022-11-08 20:16
* @Description: TODO 力扣134题 加油站
* 在一条环路上有 n 个加油站其中第 i 个加油站有汽油 gas[i]
* 你有一辆油箱容量无限的的汽车从第 i 个加油站开往第 i+1 个加油站需要消耗汽油 cost[i] 你从其中的一个加油站出发开始时油箱为空
* 给定两个整数数组 gas cost 如果你可以绕环路行驶一周则返回出发时加油站的编号否则返回 -1 如果存在解 保证 它是 唯一
* @Version: 1.0
*/
public class CanCompleteCircuit {
@Test
public void test() {
int[] gas = {1, 2, 3, 4, 5}, cost = {3, 4, 5, 1, 2};
System.out.println(canCompleteCircuit(gas, cost));
}
@Test
public void test1() {
int[] gas = {2,3,4}, cost = {3,4,3};
System.out.println(canCompleteCircuit(gas, cost));
}
/**
* 本人思路可以先通过gas数组里的值判断是否大于cost值
* 如果大于需要继续判断累加值如果开始累加小于则前面全部弃用从头开始
* 速度击败78.2%内存击败77.8%
* @param gas
* @param cost
* @return
*/
public int canCompleteCircuit(int[] gas, int[] cost) {
int length = gas.length;
int total = 0;
int sum =0 ;
int flag =0;
boolean f=false;
for (int i = 0; i < length; i++) {
total+=gas[i]-cost[i];
if (sum == 0 && gas[i] < cost[i]) {
continue;
}
sum+=gas[i]-cost[i];
if(!f&&sum>0){
flag=i;
f=true;
}
if(sum<=0){
sum=0;
f=false;
}
}
if(total>=0){
return flag;
}else {
return -1;
}
}
}

View File

@ -548,9 +548,13 @@ def plot_hot_one(data):
pass pass
def plot_mse(file_name="../others_idea/mse"): def plot_mse(file_name_mse="../others_idea/mse",data:str=''):
mse = np.loadtxt(file_name, delimiter=",") mse = np.loadtxt(file_name_mse, delimiter=",")
print(mse.shape) raw_data=np.loadtxt(data,delimiter=",")
raw_data=raw_data[:,:mse.shape[1]]
print("mse:",mse.shape)
print("raw_data:",raw_data.shape)
res=raw_data-mse
mse.shape[0]*2/3 mse.shape[0]*2/3
# mse = mse[2000:2300] # mse = mse[2000:2300]
# mse = mse[1800:2150] # mse = mse[1800:2150]
@ -566,20 +570,25 @@ def plot_mse(file_name="../others_idea/mse"):
'legend.fontsize': 5, 'legend.fontsize': 5,
} }
plt.rcParams.update(parameters) plt.rcParams.update(parameters)
plt.figure() for i in range(mse.shape[0]):
plt.rc('font', family='Times New Roman') # 全局字体样式#画混淆矩阵 plt.figure()
indices = [mse.shape[0] * i / 4 for i in range(5)] plt.rc('font', family='Times New Roman') # 全局字体样式#画混淆矩阵
classes = ['13/09/17', '14/09/17', '15/09/17', '16/09/17', '17/09/17'] indices = [mse.shape[0] * i / 4 for i in range(5)]
classes = ['13/09/17', '14/09/17', '15/09/17', '16/09/17', '17/09/17']
# plt.xticks([index + 0.5 for index in indices], classes, rotation=25) # 设置横坐标方向rotation=45为45度倾斜
plt.ylabel('MSE', fontsize=8)
plt.xlabel('Time', fontsize=8)
plt.tight_layout()
plt.axvline(res.shape[1] * 2 / 3, c='purple', ls='-.', lw=0.5, label="real fault")
plt.plot(res[i, :], lw=0.5)
# plt.xticks([index + 0.5 for index in indices], classes, rotation=25) # 设置横坐标方向rotation=45为45度倾斜
plt.ylabel('MSE', fontsize=8)
plt.xlabel('Time', fontsize=8)
plt.tight_layout()
plt.plot(mse, lw=0.5)
plt.show() plt.show()
def plot_3d(): def plot_3d():
# 线 # 线
fig = plt.figure() fig = plt.figure()
@ -728,9 +737,9 @@ def test_model_visualization(model_name=file_name):
if __name__ == '__main__': if __name__ == '__main__':
test_mse() # test_mse(fi)
# test_result( # test_result(
# file_name='E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\ResNet\ResNet_banda_result1.csv') # file_name='E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\ResNet\ResNet_banda_result3.csv')
# test_corr() # test_corr()
# acc() # acc()
# list = [3.77, 2.64, 2.35, 2.05, 1.76, 1.09, 0.757, 0.82, 1.1, 0.58, 0, 0.03, 0.02] # list = [3.77, 2.64, 2.35, 2.05, 1.76, 1.09, 0.757, 0.82, 1.1, 0.58, 0, 0.03, 0.02]
@ -739,8 +748,8 @@ if __name__ == '__main__':
# list=[99.99,98.95,99.95,96.1,95,99.65,76.25,72.64,75.87,68.74] # list=[99.99,98.95,99.95,96.1,95,99.65,76.25,72.64,75.87,68.74]
# plot_FNR1(list) # plot_FNR1(list)
# # # #
# list=[3.43,1.99,1.92,2.17,1.63,1.81,1.78,1.8,0.6] list=[3.43,1.99,1.92,2.17,1.63,1.81,1.78,1.8,0.6]
# plot_FNR2(list) plot_FNR2(list)
# 查看网络某一层的权重 # 查看网络某一层的权重
# test_model_visualization(model_name = "E:\跑模型\论文写作/SE.txt") # test_model_visualization(model_name = "E:\跑模型\论文写作/SE.txt")
@ -750,7 +759,14 @@ if __name__ == '__main__':
# test_model_weight_l(file_name) # test_model_weight_l(file_name)
# 单独预测图 # 单独预测图
# plot_mse('E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\RNet_D/banda\RNet_D_banda_mse_predict1.csv') plot_mse('E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\JM_banda/banda_joint_result_predict3.csv',data='E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\JM_banda/raw_data.csv')
#画3d图 #画3d图
# plot_3d() # plot_3d()
#原始数据图
# file_names='E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\JM_banda/raw_data.csv'
# data= np.loadtxt(file_names,delimiter=',')
# print(data.shape)

View File

@ -57,7 +57,7 @@ save_step_two_name = "../hard_model/two_weight/{0}_epoch24_9875_9867/weight".for
# batch_size, # batch_size,
# EPOCH) # EPOCH)
save_mse_name=r"./compare/mse/JM_banda/{0}_mse.csv".format(model_name) save_mse_name=r"./compare/mse/JM_banda/{0}_result.csv".format(model_name)
'''文件名''' '''文件名'''
file_name = "G:\data\SCADA数据\SCADA_已处理_粤水电达坂城2020.1月-5月\风机15.csv" file_name = "G:\data\SCADA数据\SCADA_已处理_粤水电达坂城2020.1月-5月\风机15.csv"
@ -544,9 +544,40 @@ def showResult(step_two_model: Joint_Monitoring, test_data, isPlot: bool = False
return total_result return total_result
def show_trend(step_two_model: Joint_Monitoring, test_data,isSave:bool=True):
# 获取模型的所有参数的个数
# step_two_model.count_params()
size, length, dims = test_data.shape
for epoch in range(0, size - batch_size + 1, batch_size):
each_test_data = test_data[epoch:epoch + batch_size, :, :]
output1, output2, output3, _ = step_two_model.call(each_test_data, is_first_time=True)
if epoch==0:
result1=output1
result2=output2
result3=output3
else:
result1=np.concatenate([result1,output1],axis=0)
result2=np.concatenate([result2,output2],axis=0)
result3=np.concatenate([result3,output3],axis=0)
# 转置
result1=np.transpose(result1,(1,0))
result2=np.transpose(result2,(1,0))
result3=np.transpose(result3,(1,0))
# 误报率,漏报率,准确性的计算
if isSave:
for data,j in zip([result1,result2,result3],range(3)):
save_mse_name1 = save_mse_name[:-4] + "_predict" + str(j + 1) + ".csv"
np.savetxt(save_mse_name1, data, delimiter=',')
pass
if __name__ == '__main__': if __name__ == '__main__':
total_data = loadData.execute(N=feature_num, file_name=file_name) total_data = loadData.execute(N=feature_num, file_name=file_name)
total_data = normalization(data=total_data) total_data = normalization(data=total_data)
train_data_healthy, train_label1_healthy, train_label2_healthy = get_training_data_overlapping( train_data_healthy, train_label1_healthy, train_label2_healthy = get_training_data_overlapping(
total_data[:healthy_date, :], is_Healthy=True) total_data[:healthy_date, :], is_Healthy=True)
train_data_unhealthy, train_label1_unhealthy, train_label2_unhealthy = get_training_data_overlapping( train_data_unhealthy, train_label1_unhealthy, train_label2_unhealthy = get_training_data_overlapping(
@ -587,11 +618,14 @@ if __name__ == '__main__':
# test_label2=np.expand_dims(test_label2, axis=-1)) # test_label2=np.expand_dims(test_label2, axis=-1))
###TODO 展示全部的结果 ###TODO 展示全部的结果
all_data, _, _ = get_training_data_overlapping( # all_data, _, _ = get_training_data_overlapping(
total_data[healthy_size - 2 * unhealthy_size:unhealthy_date, :], is_Healthy=True) # total_data[healthy_size - 2 * unhealthy_size:unhealthy_date, :], is_Healthy=True)
save_data=np.transpose(total_data[healthy_size - 2 * unhealthy_size:unhealthy_date, :])
np.savetxt('./compare/mse/JM_banda/raw_data.csv',save_data ,delimiter=',')
# all_data = np.concatenate([]) # all_data = np.concatenate([])
# 单次测试 # 单次测试
# showResult(step_two_model, test_data=all_data[:32], isPlot=True) # showResult(step_two_model, test_data=all_data[:32], isPlot=True)
showResult(step_two_model, test_data=all_data, isPlot=True) # show_trend(step_two_model,all_data,isSave=True)
# showResult(step_two_model, test_data=all_data, isPlot=True)
pass pass

View File

@ -42,7 +42,7 @@ save_max_name = "./mse/ResNet/{0}_timestamp{1}_feature{2}_max.csv".format(model_
feature_num, feature_num,
batch_size, batch_size,
EPOCH) EPOCH)
save_mse_name1 = "./mse/ResNet/{0}_result1.csv".format(model_name, save_mse_name1 = "./mse/ResNet/{0}_result3.csv".format(model_name,
time_stamp, time_stamp,
feature_num, feature_num,
batch_size, batch_size,
@ -395,7 +395,7 @@ if __name__ == '__main__':
# history = model.fit(train_data, train_label, epochs=20, batch_size=16, validation_data=(test_data, test_label), # history = model.fit(train_data, train_label, epochs=20, batch_size=16, validation_data=(test_data, test_label),
# callbacks=[checkpoint, early_stop]) # callbacks=[checkpoint, early_stop])
# model.save("./model/ResNet.h5") # model.save("./model/ResNet.h5")
model = tf.keras.models.load_model("model/ResNet_banda/ResNet_banda_epoch12_9942.h5") model = tf.keras.models.load_model("model/ResNet_banda/ResNet_banda_epoch10_9884.h5")
# 结果展示 # 结果展示