leecode更新
This commit is contained in:
parent
10c949ab6c
commit
88a7aefa97
|
|
@ -0,0 +1,71 @@
|
|||
package com.markilue.leecode.greedy;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Comparator;
|
||||
|
||||
/**
|
||||
* @BelongsProject: Leecode
|
||||
* @BelongsPackage: com.markilue.leecode.greedy
|
||||
* @Author: markilue
|
||||
* @CreateTime: 2022-10-24 09:34
|
||||
* @Description:
|
||||
* TODO 力扣455题 分发饼干:
|
||||
* 假设你是一位很棒的家长,想要给你的孩子们一些小饼干。但是,每个孩子最多只能给一块饼干。
|
||||
* 对每个孩子 i,都有一个胃口值 g[i],这是能让孩子们满足胃口的饼干的最小尺寸;
|
||||
* 并且每块饼干 j,都有一个尺寸 s[j] 。如果 s[j] >= g[i],我们可以将这个饼干 j 分配给孩子 i ,这个孩子会得到满足。
|
||||
* 你的目标是尽可能满足越多数量的孩子,并输出这个最大数值。
|
||||
*
|
||||
* @Version: 1.0
|
||||
*/
|
||||
public class FindContentChildren {
|
||||
|
||||
|
||||
@Test
|
||||
public void test(){
|
||||
int[] g= {1,2,3};
|
||||
int[] s={1,1};
|
||||
System.out.println(findContentChildren(g,s));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test1(){
|
||||
int[] g= {1,2};
|
||||
int[] s={1,2,3};
|
||||
System.out.println(findContentChildren(g,s));
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 取饼干最大的取满足孩子最大的,如果满足不了,就一直找到满足的孩子为止,如果找不到了就直接返回;;饼干找完了也返回
|
||||
* 速度击败10.72%,内存击败5.17%
|
||||
* @param g
|
||||
* @param s
|
||||
* @return
|
||||
*/
|
||||
public int findContentChildren(int[] g, int[] s) {
|
||||
|
||||
Arrays.sort(g);
|
||||
Arrays.sort(s);
|
||||
|
||||
int result = 0;
|
||||
int j=s.length-1;
|
||||
|
||||
for (int i = g.length-1; i >= 0; i--) {
|
||||
|
||||
if (j==-1) {
|
||||
return result;
|
||||
}
|
||||
|
||||
if(s[j]>=g[i]){
|
||||
j--;
|
||||
result++;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return result;
|
||||
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,203 @@
|
|||
package com.markilue.leecode.greedy;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
/**
|
||||
* @BelongsProject: Leecode
|
||||
* @BelongsPackage: com.markilue.leecode.greedy
|
||||
* @Author: markilue
|
||||
* @CreateTime: 2022-10-24 10:26
|
||||
* @Description: TODO 力扣376题 摆动序列:
|
||||
* 如果连续数字之间的差严格地在正数和负数之间交替,则数字序列称为 摆动序列 。第一个差(如果存在的话)可能是正数或负数。仅有一个元素或者含两个不等元素的序列也视作摆动序列。
|
||||
* 例如, [1, 7, 4, 9, 2, 5] 是一个 摆动序列 ,因为差值 (6, -3, 5, -7, 3) 是正负交替出现的。
|
||||
* 相反,[1, 4, 7, 2, 5] 和 [1, 7, 4, 5, 5] 不是摆动序列,第一个序列是因为它的前两个差值都是正数,第二个序列是因为它的最后一个差值为零。
|
||||
* 子序列 可以通过从原始序列中删除一些(也可以不删除)元素来获得,剩下的元素保持其原始顺序。
|
||||
* 给你一个整数数组 nums ,返回 nums 中作为 摆动序列 的 最长子序列的长度
|
||||
* @Version: 1.0
|
||||
*/
|
||||
public class WiggleMaxLength {
|
||||
|
||||
|
||||
@Test
|
||||
public void test() {
|
||||
int[] nums = {1, 7, 4, 9, 2, 5};
|
||||
System.out.println(wiggleMaxLength1(nums));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test1() {
|
||||
int[] nums = {1, 17, 5, 10, 13, 15, 10, 5, 16, 8};
|
||||
System.out.println(wiggleMaxLength2(nums));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test2() {
|
||||
int[] nums = {1, 2, 3, 4, 5, 6, 7, 8, 9};
|
||||
System.out.println(wiggleMaxLength2(nums));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test3() {
|
||||
int[] nums = {9, 8, 7, 6, 5, 4, 3, 2, 1};
|
||||
System.out.println(wiggleMaxLength1(nums));
|
||||
}
|
||||
|
||||
|
||||
public int wiggleMaxLength(int[] nums) {
|
||||
|
||||
|
||||
int result = 0;
|
||||
|
||||
boolean flag = true; //true就找比现在大的
|
||||
for (int i = 0; i < nums.length; i++) {
|
||||
|
||||
//找比当前大的
|
||||
if (flag && i > 0 && nums[i] > nums[i - 1]) {
|
||||
continue;
|
||||
}
|
||||
|
||||
//找比当前小的
|
||||
if (!flag && i > 0 && nums[i] < nums[i - 1]) {
|
||||
continue;
|
||||
}
|
||||
|
||||
//无论是从哪个出来的,找到了一个峰或者谷
|
||||
result++;
|
||||
|
||||
//如果是找大的出来的
|
||||
if (flag) {
|
||||
//找到一个比他小的
|
||||
flag = false;
|
||||
} else {
|
||||
flag = true;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
flag = false; //true就找比现在大的
|
||||
int result1 = 0;
|
||||
for (int i = 0; i < nums.length; i++) {
|
||||
|
||||
//找比当前大的
|
||||
if (flag && i > 0 && nums[i] > nums[i - 1]) {
|
||||
continue;
|
||||
}
|
||||
|
||||
//找比当前小的
|
||||
if (!flag && i > 0 && nums[i] < nums[i - 1]) {
|
||||
continue;
|
||||
}
|
||||
|
||||
//无论是从哪个出来的,找到了一个峰或者谷
|
||||
result1++;
|
||||
|
||||
//如果是找大的出来的
|
||||
if (flag) {
|
||||
//找到一个比他小的
|
||||
flag = false;
|
||||
} else {
|
||||
flag = true;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
return result > result1 ? result : result1;
|
||||
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 代码随想录方法
|
||||
* 速度击败100%,内存击败54.73%
|
||||
* @param nums
|
||||
* @return
|
||||
*/
|
||||
public int wiggleMaxLength1(int[] nums) {
|
||||
|
||||
|
||||
int result = 1;
|
||||
|
||||
int curdiff = 0; //记录现在的差值
|
||||
int prediff = 0; //记录上次的差值
|
||||
for (int i = 0; i < nums.length-1; i++) {
|
||||
|
||||
curdiff = nums[i + 1] - nums[i];
|
||||
|
||||
if ((prediff >= 0 && curdiff < 0) || (prediff <= 0 && curdiff > 0)) {
|
||||
result++;
|
||||
prediff = curdiff;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return result;
|
||||
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 官方动态规划法:
|
||||
* 注意到每有一个「峰」到「谷」的下降趋势,down 值才会增加,每有一个「谷」到「峰」的上升趋势,up 值才会增加。
|
||||
* 且过程中 down 与 up 的差的绝对值值恒不大于 1,即 up≤down+1 且down≤up+1,
|
||||
* 于是有max(up,down+1)=down+1 且 max(up+1,down)=up+1。这样我们可以省去不必要的比较大小的过程。
|
||||
* @return
|
||||
*/
|
||||
public int wiggleMaxLength2(int[] nums) {
|
||||
|
||||
int n = nums.length;
|
||||
if (n < 2) {
|
||||
return n;
|
||||
}
|
||||
int[] up = new int[n]; //记录上升摆动序列的最大长度
|
||||
int[] down = new int[n]; //记录下降摆动序列的最大长度
|
||||
up[0] = down[0] = 1;
|
||||
for (int i = 1; i < n; i++) {
|
||||
if (nums[i] > nums[i - 1]) {
|
||||
//如果当前数比之前的数大,那么下降摆动序列就遇上了峰值,这时候就记录上升序列和下降序列+1谁大
|
||||
up[i] = Math.max(up[i - 1], down[i - 1] + 1);
|
||||
down[i] = down[i - 1];
|
||||
} else if (nums[i] < nums[i - 1]) {
|
||||
up[i] = up[i - 1];
|
||||
down[i] = Math.max(up[i - 1] + 1, down[i - 1]);
|
||||
} else {
|
||||
up[i] = up[i - 1];
|
||||
down[i] = down[i - 1];
|
||||
}
|
||||
}
|
||||
return Math.max(up[n - 1], down[n - 1]);
|
||||
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 官方优化的动态规划法:注意到方法一中,我们仅需要前一个状态来进行转移,所以我们维护两个变量即可。
|
||||
* @param nums
|
||||
* @return
|
||||
*/
|
||||
public int wiggleMaxLength3(int[] nums) {
|
||||
int n = nums.length;
|
||||
if (n < 2) {
|
||||
return n;
|
||||
}
|
||||
int up = 1, down = 1;
|
||||
for (int i = 1; i < n; i++) {
|
||||
if (nums[i] > nums[i - 1]) {
|
||||
up = Math.max(up, down + 1);
|
||||
} else if (nums[i] < nums[i - 1]) {
|
||||
down = Math.max(up + 1, down);
|
||||
}
|
||||
}
|
||||
return Math.max(up, down);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
|
@ -31,7 +31,7 @@ feature_num = 10
|
|||
batch_size = 16
|
||||
learning_rate = 0.001
|
||||
EPOCH = 101
|
||||
model_name = "RNet_C"
|
||||
model_name = "RNet_3"
|
||||
'''EWMA超参数'''
|
||||
K = 18
|
||||
namuda = 0.01
|
||||
|
|
@ -48,12 +48,12 @@ save_step_two_name = "./model/two_weight/{0}_weight_epoch6_99899_9996/weight".fo
|
|||
batch_size,
|
||||
EPOCH)
|
||||
|
||||
save_mse_name = "./mse/RNet_C/{0}_timestamp{1}_feature{2}_result.csv".format(model_name,
|
||||
save_mse_name = "./mse/RNet_3/{0}_timestamp{1}_feature{2}_result.csv".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
save_max_name = "./mse/RNet_C/{0}_timestamp{1}_feature{2}_max.csv".format(model_name,
|
||||
save_max_name = "./mse/RNet_3/{0}_timestamp{1}_feature{2}_max.csv".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
|
|
@ -662,10 +662,10 @@ if __name__ == '__main__':
|
|||
# train_step_one(train_data=train_data_healthy, train_label1=train_label1_healthy, train_label2=train_label2_healthy)
|
||||
|
||||
# 导入第一步已经训练好的模型,一个继续训练,一个只输出结果
|
||||
# step_one_model = Joint_Monitoring()
|
||||
# # step_one_model.load_weights(save_name)
|
||||
# #
|
||||
# step_two_model = Joint_Monitoring()
|
||||
step_one_model = Joint_Monitoring()
|
||||
# step_one_model.load_weights(save_name)
|
||||
#
|
||||
step_two_model = Joint_Monitoring()
|
||||
# step_two_model.load_weights(save_name)
|
||||
|
||||
#### TODO 第二步训练
|
||||
|
|
@ -678,9 +678,9 @@ if __name__ == '__main__':
|
|||
healthy_label1=train_label1_healthy[healthy_size - 2 * unhealthy_size:, :],
|
||||
healthy_label2=train_label2_healthy[healthy_size - 2 * unhealthy_size:, ], unhealthy_data=train_data_unhealthy,
|
||||
unhealthy_label1=train_label1_unhealthy, unhealthy_label2=train_label2_unhealthy)
|
||||
# train_step_two(step_one_model=step_one_model, step_two_model=step_two_model,
|
||||
# train_data=train_data,
|
||||
# train_label1=train_label1, train_label2=np.expand_dims(train_label2, axis=-1))
|
||||
train_step_two(step_one_model=step_one_model, step_two_model=step_two_model,
|
||||
train_data=train_data,
|
||||
train_label1=train_label1, train_label2=np.expand_dims(train_label2, axis=-1))
|
||||
|
||||
healthy_size, _, _ = train_data_healthy.shape
|
||||
unhealthy_size, _, _ = train_data_unhealthy.shape
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ feature_num = 10
|
|||
batch_size = 16
|
||||
learning_rate = 0.001
|
||||
EPOCH = 101
|
||||
model_name = "RNet_C"
|
||||
model_name = "RNet_34"
|
||||
'''EWMA超参数'''
|
||||
K = 18
|
||||
namuda = 0.01
|
||||
|
|
@ -42,18 +42,18 @@ save_name = "./model/weight/{0}/weight".format(model_name,
|
|||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
save_step_two_name = "./model/two_weight/{0}_weight_epoch6_99899_9996/weight".format(model_name,
|
||||
save_step_two_name = "./model/two_weight/{0}_weight/weight".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
|
||||
save_mse_name = "./mse/RNet_C/{0}_timestamp{1}_feature{2}_result.csv".format(model_name,
|
||||
save_mse_name = "./mse/RNet_34/{0}_timestamp{1}_feature{2}_result.csv".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
save_max_name = "./mse/RNet_C/{0}_timestamp{1}_feature{2}_max.csv".format(model_name,
|
||||
save_max_name = "./mse/RNet_34/{0}_timestamp{1}_feature{2}_max.csv".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
|
|
@ -662,10 +662,10 @@ if __name__ == '__main__':
|
|||
# train_step_one(train_data=train_data_healthy, train_label1=train_label1_healthy, train_label2=train_label2_healthy)
|
||||
|
||||
# 导入第一步已经训练好的模型,一个继续训练,一个只输出结果
|
||||
# step_one_model = Joint_Monitoring()
|
||||
step_one_model = Joint_Monitoring()
|
||||
# # step_one_model.load_weights(save_name)
|
||||
# #
|
||||
# step_two_model = Joint_Monitoring()
|
||||
step_two_model = Joint_Monitoring()
|
||||
# step_two_model.load_weights(save_name)
|
||||
|
||||
#### TODO 第二步训练
|
||||
|
|
@ -678,9 +678,9 @@ if __name__ == '__main__':
|
|||
healthy_label1=train_label1_healthy[healthy_size - 2 * unhealthy_size:, :],
|
||||
healthy_label2=train_label2_healthy[healthy_size - 2 * unhealthy_size:, ], unhealthy_data=train_data_unhealthy,
|
||||
unhealthy_label1=train_label1_unhealthy, unhealthy_label2=train_label2_unhealthy)
|
||||
# train_step_two(step_one_model=step_one_model, step_two_model=step_two_model,
|
||||
# train_data=train_data,
|
||||
# train_label1=train_label1, train_label2=np.expand_dims(train_label2, axis=-1))
|
||||
train_step_two(step_one_model=step_one_model, step_two_model=step_two_model,
|
||||
train_data=train_data,
|
||||
train_label1=train_label1, train_label2=np.expand_dims(train_label2, axis=-1))
|
||||
|
||||
healthy_size, _, _ = train_data_healthy.shape
|
||||
unhealthy_size, _, _ = train_data_unhealthy.shape
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ feature_num = 10
|
|||
batch_size = 16
|
||||
learning_rate = 0.001
|
||||
EPOCH = 101
|
||||
model_name = "RNet_C"
|
||||
model_name = "RNet_35"
|
||||
'''EWMA超参数'''
|
||||
K = 18
|
||||
namuda = 0.01
|
||||
|
|
@ -42,18 +42,18 @@ save_name = "./model/weight/{0}/weight".format(model_name,
|
|||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
save_step_two_name = "./model/two_weight/{0}_weight_epoch6_99899_9996/weight".format(model_name,
|
||||
save_step_two_name = "./model/two_weight/{0}_weight/weight".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
|
||||
save_mse_name = "./mse/RNet_C/{0}_timestamp{1}_feature{2}_result.csv".format(model_name,
|
||||
save_mse_name = "./mse/RNet_35/{0}_timestamp{1}_feature{2}_result.csv".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
save_max_name = "./mse/RNet_C/{0}_timestamp{1}_feature{2}_max.csv".format(model_name,
|
||||
save_max_name = "./mse/RNet_35/{0}_timestamp{1}_feature{2}_max.csv".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
|
|
@ -662,10 +662,10 @@ if __name__ == '__main__':
|
|||
# train_step_one(train_data=train_data_healthy, train_label1=train_label1_healthy, train_label2=train_label2_healthy)
|
||||
|
||||
# 导入第一步已经训练好的模型,一个继续训练,一个只输出结果
|
||||
# step_one_model = Joint_Monitoring()
|
||||
step_one_model = Joint_Monitoring()
|
||||
# # step_one_model.load_weights(save_name)
|
||||
# #
|
||||
# step_two_model = Joint_Monitoring()
|
||||
step_two_model = Joint_Monitoring()
|
||||
# step_two_model.load_weights(save_name)
|
||||
|
||||
#### TODO 第二步训练
|
||||
|
|
@ -678,9 +678,9 @@ if __name__ == '__main__':
|
|||
healthy_label1=train_label1_healthy[healthy_size - 2 * unhealthy_size:, :],
|
||||
healthy_label2=train_label2_healthy[healthy_size - 2 * unhealthy_size:, ], unhealthy_data=train_data_unhealthy,
|
||||
unhealthy_label1=train_label1_unhealthy, unhealthy_label2=train_label2_unhealthy)
|
||||
# train_step_two(step_one_model=step_one_model, step_two_model=step_two_model,
|
||||
# train_data=train_data,
|
||||
# train_label1=train_label1, train_label2=np.expand_dims(train_label2, axis=-1))
|
||||
train_step_two(step_one_model=step_one_model, step_two_model=step_two_model,
|
||||
train_data=train_data,
|
||||
train_label1=train_label1, train_label2=np.expand_dims(train_label2, axis=-1))
|
||||
|
||||
healthy_size, _, _ = train_data_healthy.shape
|
||||
unhealthy_size, _, _ = train_data_unhealthy.shape
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ feature_num = 10
|
|||
batch_size = 16
|
||||
learning_rate = 0.001
|
||||
EPOCH = 101
|
||||
model_name = "RNet_C"
|
||||
model_name = "RNet_4"
|
||||
'''EWMA超参数'''
|
||||
K = 18
|
||||
namuda = 0.01
|
||||
|
|
@ -42,7 +42,7 @@ save_name = "./model/weight/{0}/weight".format(model_name,
|
|||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
save_step_two_name = "./model/two_weight/{0}_weight_epoch6_99899_9996/weight".format(model_name,
|
||||
save_step_two_name = "./model/two_weight/{0}_weight/weight".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
|
|
@ -662,10 +662,10 @@ if __name__ == '__main__':
|
|||
# train_step_one(train_data=train_data_healthy, train_label1=train_label1_healthy, train_label2=train_label2_healthy)
|
||||
|
||||
# 导入第一步已经训练好的模型,一个继续训练,一个只输出结果
|
||||
# step_one_model = Joint_Monitoring()
|
||||
step_one_model = Joint_Monitoring()
|
||||
# # step_one_model.load_weights(save_name)
|
||||
# #
|
||||
# step_two_model = Joint_Monitoring()
|
||||
step_two_model = Joint_Monitoring()
|
||||
# step_two_model.load_weights(save_name)
|
||||
|
||||
#### TODO 第二步训练
|
||||
|
|
@ -678,9 +678,9 @@ if __name__ == '__main__':
|
|||
healthy_label1=train_label1_healthy[healthy_size - 2 * unhealthy_size:, :],
|
||||
healthy_label2=train_label2_healthy[healthy_size - 2 * unhealthy_size:, ], unhealthy_data=train_data_unhealthy,
|
||||
unhealthy_label1=train_label1_unhealthy, unhealthy_label2=train_label2_unhealthy)
|
||||
# train_step_two(step_one_model=step_one_model, step_two_model=step_two_model,
|
||||
# train_data=train_data,
|
||||
# train_label1=train_label1, train_label2=np.expand_dims(train_label2, axis=-1))
|
||||
train_step_two(step_one_model=step_one_model, step_two_model=step_two_model,
|
||||
train_data=train_data,
|
||||
train_label1=train_label1, train_label2=np.expand_dims(train_label2, axis=-1))
|
||||
|
||||
healthy_size, _, _ = train_data_healthy.shape
|
||||
unhealthy_size, _, _ = train_data_unhealthy.shape
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ feature_num = 10
|
|||
batch_size = 16
|
||||
learning_rate = 0.001
|
||||
EPOCH = 101
|
||||
model_name = "RNet_C"
|
||||
model_name = "RNet_45"
|
||||
'''EWMA超参数'''
|
||||
K = 18
|
||||
namuda = 0.01
|
||||
|
|
@ -42,18 +42,18 @@ save_name = "./model/weight/{0}/weight".format(model_name,
|
|||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
save_step_two_name = "./model/two_weight/{0}_weight_epoch6_99899_9996/weight".format(model_name,
|
||||
save_step_two_name = "./model/two_weight/{0}_weight/weight".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
|
||||
save_mse_name = "./mse/RNet_C/{0}_timestamp{1}_feature{2}_result.csv".format(model_name,
|
||||
save_mse_name = "./mse/RNet_45/{0}_timestamp{1}_feature{2}_result.csv".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
save_max_name = "./mse/RNet_C/{0}_timestamp{1}_feature{2}_max.csv".format(model_name,
|
||||
save_max_name = "./mse/RNet_45/{0}_timestamp{1}_feature{2}_max.csv".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
|
|
@ -662,10 +662,10 @@ if __name__ == '__main__':
|
|||
# train_step_one(train_data=train_data_healthy, train_label1=train_label1_healthy, train_label2=train_label2_healthy)
|
||||
|
||||
# 导入第一步已经训练好的模型,一个继续训练,一个只输出结果
|
||||
# step_one_model = Joint_Monitoring()
|
||||
step_one_model = Joint_Monitoring()
|
||||
# # step_one_model.load_weights(save_name)
|
||||
# #
|
||||
# step_two_model = Joint_Monitoring()
|
||||
step_two_model = Joint_Monitoring()
|
||||
# step_two_model.load_weights(save_name)
|
||||
|
||||
#### TODO 第二步训练
|
||||
|
|
@ -678,9 +678,9 @@ if __name__ == '__main__':
|
|||
healthy_label1=train_label1_healthy[healthy_size - 2 * unhealthy_size:, :],
|
||||
healthy_label2=train_label2_healthy[healthy_size - 2 * unhealthy_size:, ], unhealthy_data=train_data_unhealthy,
|
||||
unhealthy_label1=train_label1_unhealthy, unhealthy_label2=train_label2_unhealthy)
|
||||
# train_step_two(step_one_model=step_one_model, step_two_model=step_two_model,
|
||||
# train_data=train_data,
|
||||
# train_label1=train_label1, train_label2=np.expand_dims(train_label2, axis=-1))
|
||||
train_step_two(step_one_model=step_one_model, step_two_model=step_two_model,
|
||||
train_data=train_data,
|
||||
train_label1=train_label1, train_label2=np.expand_dims(train_label2, axis=-1))
|
||||
|
||||
healthy_size, _, _ = train_data_healthy.shape
|
||||
unhealthy_size, _, _ = train_data_unhealthy.shape
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ feature_num = 10
|
|||
batch_size = 16
|
||||
learning_rate = 0.001
|
||||
EPOCH = 101
|
||||
model_name = "RNet_C"
|
||||
model_name = "RNet_5"
|
||||
'''EWMA超参数'''
|
||||
K = 18
|
||||
namuda = 0.01
|
||||
|
|
@ -42,7 +42,7 @@ save_name = "./model/weight/{0}/weight".format(model_name,
|
|||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
save_step_two_name = "./model/two_weight/{0}_weight_epoch6_99899_9996/weight".format(model_name,
|
||||
save_step_two_name = "./model/two_weight/{0}_weight/weight".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
|
|
@ -662,10 +662,10 @@ if __name__ == '__main__':
|
|||
# train_step_one(train_data=train_data_healthy, train_label1=train_label1_healthy, train_label2=train_label2_healthy)
|
||||
|
||||
# 导入第一步已经训练好的模型,一个继续训练,一个只输出结果
|
||||
# step_one_model = Joint_Monitoring()
|
||||
step_one_model = Joint_Monitoring()
|
||||
# # step_one_model.load_weights(save_name)
|
||||
# #
|
||||
# step_two_model = Joint_Monitoring()
|
||||
step_two_model = Joint_Monitoring()
|
||||
# step_two_model.load_weights(save_name)
|
||||
|
||||
#### TODO 第二步训练
|
||||
|
|
@ -678,9 +678,9 @@ if __name__ == '__main__':
|
|||
healthy_label1=train_label1_healthy[healthy_size - 2 * unhealthy_size:, :],
|
||||
healthy_label2=train_label2_healthy[healthy_size - 2 * unhealthy_size:, ], unhealthy_data=train_data_unhealthy,
|
||||
unhealthy_label1=train_label1_unhealthy, unhealthy_label2=train_label2_unhealthy)
|
||||
# train_step_two(step_one_model=step_one_model, step_two_model=step_two_model,
|
||||
# train_data=train_data,
|
||||
# train_label1=train_label1, train_label2=np.expand_dims(train_label2, axis=-1))
|
||||
train_step_two(step_one_model=step_one_model, step_two_model=step_two_model,
|
||||
train_data=train_data,
|
||||
train_label1=train_label1, train_label2=np.expand_dims(train_label2, axis=-1))
|
||||
|
||||
healthy_size, _, _ = train_data_healthy.shape
|
||||
unhealthy_size, _, _ = train_data_unhealthy.shape
|
||||
|
|
|
|||
|
|
@ -673,8 +673,8 @@ if __name__ == '__main__':
|
|||
#### TODO 第二步训练
|
||||
### healthy_data.shape: (300333,120,10)
|
||||
### unhealthy_data.shape: (16594,10)
|
||||
# healthy_size, _, _ = train_data_healthy.shape
|
||||
# unhealthy_size, _, _ = train_data_unhealthy.shape
|
||||
healthy_size, _, _ = train_data_healthy.shape
|
||||
unhealthy_size, _, _ = train_data_unhealthy.shape
|
||||
train_data, train_label1, train_label2, test_data, test_label1, test_label2 = split_test_data(
|
||||
healthy_data=train_data_healthy[healthy_size - 2 * unhealthy_size:, :, :],
|
||||
healthy_label1=train_label1_healthy[healthy_size - 2 * unhealthy_size:, :],
|
||||
|
|
|
|||
|
|
@ -3,19 +3,28 @@
|
|||
# coding: utf-8
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
'''
|
||||
@Author : dingjiawen
|
||||
@Date : 2022/10/20 21:35
|
||||
@Usage :
|
||||
@Usage : 测试相关画图设置
|
||||
@Desc :
|
||||
'''
|
||||
|
||||
result_file_name = "E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\RNet_C\RNet_C_timestamp120_feature10_result.csv"
|
||||
|
||||
mse_file_name="E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\RNet_D\RNet_D_timestamp120_feature10_mse_predict1.csv"
|
||||
|
||||
|
||||
max_file_name="E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\RNet_D\RNet_D_timestamp120_feature10_max_predict1.csv"
|
||||
|
||||
|
||||
|
||||
def plot_result(result_data):
|
||||
parameters = {
|
||||
'figure.dpi': 600,
|
||||
'figure.figsize': (2.5, 2),
|
||||
'figure.figsize': (2.8, 2),
|
||||
'savefig.dpi': 600,
|
||||
'xtick.direction': 'in',
|
||||
'ytick.direction': 'in',
|
||||
|
|
@ -27,42 +36,97 @@ def plot_result(result_data):
|
|||
plt.figure()
|
||||
plt.rc('font', family='Times New Roman') # 全局字体样式#画混淆矩阵
|
||||
font1 = {'family': 'Times New Roman', 'weight': 'normal', 'size': 5} # 设置坐标标签的字体大小,字体
|
||||
plt.scatter(list(range(result_data.shape[0])), result_data, c='black', s=0.001)
|
||||
plt.scatter(list(range(result_data.shape[0])), result_data, c='black', s=0.5,label="predict")
|
||||
# 画出 y=1 这条水平线
|
||||
plt.axhline(0.5, c='red', label='Failure threshold')
|
||||
plt.axhline(0.5, c='red', label='Failure threshold',lw=1)
|
||||
# 箭头指向上面的水平线
|
||||
# plt.arrow(35000, 0.9, 33000, 0.75, head_width=0.02, head_length=0.1, shape="full", fc='red', ec='red',
|
||||
# alpha=0.9, overhang=0.5)
|
||||
plt.text(result_data.shape[0] * 2 / 3 + 1000, 0.7, "Truth Fault", fontsize=5, color='red',
|
||||
verticalalignment='top')
|
||||
plt.axvline(result_data.shape[0] * 2 / 3, c='blue', ls='-.')
|
||||
# plt.arrow(result_data.shape[0]*2/3, 0.55, 2000, 0.085, width=0.00001, ec='red',length_includes_head=True)
|
||||
# plt.text(result_data.shape[0] * 2 / 3 + 1000, 0.7, "real fault", fontsize=5, color='red',
|
||||
# verticalalignment='top')
|
||||
# plt.text(0, 0.55, "Threshold", fontsize=5, color='red',
|
||||
# verticalalignment='top')
|
||||
|
||||
plt.axvline(result_data.shape[0] * 2 / 3, c='blue', ls='-.',lw=1,label='real fault')
|
||||
# plt.xticks(range(6), ('06/09/17', '12/09/17', '18/09/17', '24/09/17', '29/09/17')) # 设置x轴的标尺
|
||||
plt.text(result_data.shape[0] * 4 / 5, 0.6, "Fault", fontsize=5, color='black', verticalalignment='top',
|
||||
plt.text(result_data.shape[0] * 5 / 6, 0.4, "Fault", fontsize=5, color='black', verticalalignment='top',
|
||||
horizontalalignment='center',
|
||||
bbox={'facecolor': 'grey',
|
||||
'pad': 2.5}, fontdict=font1)
|
||||
plt.text(result_data.shape[0] * 1 / 3, 0.4, "Norm", fontsize=5, color='black', verticalalignment='top',
|
||||
'pad': 1.5,
|
||||
'linewidth':0.1}, fontdict=font1)
|
||||
plt.text(result_data.shape[0] * 1 / 3, 0.6, "Norm", fontsize=5, color='black', verticalalignment='top',
|
||||
horizontalalignment='center',
|
||||
bbox={'facecolor': 'grey',
|
||||
'pad': 2.5}, fontdict=font1)
|
||||
'pad': 1.5,'linewidth':0.1}, fontdict=font1)
|
||||
|
||||
|
||||
indices = [result_data.shape[0] *i / 4 for i in range(5)]
|
||||
print(indices)
|
||||
classes = ['N', 'IF', 'OF','TRC', 'oo'] # for i in range(cls):
|
||||
classes = ['01/09/17', '08/09/17', '15/09/17','22/09/17', '29/09/17']
|
||||
|
||||
indices1 = [ i / 4 for i in range(5)]
|
||||
classes1 = ['0', '0.25', 'Threshold','0.75', '1']
|
||||
|
||||
|
||||
# 第一个是迭代对象,表示坐标的显示顺序,第二个参数是坐标轴显示列表
|
||||
# plt.xticks(indices. classes,rotation=45)#设置横坐标方向,rotation=45为45度倾斜# plt.yticks(indices, classes, rotation=45)
|
||||
plt.xticks([index + 0.5 for index in indices], classes, rotation=45) # 设置横坐标方向,rotation=45为45度倾斜
|
||||
# plt.yticks([index + 0.5 for index in indices], classes, rotation=45)
|
||||
plt.ylabel('Actual label', fontsize=5)
|
||||
plt.xlabel('Predicted label', fontsize=5)
|
||||
plt.xticks([index + 0.5 for index in indices], classes, rotation=25) # 设置横坐标方向,rotation=45为45度倾斜
|
||||
# plt.yticks([index for index in indices1], classes1)
|
||||
plt.ylabel('Confidence', fontsize=5)
|
||||
plt.xlabel('Time', fontsize=5)
|
||||
plt.tight_layout()
|
||||
# plt.legend(loc='best',edgecolor='black',fontsize=3)
|
||||
plt.legend(loc='best',frameon=False,fontsize=3)
|
||||
# plt.grid()
|
||||
plt.show()
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
file_name = "E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\RNet_C\RNet_C_timestamp120_feature10_result.csv"
|
||||
def plot_MSE(total_MSE,total_max):
|
||||
parameters = {
|
||||
'figure.dpi': 600,
|
||||
'figure.figsize': (2.7, 2),
|
||||
'savefig.dpi': 600,
|
||||
'xtick.direction': 'in',
|
||||
'ytick.direction': 'in',
|
||||
'xtick.labelsize': 5,
|
||||
'ytick.labelsize': 5,
|
||||
'legend.fontsize': 5,
|
||||
}
|
||||
plt.rcParams.update(parameters)
|
||||
plt.figure()
|
||||
plt.rc('font', family='Times New Roman') # 全局字体样式#画混淆矩阵
|
||||
font1 = {'family': 'Times New Roman', 'weight': 'normal', 'size': 5} # 设置坐标标签的字体大小,字体
|
||||
|
||||
result_data=total_MSE
|
||||
# 画出 y=1 这条水平线
|
||||
# plt.axhline(0.5, c='red', label='Failure threshold',lw=1)
|
||||
# 箭头指向上面的水平线
|
||||
# plt.arrow(result_data.shape[0]*2/3, 0.55, 2000, 0.085, width=0.00001, ec='red',length_includes_head=True)
|
||||
# plt.text(result_data.shape[0] * 2 / 3 + 1000, 0.7, "real fault", fontsize=5, color='red',
|
||||
# verticalalignment='top')
|
||||
|
||||
|
||||
plt.axvline(result_data.shape[0] * 2 / 3, c='blue', ls='-.',lw=0.5,label="real fault")
|
||||
|
||||
indices = [result_data.shape[0] * i / 4 for i in range(5)]
|
||||
classes = ['01/09/17', '08/09/17', '15/09/17', '22/09/17', '29/09/17']
|
||||
|
||||
plt.xticks([index + 0.5 for index in indices], classes, rotation=25) # 设置横坐标方向,rotation=45为45度倾斜
|
||||
plt.ylabel('Mse', fontsize=5)
|
||||
plt.xlabel('Time', fontsize=5)
|
||||
plt.tight_layout()
|
||||
|
||||
|
||||
plt.plot(total_max,"--",label="max",linewidth=0.5)
|
||||
plt.plot(total_MSE,label="mse",linewidth=0.5,color='purple')
|
||||
plt.legend(loc='best',frameon=False,fontsize=5)
|
||||
|
||||
# plt.plot(total_mean)
|
||||
# plt.plot(min)
|
||||
plt.show()
|
||||
pass
|
||||
|
||||
|
||||
def test_result(file_name:str=result_file_name):
|
||||
# result_data = np.recfromcsv(file_name)
|
||||
result_data = np.loadtxt(file_name, delimiter=",")
|
||||
result_data = np.array(result_data)
|
||||
|
|
@ -71,3 +135,24 @@ if __name__ == '__main__':
|
|||
print(result_data)
|
||||
print(result_data.shape)
|
||||
plot_result(result_data)
|
||||
|
||||
|
||||
|
||||
def test_mse(mse_file_name:str=mse_file_name,max_file_name:str=max_file_name):
|
||||
mse_data = np.loadtxt(mse_file_name, delimiter=",")
|
||||
max_data = np.loadtxt(max_file_name,delimiter=',')
|
||||
mse_data = np.array(mse_data)
|
||||
max_data = np.array(max_data)
|
||||
print(mse_data.shape)
|
||||
print(max_data.shape)
|
||||
plot_MSE(mse_data,max_data)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# test_mse()
|
||||
test_result()
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue