leecode更新

This commit is contained in:
markilue 2022-11-17 22:04:42 +08:00
parent 663414a446
commit 32ebfcceb8
6 changed files with 212 additions and 111 deletions

View File

@ -0,0 +1,73 @@
package com.markilue.leecode.greedy;
import org.junit.Test;
import java.util.Arrays;
import java.util.Comparator;
/**
* @BelongsProject: Leecode
* @BelongsPackage: com.markilue.leecode.greedy
* @Author: markilue
* @CreateTime: 2022-11-17 20:29
* @Description: TODO 力扣452题 用最少的数量的箭引爆气球
* 有一些球形气球贴在一堵用 XY 平面表示的墙面上墙面上的气球记录在整数数组 points 其中points[i] = [xstart, xend] 表示水平直径在 xstart xend之间的气球你不知道气球的确切 y 坐标
* 一支弓箭可以沿着 x 轴从不同点 完全垂直 地射出在坐标 x 处射出一支箭若有一个气球的直径的开始和结束坐标为 xstartxend 且满足 xstart x xend则该气球会被 引爆 可以射出的弓箭的数量 没有限制 弓箭一旦被射出之后可以无限地前进
* 给你一个数组 points 返回引爆所有气球所必须射出的 最小 弓箭数
* @Version: 1.0
*/
public class FindMinArrowShots {
@Test
public void test() {
int[][] points = {{1, 2}, {3, 4},{5, 6},{7, 8}};
int[][] points1 = {{10, 16}, {2, 8},{1, 6},{7, 12}};
int[][] points2 = {{1, 2}, {2, 3},{3, 4},{4, 5}};
int[][] points3 = {{1, 4}, {1, 3},{3, 4},{4, 5}};
int[][] points4 = {{-2147483646, -2147483645}, {2147483646, 2147483647}};
int[][] points5 = {{9,12}, {1, 10},{4, 11},{8, 12}, {3, 9}, {6, 9},{6, 7}};
// System.out.println((long)(-2147483646)-(long)(2147483647)); //-4294967293
// System.out.println((int)((long)(-2147483646)-(long)(2147483647)));//3
// System.out.println((int)(-2147483646)-(int)(2147483647)); //3
System.out.println(findMinArrowShots(points));
}
/**
* 本人思路先根据xstart进行排序然后寻找第一个比xend大的xstart则result+1以此类推
* 排序的时间复杂度为O(nlogn),所以总复杂度为O(nlogn)
* 速度击败12.72%内存击败88.71%
* @param points
* @return
*/
public int findMinArrowShots(int[][] points) {
//按xstart进行排序
Arrays.sort(points, new Comparator<int[]>() {
@Override
public int compare(int[] o1, int[] o2) {
//警惕两数相减超过int范围
return o1[0] == o2[0] ? (o1[1] > o2[1]?1:-1) : (o1[0] > o2[0]?1:-1);
}
});
//遍历points找到比xend大的xstart就+1
int result=1;
int lastEnd=points[0][1];
for (int i = 1; i < points.length; i++) {
if(points[i][0]>lastEnd){
result++;
lastEnd=points[i][1];
}
//判断第二个数是不是比lastEnd还小如果还小就是子集那么就还得该lastEnd
if(points[i][1]<lastEnd){
lastEnd=points[i][1];
}
}
return result;
}
}

View File

@ -0,0 +1,50 @@
# -*- coding: utf-8 -*-
# coding: utf-8
'''
@Author : dingjiawen
@Date : 2022/11/14 10:43
@Usage : 画sigmoid
@Desc :
'''
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow.keras
import matplotlib.pyplot as plt
def sigmoid(x):
y = 1 / (1 + np.exp(-x))
# dy=y*(1-y)
return y
pass
def plot_sigmoid():
parameters = {
'figure.dpi': 600,
'figure.figsize': (8, 6),
'savefig.dpi': 600,
'xtick.direction': 'in',
'ytick.direction': 'in',
'xtick.labelsize': 15,
'ytick.labelsize': 15,
'legend.fontsize': 15,
}
plt.rcParams.update(parameters)
plt.figure()
x=np.arange(-10,10,0.02)
y=sigmoid(x)
# 画出 y=1 这条水平线
plt.axhline(0.5, c='red', label='Failure threshold', lw=2)
plt.plot(x,y,c='black',lw=2)
plt.show()
if __name__ == '__main__':
plot_sigmoid()

View File

@ -47,24 +47,20 @@ def plot_result_banda(result_data):
'savefig.dpi': 600, 'savefig.dpi': 600,
'xtick.direction': 'in', 'xtick.direction': 'in',
'ytick.direction': 'in', 'ytick.direction': 'in',
'xtick.labelsize': 5, 'xtick.labelsize': 6,
'ytick.labelsize': 5, 'ytick.labelsize': 6,
'legend.fontsize': 5, 'legend.fontsize': 5,
'font.family': 'Times New Roman',
} }
plt.rcParams.update(parameters) plt.rcParams.update(parameters)
fig, ax = plt.subplots(1, 1) fig, ax = plt.subplots(1, 1)
plt.rc('font', family='Times New Roman') # 全局字体样式 plt.rc('font', family='Times New Roman') # 全局字体样式
font1 = {'family': 'Times New Roman', 'weight': 'normal', 'size': 5} # 设置坐标标签的字体大小,字体 font1 = {'family': 'Times New Roman', 'weight': 'normal', 'size': 5} # 设置坐标标签的字体大小,字体
font2 = {'family': 'Times New Roman', 'weight': 'normal','size':7} # 设置坐标标签的字体大小,字体
plt.scatter(list(range(result_data.shape[0])), result_data, c='black', s=0.01, label="predict") plt.scatter(list(range(result_data.shape[0])), result_data, c='black', s=0.01, label="predict")
# 画出 y=1 这条水平线 # 画出 y=1 这条水平线
plt.axhline(0.5, c='red', label='Failure threshold', lw=1) plt.axhline(0.5, c='red', label='Failure threshold', lw=1)
# 箭头指向上面的水平线 # 箭头指向上面的水平线
# plt.arrow(result_data.shape[0]*2/3, 0.55, 2000, 0.085, width=0.00001, ec='red',length_includes_head=True)
# plt.text(result_data.shape[0] * 2 / 3 + 1000, 0.7, "real fault", fontsize=5, color='red',
# verticalalignment='top')
# plt.text(0, 0.55, "Threshold", fontsize=5, color='red',
# verticalalignment='top')
plt.axvline(result_data.shape[0] * 2 / 3-50, c='blue', ls='-.', lw=0.5, label='real fault') plt.axvline(result_data.shape[0] * 2 / 3-50, c='blue', ls='-.', lw=0.5, label='real fault')
# plt.axvline(415548, c='blue', ls='-.', lw=0.5, label='real fault') # plt.axvline(415548, c='blue', ls='-.', lw=0.5, label='real fault')
# plt.xticks(range(6), ('06/09/17', '12/09/17', '18/09/17', '24/09/17', '29/09/17')) # 设置x轴的标尺 # plt.xticks(range(6), ('06/09/17', '12/09/17', '18/09/17', '24/09/17', '29/09/17')) # 设置x轴的标尺
@ -89,24 +85,27 @@ def plot_result_banda(result_data):
# pad调整label与坐标轴之间的距离 # pad调整label与坐标轴之间的距离
plt.tick_params(bottom=True, top=False, left=True, right=False, direction='inout', length=2, width=0.5, pad=1) plt.tick_params(bottom=True, top=False, left=True, right=False, direction='inout', length=2, width=0.5, pad=1)
# plt.yticks([index for index in indices1], classes1) # plt.yticks([index for index in indices1], classes1)
plt.ylabel('Confidence', fontsize=5) plt.ylabel('Confidence',fontdict=font2)
plt.xlabel('Time', fontsize=5) plt.xlabel('Time',fontdict=font2)
plt.tight_layout() plt.tight_layout()
# plt.legend(loc='best', edgecolor='black', fontsize=4) # plt.legend(loc='best', edgecolor='black', fontsize=4)
plt.legend(loc='upper right', frameon=False, fontsize=4.5) plt.legend(loc='upper right', frameon=False, fontsize=4.5)
# plt.grid() # plt.grid()
# 局部方法图 # 局部方法图
axins = inset_axes(ax, width="40%", height="30%", loc='lower left', axins = inset_axes(ax, width="40%", height="30%", loc='lower left',
bbox_to_anchor=(0.1, 0.1, 1, 1), bbox_to_anchor=(0.1, 0.1, 1, 1),
bbox_transform=ax.transAxes) bbox_transform=ax.transAxes)
axins.scatter(list(range(result_data.shape[0])), result_data, c='black', s=0.001, label="predict") axins.scatter(list(range(result_data.shape[0])), result_data, c='black', s=0.005, label="predict")
axins.axvline(result_data.shape[0] * 2 / 3-50, c='blue', ls='-.', lw=0.5, label='real fault') axins.axvline(result_data.shape[0] * 2 / 3-50, c='blue', ls='-.', lw=0.5, label='real fault')
plt.axhline(0.5, c='red', label='Failure threshold', lw=0.5) plt.axhline(0.5, c='red', label='Failure threshold', lw=0.5)
# 设置放大区间 # 设置放大区间
# 设置放大区间 # 设置放大区间
zone_left = int(result_data.shape[0] * 2 / 3 -160) zone_left = int(result_data.shape[0] * 2 / 3 -160)
zone_right = int(result_data.shape[0] * 2 / 3) + 40 zone_right = int(result_data.shape[0] * 2 / 3) + 40
# zone_left = int(result_data.shape[0] * 2 / 3 +250)
# zone_right = int(result_data.shape[0] * 2 / 3) + 450
x = list(range(result_data.shape[0])) x = list(range(result_data.shape[0]))
# 坐标轴的扩展比例(根据实际数据调整) # 坐标轴的扩展比例(根据实际数据调整)
@ -140,27 +139,22 @@ def plot_result(result_data):
'savefig.dpi': 600, 'savefig.dpi': 600,
'xtick.direction': 'in', 'xtick.direction': 'in',
'ytick.direction': 'in', 'ytick.direction': 'in',
'xtick.labelsize': 5, 'xtick.labelsize': 6,
'ytick.labelsize': 5, 'ytick.labelsize': 6,
'legend.fontsize': 5, 'legend.fontsize': 5,
'font.family': 'Times New Roman',
} }
plt.rcParams.update(parameters) plt.rcParams.update(parameters)
fig, ax = plt.subplots(1, 1) fig, ax = plt.subplots(1, 1)
plt.rc('font', family='Times New Roman') # 全局字体样式 plt.rc('font', family='Times New Roman') # 全局字体样式
font1 = {'family': 'Times New Roman', 'weight': 'normal', 'size': 5} # 设置坐标标签的字体大小,字体 font1 = {'family': 'Times New Roman', 'weight': 'normal', 'size': 5} # 设置坐标标签的字体大小,字体
font2 = {'family': 'Times New Roman', 'weight': 'normal', 'size': 7} # 设置坐标标签的字体大小,字体
plt.scatter(list(range(result_data.shape[0])), result_data, c='black', s=0.01, label="predict") plt.scatter(list(range(result_data.shape[0])), result_data, c='black', s=0.01, label="predict")
# 画出 y=1 这条水平线 # 画出 y=1 这条水平线
plt.axhline(0.5, c='red', label='Failure threshold', lw=1) plt.axhline(0.5, c='red', label='Failure threshold', lw=1)
# 箭头指向上面的水平线
# plt.arrow(result_data.shape[0]*2/3, 0.55, 2000, 0.085, width=0.00001, ec='red',length_includes_head=True)
# plt.text(result_data.shape[0] * 2 / 3 + 1000, 0.7, "real fault", fontsize=5, color='red',
# verticalalignment='top')
# plt.text(0, 0.55, "Threshold", fontsize=5, color='red',
# verticalalignment='top')
plt.axvline(result_data.shape[0] * 2 / 3-50, c='blue', ls='-.', lw=0.5, label='real fault')
# plt.axvline(415548, c='blue', ls='-.', lw=0.5, label='real fault') plt.axvline(result_data.shape[0] * 2 / 3-15, c='blue', ls='-.', lw=0.5, label='real fault')
# plt.xticks(range(6), ('06/09/17', '12/09/17', '18/09/17', '24/09/17', '29/09/17')) # 设置x轴的标尺
plt.text(result_data.shape[0] * 5 / 6, 0.4, "Fault", fontsize=5, color='black', verticalalignment='top', plt.text(result_data.shape[0] * 5 / 6, 0.4, "Fault", fontsize=5, color='black', verticalalignment='top',
horizontalalignment='center', horizontalalignment='center',
bbox={'facecolor': 'grey', bbox={'facecolor': 'grey',
@ -182,8 +176,8 @@ def plot_result(result_data):
# pad调整label与坐标轴之间的距离 # pad调整label与坐标轴之间的距离
plt.tick_params(bottom=True, top=False, left=True, right=False, direction='inout', length=2, width=0.5, pad=1) plt.tick_params(bottom=True, top=False, left=True, right=False, direction='inout', length=2, width=0.5, pad=1)
# plt.yticks([index for index in indices1], classes1) # plt.yticks([index for index in indices1], classes1)
plt.ylabel('Confidence', fontsize=5) plt.ylabel('Confidence', fontdict=font2)
plt.xlabel('Time', fontsize=5) plt.xlabel('Time', fontdict=font2)
plt.tight_layout() plt.tight_layout()
# plt.legend(loc='best', edgecolor='black', fontsize=4) # plt.legend(loc='best', edgecolor='black', fontsize=4)
plt.legend(loc='best', frameon=False, fontsize=4.5) plt.legend(loc='best', frameon=False, fontsize=4.5)
@ -193,8 +187,8 @@ def plot_result(result_data):
axins = inset_axes(ax, width="40%", height="30%", loc='lower left', axins = inset_axes(ax, width="40%", height="30%", loc='lower left',
bbox_to_anchor=(0.1, 0.1, 1, 1), bbox_to_anchor=(0.1, 0.1, 1, 1),
bbox_transform=ax.transAxes) bbox_transform=ax.transAxes)
axins.scatter(list(range(result_data.shape[0])), result_data, c='black', s=0.001, label="predict") axins.scatter(list(range(result_data.shape[0])), result_data, c='black', s=0.005, label="predict")
axins.axvline(result_data.shape[0] * 2 / 3-50, c='blue', ls='-.', lw=0.5, label='real fault') axins.axvline(result_data.shape[0] * 2 / 3-15, c='blue', ls='-.', lw=0.5, label='real fault')
plt.axhline(0.5, c='red', label='Failure threshold', lw=0.5) plt.axhline(0.5, c='red', label='Failure threshold', lw=0.5)
# 设置放大区间 # 设置放大区间
# 设置放大区间 # 设置放大区间
@ -587,6 +581,48 @@ def plot_mse(file_name_mse="../others_idea/mse",data:str=''):
plt.show() plt.show()
def plot_mse_single(file_name_mse="../self_try/compare/mse/JM_banda/banda_joint_result_predict1.csv"):
mse = np.loadtxt(file_name_mse, delimiter=",")
print("mse:",mse.shape)
need_shape=int(mse.shape[0]*2/3)
# mse = mse[2000:2300]
# mse = mse[1800:2150]
# mse = mse[ need_shape+100:need_shape+377]
mse = mse[ need_shape-300:need_shape-10]
parameters = {
'figure.dpi': 600,
'figure.figsize': (2.8, 2),
'savefig.dpi': 600,
'xtick.direction': 'in',
'ytick.direction': 'in',
'xtick.labelsize': 6,
'ytick.labelsize': 6,
'legend.fontsize': 5,
'font.family':'Times New Roman'
}
plt.rcParams.update(parameters)
font2 = {'family': 'Times New Roman', 'weight': 'normal', 'size': 7} # 设置坐标标签的字体大小,字体
plt.figure()
plt.rc('font', family='Times New Roman') # 全局字体样式#画混淆矩阵
indices = [mse.shape[0] * i / 4 for i in range(5)]
# classes = ['07:21','08:21', '09:21', '10:21', '11:21']
classes = ['01:58','02:58', '03:58', '04:58', '05:58']
plt.xticks([index for index in indices], classes, rotation=25) # 设置横坐标方向rotation=45为45度倾斜
# pad调整label与坐标轴之间的距离
plt.tick_params(bottom=True, top=False, left=True, right=False, direction='inout', length=2, width=0.5, pad=1)
plt.ylabel('Residual', fontdict=font2)
plt.xlabel('Time', fontdict=font2)
plt.tight_layout()
plt.plot(mse[:], lw=0.5)
plt.show()
def plot_3d(): def plot_3d():
@ -646,7 +682,7 @@ def test_result(file_name: str = result_file_name):
# result_data = np.concatenate([result_data, data], axis=0) # result_data = np.concatenate([result_data, data], axis=0)
print(result_data) print(result_data)
print(result_data.shape) print(result_data.shape)
plot_result_banda(result_data) plot_result(result_data)
def test_mse(mse_file_name: str = mse_file_name, max_file_name: str = max_file_name): def test_mse(mse_file_name: str = mse_file_name, max_file_name: str = max_file_name):
@ -739,17 +775,17 @@ def test_model_visualization(model_name=file_name):
if __name__ == '__main__': if __name__ == '__main__':
# test_mse(fi) # test_mse(fi)
# test_result( # test_result(
# file_name='E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\ResNet\ResNet_banda_result3.csv') # file_name='E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\ResNet\ResNet_timestamp120_feature10_result.csv')
# test_corr() # test_corr()
# acc() # acc()
# list = [3.77, 2.64, 2.35, 2.05, 1.76, 1.09, 0.757, 0.82, 1.1, 0.58, 0, 0.03, 0.02] # list = [3.77, 2.64, 2.35, 2.05, 1.76, 1.09, 0.757, 0.82, 1.1, 0.58, 0, 0.03, 0.02]
# test_bar(list) # test_bar(list)
# list=[99.99,98.95,99.95,96.1,95,99.65,76.25,72.64,75.87,68.74] list=[98.56,98.95,99.95,96.1,95,99.65,76.25,72.64,75.87,68.74]
# plot_FNR1(list) plot_FNR1(list)
# # # #
list=[3.43,1.99,1.92,2.17,1.63,1.81,1.78,1.8,0.6] # list=[3.43,1.99,1.92,2.17,1.63,1.81,1.78,1.8,0.6]
plot_FNR2(list) # plot_FNR2(list)
# 查看网络某一层的权重 # 查看网络某一层的权重
# test_model_visualization(model_name = "E:\跑模型\论文写作/SE.txt") # test_model_visualization(model_name = "E:\跑模型\论文写作/SE.txt")
@ -759,8 +795,8 @@ if __name__ == '__main__':
# test_model_weight_l(file_name) # test_model_weight_l(file_name)
# 单独预测图 # 单独预测图
plot_mse('E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\JM_banda/banda_joint_result_predict3.csv',data='E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\JM_banda/raw_data.csv') # plot_mse('E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\JM_banda/banda_joint_result_predict3.csv',data='E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\JM_banda/raw_data.csv')
# plot_mse_single('E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\RNet_D/banda\RNet_D_banda_mse_predict1.csv')
#画3d图 #画3d图
# plot_3d() # plot_3d()

View File

@ -283,76 +283,7 @@ def showResult(step_two_model: tf.keras.Model, test_data, isPlot: bool = False,
return total_result return total_result
def plot_result(result_data):
parameters = {
'figure.dpi': 600,
'figure.figsize': (5, 5),
'savefig.dpi': 600,
'xtick.direction': 'in',
'ytick.direction': 'in',
'xtick.labelsize': 20,
'ytick.labelsize': 20,
'legend.fontsize': 11.3,
}
plt.rcParams.update(parameters)
plt.figure()
plt.rc('font', family='Times New Roman') # 全局字体样式#画混淆矩阵
plt.scatter(list(range(result_data.shape[0])), result_data, c='black', s=10)
# 画出 y=1 这条水平线
plt.axhline(0.5, c='red', label='Failure threshold')
# 箭头指向上面的水平线
# plt.arrow(35000, 0.9, 33000, 0.75, head_width=0.02, head_length=0.1, shape="full", fc='red', ec='red',
# alpha=0.9, overhang=0.5)
plt.text(test_data.shape[0] * 2 / 3 + 1000, 0.7, "Truth Fault", fontsize=10, color='red',
verticalalignment='top')
plt.axvline(test_data.shape[0] * 2 / 3, c='blue', ls='-.')
plt.xticks(range(6), ('06/09/17', '12/09/17', '18/09/17', '24/09/17', '29/09/17')) # 设置x轴的标尺
indices = range(result_data.shape[0])
classes = ['N', 'IF', 'OF' 'TRC', 'TSP'] # for i in range(cls):
# 第一个是迭代对象,表示坐标的显示顺序,第二个参数是坐标轴显示列表
# plt.xticks(indices. classes,rotation=45)#设置横坐标方向rotation=45为45度倾斜# plt.yticks(indices, classes, rotation=45)
plt.xticks([index + 0.5 for index in indices], classes, rotation=45) # 设置横坐标方向rotation=45为45度倾斜
# plt.yticks([index + 0.5 for index in indices], classes, rotation=45)
plt.ylabel('Actual label', fontsize=20)
plt.xlabel('Predicted label', fontsize=20)
plt.tight_layout()
plt.show()
pass
def plot_confusion_matrix_accuracy(cls, true_labels, predict_labels):
sns.set(font_scale=1.5)
parameters = {
'figure.dpi': 600,
'figure.figsize': (5, 5),
'savefig.dpi': 600,
'xtick.direction': 'in',
'ytick.direction': 'in',
'xtick.labelsize': 20,
'ytick.labelsize': 20,
'legend.fontsize': 11.3,
}
plt.rcParams.update(parameters)
plt.figure()
plt.rc('font', family='Times New Roman') # 全局字体样式#画混淆矩阵
confusion = confusion_matrix(true_labels, predict_labels)
# confusion = confusion.astype('float') / confusion.sum(axis=1)[: np.newaxis] plt.figure()
# sns.heatmap(confusion, annot=True,fmt="d",cmap="Greens")
sns.heatmap(confusion, annot=True, fmt="d", cmap="Blues", vmax=100, vmin=0, cbar=None, square=True)
indices = range(len(confusion))
classes = ['N', 'IF', 'OF' 'TRC', 'TSP'] # for i in range(cls):
# classes.append(str(i))
# 第一个是迭代对象,表示坐标的显示顺序,第二个参数是坐标轴显示列表
# plt.xticks(indices. classes,rotation=45)#设置横坐标方向rotation=45为45度倾斜# plt.yticks(indices, classes, rotation=45)
plt.xticks([index + 0.5 for index in indices], classes, rotation=45) # 设置横坐标方向rotation=45为45度倾斜
plt.yticks([index + 0.5 for index in indices], classes, rotation=45)
plt.ylabel('Actual label', fontsize=20)
plt.xlabel('Predicted label', fontsize=20)
plt.tight_layout()
plt.show()
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -204,12 +204,21 @@ def test_corr(file_name=source_path,N=10):
pass pass
def plot_raw_data():
# data, label = read_data(file_name='G:\data\SCADA数据\jb4q_8_delete_total_zero.csv', isNew=False)
# data = np.loadtxt('G:\data\SCADA数据/normalization.csv',delimiter=',')
data = pd.read_csv('G:\data\SCADA数据/normalization.csv')
print(data.shape)
print(data)
plt.plot(data)
plt.show()
if __name__ == '__main__': if __name__ == '__main__':
# test_mse() # test_mse()
# test_result() # test_result()
test_corr() # test_corr()
plot_raw_data()
pass pass

View File

@ -28,12 +28,12 @@ class Joint_Monitoring(keras.Model):
super(Joint_Monitoring, self).__init__() super(Joint_Monitoring, self).__init__()
# step one # step one
self.RepDCBlock1 = RevConvBlock(num=3, kernel_size=5) self.RepDCBlock1 = RevConvBlock(num=3, kernel_size=5)
self.conv1 = tf.keras.layers.Conv1D(filters=conv_filter, kernel_size=1, strides=2, padding='SAME') self.conv1 = tf.keras.layers.Conv1D(filters=conv_filter, kernel_size=1, strides=2, padding='SAME',kernel_initializer=0.7,bias_initializer=1)
self.upsample1 = tf.keras.layers.UpSampling1D(size=2) self.upsample1 = tf.keras.layers.UpSampling1D(size=2)
self.DACU2 = DynamicChannelAttention() self.DACU2 = DynamicChannelAttention()
self.RepDCBlock2 = RevConvBlock(num=3, kernel_size=3) self.RepDCBlock2 = RevConvBlock(num=3, kernel_size=3)
self.conv2 = tf.keras.layers.Conv1D(filters=2 * conv_filter, kernel_size=1, strides=2, padding='SAME') self.conv2 = tf.keras.layers.Conv1D(filters=2 * conv_filter, kernel_size=1, strides=2, padding='SAME',kernel_initializer=0.7,bias_initializer=1)
self.upsample2 = tf.keras.layers.UpSampling1D(size=2) self.upsample2 = tf.keras.layers.UpSampling1D(size=2)
self.DACU3 = DynamicChannelAttention() self.DACU3 = DynamicChannelAttention()
@ -382,16 +382,18 @@ class Joint_Monitoring(keras.Model):
class RevConv(keras.layers.Layer): class RevConv(keras.layers.Layer):
def __init__(self, kernel_size=3): def __init__(self, kernel_size=3,dilation_rate=2):
# 调用父类__init__()方法 # 调用父类__init__()方法
super(RevConv, self).__init__() super(RevConv, self).__init__()
self.kernel_size = kernel_size self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
def get_config(self): def get_config(self):
# 自定义层里面的属性 # 自定义层里面的属性
config = ( config = (
{ {
'kernel_size': self.kernel_size 'kernel_size': self.kernel_size,
'dilation_rate': self.dilation_rate
} }
) )
base_config = super(RevConv, self).get_config() base_config = super(RevConv, self).get_config()
@ -402,10 +404,10 @@ class RevConv(keras.layers.Layer):
_, _, output_dim = input_shape[0], input_shape[1], input_shape[2] _, _, output_dim = input_shape[0], input_shape[1], input_shape[2]
self.conv1 = tf.keras.layers.Conv1D(filters=output_dim, kernel_size=self.kernel_size, strides=1, self.conv1 = tf.keras.layers.Conv1D(filters=output_dim, kernel_size=self.kernel_size, strides=1,
padding='causal', padding='causal',
dilation_rate=4) dilation_rate=self.dilation_rate)
self.conv2 = tf.keras.layers.Conv1D(filters=output_dim, kernel_size=1, strides=1, padding='causal', self.conv2 = tf.keras.layers.Conv1D(filters=output_dim, kernel_size=1, strides=1, padding='causal',
dilation_rate=4) dilation_rate=self.dilation_rate)
# self.b2 = tf.keras.layers.BatchNormalization() # self.b2 = tf.keras.layers.BatchNormalization()
# self.b3 = tf.keras.layers.BatchNormalization() # self.b3 = tf.keras.layers.BatchNormalization()
@ -440,7 +442,7 @@ class RevConvBlock(keras.layers.Layer):
self.kernel_size = kernel_size self.kernel_size = kernel_size
self.L = [] self.L = []
for i in range(num): for i in range(num):
RepVGG = RevConv(kernel_size=kernel_size) RepVGG = RevConv(kernel_size=kernel_size,dilation_rate=(i+1)*2)
self.L.append(RepVGG) self.L.append(RepVGG)
def get_config(self): def get_config(self):