leecode更新
This commit is contained in:
parent
8d8994786c
commit
5de2a0889b
|
|
@ -0,0 +1,195 @@
|
|||
package com.markilue.leecode.backtrace;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
* @BelongsProject: Leecode
|
||||
* @BelongsPackage: com.markilue.leecode.backtrace
|
||||
* @Author: markilue
|
||||
* @CreateTime: 2022-10-14 09:49
|
||||
* @Description: TODO 力扣40题 组合总和II:
|
||||
* 给定一个候选人编号的集合 candidates 和一个目标数 target ,找出 candidates 中所有可以使数字和为 target 的组合。
|
||||
* candidates 中的每个数字在每个组合中只能使用 一次 。
|
||||
* 注意:解集不能包含重复的组合。
|
||||
* @Version: 1.0
|
||||
*/
|
||||
public class combinationSum2 {
|
||||
|
||||
@Test
|
||||
public void test(){
|
||||
int[] candidates = {10,1,2,7,6,1,5};
|
||||
int target = 8;
|
||||
System.out.println(combinationSum22(candidates, target));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test1(){
|
||||
int[] candidates = {2,5,2,1,2};
|
||||
int target = 5;
|
||||
System.out.println(combinationSum2(candidates, target));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test2(){
|
||||
int[] candidates = {1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1};
|
||||
int target = 30;
|
||||
System.out.println(combinationSum21(candidates, target));
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 回溯算法:这题的核心难点在于解集不能包含重复的组合,而候选者集合中会出现重复数字
|
||||
* 1)这里考虑先将候选者集合排序,在进行逻辑处理
|
||||
* 2)使用hashset去重
|
||||
*
|
||||
* @param candidates
|
||||
* @param target
|
||||
* @return
|
||||
*/
|
||||
public List<List<Integer>> combinationSum2(int[] candidates, int target) {
|
||||
Arrays.sort(candidates);
|
||||
backtracking(candidates, target, 0);
|
||||
ArrayList<List<Integer>> lists = new ArrayList<>();
|
||||
lists.addAll(result);
|
||||
|
||||
return lists;
|
||||
|
||||
}
|
||||
|
||||
List<List<Integer>> result = new ArrayList<>();
|
||||
HashSet<List<Integer>> result1 = new HashSet<>();
|
||||
List<Integer> cur = new ArrayList<>();
|
||||
int sum = 0;
|
||||
|
||||
//当重复元素较多时,直接超出时间限制
|
||||
public void backtracking(int[] candidates, int target, int val) {
|
||||
|
||||
if (sum == target) {
|
||||
cur.sort(new Comparator<Integer>() {
|
||||
@Override
|
||||
public int compare(Integer o1, Integer o2) {
|
||||
return o1-o2;
|
||||
}
|
||||
});
|
||||
result1.add(new ArrayList<>(cur));
|
||||
return;
|
||||
}
|
||||
|
||||
for (int i = val; i < candidates.length && sum + candidates[i] <= target; i++) {
|
||||
sum += candidates[i];
|
||||
cur.add(candidates[i]);
|
||||
backtracking(candidates, target, i + 1);
|
||||
cur.remove(cur.size() - 1);
|
||||
sum -= candidates[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 速度击败99.77%,内存击败52.35%
|
||||
* @param candidates
|
||||
* @param target
|
||||
* @return
|
||||
*/
|
||||
public List<List<Integer>> combinationSum21(int[] candidates, int target) {
|
||||
Arrays.sort(candidates);
|
||||
boolean[] used = new boolean[candidates.length];
|
||||
// System.out.println(Arrays.toString(used));
|
||||
backtracking1(candidates, target, 0,used);
|
||||
|
||||
return result;
|
||||
|
||||
}
|
||||
|
||||
//代码随想录方法回溯:在回溯过程中就可以回溯,元素在同一个组合内是可以重复的,
|
||||
// 但两个组合不能重复,所以需要去重同一树层上使用过的元素,同一树枝上不用去重
|
||||
public void backtracking1(int[] candidates, int target, int val,boolean[] used) {
|
||||
|
||||
if (sum == target) {
|
||||
result.add(new ArrayList<>(cur));
|
||||
return;
|
||||
}
|
||||
|
||||
for (int i = val; i < candidates.length && sum + candidates[i] <= target; i++) {
|
||||
|
||||
//如果used[i-1]==True,则说明同一树枝使用过candidates[i]
|
||||
//如果used[i-1]=false,则说明同一树层使用过candidates[i]
|
||||
//TODO 为什么同一树层使用过,就一定可以保证可以不在使用这个数?
|
||||
// 因为由于这个树层使用第2数时,实际上等同于这个树层使用第1数,下个树层使用第2数的子集,所以一定是包含关系,可以直接跳过
|
||||
if(i>0&&candidates[i]==candidates[i-1]&&used[i-1]==false){
|
||||
continue;
|
||||
}
|
||||
//使用以下判断,也可以进行跳过
|
||||
// if(i>val&&candidates[i]==candidates[i-1])
|
||||
|
||||
sum += candidates[i];
|
||||
cur.add(candidates[i]);
|
||||
used[i]=true;
|
||||
backtracking1(candidates, target, i + 1,used);
|
||||
used[i]=false;
|
||||
cur.remove(cur.size() - 1);
|
||||
sum -= candidates[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
List<int[]> freq = new ArrayList<int[]>();
|
||||
List<List<Integer>> ans = new ArrayList<List<Integer>>();
|
||||
List<Integer> sequence = new ArrayList<Integer>();
|
||||
|
||||
/**
|
||||
* 官方回溯方法:
|
||||
* 实际上是将问题转为同一个数能无限使用(最多使用rest/num=>most次)的组合
|
||||
* 与本人的combinationSum中的自己实现思路类似
|
||||
* @param candidates
|
||||
* @param target
|
||||
* @return
|
||||
*/
|
||||
public List<List<Integer>> combinationSum22(int[] candidates, int target) {
|
||||
Arrays.sort(candidates);
|
||||
|
||||
//这里是使用freq记录每个数字出现的次数
|
||||
for (int num : candidates) {
|
||||
int size = freq.size();
|
||||
if (freq.isEmpty() || num != freq.get(size - 1)[0]) {
|
||||
freq.add(new int[]{num, 1});
|
||||
} else {
|
||||
++freq.get(size - 1)[1];
|
||||
}
|
||||
}
|
||||
//开始递归回溯
|
||||
dfs(0, target);
|
||||
return ans;
|
||||
}
|
||||
|
||||
public void dfs(int pos, int rest) {
|
||||
//等于0,意味着找到了这个序列
|
||||
if (rest == 0) {
|
||||
ans.add(new ArrayList<Integer>(sequence));
|
||||
return;
|
||||
}
|
||||
//遍历到了最后,或者当前freq大于rest就返回
|
||||
if (pos == freq.size() || rest < freq.get(pos)[0]) {
|
||||
return;
|
||||
}
|
||||
//还不够就继续找
|
||||
dfs(pos + 1, rest);
|
||||
|
||||
//同一个数最多能使用多少次
|
||||
int most = Math.min(rest / freq.get(pos)[0], freq.get(pos)[1]);
|
||||
for (int i = 1; i <= most; ++i) {
|
||||
sequence.add(freq.get(pos)[0]);
|
||||
dfs(pos + 1, rest - i * freq.get(pos)[0]);
|
||||
}
|
||||
for (int i = 1; i <= most; ++i) {
|
||||
sequence.remove(sequence.size() - 1);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -0,0 +1,203 @@
|
|||
package com.markilue.leecode.backtrace;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* @BelongsProject: Leecode
|
||||
* @BelongsPackage: com.markilue.leecode.backtrace
|
||||
* @Author: markilue
|
||||
* @CreateTime: 2022-10-14 11:40
|
||||
* @Description: TODO 力扣131题 分割回文串:
|
||||
* 给你一个字符串 s,请你将 s 分割成一些子串,使每个子串都是 回文串 。返回 s 所有可能的分割方案。
|
||||
* 回文串 是正着读和反着读都一样的字符串。
|
||||
* @Version: 1.0
|
||||
*/
|
||||
public class partition {
|
||||
|
||||
@Test
|
||||
public void test() {
|
||||
String s = "aab";
|
||||
System.out.println(partition1(s));
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test1() {
|
||||
String s = "abba";
|
||||
System.out.println(partition1(s));
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test2() {
|
||||
String s = "cdd";
|
||||
System.out.println(partition2(s));
|
||||
|
||||
}
|
||||
|
||||
|
||||
List<List<String>> result = new ArrayList<>();
|
||||
List<String> cur = new ArrayList<>();
|
||||
|
||||
|
||||
/**
|
||||
* 自己的思路:分别按不同长度对s进行划分,不是回溯法
|
||||
* 存在问题:同同一次划分时的长度可能不一样
|
||||
*
|
||||
* @param s
|
||||
* @return
|
||||
*/
|
||||
public List<List<String>> partition(String s) {
|
||||
//按i对s进行划分
|
||||
for (int i = 1; i <= s.length(); i++) {
|
||||
StringBuilder builder = new StringBuilder();
|
||||
for (int j = 0; j < s.length(); j++) {
|
||||
if (j != 0 && j % i == 0) {
|
||||
cur.add(builder.toString());
|
||||
builder.delete(0, builder.length());
|
||||
builder.append(s.charAt(j));
|
||||
} else {
|
||||
builder.append(s.charAt(j));
|
||||
}
|
||||
|
||||
}
|
||||
//添加最后的builder内容
|
||||
if (builder.length() != 0) {
|
||||
cur.add(builder.toString());
|
||||
}
|
||||
backtracking(s);
|
||||
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
public void backtracking(String s) {
|
||||
|
||||
for (String s1 : cur) {
|
||||
if (!isPalindrome(s1)) {
|
||||
cur.clear();
|
||||
} else {
|
||||
result.add(new ArrayList<String>(cur));
|
||||
cur.clear();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
public boolean isPalindrome(String s) {
|
||||
char[] chars = s.toCharArray();
|
||||
for (int i = 0; i < chars.length / 2; i++) {
|
||||
if (chars[i] != chars[chars.length - i - 1]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 自己的思路2:分别按不同长度对s进行划分,不是回溯法
|
||||
* 速度击败77.01%,内存击败23.77%
|
||||
*
|
||||
* @param s
|
||||
* @return
|
||||
*/
|
||||
public List<List<String>> partition1(String s) {
|
||||
|
||||
backtracking1(s);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
public void backtracking1(String s) {
|
||||
|
||||
if (s.length() == 1) {
|
||||
cur.add(s);
|
||||
result.add(new ArrayList<String>(cur));
|
||||
return;
|
||||
}
|
||||
|
||||
for (int i = 1; i < s.length(); i++) {
|
||||
//按长度分割
|
||||
String str1 = s.substring(0, i);
|
||||
|
||||
cur.add(str1);
|
||||
|
||||
if (!isPalindrome(str1)) {
|
||||
cur.remove(cur.size() - 1);
|
||||
continue;
|
||||
}
|
||||
|
||||
String str2 = s.substring(i, s.length());
|
||||
backtracking1(str2);
|
||||
cur.remove(cur.size() - 1);
|
||||
cur.remove(cur.size() - 1);
|
||||
}
|
||||
|
||||
cur.add(s);
|
||||
if (isPalindrome(s)) {
|
||||
|
||||
result.add(new ArrayList<String>(cur));
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 代码随想录思路:分别按不同长度对s进行划分,不是回溯法,在代码上进行了优化,当前是回文再加入,避免了删除两次等逻辑
|
||||
* 速度击败16.16%,内存击败26.93%
|
||||
*
|
||||
* @param s
|
||||
* @return
|
||||
*/
|
||||
public List<List<String>> partition2(String s) {
|
||||
|
||||
backtracking2(s,0);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
public void backtracking2(String s,int start) {
|
||||
|
||||
if (start>=s.length()) {
|
||||
result.add(new ArrayList<String>(cur));
|
||||
return;
|
||||
}
|
||||
|
||||
for (int i = start; i < s.length(); i++) {
|
||||
//按长度分割
|
||||
if (isPalindrome1(s,start,i)) {
|
||||
String str1 = s.substring(start, i+1);
|
||||
cur.add(str1);
|
||||
}else {
|
||||
continue;
|
||||
}
|
||||
backtracking2(s,i+1);
|
||||
cur.remove(cur.size() - 1);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public boolean isPalindrome1(String s, int start, int end) {
|
||||
char[] chars = s.toCharArray();
|
||||
for (int i = start, j = end; i < j; i++, j--) {
|
||||
if (chars[i] != chars[j]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
|
@ -22,6 +22,7 @@ from model.Joint_Monitoring.compare.RNet import Joint_Monitoring
|
|||
from model.CommonFunction.CommonFunction import *
|
||||
from sklearn.model_selection import train_test_split
|
||||
from tensorflow.keras.models import load_model, save_model
|
||||
import random
|
||||
|
||||
'''超参数设置'''
|
||||
time_stamp = 120
|
||||
|
|
@ -35,11 +36,11 @@ K = 18
|
|||
namuda = 0.01
|
||||
'''保存名称'''
|
||||
|
||||
save_name = "./model/weight/{0}_timestamp{1}_feature{2}_weight/weight".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
save_name = "./model/weight/{0}_timestamp{1}_feature{2}_weight_epoch2_loss0.007/weight".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
save_step_two_name = "./model/two_weight/{0}_timestamp{1}_feature{2}_weight/weight".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
|
|
@ -234,53 +235,6 @@ def EWMA(data, K=K, namuda=namuda):
|
|||
pass
|
||||
|
||||
|
||||
def get_MSE(data, label, new_model):
|
||||
predicted_data = new_model.predict(data)
|
||||
|
||||
temp = np.abs(predicted_data - label)
|
||||
temp1 = (temp - np.broadcast_to(np.mean(temp, axis=0), shape=predicted_data.shape))
|
||||
temp2 = np.broadcast_to(np.sqrt(np.var(temp, axis=0)), shape=predicted_data.shape)
|
||||
temp3 = temp1 / temp2
|
||||
mse = np.sum((temp1 / temp2) ** 2, axis=1)
|
||||
print("z:", mse)
|
||||
print(mse.shape)
|
||||
|
||||
# mse=np.mean((predicted_data-label)**2,axis=1)
|
||||
print("mse", mse)
|
||||
|
||||
dims, = mse.shape
|
||||
|
||||
mean = np.mean(mse)
|
||||
std = np.sqrt(np.var(mse))
|
||||
max = mean + 3 * std
|
||||
# min = mean-3*std
|
||||
max = np.broadcast_to(max, shape=[dims, ])
|
||||
# min = np.broadcast_to(min,shape=[dims,])
|
||||
mean = np.broadcast_to(mean, shape=[dims, ])
|
||||
|
||||
# plt.plot(max)
|
||||
# plt.plot(mse)
|
||||
# plt.plot(mean)
|
||||
# # plt.plot(min)
|
||||
# plt.show()
|
||||
#
|
||||
#
|
||||
return mse, mean, max
|
||||
# pass
|
||||
|
||||
|
||||
def condition_monitoring_model():
|
||||
input = tf.keras.Input(shape=[time_stamp, feature_num])
|
||||
conv1 = tf.keras.layers.Conv1D(filters=256, kernel_size=1)(input)
|
||||
GRU1 = tf.keras.layers.GRU(128, return_sequences=False)(conv1)
|
||||
d1 = tf.keras.layers.Dense(300)(GRU1)
|
||||
output = tf.keras.layers.Dense(10)(d1)
|
||||
|
||||
model = tf.keras.Model(inputs=input, outputs=output)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
# trian_data:(300455,120,10)
|
||||
# trian_label1:(300455,10)
|
||||
# trian_label2:(300455,)
|
||||
|
|
@ -523,6 +477,94 @@ def showResult(step_two_model: Joint_Monitoring, test_data, isPlot: bool = False
|
|||
return total_result
|
||||
|
||||
|
||||
def get_MSE(data, label, new_model, isStandard: bool = True, isPlot: bool = True):
|
||||
predicted_data1 = []
|
||||
predicted_data2 = []
|
||||
predicted_data3 = []
|
||||
size, length, dims = data.shape
|
||||
for epoch in range(0, size - batch_size + 1, batch_size):
|
||||
each_test_data = data[epoch:epoch + batch_size, :, :]
|
||||
output1, output2, output3, _ = new_model.call(inputs=each_test_data, is_first_time=True)
|
||||
predicted_data1.append(output1)
|
||||
predicted_data2.append(output2)
|
||||
predicted_data3.append(output3)
|
||||
|
||||
predicted_data1 = np.reshape(predicted_data1, [-1, 10])
|
||||
predicted_data2 = np.reshape(predicted_data2, [-1, 10])
|
||||
predicted_data3 = np.reshape(predicted_data3, [-1, 10])
|
||||
temp = np.abs(predicted_data1 - label)
|
||||
temp1 = (temp - np.broadcast_to(np.mean(temp, axis=0), shape=predicted_data1.shape))
|
||||
temp2 = np.broadcast_to(np.sqrt(np.var(temp, axis=0)), shape=predicted_data1.shape)
|
||||
temp3 = temp1 / temp2
|
||||
mse = np.sum((temp1 / temp2) ** 2, axis=1)
|
||||
print("z:", mse)
|
||||
print(mse.shape)
|
||||
|
||||
# mse=np.mean((predicted_data-label)**2,axis=1)
|
||||
print("mse", mse)
|
||||
if isStandard:
|
||||
dims, = mse.shape
|
||||
mean = np.mean(mse)
|
||||
std = np.sqrt(np.var(mse))
|
||||
max = mean + 3 * std
|
||||
print("max:", max)
|
||||
# min = mean-3*std
|
||||
max = np.broadcast_to(max, shape=[dims, ])
|
||||
# min = np.broadcast_to(min,shape=[dims,])
|
||||
mean = np.broadcast_to(mean, shape=[dims, ])
|
||||
if isPlot:
|
||||
plt.figure(random.randint(1, 9))
|
||||
plt.plot(max)
|
||||
plt.plot(mse)
|
||||
plt.plot(mean)
|
||||
# plt.plot(min)
|
||||
plt.show()
|
||||
else:
|
||||
if isPlot:
|
||||
plt.figure(random.randint(1, 9))
|
||||
plt.plot(mse)
|
||||
# plt.plot(min)
|
||||
plt.show()
|
||||
return mse
|
||||
|
||||
return mse, mean, max
|
||||
# pass
|
||||
|
||||
|
||||
# healthy_data是健康数据,用于确定阈值,all_data是完整的数据,用于模型出结果
|
||||
def getResult(model: tf.keras.Model, healthy_data, healthy_label, unhealthy_data, unhealthy_label, isPlot: bool = False,
|
||||
isSave: bool = False):
|
||||
# TODO 计算MSE确定阈值
|
||||
|
||||
mse, mean, max = get_MSE(healthy_data, healthy_label, model)
|
||||
|
||||
# 误报率的计算
|
||||
total, = mse.shape
|
||||
faultNum = 0
|
||||
faultList = []
|
||||
for i in range(total):
|
||||
if (mse[i] > max[i]):
|
||||
faultNum += 1
|
||||
faultList.append(mse[i])
|
||||
|
||||
fault_rate = faultNum / total
|
||||
print("误报率:", fault_rate)
|
||||
|
||||
# 漏报率计算
|
||||
missNum = 0
|
||||
missList = []
|
||||
mse1 = get_MSE(unhealthy_data, unhealthy_label, model, isStandard=False)
|
||||
all, = mse1.shape
|
||||
for i in range(all):
|
||||
if (mse1[i] < max[0]):
|
||||
missNum += 1
|
||||
missList.append(mse1[i])
|
||||
|
||||
miss_rate = missNum / all
|
||||
print("漏报率:", miss_rate)
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
total_data = loadData.execute(N=feature_num, file_name=file_name)
|
||||
total_data = normalization(data=total_data)
|
||||
|
|
@ -535,38 +577,26 @@ if __name__ == '__main__':
|
|||
# 单次测试
|
||||
# train_step_one(train_data=train_data_healthy[:256, :, :], train_label1=train_label1_healthy[:256, :],
|
||||
# train_label2=train_label2_healthy[:256, ])
|
||||
|
||||
train_step_one(train_data=train_data_healthy, train_label1=train_label1_healthy, train_label2=train_label2_healthy)
|
||||
#### 模型训练
|
||||
# train_step_one(train_data=train_data_healthy, train_label1=train_label1_healthy, train_label2=train_label2_healthy)
|
||||
|
||||
# 导入第一步已经训练好的模型,一个继续训练,一个只输出结果
|
||||
# step_one_model = Joint_Monitoring()
|
||||
# step_one_model.load_weights(save_name)
|
||||
step_one_model = Joint_Monitoring()
|
||||
step_one_model.load_weights(save_name)
|
||||
#
|
||||
# step_two_model = Joint_Monitoring()
|
||||
# step_two_model.load_weights(save_name)
|
||||
|
||||
# #### TODO 第二步训练
|
||||
# ### healthy_data.shape: (300333,120,10)
|
||||
# ### unhealthy_data.shape: (16594,10)
|
||||
# healthy_size, _, _ = train_data_healthy.shape
|
||||
# unhealthy_size, _, _ = train_data_unhealthy.shape
|
||||
# # train_data, train_label1, train_label2, test_data, test_label1, test_label2 = split_test_data(
|
||||
# # healthy_data=train_data_healthy[healthy_size - 2 * unhealthy_size:, :, :],
|
||||
# # healthy_label1=train_label1_healthy[healthy_size - 2 * unhealthy_size:, :],
|
||||
# # healthy_label2=train_label2_healthy[healthy_size - 2 * unhealthy_size:, ], unhealthy_data=train_data_unhealthy,
|
||||
# # unhealthy_label1=train_label1_unhealthy, unhealthy_label2=train_label2_unhealthy)
|
||||
# # train_step_two(step_one_model=step_one_model, step_two_model=step_two_model,
|
||||
# # train_data=train_data,
|
||||
# # train_label1=train_label1, train_label2=np.expand_dims(train_label2, axis=-1))
|
||||
#
|
||||
# ### TODO 测试测试集
|
||||
# step_one_model = Joint_Monitoring()
|
||||
# step_one_model.load_weights(save_name)
|
||||
# step_two_model = Joint_Monitoring()
|
||||
# step_two_model.load_weights(save_step_two_name)
|
||||
# # test(step_one_model=step_one_model, step_two_model=step_two_model, test_data=test_data, test_label1=test_label1,
|
||||
# # test_label2=np.expand_dims(test_label2, axis=-1))
|
||||
#
|
||||
# #### TODO 计算MSE
|
||||
healthy_size, _, _ = train_data_healthy.shape
|
||||
unhealthy_size, _, _ = train_data_unhealthy.shape
|
||||
all_data, _, _ = get_training_data_overlapping(
|
||||
total_data[healthy_size - 2 * unhealthy_size:unhealthy_date, :], is_Healthy=True)
|
||||
|
||||
getResult(step_one_model, healthy_data=train_data_healthy[healthy_size - 2 * unhealthy_size:, :],
|
||||
healthy_label=train_label1_healthy[healthy_size - 2 * unhealthy_size:, :],
|
||||
unhealthy_data=train_data_unhealthy, unhealthy_label=train_label1_unhealthy)
|
||||
|
||||
# ###TODO 展示全部的结果
|
||||
# all_data, _, _ = get_training_data_overlapping(
|
||||
# total_data[healthy_size - 2 * unhealthy_size:unhealthy_date, :], is_Healthy=True)
|
||||
|
|
|
|||
Loading…
Reference in New Issue