leecode更新
This commit is contained in:
parent
663414a446
commit
32ebfcceb8
|
|
@ -0,0 +1,73 @@
|
||||||
|
package com.markilue.leecode.greedy;
|
||||||
|
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.Comparator;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @BelongsProject: Leecode
|
||||||
|
* @BelongsPackage: com.markilue.leecode.greedy
|
||||||
|
* @Author: markilue
|
||||||
|
* @CreateTime: 2022-11-17 20:29
|
||||||
|
* @Description: TODO 力扣452题 用最少的数量的箭引爆气球:
|
||||||
|
* 有一些球形气球贴在一堵用 XY 平面表示的墙面上。墙面上的气球记录在整数数组 points ,其中points[i] = [xstart, xend] 表示水平直径在 xstart 和 xend之间的气球。你不知道气球的确切 y 坐标。
|
||||||
|
* 一支弓箭可以沿着 x 轴从不同点 完全垂直 地射出。在坐标 x 处射出一支箭,若有一个气球的直径的开始和结束坐标为 xstart,xend, 且满足 xstart ≤ x ≤ xend,则该气球会被 引爆 。可以射出的弓箭的数量 没有限制 。 弓箭一旦被射出之后,可以无限地前进。
|
||||||
|
* 给你一个数组 points ,返回引爆所有气球所必须射出的 最小 弓箭数 。
|
||||||
|
* @Version: 1.0
|
||||||
|
*/
|
||||||
|
public class FindMinArrowShots {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void test() {
|
||||||
|
int[][] points = {{1, 2}, {3, 4},{5, 6},{7, 8}};
|
||||||
|
int[][] points1 = {{10, 16}, {2, 8},{1, 6},{7, 12}};
|
||||||
|
int[][] points2 = {{1, 2}, {2, 3},{3, 4},{4, 5}};
|
||||||
|
int[][] points3 = {{1, 4}, {1, 3},{3, 4},{4, 5}};
|
||||||
|
int[][] points4 = {{-2147483646, -2147483645}, {2147483646, 2147483647}};
|
||||||
|
int[][] points5 = {{9,12}, {1, 10},{4, 11},{8, 12}, {3, 9}, {6, 9},{6, 7}};
|
||||||
|
// System.out.println((long)(-2147483646)-(long)(2147483647)); //-4294967293
|
||||||
|
// System.out.println((int)((long)(-2147483646)-(long)(2147483647)));//3
|
||||||
|
// System.out.println((int)(-2147483646)-(int)(2147483647)); //3
|
||||||
|
System.out.println(findMinArrowShots(points));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 本人思路:先根据xstart进行排序,然后寻找第一个比xend大的xstart则result+1以此类推
|
||||||
|
* 排序的时间复杂度为O(nlogn),所以总复杂度为O(nlogn)
|
||||||
|
* 速度击败12.72%,内存击败88.71%
|
||||||
|
* @param points
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public int findMinArrowShots(int[][] points) {
|
||||||
|
|
||||||
|
//按xstart进行排序
|
||||||
|
Arrays.sort(points, new Comparator<int[]>() {
|
||||||
|
@Override
|
||||||
|
public int compare(int[] o1, int[] o2) {
|
||||||
|
//警惕两数相减超过int范围
|
||||||
|
return o1[0] == o2[0] ? (o1[1] > o2[1]?1:-1) : (o1[0] > o2[0]?1:-1);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
|
||||||
|
//遍历points,找到比xend大的xstart就+1
|
||||||
|
int result=1;
|
||||||
|
int lastEnd=points[0][1];
|
||||||
|
for (int i = 1; i < points.length; i++) {
|
||||||
|
|
||||||
|
if(points[i][0]>lastEnd){
|
||||||
|
result++;
|
||||||
|
lastEnd=points[i][1];
|
||||||
|
}
|
||||||
|
//判断第二个数是不是比lastEnd还小,如果还小就是子集,那么就还得该lastEnd
|
||||||
|
if(points[i][1]<lastEnd){
|
||||||
|
lastEnd=points[i][1];
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,50 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
# coding: utf-8
|
||||||
|
|
||||||
|
'''
|
||||||
|
@Author : dingjiawen
|
||||||
|
@Date : 2022/11/14 10:43
|
||||||
|
@Usage : 画sigmoid
|
||||||
|
@Desc :
|
||||||
|
'''
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import tensorflow as tf
|
||||||
|
import tensorflow.keras
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
|
||||||
|
def sigmoid(x):
|
||||||
|
y = 1 / (1 + np.exp(-x))
|
||||||
|
# dy=y*(1-y)
|
||||||
|
return y
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def plot_sigmoid():
|
||||||
|
parameters = {
|
||||||
|
'figure.dpi': 600,
|
||||||
|
'figure.figsize': (8, 6),
|
||||||
|
'savefig.dpi': 600,
|
||||||
|
'xtick.direction': 'in',
|
||||||
|
'ytick.direction': 'in',
|
||||||
|
'xtick.labelsize': 15,
|
||||||
|
'ytick.labelsize': 15,
|
||||||
|
'legend.fontsize': 15,
|
||||||
|
}
|
||||||
|
plt.rcParams.update(parameters)
|
||||||
|
plt.figure()
|
||||||
|
x=np.arange(-10,10,0.02)
|
||||||
|
y=sigmoid(x)
|
||||||
|
# 画出 y=1 这条水平线
|
||||||
|
plt.axhline(0.5, c='red', label='Failure threshold', lw=2)
|
||||||
|
plt.plot(x,y,c='black',lw=2)
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
plot_sigmoid()
|
||||||
|
|
@ -47,24 +47,20 @@ def plot_result_banda(result_data):
|
||||||
'savefig.dpi': 600,
|
'savefig.dpi': 600,
|
||||||
'xtick.direction': 'in',
|
'xtick.direction': 'in',
|
||||||
'ytick.direction': 'in',
|
'ytick.direction': 'in',
|
||||||
'xtick.labelsize': 5,
|
'xtick.labelsize': 6,
|
||||||
'ytick.labelsize': 5,
|
'ytick.labelsize': 6,
|
||||||
'legend.fontsize': 5,
|
'legend.fontsize': 5,
|
||||||
|
'font.family': 'Times New Roman',
|
||||||
}
|
}
|
||||||
plt.rcParams.update(parameters)
|
plt.rcParams.update(parameters)
|
||||||
fig, ax = plt.subplots(1, 1)
|
fig, ax = plt.subplots(1, 1)
|
||||||
plt.rc('font', family='Times New Roman') # 全局字体样式
|
plt.rc('font', family='Times New Roman') # 全局字体样式
|
||||||
font1 = {'family': 'Times New Roman', 'weight': 'normal', 'size': 5} # 设置坐标标签的字体大小,字体
|
font1 = {'family': 'Times New Roman', 'weight': 'normal', 'size': 5} # 设置坐标标签的字体大小,字体
|
||||||
|
font2 = {'family': 'Times New Roman', 'weight': 'normal','size':7} # 设置坐标标签的字体大小,字体
|
||||||
plt.scatter(list(range(result_data.shape[0])), result_data, c='black', s=0.01, label="predict")
|
plt.scatter(list(range(result_data.shape[0])), result_data, c='black', s=0.01, label="predict")
|
||||||
# 画出 y=1 这条水平线
|
# 画出 y=1 这条水平线
|
||||||
plt.axhline(0.5, c='red', label='Failure threshold', lw=1)
|
plt.axhline(0.5, c='red', label='Failure threshold', lw=1)
|
||||||
# 箭头指向上面的水平线
|
# 箭头指向上面的水平线
|
||||||
# plt.arrow(result_data.shape[0]*2/3, 0.55, 2000, 0.085, width=0.00001, ec='red',length_includes_head=True)
|
|
||||||
# plt.text(result_data.shape[0] * 2 / 3 + 1000, 0.7, "real fault", fontsize=5, color='red',
|
|
||||||
# verticalalignment='top')
|
|
||||||
# plt.text(0, 0.55, "Threshold", fontsize=5, color='red',
|
|
||||||
# verticalalignment='top')
|
|
||||||
|
|
||||||
plt.axvline(result_data.shape[0] * 2 / 3-50, c='blue', ls='-.', lw=0.5, label='real fault')
|
plt.axvline(result_data.shape[0] * 2 / 3-50, c='blue', ls='-.', lw=0.5, label='real fault')
|
||||||
# plt.axvline(415548, c='blue', ls='-.', lw=0.5, label='real fault')
|
# plt.axvline(415548, c='blue', ls='-.', lw=0.5, label='real fault')
|
||||||
# plt.xticks(range(6), ('06/09/17', '12/09/17', '18/09/17', '24/09/17', '29/09/17')) # 设置x轴的标尺
|
# plt.xticks(range(6), ('06/09/17', '12/09/17', '18/09/17', '24/09/17', '29/09/17')) # 设置x轴的标尺
|
||||||
|
|
@ -89,24 +85,27 @@ def plot_result_banda(result_data):
|
||||||
# pad调整label与坐标轴之间的距离
|
# pad调整label与坐标轴之间的距离
|
||||||
plt.tick_params(bottom=True, top=False, left=True, right=False, direction='inout', length=2, width=0.5, pad=1)
|
plt.tick_params(bottom=True, top=False, left=True, right=False, direction='inout', length=2, width=0.5, pad=1)
|
||||||
# plt.yticks([index for index in indices1], classes1)
|
# plt.yticks([index for index in indices1], classes1)
|
||||||
plt.ylabel('Confidence', fontsize=5)
|
plt.ylabel('Confidence',fontdict=font2)
|
||||||
plt.xlabel('Time', fontsize=5)
|
plt.xlabel('Time',fontdict=font2)
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
# plt.legend(loc='best', edgecolor='black', fontsize=4)
|
# plt.legend(loc='best', edgecolor='black', fontsize=4)
|
||||||
plt.legend(loc='upper right', frameon=False, fontsize=4.5)
|
plt.legend(loc='upper right', frameon=False, fontsize=4.5)
|
||||||
# plt.grid()
|
# plt.grid()
|
||||||
|
|
||||||
# 局部方法图
|
# 局部方法图
|
||||||
|
|
||||||
axins = inset_axes(ax, width="40%", height="30%", loc='lower left',
|
axins = inset_axes(ax, width="40%", height="30%", loc='lower left',
|
||||||
bbox_to_anchor=(0.1, 0.1, 1, 1),
|
bbox_to_anchor=(0.1, 0.1, 1, 1),
|
||||||
bbox_transform=ax.transAxes)
|
bbox_transform=ax.transAxes)
|
||||||
axins.scatter(list(range(result_data.shape[0])), result_data, c='black', s=0.001, label="predict")
|
axins.scatter(list(range(result_data.shape[0])), result_data, c='black', s=0.005, label="predict")
|
||||||
axins.axvline(result_data.shape[0] * 2 / 3-50, c='blue', ls='-.', lw=0.5, label='real fault')
|
axins.axvline(result_data.shape[0] * 2 / 3-50, c='blue', ls='-.', lw=0.5, label='real fault')
|
||||||
plt.axhline(0.5, c='red', label='Failure threshold', lw=0.5)
|
plt.axhline(0.5, c='red', label='Failure threshold', lw=0.5)
|
||||||
# 设置放大区间
|
# 设置放大区间
|
||||||
# 设置放大区间
|
# 设置放大区间
|
||||||
zone_left = int(result_data.shape[0] * 2 / 3 -160)
|
zone_left = int(result_data.shape[0] * 2 / 3 -160)
|
||||||
zone_right = int(result_data.shape[0] * 2 / 3) + 40
|
zone_right = int(result_data.shape[0] * 2 / 3) + 40
|
||||||
|
# zone_left = int(result_data.shape[0] * 2 / 3 +250)
|
||||||
|
# zone_right = int(result_data.shape[0] * 2 / 3) + 450
|
||||||
x = list(range(result_data.shape[0]))
|
x = list(range(result_data.shape[0]))
|
||||||
|
|
||||||
# 坐标轴的扩展比例(根据实际数据调整)
|
# 坐标轴的扩展比例(根据实际数据调整)
|
||||||
|
|
@ -140,27 +139,22 @@ def plot_result(result_data):
|
||||||
'savefig.dpi': 600,
|
'savefig.dpi': 600,
|
||||||
'xtick.direction': 'in',
|
'xtick.direction': 'in',
|
||||||
'ytick.direction': 'in',
|
'ytick.direction': 'in',
|
||||||
'xtick.labelsize': 5,
|
'xtick.labelsize': 6,
|
||||||
'ytick.labelsize': 5,
|
'ytick.labelsize': 6,
|
||||||
'legend.fontsize': 5,
|
'legend.fontsize': 5,
|
||||||
|
'font.family': 'Times New Roman',
|
||||||
}
|
}
|
||||||
plt.rcParams.update(parameters)
|
plt.rcParams.update(parameters)
|
||||||
fig, ax = plt.subplots(1, 1)
|
fig, ax = plt.subplots(1, 1)
|
||||||
plt.rc('font', family='Times New Roman') # 全局字体样式
|
plt.rc('font', family='Times New Roman') # 全局字体样式
|
||||||
font1 = {'family': 'Times New Roman', 'weight': 'normal', 'size': 5} # 设置坐标标签的字体大小,字体
|
font1 = {'family': 'Times New Roman', 'weight': 'normal', 'size': 5} # 设置坐标标签的字体大小,字体
|
||||||
|
font2 = {'family': 'Times New Roman', 'weight': 'normal', 'size': 7} # 设置坐标标签的字体大小,字体
|
||||||
plt.scatter(list(range(result_data.shape[0])), result_data, c='black', s=0.01, label="predict")
|
plt.scatter(list(range(result_data.shape[0])), result_data, c='black', s=0.01, label="predict")
|
||||||
# 画出 y=1 这条水平线
|
# 画出 y=1 这条水平线
|
||||||
plt.axhline(0.5, c='red', label='Failure threshold', lw=1)
|
plt.axhline(0.5, c='red', label='Failure threshold', lw=1)
|
||||||
# 箭头指向上面的水平线
|
|
||||||
# plt.arrow(result_data.shape[0]*2/3, 0.55, 2000, 0.085, width=0.00001, ec='red',length_includes_head=True)
|
|
||||||
# plt.text(result_data.shape[0] * 2 / 3 + 1000, 0.7, "real fault", fontsize=5, color='red',
|
|
||||||
# verticalalignment='top')
|
|
||||||
# plt.text(0, 0.55, "Threshold", fontsize=5, color='red',
|
|
||||||
# verticalalignment='top')
|
|
||||||
|
|
||||||
plt.axvline(result_data.shape[0] * 2 / 3-50, c='blue', ls='-.', lw=0.5, label='real fault')
|
|
||||||
# plt.axvline(415548, c='blue', ls='-.', lw=0.5, label='real fault')
|
plt.axvline(result_data.shape[0] * 2 / 3-15, c='blue', ls='-.', lw=0.5, label='real fault')
|
||||||
# plt.xticks(range(6), ('06/09/17', '12/09/17', '18/09/17', '24/09/17', '29/09/17')) # 设置x轴的标尺
|
|
||||||
plt.text(result_data.shape[0] * 5 / 6, 0.4, "Fault", fontsize=5, color='black', verticalalignment='top',
|
plt.text(result_data.shape[0] * 5 / 6, 0.4, "Fault", fontsize=5, color='black', verticalalignment='top',
|
||||||
horizontalalignment='center',
|
horizontalalignment='center',
|
||||||
bbox={'facecolor': 'grey',
|
bbox={'facecolor': 'grey',
|
||||||
|
|
@ -182,8 +176,8 @@ def plot_result(result_data):
|
||||||
# pad调整label与坐标轴之间的距离
|
# pad调整label与坐标轴之间的距离
|
||||||
plt.tick_params(bottom=True, top=False, left=True, right=False, direction='inout', length=2, width=0.5, pad=1)
|
plt.tick_params(bottom=True, top=False, left=True, right=False, direction='inout', length=2, width=0.5, pad=1)
|
||||||
# plt.yticks([index for index in indices1], classes1)
|
# plt.yticks([index for index in indices1], classes1)
|
||||||
plt.ylabel('Confidence', fontsize=5)
|
plt.ylabel('Confidence', fontdict=font2)
|
||||||
plt.xlabel('Time', fontsize=5)
|
plt.xlabel('Time', fontdict=font2)
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
# plt.legend(loc='best', edgecolor='black', fontsize=4)
|
# plt.legend(loc='best', edgecolor='black', fontsize=4)
|
||||||
plt.legend(loc='best', frameon=False, fontsize=4.5)
|
plt.legend(loc='best', frameon=False, fontsize=4.5)
|
||||||
|
|
@ -193,8 +187,8 @@ def plot_result(result_data):
|
||||||
axins = inset_axes(ax, width="40%", height="30%", loc='lower left',
|
axins = inset_axes(ax, width="40%", height="30%", loc='lower left',
|
||||||
bbox_to_anchor=(0.1, 0.1, 1, 1),
|
bbox_to_anchor=(0.1, 0.1, 1, 1),
|
||||||
bbox_transform=ax.transAxes)
|
bbox_transform=ax.transAxes)
|
||||||
axins.scatter(list(range(result_data.shape[0])), result_data, c='black', s=0.001, label="predict")
|
axins.scatter(list(range(result_data.shape[0])), result_data, c='black', s=0.005, label="predict")
|
||||||
axins.axvline(result_data.shape[0] * 2 / 3-50, c='blue', ls='-.', lw=0.5, label='real fault')
|
axins.axvline(result_data.shape[0] * 2 / 3-15, c='blue', ls='-.', lw=0.5, label='real fault')
|
||||||
plt.axhline(0.5, c='red', label='Failure threshold', lw=0.5)
|
plt.axhline(0.5, c='red', label='Failure threshold', lw=0.5)
|
||||||
# 设置放大区间
|
# 设置放大区间
|
||||||
# 设置放大区间
|
# 设置放大区间
|
||||||
|
|
@ -587,6 +581,48 @@ def plot_mse(file_name_mse="../others_idea/mse",data:str=''):
|
||||||
|
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
def plot_mse_single(file_name_mse="../self_try/compare/mse/JM_banda/banda_joint_result_predict1.csv"):
|
||||||
|
mse = np.loadtxt(file_name_mse, delimiter=",")
|
||||||
|
|
||||||
|
print("mse:",mse.shape)
|
||||||
|
|
||||||
|
need_shape=int(mse.shape[0]*2/3)
|
||||||
|
# mse = mse[2000:2300]
|
||||||
|
# mse = mse[1800:2150]
|
||||||
|
# mse = mse[ need_shape+100:need_shape+377]
|
||||||
|
mse = mse[ need_shape-300:need_shape-10]
|
||||||
|
|
||||||
|
parameters = {
|
||||||
|
'figure.dpi': 600,
|
||||||
|
'figure.figsize': (2.8, 2),
|
||||||
|
'savefig.dpi': 600,
|
||||||
|
'xtick.direction': 'in',
|
||||||
|
'ytick.direction': 'in',
|
||||||
|
'xtick.labelsize': 6,
|
||||||
|
'ytick.labelsize': 6,
|
||||||
|
'legend.fontsize': 5,
|
||||||
|
'font.family':'Times New Roman'
|
||||||
|
}
|
||||||
|
plt.rcParams.update(parameters)
|
||||||
|
font2 = {'family': 'Times New Roman', 'weight': 'normal', 'size': 7} # 设置坐标标签的字体大小,字体
|
||||||
|
|
||||||
|
plt.figure()
|
||||||
|
plt.rc('font', family='Times New Roman') # 全局字体样式#画混淆矩阵
|
||||||
|
indices = [mse.shape[0] * i / 4 for i in range(5)]
|
||||||
|
# classes = ['07:21','08:21', '09:21', '10:21', '11:21']
|
||||||
|
classes = ['01:58','02:58', '03:58', '04:58', '05:58']
|
||||||
|
|
||||||
|
plt.xticks([index for index in indices], classes, rotation=25) # 设置横坐标方向,rotation=45为45度倾斜
|
||||||
|
# pad调整label与坐标轴之间的距离
|
||||||
|
plt.tick_params(bottom=True, top=False, left=True, right=False, direction='inout', length=2, width=0.5, pad=1)
|
||||||
|
plt.ylabel('Residual', fontdict=font2)
|
||||||
|
plt.xlabel('Time', fontdict=font2)
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.plot(mse[:], lw=0.5)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
def plot_3d():
|
def plot_3d():
|
||||||
|
|
@ -646,7 +682,7 @@ def test_result(file_name: str = result_file_name):
|
||||||
# result_data = np.concatenate([result_data, data], axis=0)
|
# result_data = np.concatenate([result_data, data], axis=0)
|
||||||
print(result_data)
|
print(result_data)
|
||||||
print(result_data.shape)
|
print(result_data.shape)
|
||||||
plot_result_banda(result_data)
|
plot_result(result_data)
|
||||||
|
|
||||||
|
|
||||||
def test_mse(mse_file_name: str = mse_file_name, max_file_name: str = max_file_name):
|
def test_mse(mse_file_name: str = mse_file_name, max_file_name: str = max_file_name):
|
||||||
|
|
@ -739,17 +775,17 @@ def test_model_visualization(model_name=file_name):
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# test_mse(fi)
|
# test_mse(fi)
|
||||||
# test_result(
|
# test_result(
|
||||||
# file_name='E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\ResNet\ResNet_banda_result3.csv')
|
# file_name='E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\ResNet\ResNet_timestamp120_feature10_result.csv')
|
||||||
# test_corr()
|
# test_corr()
|
||||||
# acc()
|
# acc()
|
||||||
# list = [3.77, 2.64, 2.35, 2.05, 1.76, 1.09, 0.757, 0.82, 1.1, 0.58, 0, 0.03, 0.02]
|
# list = [3.77, 2.64, 2.35, 2.05, 1.76, 1.09, 0.757, 0.82, 1.1, 0.58, 0, 0.03, 0.02]
|
||||||
# test_bar(list)
|
# test_bar(list)
|
||||||
|
|
||||||
# list=[99.99,98.95,99.95,96.1,95,99.65,76.25,72.64,75.87,68.74]
|
list=[98.56,98.95,99.95,96.1,95,99.65,76.25,72.64,75.87,68.74]
|
||||||
# plot_FNR1(list)
|
plot_FNR1(list)
|
||||||
# #
|
# #
|
||||||
list=[3.43,1.99,1.92,2.17,1.63,1.81,1.78,1.8,0.6]
|
# list=[3.43,1.99,1.92,2.17,1.63,1.81,1.78,1.8,0.6]
|
||||||
plot_FNR2(list)
|
# plot_FNR2(list)
|
||||||
|
|
||||||
# 查看网络某一层的权重
|
# 查看网络某一层的权重
|
||||||
# test_model_visualization(model_name = "E:\跑模型\论文写作/SE.txt")
|
# test_model_visualization(model_name = "E:\跑模型\论文写作/SE.txt")
|
||||||
|
|
@ -759,8 +795,8 @@ if __name__ == '__main__':
|
||||||
# test_model_weight_l(file_name)
|
# test_model_weight_l(file_name)
|
||||||
|
|
||||||
# 单独预测图
|
# 单独预测图
|
||||||
plot_mse('E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\JM_banda/banda_joint_result_predict3.csv',data='E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\JM_banda/raw_data.csv')
|
# plot_mse('E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\JM_banda/banda_joint_result_predict3.csv',data='E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\JM_banda/raw_data.csv')
|
||||||
|
# plot_mse_single('E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\RNet_D/banda\RNet_D_banda_mse_predict1.csv')
|
||||||
#画3d图
|
#画3d图
|
||||||
# plot_3d()
|
# plot_3d()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -283,76 +283,7 @@ def showResult(step_two_model: tf.keras.Model, test_data, isPlot: bool = False,
|
||||||
return total_result
|
return total_result
|
||||||
|
|
||||||
|
|
||||||
def plot_result(result_data):
|
|
||||||
parameters = {
|
|
||||||
'figure.dpi': 600,
|
|
||||||
'figure.figsize': (5, 5),
|
|
||||||
'savefig.dpi': 600,
|
|
||||||
'xtick.direction': 'in',
|
|
||||||
'ytick.direction': 'in',
|
|
||||||
'xtick.labelsize': 20,
|
|
||||||
'ytick.labelsize': 20,
|
|
||||||
'legend.fontsize': 11.3,
|
|
||||||
}
|
|
||||||
plt.rcParams.update(parameters)
|
|
||||||
plt.figure()
|
|
||||||
plt.rc('font', family='Times New Roman') # 全局字体样式#画混淆矩阵
|
|
||||||
|
|
||||||
plt.scatter(list(range(result_data.shape[0])), result_data, c='black', s=10)
|
|
||||||
# 画出 y=1 这条水平线
|
|
||||||
plt.axhline(0.5, c='red', label='Failure threshold')
|
|
||||||
# 箭头指向上面的水平线
|
|
||||||
# plt.arrow(35000, 0.9, 33000, 0.75, head_width=0.02, head_length=0.1, shape="full", fc='red', ec='red',
|
|
||||||
# alpha=0.9, overhang=0.5)
|
|
||||||
plt.text(test_data.shape[0] * 2 / 3 + 1000, 0.7, "Truth Fault", fontsize=10, color='red',
|
|
||||||
verticalalignment='top')
|
|
||||||
plt.axvline(test_data.shape[0] * 2 / 3, c='blue', ls='-.')
|
|
||||||
plt.xticks(range(6), ('06/09/17', '12/09/17', '18/09/17', '24/09/17', '29/09/17')) # 设置x轴的标尺
|
|
||||||
|
|
||||||
|
|
||||||
indices = range(result_data.shape[0])
|
|
||||||
classes = ['N', 'IF', 'OF' 'TRC', 'TSP'] # for i in range(cls):
|
|
||||||
# 第一个是迭代对象,表示坐标的显示顺序,第二个参数是坐标轴显示列表
|
|
||||||
# plt.xticks(indices. classes,rotation=45)#设置横坐标方向,rotation=45为45度倾斜# plt.yticks(indices, classes, rotation=45)
|
|
||||||
plt.xticks([index + 0.5 for index in indices], classes, rotation=45) # 设置横坐标方向,rotation=45为45度倾斜
|
|
||||||
# plt.yticks([index + 0.5 for index in indices], classes, rotation=45)
|
|
||||||
plt.ylabel('Actual label', fontsize=20)
|
|
||||||
plt.xlabel('Predicted label', fontsize=20)
|
|
||||||
plt.tight_layout()
|
|
||||||
plt.show()
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def plot_confusion_matrix_accuracy(cls, true_labels, predict_labels):
|
|
||||||
sns.set(font_scale=1.5)
|
|
||||||
parameters = {
|
|
||||||
'figure.dpi': 600,
|
|
||||||
'figure.figsize': (5, 5),
|
|
||||||
'savefig.dpi': 600,
|
|
||||||
'xtick.direction': 'in',
|
|
||||||
'ytick.direction': 'in',
|
|
||||||
'xtick.labelsize': 20,
|
|
||||||
'ytick.labelsize': 20,
|
|
||||||
'legend.fontsize': 11.3,
|
|
||||||
}
|
|
||||||
plt.rcParams.update(parameters)
|
|
||||||
plt.figure()
|
|
||||||
plt.rc('font', family='Times New Roman') # 全局字体样式#画混淆矩阵
|
|
||||||
confusion = confusion_matrix(true_labels, predict_labels)
|
|
||||||
# confusion = confusion.astype('float') / confusion.sum(axis=1)[: np.newaxis] plt.figure()
|
|
||||||
# sns.heatmap(confusion, annot=True,fmt="d",cmap="Greens")
|
|
||||||
sns.heatmap(confusion, annot=True, fmt="d", cmap="Blues", vmax=100, vmin=0, cbar=None, square=True)
|
|
||||||
indices = range(len(confusion))
|
|
||||||
classes = ['N', 'IF', 'OF' 'TRC', 'TSP'] # for i in range(cls):
|
|
||||||
# classes.append(str(i))
|
|
||||||
# 第一个是迭代对象,表示坐标的显示顺序,第二个参数是坐标轴显示列表
|
|
||||||
# plt.xticks(indices. classes,rotation=45)#设置横坐标方向,rotation=45为45度倾斜# plt.yticks(indices, classes, rotation=45)
|
|
||||||
plt.xticks([index + 0.5 for index in indices], classes, rotation=45) # 设置横坐标方向,rotation=45为45度倾斜
|
|
||||||
plt.yticks([index + 0.5 for index in indices], classes, rotation=45)
|
|
||||||
plt.ylabel('Actual label', fontsize=20)
|
|
||||||
plt.xlabel('Predicted label', fontsize=20)
|
|
||||||
plt.tight_layout()
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
||||||
|
|
@ -204,12 +204,21 @@ def test_corr(file_name=source_path,N=10):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def plot_raw_data():
|
||||||
|
# data, label = read_data(file_name='G:\data\SCADA数据\jb4q_8_delete_total_zero.csv', isNew=False)
|
||||||
|
# data = np.loadtxt('G:\data\SCADA数据/normalization.csv',delimiter=',')
|
||||||
|
data = pd.read_csv('G:\data\SCADA数据/normalization.csv')
|
||||||
|
print(data.shape)
|
||||||
|
print(data)
|
||||||
|
plt.plot(data)
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# test_mse()
|
# test_mse()
|
||||||
# test_result()
|
# test_result()
|
||||||
test_corr()
|
# test_corr()
|
||||||
|
plot_raw_data()
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -28,12 +28,12 @@ class Joint_Monitoring(keras.Model):
|
||||||
super(Joint_Monitoring, self).__init__()
|
super(Joint_Monitoring, self).__init__()
|
||||||
# step one
|
# step one
|
||||||
self.RepDCBlock1 = RevConvBlock(num=3, kernel_size=5)
|
self.RepDCBlock1 = RevConvBlock(num=3, kernel_size=5)
|
||||||
self.conv1 = tf.keras.layers.Conv1D(filters=conv_filter, kernel_size=1, strides=2, padding='SAME')
|
self.conv1 = tf.keras.layers.Conv1D(filters=conv_filter, kernel_size=1, strides=2, padding='SAME',kernel_initializer=0.7,bias_initializer=1)
|
||||||
self.upsample1 = tf.keras.layers.UpSampling1D(size=2)
|
self.upsample1 = tf.keras.layers.UpSampling1D(size=2)
|
||||||
|
|
||||||
self.DACU2 = DynamicChannelAttention()
|
self.DACU2 = DynamicChannelAttention()
|
||||||
self.RepDCBlock2 = RevConvBlock(num=3, kernel_size=3)
|
self.RepDCBlock2 = RevConvBlock(num=3, kernel_size=3)
|
||||||
self.conv2 = tf.keras.layers.Conv1D(filters=2 * conv_filter, kernel_size=1, strides=2, padding='SAME')
|
self.conv2 = tf.keras.layers.Conv1D(filters=2 * conv_filter, kernel_size=1, strides=2, padding='SAME',kernel_initializer=0.7,bias_initializer=1)
|
||||||
self.upsample2 = tf.keras.layers.UpSampling1D(size=2)
|
self.upsample2 = tf.keras.layers.UpSampling1D(size=2)
|
||||||
|
|
||||||
self.DACU3 = DynamicChannelAttention()
|
self.DACU3 = DynamicChannelAttention()
|
||||||
|
|
@ -382,16 +382,18 @@ class Joint_Monitoring(keras.Model):
|
||||||
|
|
||||||
class RevConv(keras.layers.Layer):
|
class RevConv(keras.layers.Layer):
|
||||||
|
|
||||||
def __init__(self, kernel_size=3):
|
def __init__(self, kernel_size=3,dilation_rate=2):
|
||||||
# 调用父类__init__()方法
|
# 调用父类__init__()方法
|
||||||
super(RevConv, self).__init__()
|
super(RevConv, self).__init__()
|
||||||
self.kernel_size = kernel_size
|
self.kernel_size = kernel_size
|
||||||
|
self.dilation_rate = dilation_rate
|
||||||
|
|
||||||
def get_config(self):
|
def get_config(self):
|
||||||
# 自定义层里面的属性
|
# 自定义层里面的属性
|
||||||
config = (
|
config = (
|
||||||
{
|
{
|
||||||
'kernel_size': self.kernel_size
|
'kernel_size': self.kernel_size,
|
||||||
|
'dilation_rate': self.dilation_rate
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
base_config = super(RevConv, self).get_config()
|
base_config = super(RevConv, self).get_config()
|
||||||
|
|
@ -402,10 +404,10 @@ class RevConv(keras.layers.Layer):
|
||||||
_, _, output_dim = input_shape[0], input_shape[1], input_shape[2]
|
_, _, output_dim = input_shape[0], input_shape[1], input_shape[2]
|
||||||
self.conv1 = tf.keras.layers.Conv1D(filters=output_dim, kernel_size=self.kernel_size, strides=1,
|
self.conv1 = tf.keras.layers.Conv1D(filters=output_dim, kernel_size=self.kernel_size, strides=1,
|
||||||
padding='causal',
|
padding='causal',
|
||||||
dilation_rate=4)
|
dilation_rate=self.dilation_rate)
|
||||||
|
|
||||||
self.conv2 = tf.keras.layers.Conv1D(filters=output_dim, kernel_size=1, strides=1, padding='causal',
|
self.conv2 = tf.keras.layers.Conv1D(filters=output_dim, kernel_size=1, strides=1, padding='causal',
|
||||||
dilation_rate=4)
|
dilation_rate=self.dilation_rate)
|
||||||
# self.b2 = tf.keras.layers.BatchNormalization()
|
# self.b2 = tf.keras.layers.BatchNormalization()
|
||||||
|
|
||||||
# self.b3 = tf.keras.layers.BatchNormalization()
|
# self.b3 = tf.keras.layers.BatchNormalization()
|
||||||
|
|
@ -440,7 +442,7 @@ class RevConvBlock(keras.layers.Layer):
|
||||||
self.kernel_size = kernel_size
|
self.kernel_size = kernel_size
|
||||||
self.L = []
|
self.L = []
|
||||||
for i in range(num):
|
for i in range(num):
|
||||||
RepVGG = RevConv(kernel_size=kernel_size)
|
RepVGG = RevConv(kernel_size=kernel_size,dilation_rate=(i+1)*2)
|
||||||
self.L.append(RepVGG)
|
self.L.append(RepVGG)
|
||||||
|
|
||||||
def get_config(self):
|
def get_config(self):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue