leecode更新

This commit is contained in:
markilue 2022-10-26 11:26:43 +08:00
parent 82d051a82a
commit 4b279d5f1d
6 changed files with 446 additions and 32 deletions

View File

@ -0,0 +1,110 @@
package com.markilue.leecode.greedy;
import org.junit.Test;
/**
* @BelongsProject: Leecode
* @BelongsPackage: com.markilue.leecode.greedy
* @Author: markilue
* @CreateTime: 2022-10-26 10:48
* @Description:
* TODO 力扣55题 跳跃游戏
* 给定一个非负整数数组 nums 你最初位于数组的 第一个下标
* 数组中的每个元素代表你在该位置可以跳跃的最大长度
* 判断你是否能够到达最后一个下标
* @Version: 1.0
*/
public class CanJump {
@Test
public void test(){
int[] nums = {2, 3, 1, 1, 4};
System.out.println(canJump1(nums));
}
@Test
public void test1(){
int[] nums = {3,2,1,0,4};
System.out.println(canJump1(nums));
}
/**
* 自己的思路似乎可以使用递归去做
* 要判断最后一个下标能否到就是判断上一个能到这个下标的位置能不能到
* 速度击败94.15%内存击败39.95%
* @param nums
* @return
*/
public boolean canJump(int[] nums) {
return jump(nums,nums.length-1);
}
public boolean jump(int[] nums,int j) {
//怎么样也到不了了
if(j<0){
return false;
}
if(j==0){
return true;
}
//寻找上一个能到j位置的点
int i=j-1;
for (int k=1; i >= 0; i--,k++) {
if(nums[i]<k){
continue;
}else {
//找到了上一次能到j位置的点
break;
}
}
return jump(nums,i);
}
/**
* 代码随想录贪心利用计算最大覆盖面积去算
*
* 速度击败94.15%内存击败13.37%
* @param nums
* @return
*/
public boolean canJump1(int[] nums) {
if(nums.length==1){
return true;
}
int cover=0;
for (int i = 0; i <= cover; i++) {
//最大覆盖面积
cover=Math.max(cover,nums[i]+i);
if(cover>= nums.length-1)return true;
}
return false;
}
/**
* 评论区动态规划法与自己的方法类似但是这个最小步长比较巧妙
* @param nums
* @return
*/
public boolean canJump2(int[] nums) {
int length = nums.length;
if (length == 1) return true;
int minStep = 1; //定义一个数为达到最后最后一个结点最小需要的步数
for (int i = length-2; i >0; i--) { //从倒数第二个往第二个开始遍历
if (nums[i]<minStep){ // 如果当前元素的值小于最小步数,则到达最后一个元素的最小步数+1;
minStep++;
}else {
minStep = 1; //如果当前元素的值大于或等于最小步数则一定能到达最后一个元素
// 此时可以就当前元素认为是最后一个元素并对于前一个元素来说最小步数为1;
}
}
return nums[0] >= minStep; //此时minStep为达到"最后一个元素"(并不是nums[length-1])的最小步数只要判断第一个元素的值是否大于或等于最小步数就可以了;
}
}

View File

@ -13,8 +13,8 @@ import java.util.Comparator;
* @Description:
* TODO 力扣455题 分发饼干:
* 假设你是一位很棒的家长想要给你的孩子们一些小饼干但是每个孩子最多只能给一块饼干
* 对每个孩子 i都有一个胃口值 g[i]这是能让孩子们满足胃口的饼干的最小尺寸
* 并且每块饼干 j都有一个尺寸 s[j] 如果 s[j] >= g[i]我们可以将这个饼干 j 分配给孩子 i 这个孩子会得到满足
* 对每个孩子 i都有一个胃口值g[i]这是能让孩子们满足胃口的饼干的最小尺寸
* 并且每块饼干 j都有一个尺寸 s[j]如果 s[j]>= g[i]我们可以将这个饼干 j 分配给孩子 i 这个孩子会得到满足
* 你的目标是尽可能满足越多数量的孩子并输出这个最大数值
*
* @Version: 1.0

View File

@ -0,0 +1,112 @@
package com.markilue.leecode.greedy;
import org.junit.Test;
/**
* @BelongsProject: Leecode
* @BelongsPackage: com.markilue.leecode.greedy
* @Author: markilue
* @CreateTime: 2022-10-26 09:49
* @Description: TODO 力扣122题 买卖股票的最佳时机II:
* 给你一个整数数组 prices 其中prices[i] 表示某支股票第 i 天的价格
* 在每一天你可以决定是否购买和/或出售股票你在任何时候最多只能持有 一股 股票你也可以先购买然后在 同一天 出售
* 返回 你能获得的 最大 利润
* @Version: 1.0
*/
public class MaxProfit {
@Test
public void test(){
int[] prices = {7, 1, 5, 3, 6, 4};
System.out.println(maxProfit2(prices));
}
@Test
public void test1(){
int[] prices = {1,2,3,4,5};
System.out.println(maxProfit(prices));
}
@Test
public void test2(){
int[] prices = {7,6,4,3,15};
System.out.println(maxProfit(prices));
}
/**
* 思路由于股票可以当天买当天卖所以逻辑变成了
* 只要下一个天比今天小那么今天卖出去反之则不卖
* 速度击败81.82%内存击败98.42%
* @param prices
* @return
*/
public int maxProfit(int[] prices) {
//记录买入净额
int buy = prices[0];
int total = 0;
for (int i = 1; i < prices.length; i++) {
if (prices[i] - prices[i - 1] > 0) {
if(i==prices.length-1){
//比他大还是最后一组必须卖
total += prices[i] - buy;
}
//不卖
continue;
} else {
//
total += prices[i-1] - buy;
buy=prices[i];
}
}
return total;
}
/**
* 代码随想录贪心速度击败81.82%内存击败35.76%
* @param prices
* @return
*/
public int maxProfit1(int[] prices) {
//记录买入净额
int total = 0;
for (int i = 1; i < prices.length; i++) {
//省去了判断最后一次的操作但是要进行多次加法
total+=Math.max(prices[i]-prices[i-1],0);
}
return total;
}
/**
* 官方动态规划法
* 定义状态 dp[i][0] 表示第 i 天交易完后手里没有股票的最大利润p[i][1] 表示第 i 天交易完后手里持有一支股票的最大利润i 0 开始
* dp[i][0] 的转移方程
* 1如果这一天交易完后手里没有股票那么可能的转移状态为前一天已经没有股票 dp[i1][0]
* 2或者前一天结束的时候手里持有一支股票 dp[i1][1]这时候我们要将其卖出并获得 prices[i] 的收益
* dp[i][0]类似
* @param prices
* @return
*/
public int maxProfit2(int[] prices) {
int n = prices.length;
int dp0 = 0, dp1 = -prices[0];
for (int i = 1; i < n; ++i) {
int newDp0 = Math.max(dp0, dp1 + prices[i]);
int newDp1 = Math.max(dp1, dp0 - prices[i]);
dp0 = newDp0;
dp1 = newDp1;
}
return dp0;
}
}

View File

@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# coding: utf-8
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import random
@ -23,6 +24,8 @@ max_file_name = "E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_m
source_path = "G:\data\SCADA数据\jb4q_8_delete_total_zero.csv"
list = [64.16, 65.26, 65.11, 66.6, 67.16, 66.28, 73.86, 75.24, 73.98, 76.7, 98.86, 99.45, 99.97]
def plot_result(result_data):
parameters = {
@ -127,21 +130,21 @@ def plot_MSE(total_MSE, total_max):
pass
def plot_Corr(data, label):
def plot_Corr(data, size: int = 1):
parameters = {
'figure.dpi': 600,
'figure.figsize': (2.8, 2),
'figure.figsize': (2.8 * size, 2 * size),
'savefig.dpi': 600,
'xtick.direction': 'inout',
'ytick.direction': 'inout',
'xtick.labelsize': 3,
'ytick.labelsize': 3,
'legend.fontsize': 5,
'xtick.labelsize': 3 * size,
'ytick.labelsize': 3 * size,
'legend.fontsize': 5 * size,
}
plt.rcParams.update(parameters)
plt.figure()
plt.rc('font', family='Times New Roman') # 全局字体样式#画混淆矩阵
font1 = {'family': 'Times New Roman', 'weight': 'normal', 'size': 4} # 设置坐标标签的字体大小,字体
font1 = {'family': 'Times New Roman', 'weight': 'normal', 'size': 4 * size} # 设置坐标标签的字体大小,字体
print("计算皮尔逊相关系数")
pd_data = pd.DataFrame(data)
@ -149,11 +152,12 @@ def plot_Corr(data, label):
print(person)
# 画热点图heatmap
cmap = sns.heatmap(person, annot=True, annot_kws={
'fontsize': 2.6
'fontsize': 2.6 * size
})
classes = ['Gs', 'Gio', 'Gip', 'Gp', 'Gwt', 'En', 'Gft', 'Grt', 'Gwt', 'Et', 'Rs', 'Ap', 'Ws', 'Dw', 'Ges', 'Gt', 'Vx','Vy']
classes = ['Gs', 'Gio', 'Gip', 'Gp', 'Gwt', 'En', 'Gft', 'Grt', 'Gwt', 'Et', 'Rs', 'Ap', 'Ws', 'Dw', 'Ges', 'Gt',
'Vx', 'Vy']
indices = range(len(person))
plt.title("Heatmap of correlation coefficient matrix", size=6, fontdict=font1)
plt.title("Heatmap of correlation coefficient matrix", size=6 * size, fontdict=font1)
# pad调整label与坐标轴之间的距离
plt.tick_params(bottom=False, top=False, left=False, right=False, direction='inout', length=2, width=0.5, pad=1)
plt.xticks([index + 0.5 for index in indices], classes, rotation=0) # 设置横坐标方向rotation=45为45度倾斜
@ -161,7 +165,7 @@ def plot_Corr(data, label):
# 调整色带的标签:
cbar = cmap.collections[0].colorbar
cbar.ax.tick_params(labelsize=4, labelcolor="black", length=2, width=0.5,pad=1)
cbar.ax.tick_params(labelsize=4 * size, labelcolor="black", length=2, width=0.5, pad=1)
cbar.ax.set_ylabel(ylabel="color scale", color="black", loc="center", fontdict=font1)
# plt.axis('off') # 去坐标轴
@ -171,6 +175,179 @@ def plot_Corr(data, label):
pass
def plot_bar(y_data):
parameters = {
'figure.dpi': 600,
'figure.figsize': (10, 6),
'savefig.dpi': 600,
'xtick.direction': 'in',
'ytick.direction': 'in',
'xtick.labelsize': 15,
'ytick.labelsize': 15,
'legend.fontsize': 12,
}
plt.rcParams.update(parameters)
plt.figure()
plt.rc('font', family='Times New Roman') # 全局字体样式#画混淆矩阵
font1 = {'family': 'Times New Roman', 'weight': 'normal', 'size': 5} # 设置坐标标签的字体大小,字体
x_width = [ i for i in range(0, len(y_data))]
# x2_width = [i + 0.3 for i in x_width]
plt.bar(x_width[0], y_data[0], lw=1, color=['#FAF4E1'], width=0.5, label="CNN", edgecolor='black')
plt.bar(x_width[1] , y_data[1], lw=1, color=['#F5E3C4'], width=0.5, label="GRU", edgecolor='black')
plt.bar(x_width[2] , y_data[2], lw=1, color=['#EBC99D'], width=0.5, label="CNN-GRU", edgecolor='black')
plt.bar(x_width[3] , y_data[3], lw=1, color=['#FFC79C'], width=0.5, label="DCConv", edgecolor='black')
plt.bar(x_width[4] , y_data[4], lw=1, color=['#BEE9C7'], width=0.5, label="RepDCConv", edgecolor='black')
plt.bar(x_width[5] , y_data[5], lw=1, color=['#B8E9D0'], width=0.5,label="RNet-MSE", edgecolor='black')
plt.bar(x_width[6] , y_data[6], lw=1, color=['#B9E9E2'], width=0.5, label="RNet", edgecolor='black')
plt.bar(x_width[7] , y_data[7], lw=1, color=['#D6E6F2'], width=0.5, label="RNet-SE", edgecolor='black')
plt.bar(x_width[8] , y_data[8], lw=1, color=['#B4D1E9'], width=0.5, label="RNet-L", edgecolor='black')
plt.bar(x_width[9] , y_data[9], lw=1, color=['#AEB5EE'], width=0.5, label="RNet-D", edgecolor='black')
plt.bar(x_width[10] , y_data[10], lw=1, color=['#D2D3FC'], width=0.5, label="ResNet-18", edgecolor='black')
plt.bar(x_width[11] , y_data[11], lw=1, color=['#D5A9FF'], width=0.5, label="ResNet-C", edgecolor='black')
plt.bar(x_width[12] , y_data[12], lw=1, color=['#E000F5'], width=0.5, label="JMNet", edgecolor='black')
# plt.tick_params(bottom=False, top=False, left=True, right=False, direction='in', pad=1)
plt.xticks([])
plt.ylabel('False Positive Rate(%)',fontsize=18)
# plt.xlabel('Time', fontsize=5)
# plt.tight_layout()
num1, num2, num3, num4 = 0.1, 1, 3, 0
plt.legend(bbox_to_anchor=(num1, num2), loc=num3, borderaxespad=num4, ncol=5, frameon=False, markerscale=0.5)
plt.ylim([-0.01, 5])
plt.show()
def acc(y_data=list):
parameters = {
'figure.dpi': 600,
'figure.figsize': (10, 6),
'savefig.dpi': 600,
'xtick.direction': 'in',
'ytick.direction': 'in',
'xtick.labelsize': 15,
'ytick.labelsize': 15,
'legend.fontsize': 12,
}
plt.rcParams.update(parameters)
plt.figure()
plt.rc('font', family='Times New Roman') # 全局字体样式#画混淆矩阵
font1 = {'family': 'Times New Roman', 'weight': 'normal', 'size': 5} # 设置坐标标签的字体大小,字体
x_width = [ i/2 for i in range(0, len(y_data))]
# x2_width = [i + 0.3 for i in x_width]
plt.bar(x_width[0] , y_data[0], lw=1, color=['#FAF4E1'], width=0.25, label="CNN", edgecolor='black')
plt.bar(x_width[1] , y_data[1], lw=1, color=['#F5E3C4'], width=0.25, label="GRU", edgecolor='black')
plt.bar(x_width[2] , y_data[2], lw=1, color=['#EBC99D'], width=0.25, label="CNN-GRU", edgecolor='black')
plt.bar(x_width[3] , y_data[3], lw=1, color=['#FFC79C'], width=0.25, label="DCConv", edgecolor='black')
plt.bar(x_width[4] , y_data[4], lw=1, color=['#BEE9C7'], width=0.25, label="RepDCConv", edgecolor='black')
plt.bar(x_width[5] , y_data[5], lw=1, color=['#B8E9D0'], width=0.25,label="RNet-MSE", edgecolor='black')
plt.bar(x_width[6] , y_data[6], lw=1, color=['#B9E9E2'], width=0.25, label="RNet", edgecolor='black')
plt.bar(x_width[7] , y_data[7], lw=1, color=['#D6E6F2'], width=0.25, label="RNet-SE", edgecolor='black')
plt.bar(x_width[8] , y_data[8], lw=1, color=['#B4D1E9'], width=0.25, label="RNet-L", edgecolor='black')
plt.bar(x_width[9] , y_data[9], lw=1, color=['#AEB5EE'], width=0.25, label="RNet-D", edgecolor='black')
plt.bar(x_width[10] , y_data[10], lw=1, color=['#D2D3FC'], width=0.25, label="ResNet-18", edgecolor='black')
plt.bar(x_width[11] , y_data[11], lw=1, color=['#D5A9FF'], width=0.25, label="ResNet-C", edgecolor='black')
plt.bar(x_width[12] , y_data[12], lw=1, color=['#E000F5'], width=0.25, label="JMNet", edgecolor='black')
# plt.tick_params(bottom=False, top=False, left=True, right=False, direction='in', pad=1)
plt.xticks([])
plt.ylabel('Accuracy(%)',fontsize=18)
# plt.xlabel('Time', fontsize=5)
# plt.tight_layout()
num1, num2, num3, num4 = 0.08, 1, 3, 0
plt.legend(bbox_to_anchor=(num1, num2), loc=num3, borderaxespad=num4, ncol=5, frameon=False, markerscale=0.5)
plt.ylim([60, 105])
plt.show()
def plot_FNR1(y_data):
parameters = {
'figure.dpi': 600,
'figure.figsize': (10, 6),
'savefig.dpi': 600,
'xtick.direction': 'in',
'ytick.direction': 'in',
'xtick.labelsize': 15,
'ytick.labelsize': 15,
'legend.fontsize': 12,
}
plt.rcParams.update(parameters)
plt.figure()
plt.rc('font', family='Times New Roman') # 全局字体样式#画混淆矩阵
font1 = {'family': 'Times New Roman', 'weight': 'normal', 'size': 5} # 设置坐标标签的字体大小,字体
x_width = range(0, len(y_data))
# x2_width = [i + 0.3 for i in x_width]
plt.bar(x_width[0], y_data[0], lw=1, color=['#FAF4E1'], width=0.5*5/6, label="CNN", edgecolor='black')
plt.bar(x_width[1] , y_data[1], lw=1, color=['#F5E3C4'], width=0.5*5/6, label="GRU", edgecolor='black')
plt.bar(x_width[2] , y_data[2], lw=1, color=['#EBC99D'], width=0.5*5/6, label="CNN-GRU", edgecolor='black')
plt.bar(x_width[3] , y_data[3], lw=1, color=['#FFC79C'], width=0.5*5/6, label="DCConv", edgecolor='black')
plt.bar(x_width[4] , y_data[4], lw=1, color=['#BEE9C7'], width=0.5*5/6, label="RepDCConv", edgecolor='black')
plt.bar(x_width[5] , y_data[5], lw=1, color=['#B8E9D0'], width=0.5*5/6,label="RNet-MSE", edgecolor='black')
plt.bar(x_width[6] , y_data[6], lw=1, color=['#B9E9E2'], width=0.5*5/6, label="RNet", edgecolor='black')
plt.bar(x_width[7] , y_data[7], lw=1, color=['#D6E6F2'], width=0.5*5/6, label="RNet-SE", edgecolor='black')
plt.bar(x_width[8] , y_data[8], lw=1, color=['#B4D1E9'], width=0.5*5/6, label="RNet-L", edgecolor='black')
plt.bar(x_width[9] , y_data[9], lw=1, color=['#AEB5EE'], width=0.5*5/6, label="RNet-D", edgecolor='black')
# plt.tick_params(bottom=False, top=False, left=True, right=False, direction='in', pad=1)
plt.xticks([])
plt.ylabel('False Negative Rate(%)',fontsize=18)
# plt.xlabel('Time', fontsize=5)
# plt.tight_layout()
num1, num2, num3, num4 = 0.16, 1, 3, 0
plt.legend(bbox_to_anchor=(num1, num2), loc=num3, borderaxespad=num4, ncol=5, frameon=False, markerscale=0.5)
plt.ylim([60, 100])
plt.show()
def plot_FNR2(y_data):
parameters = {
'figure.dpi': 600,
'figure.figsize': (10, 6),
'savefig.dpi': 600,
'xtick.direction': 'in',
'ytick.direction': 'in',
'xtick.labelsize': 15,
'ytick.labelsize': 15,
'legend.fontsize': 12,
}
plt.rcParams.update(parameters)
plt.figure()
plt.rc('font', family='Times New Roman') # 全局字体样式#画混淆矩阵
font1 = {'family': 'Times New Roman', 'weight': 'normal', 'size': 5} # 设置坐标标签的字体大小,字体
x_width = range(0, len(y_data))
# x2_width = [i + 0.3 for i in x_width]
plt.bar(x_width[0] , y_data[0], lw=1, color=['#FAF4E1'], width=0.5*2/3, label="ResNet-18", edgecolor='black')
plt.bar(x_width[1] , y_data[1], lw=1, color=['#F5E3C4'], width=0.5*2/3, label="RNet-3", edgecolor='black')
plt.bar(x_width[2] , y_data[2], lw=1, color=['#EBC99D'], width=0.5*2/3, label="RNet-4", edgecolor='black')
plt.bar(x_width[3] , y_data[3], lw=1, color=['#FFC79C'], width=0.5*2/3, label="RNet-5", edgecolor='black')
plt.bar(x_width[4] , y_data[4], lw=1, color=['#D6E6F2'], width=0.5*2/3, label="RNet-34", edgecolor='black')
plt.bar(x_width[5] , y_data[5], lw=1, color=['#B4D1E9'], width=0.5*2/3, label="RNet-35", edgecolor='black')
plt.bar(x_width[6] , y_data[6], lw=1, color=['#AEB5EE'], width=0.5*2/3, label="RNet-45", edgecolor='black')
# plt.bar(x_width[7] + 2.0, y_data[10], lw=0.5, color=['#8085e9'], width=1, label="ResNet-18", edgecolor='black')
plt.bar(x_width[7] , y_data[7], lw=1, color=['#D5A9FF'], width=0.5*2/3, label="ResNet-C", edgecolor='black')
plt.bar(x_width[8] , y_data[8], lw=1, color=['#E000F5'], width=0.5*2/3, label="JMNet", edgecolor='black')
# plt.tick_params(bottom=False, top=False, left=True, right=False, direction='in', pad=1)
plt.xticks([])
plt.ylabel('False Negative Rate(%)',fontsize=18)
# plt.xlabel('Time', fontsize=5)
# plt.tight_layout()
num1, num2, num3, num4 = 0.16, 1, 3, 0
plt.legend(bbox_to_anchor=(num1, num2), loc=num3, borderaxespad=num4, ncol=5, frameon=False, markerscale=0.5)
plt.ylim([0, 5])
plt.show()
def test_result(file_name: str = result_file_name):
# result_data = np.recfromcsv(file_name)
result_data = np.loadtxt(file_name, delimiter=",")
@ -180,8 +357,10 @@ def test_result(file_name: str = result_file_name):
print(theshold)
print(theshold * 2 / 3)
# 计算误报率和漏报率
positive_rate=result_data[:int(theshold*2/3)][result_data[:int(theshold*2/3)] < 0.66].__len__()/(theshold*2/3)
negative_rate=result_data[int(theshold*2/3):][result_data[int(theshold*2/3):] > 0.66].__len__()/(theshold*1/3)
positive_rate = result_data[:int(theshold * 2 / 3)][result_data[:int(theshold * 2 / 3)] < 0.66].__len__() / (
theshold * 2 / 3)
negative_rate = result_data[int(theshold * 2 / 3):][result_data[int(theshold * 2 / 3):] > 0.66].__len__() / (
theshold * 1 / 3)
print("误报率:", positive_rate)
print("漏报率", negative_rate)
@ -208,13 +387,25 @@ def test_corr(file_name=source_path, N=10):
print(needed_data)
print(needed_data.shape)
# plot_original_data(needed_data)
person = plot_Corr(needed_data, label)
person = plot_Corr(needed_data, size=3)
person = np.array(person)
pass
def test_bar(y_data=list):
plot_bar(y_data)
if __name__ == '__main__':
# test_mse()
test_result(file_name='E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\RNet_345\RNet_345_timestamp120_feature10_result.csv')
# test_result(file_name='E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\RNet_C\RNet_C_timestamp120_feature10_result2.csv')
# test_corr()
pass
# acc()
# list = [3.77, 2.64, 2.35, 2.05, 1.76, 1.09, 0.757, 0.82, 1.1, 0.58, 0, 0.03, 0.02]
# test_bar(list)
# list=[99.99,98.95,99.95,96.1,95,99.65,76.25,72.64,75.87,68.74]
# plot_FNR1(list)
#
list=[3.43,1.99,1.92,2.17,1.63,1.81,1.78,1.8,0.6]
plot_FNR2(list)

View File

@ -44,13 +44,13 @@ save_name = "./model/weight/{0}/weight".format(model_name,
feature_num,
batch_size,
EPOCH)
save_step_two_name = "./model/two_weight/{0}_weight_epoch6_99899_9996/weight".format(model_name,
save_step_two_name = "./model/two_weight/{0}_weight_epoch5_9991_9995/weight".format(model_name,
time_stamp,
feature_num,
batch_size,
EPOCH)
save_mse_name = "./mse/RNet_C/{0}_timestamp{1}_feature{2}_result.csv".format(model_name,
save_mse_name = "./mse/RNet_C/{0}_timestamp{1}_feature{2}_result2.csv".format(model_name,
time_stamp,
feature_num,
batch_size,
@ -651,7 +651,8 @@ def getResult(model: tf.keras.Model, healthy_data, healthy_label, unhealthy_data
if __name__ == '__main__':
total_data = loadData.execute(N=feature_num, file_name=file_name)
# total_data = loadData.execute(N=feature_num, file_name=file_name)
total_data = np.load("G:\data\SCADA数据\靖边8号处理后的数据\原始10SCADA数据/total_data.npy")
total_data = normalization(data=total_data)
train_data_healthy, train_label1_healthy, train_label2_healthy = get_training_data_overlapping(
total_data[:healthy_date, :], is_Healthy=True)
@ -675,11 +676,11 @@ if __name__ == '__main__':
### unhealthy_data.shape: (16594,10)
healthy_size, _, _ = train_data_healthy.shape
unhealthy_size, _, _ = train_data_unhealthy.shape
train_data, train_label1, train_label2, test_data, test_label1, test_label2 = split_test_data(
healthy_data=train_data_healthy[healthy_size - 2 * unhealthy_size:, :, :],
healthy_label1=train_label1_healthy[healthy_size - 2 * unhealthy_size:, :],
healthy_label2=train_label2_healthy[healthy_size - 2 * unhealthy_size:, ], unhealthy_data=train_data_unhealthy,
unhealthy_label1=train_label1_unhealthy, unhealthy_label2=train_label2_unhealthy)
# train_data, train_label1, train_label2, test_data, test_label1, test_label2 = split_test_data(
# healthy_data=train_data_healthy[healthy_size - 2 * unhealthy_size:, :, :],
# healthy_label1=train_label1_healthy[healthy_size - 2 * unhealthy_size:, :],
# healthy_label2=train_label2_healthy[healthy_size - 2 * unhealthy_size:, ], unhealthy_data=train_data_unhealthy,
# unhealthy_label1=train_label1_unhealthy, unhealthy_label2=train_label2_unhealthy)
# train_step_two(step_one_model=step_one_model, step_two_model=step_two_model,
# train_data=train_data,
# train_label1=train_label1, train_label2=np.expand_dims(train_label2, axis=-1))

View File

@ -21,7 +21,7 @@ K = 18
namuda = 0.01
'''保存名称'''
save_name = "./model/{0}.h5".format(model_name,
save_name = "./model/{0}_9990_9998.h5".format(model_name,
time_stamp,
feature_num,
batch_size,