leecode更新
This commit is contained in:
parent
663414a446
commit
32ebfcceb8
|
|
@ -0,0 +1,73 @@
|
|||
package com.markilue.leecode.greedy;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Comparator;
|
||||
|
||||
/**
|
||||
* @BelongsProject: Leecode
|
||||
* @BelongsPackage: com.markilue.leecode.greedy
|
||||
* @Author: markilue
|
||||
* @CreateTime: 2022-11-17 20:29
|
||||
* @Description: TODO 力扣452题 用最少的数量的箭引爆气球:
|
||||
* 有一些球形气球贴在一堵用 XY 平面表示的墙面上。墙面上的气球记录在整数数组 points ,其中points[i] = [xstart, xend] 表示水平直径在 xstart 和 xend之间的气球。你不知道气球的确切 y 坐标。
|
||||
* 一支弓箭可以沿着 x 轴从不同点 完全垂直 地射出。在坐标 x 处射出一支箭,若有一个气球的直径的开始和结束坐标为 xstart,xend, 且满足 xstart ≤ x ≤ xend,则该气球会被 引爆 。可以射出的弓箭的数量 没有限制 。 弓箭一旦被射出之后,可以无限地前进。
|
||||
* 给你一个数组 points ,返回引爆所有气球所必须射出的 最小 弓箭数 。
|
||||
* @Version: 1.0
|
||||
*/
|
||||
public class FindMinArrowShots {
|
||||
|
||||
@Test
|
||||
public void test() {
|
||||
int[][] points = {{1, 2}, {3, 4},{5, 6},{7, 8}};
|
||||
int[][] points1 = {{10, 16}, {2, 8},{1, 6},{7, 12}};
|
||||
int[][] points2 = {{1, 2}, {2, 3},{3, 4},{4, 5}};
|
||||
int[][] points3 = {{1, 4}, {1, 3},{3, 4},{4, 5}};
|
||||
int[][] points4 = {{-2147483646, -2147483645}, {2147483646, 2147483647}};
|
||||
int[][] points5 = {{9,12}, {1, 10},{4, 11},{8, 12}, {3, 9}, {6, 9},{6, 7}};
|
||||
// System.out.println((long)(-2147483646)-(long)(2147483647)); //-4294967293
|
||||
// System.out.println((int)((long)(-2147483646)-(long)(2147483647)));//3
|
||||
// System.out.println((int)(-2147483646)-(int)(2147483647)); //3
|
||||
System.out.println(findMinArrowShots(points));
|
||||
}
|
||||
|
||||
/**
|
||||
* 本人思路:先根据xstart进行排序,然后寻找第一个比xend大的xstart则result+1以此类推
|
||||
* 排序的时间复杂度为O(nlogn),所以总复杂度为O(nlogn)
|
||||
* 速度击败12.72%,内存击败88.71%
|
||||
* @param points
|
||||
* @return
|
||||
*/
|
||||
public int findMinArrowShots(int[][] points) {
|
||||
|
||||
//按xstart进行排序
|
||||
Arrays.sort(points, new Comparator<int[]>() {
|
||||
@Override
|
||||
public int compare(int[] o1, int[] o2) {
|
||||
//警惕两数相减超过int范围
|
||||
return o1[0] == o2[0] ? (o1[1] > o2[1]?1:-1) : (o1[0] > o2[0]?1:-1);
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
//遍历points,找到比xend大的xstart就+1
|
||||
int result=1;
|
||||
int lastEnd=points[0][1];
|
||||
for (int i = 1; i < points.length; i++) {
|
||||
|
||||
if(points[i][0]>lastEnd){
|
||||
result++;
|
||||
lastEnd=points[i][1];
|
||||
}
|
||||
//判断第二个数是不是比lastEnd还小,如果还小就是子集,那么就还得该lastEnd
|
||||
if(points[i][1]<lastEnd){
|
||||
lastEnd=points[i][1];
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return result;
|
||||
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,50 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
# coding: utf-8
|
||||
|
||||
'''
|
||||
@Author : dingjiawen
|
||||
@Date : 2022/11/14 10:43
|
||||
@Usage : 画sigmoid
|
||||
@Desc :
|
||||
'''
|
||||
|
||||
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import tensorflow as tf
|
||||
import tensorflow.keras
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
def sigmoid(x):
|
||||
y = 1 / (1 + np.exp(-x))
|
||||
# dy=y*(1-y)
|
||||
return y
|
||||
pass
|
||||
|
||||
|
||||
def plot_sigmoid():
|
||||
parameters = {
|
||||
'figure.dpi': 600,
|
||||
'figure.figsize': (8, 6),
|
||||
'savefig.dpi': 600,
|
||||
'xtick.direction': 'in',
|
||||
'ytick.direction': 'in',
|
||||
'xtick.labelsize': 15,
|
||||
'ytick.labelsize': 15,
|
||||
'legend.fontsize': 15,
|
||||
}
|
||||
plt.rcParams.update(parameters)
|
||||
plt.figure()
|
||||
x=np.arange(-10,10,0.02)
|
||||
y=sigmoid(x)
|
||||
# 画出 y=1 这条水平线
|
||||
plt.axhline(0.5, c='red', label='Failure threshold', lw=2)
|
||||
plt.plot(x,y,c='black',lw=2)
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
plot_sigmoid()
|
||||
|
|
@ -47,24 +47,20 @@ def plot_result_banda(result_data):
|
|||
'savefig.dpi': 600,
|
||||
'xtick.direction': 'in',
|
||||
'ytick.direction': 'in',
|
||||
'xtick.labelsize': 5,
|
||||
'ytick.labelsize': 5,
|
||||
'xtick.labelsize': 6,
|
||||
'ytick.labelsize': 6,
|
||||
'legend.fontsize': 5,
|
||||
'font.family': 'Times New Roman',
|
||||
}
|
||||
plt.rcParams.update(parameters)
|
||||
fig, ax = plt.subplots(1, 1)
|
||||
plt.rc('font', family='Times New Roman') # 全局字体样式
|
||||
font1 = {'family': 'Times New Roman', 'weight': 'normal', 'size': 5} # 设置坐标标签的字体大小,字体
|
||||
font2 = {'family': 'Times New Roman', 'weight': 'normal','size':7} # 设置坐标标签的字体大小,字体
|
||||
plt.scatter(list(range(result_data.shape[0])), result_data, c='black', s=0.01, label="predict")
|
||||
# 画出 y=1 这条水平线
|
||||
plt.axhline(0.5, c='red', label='Failure threshold', lw=1)
|
||||
# 箭头指向上面的水平线
|
||||
# plt.arrow(result_data.shape[0]*2/3, 0.55, 2000, 0.085, width=0.00001, ec='red',length_includes_head=True)
|
||||
# plt.text(result_data.shape[0] * 2 / 3 + 1000, 0.7, "real fault", fontsize=5, color='red',
|
||||
# verticalalignment='top')
|
||||
# plt.text(0, 0.55, "Threshold", fontsize=5, color='red',
|
||||
# verticalalignment='top')
|
||||
|
||||
plt.axvline(result_data.shape[0] * 2 / 3-50, c='blue', ls='-.', lw=0.5, label='real fault')
|
||||
# plt.axvline(415548, c='blue', ls='-.', lw=0.5, label='real fault')
|
||||
# plt.xticks(range(6), ('06/09/17', '12/09/17', '18/09/17', '24/09/17', '29/09/17')) # 设置x轴的标尺
|
||||
|
|
@ -89,24 +85,27 @@ def plot_result_banda(result_data):
|
|||
# pad调整label与坐标轴之间的距离
|
||||
plt.tick_params(bottom=True, top=False, left=True, right=False, direction='inout', length=2, width=0.5, pad=1)
|
||||
# plt.yticks([index for index in indices1], classes1)
|
||||
plt.ylabel('Confidence', fontsize=5)
|
||||
plt.xlabel('Time', fontsize=5)
|
||||
plt.ylabel('Confidence',fontdict=font2)
|
||||
plt.xlabel('Time',fontdict=font2)
|
||||
plt.tight_layout()
|
||||
# plt.legend(loc='best', edgecolor='black', fontsize=4)
|
||||
plt.legend(loc='upper right', frameon=False, fontsize=4.5)
|
||||
# plt.grid()
|
||||
|
||||
# 局部方法图
|
||||
|
||||
axins = inset_axes(ax, width="40%", height="30%", loc='lower left',
|
||||
bbox_to_anchor=(0.1, 0.1, 1, 1),
|
||||
bbox_transform=ax.transAxes)
|
||||
axins.scatter(list(range(result_data.shape[0])), result_data, c='black', s=0.001, label="predict")
|
||||
axins.scatter(list(range(result_data.shape[0])), result_data, c='black', s=0.005, label="predict")
|
||||
axins.axvline(result_data.shape[0] * 2 / 3-50, c='blue', ls='-.', lw=0.5, label='real fault')
|
||||
plt.axhline(0.5, c='red', label='Failure threshold', lw=0.5)
|
||||
# 设置放大区间
|
||||
# 设置放大区间
|
||||
zone_left = int(result_data.shape[0] * 2 / 3 -160)
|
||||
zone_right = int(result_data.shape[0] * 2 / 3) + 40
|
||||
# zone_left = int(result_data.shape[0] * 2 / 3 +250)
|
||||
# zone_right = int(result_data.shape[0] * 2 / 3) + 450
|
||||
x = list(range(result_data.shape[0]))
|
||||
|
||||
# 坐标轴的扩展比例(根据实际数据调整)
|
||||
|
|
@ -140,27 +139,22 @@ def plot_result(result_data):
|
|||
'savefig.dpi': 600,
|
||||
'xtick.direction': 'in',
|
||||
'ytick.direction': 'in',
|
||||
'xtick.labelsize': 5,
|
||||
'ytick.labelsize': 5,
|
||||
'xtick.labelsize': 6,
|
||||
'ytick.labelsize': 6,
|
||||
'legend.fontsize': 5,
|
||||
'font.family': 'Times New Roman',
|
||||
}
|
||||
plt.rcParams.update(parameters)
|
||||
fig, ax = plt.subplots(1, 1)
|
||||
plt.rc('font', family='Times New Roman') # 全局字体样式
|
||||
font1 = {'family': 'Times New Roman', 'weight': 'normal', 'size': 5} # 设置坐标标签的字体大小,字体
|
||||
font2 = {'family': 'Times New Roman', 'weight': 'normal', 'size': 7} # 设置坐标标签的字体大小,字体
|
||||
plt.scatter(list(range(result_data.shape[0])), result_data, c='black', s=0.01, label="predict")
|
||||
# 画出 y=1 这条水平线
|
||||
plt.axhline(0.5, c='red', label='Failure threshold', lw=1)
|
||||
# 箭头指向上面的水平线
|
||||
# plt.arrow(result_data.shape[0]*2/3, 0.55, 2000, 0.085, width=0.00001, ec='red',length_includes_head=True)
|
||||
# plt.text(result_data.shape[0] * 2 / 3 + 1000, 0.7, "real fault", fontsize=5, color='red',
|
||||
# verticalalignment='top')
|
||||
# plt.text(0, 0.55, "Threshold", fontsize=5, color='red',
|
||||
# verticalalignment='top')
|
||||
|
||||
plt.axvline(result_data.shape[0] * 2 / 3-50, c='blue', ls='-.', lw=0.5, label='real fault')
|
||||
# plt.axvline(415548, c='blue', ls='-.', lw=0.5, label='real fault')
|
||||
# plt.xticks(range(6), ('06/09/17', '12/09/17', '18/09/17', '24/09/17', '29/09/17')) # 设置x轴的标尺
|
||||
|
||||
plt.axvline(result_data.shape[0] * 2 / 3-15, c='blue', ls='-.', lw=0.5, label='real fault')
|
||||
plt.text(result_data.shape[0] * 5 / 6, 0.4, "Fault", fontsize=5, color='black', verticalalignment='top',
|
||||
horizontalalignment='center',
|
||||
bbox={'facecolor': 'grey',
|
||||
|
|
@ -182,8 +176,8 @@ def plot_result(result_data):
|
|||
# pad调整label与坐标轴之间的距离
|
||||
plt.tick_params(bottom=True, top=False, left=True, right=False, direction='inout', length=2, width=0.5, pad=1)
|
||||
# plt.yticks([index for index in indices1], classes1)
|
||||
plt.ylabel('Confidence', fontsize=5)
|
||||
plt.xlabel('Time', fontsize=5)
|
||||
plt.ylabel('Confidence', fontdict=font2)
|
||||
plt.xlabel('Time', fontdict=font2)
|
||||
plt.tight_layout()
|
||||
# plt.legend(loc='best', edgecolor='black', fontsize=4)
|
||||
plt.legend(loc='best', frameon=False, fontsize=4.5)
|
||||
|
|
@ -193,8 +187,8 @@ def plot_result(result_data):
|
|||
axins = inset_axes(ax, width="40%", height="30%", loc='lower left',
|
||||
bbox_to_anchor=(0.1, 0.1, 1, 1),
|
||||
bbox_transform=ax.transAxes)
|
||||
axins.scatter(list(range(result_data.shape[0])), result_data, c='black', s=0.001, label="predict")
|
||||
axins.axvline(result_data.shape[0] * 2 / 3-50, c='blue', ls='-.', lw=0.5, label='real fault')
|
||||
axins.scatter(list(range(result_data.shape[0])), result_data, c='black', s=0.005, label="predict")
|
||||
axins.axvline(result_data.shape[0] * 2 / 3-15, c='blue', ls='-.', lw=0.5, label='real fault')
|
||||
plt.axhline(0.5, c='red', label='Failure threshold', lw=0.5)
|
||||
# 设置放大区间
|
||||
# 设置放大区间
|
||||
|
|
@ -587,6 +581,48 @@ def plot_mse(file_name_mse="../others_idea/mse",data:str=''):
|
|||
|
||||
plt.show()
|
||||
|
||||
def plot_mse_single(file_name_mse="../self_try/compare/mse/JM_banda/banda_joint_result_predict1.csv"):
|
||||
mse = np.loadtxt(file_name_mse, delimiter=",")
|
||||
|
||||
print("mse:",mse.shape)
|
||||
|
||||
need_shape=int(mse.shape[0]*2/3)
|
||||
# mse = mse[2000:2300]
|
||||
# mse = mse[1800:2150]
|
||||
# mse = mse[ need_shape+100:need_shape+377]
|
||||
mse = mse[ need_shape-300:need_shape-10]
|
||||
|
||||
parameters = {
|
||||
'figure.dpi': 600,
|
||||
'figure.figsize': (2.8, 2),
|
||||
'savefig.dpi': 600,
|
||||
'xtick.direction': 'in',
|
||||
'ytick.direction': 'in',
|
||||
'xtick.labelsize': 6,
|
||||
'ytick.labelsize': 6,
|
||||
'legend.fontsize': 5,
|
||||
'font.family':'Times New Roman'
|
||||
}
|
||||
plt.rcParams.update(parameters)
|
||||
font2 = {'family': 'Times New Roman', 'weight': 'normal', 'size': 7} # 设置坐标标签的字体大小,字体
|
||||
|
||||
plt.figure()
|
||||
plt.rc('font', family='Times New Roman') # 全局字体样式#画混淆矩阵
|
||||
indices = [mse.shape[0] * i / 4 for i in range(5)]
|
||||
# classes = ['07:21','08:21', '09:21', '10:21', '11:21']
|
||||
classes = ['01:58','02:58', '03:58', '04:58', '05:58']
|
||||
|
||||
plt.xticks([index for index in indices], classes, rotation=25) # 设置横坐标方向,rotation=45为45度倾斜
|
||||
# pad调整label与坐标轴之间的距离
|
||||
plt.tick_params(bottom=True, top=False, left=True, right=False, direction='inout', length=2, width=0.5, pad=1)
|
||||
plt.ylabel('Residual', fontdict=font2)
|
||||
plt.xlabel('Time', fontdict=font2)
|
||||
plt.tight_layout()
|
||||
plt.plot(mse[:], lw=0.5)
|
||||
|
||||
|
||||
|
||||
plt.show()
|
||||
|
||||
|
||||
def plot_3d():
|
||||
|
|
@ -646,7 +682,7 @@ def test_result(file_name: str = result_file_name):
|
|||
# result_data = np.concatenate([result_data, data], axis=0)
|
||||
print(result_data)
|
||||
print(result_data.shape)
|
||||
plot_result_banda(result_data)
|
||||
plot_result(result_data)
|
||||
|
||||
|
||||
def test_mse(mse_file_name: str = mse_file_name, max_file_name: str = max_file_name):
|
||||
|
|
@ -739,17 +775,17 @@ def test_model_visualization(model_name=file_name):
|
|||
if __name__ == '__main__':
|
||||
# test_mse(fi)
|
||||
# test_result(
|
||||
# file_name='E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\ResNet\ResNet_banda_result3.csv')
|
||||
# file_name='E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\ResNet\ResNet_timestamp120_feature10_result.csv')
|
||||
# test_corr()
|
||||
# acc()
|
||||
# list = [3.77, 2.64, 2.35, 2.05, 1.76, 1.09, 0.757, 0.82, 1.1, 0.58, 0, 0.03, 0.02]
|
||||
# test_bar(list)
|
||||
|
||||
# list=[99.99,98.95,99.95,96.1,95,99.65,76.25,72.64,75.87,68.74]
|
||||
# plot_FNR1(list)
|
||||
list=[98.56,98.95,99.95,96.1,95,99.65,76.25,72.64,75.87,68.74]
|
||||
plot_FNR1(list)
|
||||
# #
|
||||
list=[3.43,1.99,1.92,2.17,1.63,1.81,1.78,1.8,0.6]
|
||||
plot_FNR2(list)
|
||||
# list=[3.43,1.99,1.92,2.17,1.63,1.81,1.78,1.8,0.6]
|
||||
# plot_FNR2(list)
|
||||
|
||||
# 查看网络某一层的权重
|
||||
# test_model_visualization(model_name = "E:\跑模型\论文写作/SE.txt")
|
||||
|
|
@ -759,8 +795,8 @@ if __name__ == '__main__':
|
|||
# test_model_weight_l(file_name)
|
||||
|
||||
# 单独预测图
|
||||
plot_mse('E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\JM_banda/banda_joint_result_predict3.csv',data='E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\JM_banda/raw_data.csv')
|
||||
|
||||
# plot_mse('E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\JM_banda/banda_joint_result_predict3.csv',data='E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\JM_banda/raw_data.csv')
|
||||
# plot_mse_single('E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\RNet_D/banda\RNet_D_banda_mse_predict1.csv')
|
||||
#画3d图
|
||||
# plot_3d()
|
||||
|
||||
|
|
|
|||
|
|
@ -283,76 +283,7 @@ def showResult(step_two_model: tf.keras.Model, test_data, isPlot: bool = False,
|
|||
return total_result
|
||||
|
||||
|
||||
def plot_result(result_data):
|
||||
parameters = {
|
||||
'figure.dpi': 600,
|
||||
'figure.figsize': (5, 5),
|
||||
'savefig.dpi': 600,
|
||||
'xtick.direction': 'in',
|
||||
'ytick.direction': 'in',
|
||||
'xtick.labelsize': 20,
|
||||
'ytick.labelsize': 20,
|
||||
'legend.fontsize': 11.3,
|
||||
}
|
||||
plt.rcParams.update(parameters)
|
||||
plt.figure()
|
||||
plt.rc('font', family='Times New Roman') # 全局字体样式#画混淆矩阵
|
||||
|
||||
plt.scatter(list(range(result_data.shape[0])), result_data, c='black', s=10)
|
||||
# 画出 y=1 这条水平线
|
||||
plt.axhline(0.5, c='red', label='Failure threshold')
|
||||
# 箭头指向上面的水平线
|
||||
# plt.arrow(35000, 0.9, 33000, 0.75, head_width=0.02, head_length=0.1, shape="full", fc='red', ec='red',
|
||||
# alpha=0.9, overhang=0.5)
|
||||
plt.text(test_data.shape[0] * 2 / 3 + 1000, 0.7, "Truth Fault", fontsize=10, color='red',
|
||||
verticalalignment='top')
|
||||
plt.axvline(test_data.shape[0] * 2 / 3, c='blue', ls='-.')
|
||||
plt.xticks(range(6), ('06/09/17', '12/09/17', '18/09/17', '24/09/17', '29/09/17')) # 设置x轴的标尺
|
||||
|
||||
|
||||
indices = range(result_data.shape[0])
|
||||
classes = ['N', 'IF', 'OF' 'TRC', 'TSP'] # for i in range(cls):
|
||||
# 第一个是迭代对象,表示坐标的显示顺序,第二个参数是坐标轴显示列表
|
||||
# plt.xticks(indices. classes,rotation=45)#设置横坐标方向,rotation=45为45度倾斜# plt.yticks(indices, classes, rotation=45)
|
||||
plt.xticks([index + 0.5 for index in indices], classes, rotation=45) # 设置横坐标方向,rotation=45为45度倾斜
|
||||
# plt.yticks([index + 0.5 for index in indices], classes, rotation=45)
|
||||
plt.ylabel('Actual label', fontsize=20)
|
||||
plt.xlabel('Predicted label', fontsize=20)
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
pass
|
||||
|
||||
|
||||
def plot_confusion_matrix_accuracy(cls, true_labels, predict_labels):
|
||||
sns.set(font_scale=1.5)
|
||||
parameters = {
|
||||
'figure.dpi': 600,
|
||||
'figure.figsize': (5, 5),
|
||||
'savefig.dpi': 600,
|
||||
'xtick.direction': 'in',
|
||||
'ytick.direction': 'in',
|
||||
'xtick.labelsize': 20,
|
||||
'ytick.labelsize': 20,
|
||||
'legend.fontsize': 11.3,
|
||||
}
|
||||
plt.rcParams.update(parameters)
|
||||
plt.figure()
|
||||
plt.rc('font', family='Times New Roman') # 全局字体样式#画混淆矩阵
|
||||
confusion = confusion_matrix(true_labels, predict_labels)
|
||||
# confusion = confusion.astype('float') / confusion.sum(axis=1)[: np.newaxis] plt.figure()
|
||||
# sns.heatmap(confusion, annot=True,fmt="d",cmap="Greens")
|
||||
sns.heatmap(confusion, annot=True, fmt="d", cmap="Blues", vmax=100, vmin=0, cbar=None, square=True)
|
||||
indices = range(len(confusion))
|
||||
classes = ['N', 'IF', 'OF' 'TRC', 'TSP'] # for i in range(cls):
|
||||
# classes.append(str(i))
|
||||
# 第一个是迭代对象,表示坐标的显示顺序,第二个参数是坐标轴显示列表
|
||||
# plt.xticks(indices. classes,rotation=45)#设置横坐标方向,rotation=45为45度倾斜# plt.yticks(indices, classes, rotation=45)
|
||||
plt.xticks([index + 0.5 for index in indices], classes, rotation=45) # 设置横坐标方向,rotation=45为45度倾斜
|
||||
plt.yticks([index + 0.5 for index in indices], classes, rotation=45)
|
||||
plt.ylabel('Actual label', fontsize=20)
|
||||
plt.xlabel('Predicted label', fontsize=20)
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
|||
|
|
@ -204,12 +204,21 @@ def test_corr(file_name=source_path,N=10):
|
|||
pass
|
||||
|
||||
|
||||
def plot_raw_data():
|
||||
# data, label = read_data(file_name='G:\data\SCADA数据\jb4q_8_delete_total_zero.csv', isNew=False)
|
||||
# data = np.loadtxt('G:\data\SCADA数据/normalization.csv',delimiter=',')
|
||||
data = pd.read_csv('G:\data\SCADA数据/normalization.csv')
|
||||
print(data.shape)
|
||||
print(data)
|
||||
plt.plot(data)
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# test_mse()
|
||||
# test_result()
|
||||
test_corr()
|
||||
# test_corr()
|
||||
plot_raw_data()
|
||||
pass
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -28,12 +28,12 @@ class Joint_Monitoring(keras.Model):
|
|||
super(Joint_Monitoring, self).__init__()
|
||||
# step one
|
||||
self.RepDCBlock1 = RevConvBlock(num=3, kernel_size=5)
|
||||
self.conv1 = tf.keras.layers.Conv1D(filters=conv_filter, kernel_size=1, strides=2, padding='SAME')
|
||||
self.conv1 = tf.keras.layers.Conv1D(filters=conv_filter, kernel_size=1, strides=2, padding='SAME',kernel_initializer=0.7,bias_initializer=1)
|
||||
self.upsample1 = tf.keras.layers.UpSampling1D(size=2)
|
||||
|
||||
self.DACU2 = DynamicChannelAttention()
|
||||
self.RepDCBlock2 = RevConvBlock(num=3, kernel_size=3)
|
||||
self.conv2 = tf.keras.layers.Conv1D(filters=2 * conv_filter, kernel_size=1, strides=2, padding='SAME')
|
||||
self.conv2 = tf.keras.layers.Conv1D(filters=2 * conv_filter, kernel_size=1, strides=2, padding='SAME',kernel_initializer=0.7,bias_initializer=1)
|
||||
self.upsample2 = tf.keras.layers.UpSampling1D(size=2)
|
||||
|
||||
self.DACU3 = DynamicChannelAttention()
|
||||
|
|
@ -382,16 +382,18 @@ class Joint_Monitoring(keras.Model):
|
|||
|
||||
class RevConv(keras.layers.Layer):
|
||||
|
||||
def __init__(self, kernel_size=3):
|
||||
def __init__(self, kernel_size=3,dilation_rate=2):
|
||||
# 调用父类__init__()方法
|
||||
super(RevConv, self).__init__()
|
||||
self.kernel_size = kernel_size
|
||||
self.dilation_rate = dilation_rate
|
||||
|
||||
def get_config(self):
|
||||
# 自定义层里面的属性
|
||||
config = (
|
||||
{
|
||||
'kernel_size': self.kernel_size
|
||||
'kernel_size': self.kernel_size,
|
||||
'dilation_rate': self.dilation_rate
|
||||
}
|
||||
)
|
||||
base_config = super(RevConv, self).get_config()
|
||||
|
|
@ -402,10 +404,10 @@ class RevConv(keras.layers.Layer):
|
|||
_, _, output_dim = input_shape[0], input_shape[1], input_shape[2]
|
||||
self.conv1 = tf.keras.layers.Conv1D(filters=output_dim, kernel_size=self.kernel_size, strides=1,
|
||||
padding='causal',
|
||||
dilation_rate=4)
|
||||
dilation_rate=self.dilation_rate)
|
||||
|
||||
self.conv2 = tf.keras.layers.Conv1D(filters=output_dim, kernel_size=1, strides=1, padding='causal',
|
||||
dilation_rate=4)
|
||||
dilation_rate=self.dilation_rate)
|
||||
# self.b2 = tf.keras.layers.BatchNormalization()
|
||||
|
||||
# self.b3 = tf.keras.layers.BatchNormalization()
|
||||
|
|
@ -440,7 +442,7 @@ class RevConvBlock(keras.layers.Layer):
|
|||
self.kernel_size = kernel_size
|
||||
self.L = []
|
||||
for i in range(num):
|
||||
RepVGG = RevConv(kernel_size=kernel_size)
|
||||
RepVGG = RevConv(kernel_size=kernel_size,dilation_rate=(i+1)*2)
|
||||
self.L.append(RepVGG)
|
||||
|
||||
def get_config(self):
|
||||
|
|
|
|||
Loading…
Reference in New Issue