leecode更新
This commit is contained in:
parent
1b92547add
commit
10c949ab6c
|
|
@ -17,6 +17,19 @@
|
|||
<groupId>com.atguigu</groupId>
|
||||
<version>1.0-SNAPSHOT</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>junit</groupId>
|
||||
<artifactId>junit</artifactId>
|
||||
<version>4.13.2</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>junit</groupId>
|
||||
<artifactId>junit</artifactId>
|
||||
<version>4.13.2</version>
|
||||
<scope>compile</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -8,4 +8,5 @@ mysql.username=root
|
|||
mysql.password=123456
|
||||
|
||||
# clickhouseÅäÖÃ
|
||||
clickhouse.url=jdbc:clickhouse://Ding202:8123/user_profile0224
|
||||
#clickhouse.url=jdbc:clickhouse://Ding202:8123/user_profile0224
|
||||
clickhouse.url=jdbc:clickhouse://localhost:8123/default
|
||||
|
|
|
|||
|
|
@ -0,0 +1,44 @@
|
|||
import com.atguigu.userprofile.common.utils.ClickhouseUtils
|
||||
|
||||
object clickhouseJDBC {
|
||||
|
||||
def main(args: Array[String]): Unit = {
|
||||
|
||||
val createSQL =
|
||||
s"""
|
||||
|create table t_order_mt(
|
||||
| uid UInt32,
|
||||
| sku_id String,
|
||||
| total_amount Decimal(16,2),
|
||||
| create_time Datetime
|
||||
| ) engine =MergeTree
|
||||
| partition by toYYYYMMDD(create_time)
|
||||
| primary key (uid)
|
||||
| order by (uid,sku_id)
|
||||
|""".stripMargin
|
||||
|
||||
ClickhouseUtils.executeSql(createSQL)
|
||||
|
||||
val insertSQL=
|
||||
"""
|
||||
|insert into t_order_mt
|
||||
|values(101,'sku_001',1000.00,'2020-06-01 12:00:00') ,
|
||||
|(102,'sku_002',2000.00,'2020-06-01 11:00:00'),
|
||||
|(102,'sku_004',2500.00,'2020-06-01 12:00:00'),
|
||||
|(102,'sku_002',2000.00,'2020-06-01 13:00:00')
|
||||
|(102,'sku_002',12000.00,'2020-06-01 13:00:00')
|
||||
|(102,'sku_002',600.00,'2020-06-02 12:00:00')
|
||||
|""".stripMargin
|
||||
|
||||
ClickhouseUtils.executeSql(insertSQL)
|
||||
|
||||
|
||||
val selectSQL=
|
||||
"""
|
||||
|select * from t_order_mt
|
||||
|""".stripMargin
|
||||
|
||||
ClickhouseUtils.executeSql(selectSQL)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -0,0 +1,20 @@
|
|||
import com.atguigu.userprofile.common.utils.ClickhouseUtils;
|
||||
import org.junit.Test;
|
||||
|
||||
/**
|
||||
* @BelongsProject: user-profile-manager0111
|
||||
* @BelongsPackage: PACKAGE_NAME
|
||||
* @Author: markilue
|
||||
* @CreateTime: 2022-10-20 16:58
|
||||
* @Description: TODO 尝试连接clickhouse
|
||||
* @Version: 1.0
|
||||
*/
|
||||
public class clickhouseJDNC {
|
||||
|
||||
|
||||
@Test
|
||||
|
||||
public void testClickhouse(){
|
||||
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,285 @@
|
|||
package com.markilue.leecode.backtrace;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* @BelongsProject: Leecode
|
||||
* @BelongsPackage: com.markilue.leecode.backtrace
|
||||
* @Author: markilue
|
||||
* @CreateTime: 2022-10-21 10:11
|
||||
* @Description: TODO 力扣37题 解数独:
|
||||
* 编写一个程序,通过填充空格来解决数独问题。
|
||||
* 数独的解法需 遵循如下规则:
|
||||
* 数字 1-9 在每一行只能出现一次。
|
||||
* 数字 1-9 在每一列只能出现一次。
|
||||
* 数字 1-9 在每一个以粗实线分隔的 3x3 宫内只能出现一次。(请参考示例图)
|
||||
* 数独部分空格内已填入了数字,空白格用 '.' 表示。
|
||||
* @Version: 1.0
|
||||
*/
|
||||
public class SolveSudoku {
|
||||
|
||||
@Test
|
||||
public void test() {
|
||||
char a = '1';
|
||||
char[][] board = {
|
||||
{'5', '3', '.', '.', '7', '.', '.', '.', '.'},
|
||||
{'6', '.', '.', '1', '9', '5', '.', '.', '.'},
|
||||
{'.', '9', '8', '.', '.', '.', '.', '6', '.'},
|
||||
{'8', '.', '.', '.', '6', '.', '.', '.', '3'},
|
||||
{'4', '.', '.', '8', '.', '3', '.', '.', '1'},
|
||||
{'7', '.', '.', '.', '2', '.', '.', '.', '6'},
|
||||
{'.', '6', '.', '.', '.', '.', '2', '8', '.'},
|
||||
{'.', '.', '.', '4', '1', '9', '.', '.', '5'},
|
||||
{'.', '.', '.', '.', '8', '.', '.', '7', '9'}};
|
||||
|
||||
solveSudoku(board);
|
||||
for (int i = 0; i < board.length; i++) {
|
||||
System.out.println(Arrays.toString(board[i]));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public void solveSudoku(char[][] board) {
|
||||
backtracking1(board);
|
||||
}
|
||||
|
||||
/**
|
||||
* 自己思路:填写i,j位置的值
|
||||
* 存在问题,似乎存在问题,出错之后会一直return回溯了?
|
||||
* @param board
|
||||
* @param i
|
||||
* @param j
|
||||
*/
|
||||
public void backtracking(char[][] board, int i, int j) {
|
||||
|
||||
if (i == board.length) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (j == board.length) {
|
||||
backtracking(board, i + 1, 0);
|
||||
return;
|
||||
}
|
||||
|
||||
if(board[i][j]!='.'){
|
||||
backtracking(board,i,j+1);
|
||||
return;
|
||||
}
|
||||
|
||||
for (int k = 1; k <= 9; k++) {
|
||||
|
||||
if (!check(board, i, j, k)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
board[i][j] = (char) ('0' + k);
|
||||
backtracking(board, i, j + 1);
|
||||
// board[i][j] = '.';
|
||||
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 代码随想录思路:
|
||||
* 速度击败6.41%,内存击败65.83%
|
||||
* @param board
|
||||
* @return
|
||||
*/
|
||||
public boolean backtracking1(char[][] board) {
|
||||
|
||||
for (int i = 0; i < board.length; i++) {
|
||||
for (int j = 0; j < board[0].length; j++) {
|
||||
if(board[i][j]!='.')continue;
|
||||
for (char k = '1'; k <= '9'; k++) {
|
||||
if(check1(board,i,j,k)){
|
||||
board[i][j]=k;
|
||||
//找到第一个直接return
|
||||
if(backtracking1(board))return true;
|
||||
board[i][j]='.'; //回溯撤销k
|
||||
}
|
||||
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
}
|
||||
return true; //遍历完了都没有返回,那么就说明找到了合适的位置
|
||||
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* i,j分别是需要填写的横坐标和纵坐标,value是要填入的值
|
||||
*
|
||||
* @param board
|
||||
* @param i
|
||||
* @param j
|
||||
* @return
|
||||
*/
|
||||
public boolean check1(char[][] board, int i, int j, char value) {
|
||||
|
||||
// //测试当前位置
|
||||
// if (board[i][j] != '.') {
|
||||
// return false;
|
||||
// }
|
||||
|
||||
//监测横轴
|
||||
for (int k = 0; k < board[i].length; k++) {
|
||||
if (board[i][k] == value) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
//监测纵轴
|
||||
for (int k = 0; k < board[j].length; k++) {
|
||||
if ( board[k][j] == value) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
//监测九宫格
|
||||
int a = (i / 3)*3;
|
||||
int b = (j / 3)*3;
|
||||
|
||||
|
||||
for (int k = a; k < a+2; k++) {
|
||||
for (int l = b; l <= b+2; l++) {
|
||||
if ( board[k][l] == value) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* i,j分别是需要填写的横坐标和纵坐标,value是要填入的值
|
||||
*
|
||||
* @param board
|
||||
* @param i
|
||||
* @param j
|
||||
* @return
|
||||
*/
|
||||
public boolean check(char[][] board, int i, int j, int value) {
|
||||
|
||||
// //测试当前位置
|
||||
// if (board[i][j] != '.') {
|
||||
// return false;
|
||||
// }
|
||||
|
||||
//监测横轴
|
||||
for (int k = 0; k < board[i].length; k++) {
|
||||
if (board[i][k] != '.' && board[i][k] - '0' == value) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
//监测纵轴
|
||||
for (int k = 0; k < board[j].length; k++) {
|
||||
if (board[k][j] != '.' && board[k][j] - '0' == value) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
//监测九宫格
|
||||
int a = i % 3;
|
||||
int b = j % 3;
|
||||
int xstart = 0;
|
||||
int xend = 0;
|
||||
int ystart = 0;
|
||||
int yend = 0;
|
||||
if (a == 0) {
|
||||
xstart = i + 1;
|
||||
xend = i + 2;
|
||||
} else if (a == 1) {
|
||||
xstart = i - 1;
|
||||
xend = i + 1;
|
||||
} else {
|
||||
xstart = i - 2;
|
||||
xend = i - 1;
|
||||
}
|
||||
|
||||
if (b == 0) {
|
||||
ystart = j + 1;
|
||||
yend = j + 2;
|
||||
} else if (b == 1) {
|
||||
ystart = j - 1;
|
||||
yend = j + 1;
|
||||
} else {
|
||||
ystart = j - 2;
|
||||
yend = j - 1;
|
||||
}
|
||||
|
||||
|
||||
for (int k = xstart; k <= xend; k++) {
|
||||
for (int l = ystart; l <= yend; l++) {
|
||||
if (board[k][l] != '.' && board[k][l] - '0' == value) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 官方回溯法:本质上就是记录各个位置使用的数组的情况,然后根据条件:
|
||||
* line[i][digit] = column[j][digit] = block[i / 3][j / 3][digit] = false才能填入
|
||||
* 需要填写的位置通过spaces记录下来,然后按个去填写
|
||||
* 官方还有两种优化方式:位运算方式和枚举优化,但是都不是算法层面的优化,此处不一一写
|
||||
* 位运算优化:分别是使用int[i]=k数组去记录位置i填入的数字k
|
||||
* 枚举优化:如果一个空白格只有唯一的数可以填入,也就是其对应的 bb 值和 b-1b−1 进行按位与运算后得到 00(即 bb 中只有一个二进制位为 11)。此时,我们就可以确定这个空白格填入的数,而不用等到递归时再去处理它。
|
||||
*
|
||||
*/
|
||||
private boolean[][] line = new boolean[9][9];
|
||||
private boolean[][] column = new boolean[9][9];
|
||||
private boolean[][][] block = new boolean[3][3][9];
|
||||
private boolean valid = false;
|
||||
private List<int[]> spaces = new ArrayList<int[]>();
|
||||
|
||||
public void solveSudoku1(char[][] board) {
|
||||
for (int i = 0; i < 9; ++i) {
|
||||
for (int j = 0; j < 9; ++j) {
|
||||
if (board[i][j] == '.') {
|
||||
spaces.add(new int[]{i, j});
|
||||
} else {
|
||||
int digit = board[i][j] - '0' - 1;
|
||||
line[i][digit] = column[j][digit] = block[i / 3][j / 3][digit] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dfs(board, 0);
|
||||
}
|
||||
|
||||
public void dfs(char[][] board, int pos) {
|
||||
if (pos == spaces.size()) {
|
||||
valid = true;
|
||||
return;
|
||||
}
|
||||
|
||||
int[] space = spaces.get(pos);
|
||||
int i = space[0], j = space[1];
|
||||
for (int digit = 0; digit < 9 && !valid; ++digit) {
|
||||
if (!line[i][digit] && !column[j][digit] && !block[i / 3][j / 3][digit]) {
|
||||
line[i][digit] = column[j][digit] = block[i / 3][j / 3][digit] = true;
|
||||
board[i][j] = (char) (digit + '0' + 1);
|
||||
dfs(board, pos + 1);
|
||||
line[i][digit] = column[j][digit] = block[i / 3][j / 3][digit] = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
|
@ -21,23 +21,23 @@ K = 18
|
|||
namuda = 0.01
|
||||
'''保存名称'''
|
||||
|
||||
save_name = "./model//{0}.h5".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
save_name = "./model/{0}.h5".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
save_step_two_name = "./model/two_weight/{0}_weight_epoch6_99899_9996/weight".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
|
||||
save_mse_name = "./mse/RNet_C/{0}_timestamp{1}_feature{2}_result.csv".format(model_name,
|
||||
save_mse_name = "./mse/ResNet/{0}_timestamp{1}_feature{2}_result.csv".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
EPOCH)
|
||||
save_max_name = "./mse/RNet_C/{0}_timestamp{1}_feature{2}_max.csv".format(model_name,
|
||||
save_max_name = "./mse/ResNet/{0}_timestamp{1}_feature{2}_max.csv".format(model_name,
|
||||
time_stamp,
|
||||
feature_num,
|
||||
batch_size,
|
||||
|
|
@ -220,11 +220,134 @@ def resnet_Model():
|
|||
|
||||
dense = tf.keras.layers.Dense(128, activation=tf.nn.relu)(dropout)
|
||||
dense = tf.keras.layers.BatchNormalization(name="bn_last")(dense)
|
||||
dense = tf.keras.layers.Dense(2, activation=tf.nn.sigmoid)(dense)
|
||||
dense = tf.keras.layers.Dense(1, activation=tf.nn.sigmoid)(dense)
|
||||
model = tf.keras.Model(inputs=inputs, outputs=dense)
|
||||
return model
|
||||
|
||||
|
||||
def showResult(step_two_model: tf.keras.Model, test_data, isPlot: bool = False, isSave: bool = True):
|
||||
# 获取模型的所有参数的个数
|
||||
# step_two_model.count_params()
|
||||
total_result = []
|
||||
|
||||
size, length, dims = test_data.shape
|
||||
predict_label = step_two_model.predict(test_data)
|
||||
total_result = np.reshape(total_result, [total_result.__len__(), -1])
|
||||
total_result = np.reshape(total_result, [-1, ])
|
||||
|
||||
# 误报率,漏报率,准确性的计算
|
||||
|
||||
if isSave:
|
||||
np.savetxt(save_mse_name, total_result, delimiter=',')
|
||||
if isPlot:
|
||||
plt.figure(1, figsize=(6.0, 2.68))
|
||||
plt.subplots_adjust(left=0.1, right=0.94, bottom=0.2, top=0.9, wspace=None,
|
||||
hspace=None)
|
||||
plt.tight_layout()
|
||||
font1 = {'family': 'Times New Roman', 'weight': 'normal', 'size': 10} # 设置坐标标签的字体大小,字体
|
||||
|
||||
plt.scatter(list(range(total_result.shape[0])), total_result, c='black', s=10)
|
||||
# 画出 y=1 这条水平线
|
||||
plt.axhline(0.5, c='red', label='Failure threshold')
|
||||
# 箭头指向上面的水平线
|
||||
# plt.arrow(35000, 0.9, 33000, 0.75, head_width=0.02, head_length=0.1, shape="full", fc='red', ec='red',
|
||||
# alpha=0.9, overhang=0.5)
|
||||
plt.text(test_data.shape[0] * 2 / 3 + 1000, 0.7, "Truth Fault", fontsize=10, color='red',
|
||||
verticalalignment='top')
|
||||
plt.axvline(test_data.shape[0] * 2 / 3, c='blue', ls='-.')
|
||||
plt.xticks(range(6), ('06/09/17', '12/09/17', '18/09/17', '24/09/17', '29/09/17')) # 设置x轴的标尺
|
||||
plt.tick_params() # 设置轴显示
|
||||
plt.xlabel("time", fontdict=font1)
|
||||
plt.ylabel("confience", fontdict=font1)
|
||||
plt.text(total_result.shape[0] * 4 / 5, 0.6, "Fault", fontsize=10, color='black', verticalalignment='top',
|
||||
horizontalalignment='center',
|
||||
bbox={'facecolor': 'grey',
|
||||
'pad': 10}, fontdict=font1)
|
||||
plt.text(total_result.shape[0] * 1 / 3, 0.4, "Norm", fontsize=10, color='black', verticalalignment='top',
|
||||
horizontalalignment='center',
|
||||
bbox={'facecolor': 'grey',
|
||||
'pad': 10}, fontdict=font1)
|
||||
|
||||
plt.grid()
|
||||
# plt.ylim(0, 1)
|
||||
# plt.xlim(-50, 1300)
|
||||
# plt.legend("", loc='upper left')
|
||||
plt.show()
|
||||
return total_result
|
||||
|
||||
|
||||
def plot_result(result_data):
|
||||
parameters = {
|
||||
'figure.dpi': 600,
|
||||
'figure.figsize': (5, 5),
|
||||
'savefig.dpi': 600,
|
||||
'xtick.direction': 'in',
|
||||
'ytick.direction': 'in',
|
||||
'xtick.labelsize': 20,
|
||||
'ytick.labelsize': 20,
|
||||
'legend.fontsize': 11.3,
|
||||
}
|
||||
plt.rcParams.update(parameters)
|
||||
plt.figure()
|
||||
plt.rc('font', family='Times New Roman') # 全局字体样式#画混淆矩阵
|
||||
|
||||
plt.scatter(list(range(result_data.shape[0])), result_data, c='black', s=10)
|
||||
# 画出 y=1 这条水平线
|
||||
plt.axhline(0.5, c='red', label='Failure threshold')
|
||||
# 箭头指向上面的水平线
|
||||
# plt.arrow(35000, 0.9, 33000, 0.75, head_width=0.02, head_length=0.1, shape="full", fc='red', ec='red',
|
||||
# alpha=0.9, overhang=0.5)
|
||||
plt.text(test_data.shape[0] * 2 / 3 + 1000, 0.7, "Truth Fault", fontsize=10, color='red',
|
||||
verticalalignment='top')
|
||||
plt.axvline(test_data.shape[0] * 2 / 3, c='blue', ls='-.')
|
||||
plt.xticks(range(6), ('06/09/17', '12/09/17', '18/09/17', '24/09/17', '29/09/17')) # 设置x轴的标尺
|
||||
|
||||
|
||||
indices = range(result_data.shape[0])
|
||||
classes = ['N', 'IF', 'OF' 'TRC', 'TSP'] # for i in range(cls):
|
||||
# 第一个是迭代对象,表示坐标的显示顺序,第二个参数是坐标轴显示列表
|
||||
# plt.xticks(indices. classes,rotation=45)#设置横坐标方向,rotation=45为45度倾斜# plt.yticks(indices, classes, rotation=45)
|
||||
plt.xticks([index + 0.5 for index in indices], classes, rotation=45) # 设置横坐标方向,rotation=45为45度倾斜
|
||||
# plt.yticks([index + 0.5 for index in indices], classes, rotation=45)
|
||||
plt.ylabel('Actual label', fontsize=20)
|
||||
plt.xlabel('Predicted label', fontsize=20)
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
pass
|
||||
|
||||
|
||||
def plot_confusion_matrix_accuracy(cls, true_labels, predict_labels):
|
||||
sns.set(font_scale=1.5)
|
||||
parameters = {
|
||||
'figure.dpi': 600,
|
||||
'figure.figsize': (5, 5),
|
||||
'savefig.dpi': 600,
|
||||
'xtick.direction': 'in',
|
||||
'ytick.direction': 'in',
|
||||
'xtick.labelsize': 20,
|
||||
'ytick.labelsize': 20,
|
||||
'legend.fontsize': 11.3,
|
||||
}
|
||||
plt.rcParams.update(parameters)
|
||||
plt.figure()
|
||||
plt.rc('font', family='Times New Roman') # 全局字体样式#画混淆矩阵
|
||||
confusion = confusion_matrix(true_labels, predict_labels)
|
||||
# confusion = confusion.astype('float') / confusion.sum(axis=1)[: np.newaxis] plt.figure()
|
||||
# sns.heatmap(confusion, annot=True,fmt="d",cmap="Greens")
|
||||
sns.heatmap(confusion, annot=True, fmt="d", cmap="Blues", vmax=100, vmin=0, cbar=None, square=True)
|
||||
indices = range(len(confusion))
|
||||
classes = ['N', 'IF', 'OF' 'TRC', 'TSP'] # for i in range(cls):
|
||||
# classes.append(str(i))
|
||||
# 第一个是迭代对象,表示坐标的显示顺序,第二个参数是坐标轴显示列表
|
||||
# plt.xticks(indices. classes,rotation=45)#设置横坐标方向,rotation=45为45度倾斜# plt.yticks(indices, classes, rotation=45)
|
||||
plt.xticks([index + 0.5 for index in indices], classes, rotation=45) # 设置横坐标方向,rotation=45为45度倾斜
|
||||
plt.yticks([index + 0.5 for index in indices], classes, rotation=45)
|
||||
plt.ylabel('Actual label', fontsize=20)
|
||||
plt.xlabel('Predicted label', fontsize=20)
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# # 数据读入
|
||||
#
|
||||
|
|
@ -251,60 +374,68 @@ if __name__ == '__main__':
|
|||
model.compile(optimizer=tf.optimizers.Adam(), loss=tf.losses.binary_crossentropy,
|
||||
metrics=['acc'])
|
||||
model.summary()
|
||||
early_stop = EarlyStopping(monitor='val_loss', min_delta=0.0001, patience=5, mode='min', verbose=1)
|
||||
|
||||
checkpoint = tf.keras.callbacks.ModelCheckpoint(
|
||||
filepath=save_name,
|
||||
monitor='val_loss',
|
||||
verbose=1,
|
||||
save_best_only=True,
|
||||
mode='min',
|
||||
period=1)
|
||||
|
||||
history = model.fit(train_data, train_label, epochs=20, batch_size=32, validation_data=(test_data, test_label),
|
||||
callbacks=[checkpoint, early_stop])
|
||||
model.save("./model/ResNet.h5")
|
||||
model = tf.keras.models.load_model("../model/ResNet_model.h5")
|
||||
# early_stop = EarlyStopping(monitor='val_loss', min_delta=0.0001, patience=5, mode='min', verbose=1)
|
||||
#
|
||||
# checkpoint = tf.keras.callbacks.ModelCheckpoint(
|
||||
# filepath=save_name,
|
||||
# monitor='val_loss',
|
||||
# verbose=1,
|
||||
# save_best_only=False,
|
||||
# mode='min',
|
||||
# period=1)
|
||||
#
|
||||
# history = model.fit(train_data, train_label, epochs=20, batch_size=16, validation_data=(test_data, test_label),
|
||||
# callbacks=[checkpoint, early_stop])
|
||||
# model.save("./model/ResNet.h5")
|
||||
model = tf.keras.models.load_model("./model/ResNet_2_9906.h5")
|
||||
|
||||
# 结果展示
|
||||
trained_data = tf.keras.models.Model(inputs=model.input, outputs=model.get_layer('bn_last').output).predict(
|
||||
train_data)
|
||||
predict_label = model.predict(test_data)
|
||||
predict_label_max = np.argmax(predict_label, axis=1)
|
||||
predict_label = np.expand_dims(predict_label_max, axis=1)
|
||||
|
||||
confusion_matrix = confusion_matrix(test_label, predict_label)
|
||||
healthy_size, _, _ = train_data_healthy.shape
|
||||
unhealthy_size, _, _ = train_data_unhealthy.shape
|
||||
all_data, _, _ = get_training_data_overlapping(
|
||||
total_data[healthy_size - 2 * unhealthy_size:unhealthy_date, :], is_Healthy=True)
|
||||
|
||||
tsne = TSNE(n_components=3, verbose=2, perplexity=30, n_iter=5000).fit_transform(trained_data)
|
||||
print("tsne[:,0]", tsne[:, 0])
|
||||
print("tsne[:,1]", tsne[:, 1])
|
||||
print("tsne[:,2]", tsne[:, 2])
|
||||
x, y, z = tsne[:, 0], tsne[:, 1], tsne[:, 2]
|
||||
x = (x - np.min(x)) / (np.max(x) - np.min(x))
|
||||
y = (y - np.min(y)) / (np.max(y) - np.min(y))
|
||||
z = (z - np.min(z)) / (np.max(z) - np.min(z))
|
||||
showResult(step_two_model, test_data=all_data, isPlot=True)
|
||||
|
||||
fig1 = plt.figure()
|
||||
ax1 = fig1.add_subplot(projection='3d')
|
||||
ax1.scatter3D(x, y, z, c=train_label, cmap=plt.cm.get_cmap("jet", 10))
|
||||
|
||||
fig2 = plt.figure()
|
||||
ax2 = fig2.add_subplot()
|
||||
sns.heatmap(confusion_matrix, annot=True, fmt="d", cmap='Blues')
|
||||
plt.ylabel('Actual label')
|
||||
plt.xlabel('Predicted label')
|
||||
|
||||
# fig3 = plt.figure()
|
||||
# ax3 = fig3.add_subplot()
|
||||
# plt.plot(history.epoch, history.history.get('acc'), label='acc')
|
||||
# plt.plot(history.epoch, history.history.get('val_acc'), label='val_acc')
|
||||
# trained_data = tf.keras.models.Model(inputs=model.input, outputs=model.get_layer('bn_last').output).predict(
|
||||
# train_data)
|
||||
# predict_label = model.predict(test_data)
|
||||
# predict_label_max = np.argmax(predict_label, axis=1)
|
||||
# predict_label = np.expand_dims(predict_label_max, axis=1)
|
||||
#
|
||||
# fig4 = plt.figure()
|
||||
# ax4 = fig3.add_subplot()
|
||||
# plt.plot(history.epoch, history.history.get('loss'), label='loss')
|
||||
# plt.plot(history.epoch, history.history.get('val_loss'), label='val_loss')
|
||||
plt.legend()
|
||||
plt.show()
|
||||
|
||||
score = model.evaluate(test_data, test_label)
|
||||
print('score:', score)
|
||||
# confusion_matrix = confusion_matrix(test_label, predict_label)
|
||||
#
|
||||
# tsne = TSNE(n_components=3, verbose=2, perplexity=30, n_iter=5000).fit_transform(trained_data)
|
||||
# print("tsne[:,0]", tsne[:, 0])
|
||||
# print("tsne[:,1]", tsne[:, 1])
|
||||
# print("tsne[:,2]", tsne[:, 2])
|
||||
# x, y, z = tsne[:, 0], tsne[:, 1], tsne[:, 2]
|
||||
# x = (x - np.min(x)) / (np.max(x) - np.min(x))
|
||||
# y = (y - np.min(y)) / (np.max(y) - np.min(y))
|
||||
# z = (z - np.min(z)) / (np.max(z) - np.min(z))
|
||||
#
|
||||
# fig1 = plt.figure()
|
||||
# ax1 = fig1.add_subplot(projection='3d')
|
||||
# ax1.scatter3D(x, y, z, c=train_label, cmap=plt.cm.get_cmap("jet", 10))
|
||||
#
|
||||
# fig2 = plt.figure()
|
||||
# ax2 = fig2.add_subplot()
|
||||
# sns.heatmap(confusion_matrix, annot=True, fmt="d", cmap='Blues')
|
||||
# plt.ylabel('Actual label')
|
||||
# plt.xlabel('Predicted label')
|
||||
#
|
||||
# # fig3 = plt.figure()
|
||||
# # ax3 = fig3.add_subplot()
|
||||
# # plt.plot(history.epoch, history.history.get('acc'), label='acc')
|
||||
# # plt.plot(history.epoch, history.history.get('val_acc'), label='val_acc')
|
||||
# #
|
||||
# # fig4 = plt.figure()
|
||||
# # ax4 = fig3.add_subplot()
|
||||
# # plt.plot(history.epoch, history.history.get('loss'), label='loss')
|
||||
# # plt.plot(history.epoch, history.history.get('val_loss'), label='val_loss')
|
||||
# plt.legend()
|
||||
# plt.show()
|
||||
#
|
||||
# score = model.evaluate(test_data, test_label)
|
||||
# print('score:', score)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,73 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
# coding: utf-8
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
'''
|
||||
@Author : dingjiawen
|
||||
@Date : 2022/10/20 21:35
|
||||
@Usage :
|
||||
@Desc :
|
||||
'''
|
||||
|
||||
|
||||
def plot_result(result_data):
|
||||
parameters = {
|
||||
'figure.dpi': 600,
|
||||
'figure.figsize': (2.5, 2),
|
||||
'savefig.dpi': 600,
|
||||
'xtick.direction': 'in',
|
||||
'ytick.direction': 'in',
|
||||
'xtick.labelsize': 5,
|
||||
'ytick.labelsize': 5,
|
||||
'legend.fontsize': 5,
|
||||
}
|
||||
plt.rcParams.update(parameters)
|
||||
plt.figure()
|
||||
plt.rc('font', family='Times New Roman') # 全局字体样式#画混淆矩阵
|
||||
font1 = {'family': 'Times New Roman', 'weight': 'normal', 'size': 5} # 设置坐标标签的字体大小,字体
|
||||
plt.scatter(list(range(result_data.shape[0])), result_data, c='black', s=0.001)
|
||||
# 画出 y=1 这条水平线
|
||||
plt.axhline(0.5, c='red', label='Failure threshold')
|
||||
# 箭头指向上面的水平线
|
||||
# plt.arrow(35000, 0.9, 33000, 0.75, head_width=0.02, head_length=0.1, shape="full", fc='red', ec='red',
|
||||
# alpha=0.9, overhang=0.5)
|
||||
plt.text(result_data.shape[0] * 2 / 3 + 1000, 0.7, "Truth Fault", fontsize=5, color='red',
|
||||
verticalalignment='top')
|
||||
plt.axvline(result_data.shape[0] * 2 / 3, c='blue', ls='-.')
|
||||
# plt.xticks(range(6), ('06/09/17', '12/09/17', '18/09/17', '24/09/17', '29/09/17')) # 设置x轴的标尺
|
||||
plt.text(result_data.shape[0] * 4 / 5, 0.6, "Fault", fontsize=5, color='black', verticalalignment='top',
|
||||
horizontalalignment='center',
|
||||
bbox={'facecolor': 'grey',
|
||||
'pad': 2.5}, fontdict=font1)
|
||||
plt.text(result_data.shape[0] * 1 / 3, 0.4, "Norm", fontsize=5, color='black', verticalalignment='top',
|
||||
horizontalalignment='center',
|
||||
bbox={'facecolor': 'grey',
|
||||
'pad': 2.5}, fontdict=font1)
|
||||
|
||||
|
||||
indices = [result_data.shape[0] *i / 4 for i in range(5)]
|
||||
print(indices)
|
||||
classes = ['N', 'IF', 'OF','TRC', 'oo'] # for i in range(cls):
|
||||
# 第一个是迭代对象,表示坐标的显示顺序,第二个参数是坐标轴显示列表
|
||||
# plt.xticks(indices. classes,rotation=45)#设置横坐标方向,rotation=45为45度倾斜# plt.yticks(indices, classes, rotation=45)
|
||||
plt.xticks([index + 0.5 for index in indices], classes, rotation=45) # 设置横坐标方向,rotation=45为45度倾斜
|
||||
# plt.yticks([index + 0.5 for index in indices], classes, rotation=45)
|
||||
plt.ylabel('Actual label', fontsize=5)
|
||||
plt.xlabel('Predicted label', fontsize=5)
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
file_name = "E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\RNet_C\RNet_C_timestamp120_feature10_result.csv"
|
||||
# result_data = np.recfromcsv(file_name)
|
||||
result_data = np.loadtxt(file_name, delimiter=",")
|
||||
result_data = np.array(result_data)
|
||||
data = np.zeros([208, ])
|
||||
result_data = np.concatenate([result_data, data], axis=0)
|
||||
print(result_data)
|
||||
print(result_data.shape)
|
||||
plot_result(result_data)
|
||||
Loading…
Reference in New Issue