leecode更新

This commit is contained in:
markilue 2022-10-21 13:41:39 +08:00
parent 1b92547add
commit 10c949ab6c
7 changed files with 627 additions and 60 deletions

View File

@ -17,6 +17,19 @@
<groupId>com.atguigu</groupId>
<version>1.0-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>4.13.2</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>4.13.2</version>
<scope>compile</scope>
</dependency>
</dependencies>

View File

@ -8,4 +8,5 @@ mysql.username=root
mysql.password=123456
# clickhouseÅäÖÃ
clickhouse.url=jdbc:clickhouse://Ding202:8123/user_profile0224
#clickhouse.url=jdbc:clickhouse://Ding202:8123/user_profile0224
clickhouse.url=jdbc:clickhouse://localhost:8123/default

View File

@ -0,0 +1,44 @@
import com.atguigu.userprofile.common.utils.ClickhouseUtils
object clickhouseJDBC {
def main(args: Array[String]): Unit = {
val createSQL =
s"""
|create table t_order_mt(
| uid UInt32,
| sku_id String,
| total_amount Decimal(16,2),
| create_time Datetime
| ) engine =MergeTree
| partition by toYYYYMMDD(create_time)
| primary key (uid)
| order by (uid,sku_id)
|""".stripMargin
ClickhouseUtils.executeSql(createSQL)
val insertSQL=
"""
|insert into t_order_mt
|values(101,'sku_001',1000.00,'2020-06-01 12:00:00') ,
|(102,'sku_002',2000.00,'2020-06-01 11:00:00'),
|(102,'sku_004',2500.00,'2020-06-01 12:00:00'),
|(102,'sku_002',2000.00,'2020-06-01 13:00:00')
|(102,'sku_002',12000.00,'2020-06-01 13:00:00')
|(102,'sku_002',600.00,'2020-06-02 12:00:00')
|""".stripMargin
ClickhouseUtils.executeSql(insertSQL)
val selectSQL=
"""
|select * from t_order_mt
|""".stripMargin
ClickhouseUtils.executeSql(selectSQL)
}
}

View File

@ -0,0 +1,20 @@
import com.atguigu.userprofile.common.utils.ClickhouseUtils;
import org.junit.Test;
/**
* @BelongsProject: user-profile-manager0111
* @BelongsPackage: PACKAGE_NAME
* @Author: markilue
* @CreateTime: 2022-10-20 16:58
* @Description: TODO 尝试连接clickhouse
* @Version: 1.0
*/
public class clickhouseJDNC {
@Test
public void testClickhouse(){
}
}

View File

@ -0,0 +1,285 @@
package com.markilue.leecode.backtrace;
import org.junit.Test;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
* @BelongsProject: Leecode
* @BelongsPackage: com.markilue.leecode.backtrace
* @Author: markilue
* @CreateTime: 2022-10-21 10:11
* @Description: TODO 力扣37题 解数独
* 编写一个程序通过填充空格来解决数独问题
* 数独的解法需 遵循如下规则
* 数字 1-9 在每一行只能出现一次
* 数字 1-9 在每一列只能出现一次
* 数字 1-9 在每一个以粗实线分隔的 3x3 宫内只能出现一次请参考示例图
* 数独部分空格内已填入了数字空白格用 '.' 表示
* @Version: 1.0
*/
public class SolveSudoku {
@Test
public void test() {
char a = '1';
char[][] board = {
{'5', '3', '.', '.', '7', '.', '.', '.', '.'},
{'6', '.', '.', '1', '9', '5', '.', '.', '.'},
{'.', '9', '8', '.', '.', '.', '.', '6', '.'},
{'8', '.', '.', '.', '6', '.', '.', '.', '3'},
{'4', '.', '.', '8', '.', '3', '.', '.', '1'},
{'7', '.', '.', '.', '2', '.', '.', '.', '6'},
{'.', '6', '.', '.', '.', '.', '2', '8', '.'},
{'.', '.', '.', '4', '1', '9', '.', '.', '5'},
{'.', '.', '.', '.', '8', '.', '.', '7', '9'}};
solveSudoku(board);
for (int i = 0; i < board.length; i++) {
System.out.println(Arrays.toString(board[i]));
}
}
public void solveSudoku(char[][] board) {
backtracking1(board);
}
/**
* 自己思路填写i,j位置的值
* 存在问题似乎存在问题出错之后会一直return回溯了
* @param board
* @param i
* @param j
*/
public void backtracking(char[][] board, int i, int j) {
if (i == board.length) {
return;
}
if (j == board.length) {
backtracking(board, i + 1, 0);
return;
}
if(board[i][j]!='.'){
backtracking(board,i,j+1);
return;
}
for (int k = 1; k <= 9; k++) {
if (!check(board, i, j, k)) {
continue;
}
board[i][j] = (char) ('0' + k);
backtracking(board, i, j + 1);
// board[i][j] = '.';
}
}
/**
* 代码随想录思路
* 速度击败6.41%内存击败65.83%
* @param board
* @return
*/
public boolean backtracking1(char[][] board) {
for (int i = 0; i < board.length; i++) {
for (int j = 0; j < board[0].length; j++) {
if(board[i][j]!='.')continue;
for (char k = '1'; k <= '9'; k++) {
if(check1(board,i,j,k)){
board[i][j]=k;
//找到第一个直接return
if(backtracking1(board))return true;
board[i][j]='.'; //回溯撤销k
}
}
return false;
}
}
return true; //遍历完了都没有返回那么就说明找到了合适的位置
}
/**
* i,j分别是需要填写的横坐标和纵坐标value是要填入的值
*
* @param board
* @param i
* @param j
* @return
*/
public boolean check1(char[][] board, int i, int j, char value) {
// //测试当前位置
// if (board[i][j] != '.') {
// return false;
// }
//监测横轴
for (int k = 0; k < board[i].length; k++) {
if (board[i][k] == value) {
return false;
}
}
//监测纵轴
for (int k = 0; k < board[j].length; k++) {
if ( board[k][j] == value) {
return false;
}
}
//监测九宫格
int a = (i / 3)*3;
int b = (j / 3)*3;
for (int k = a; k < a+2; k++) {
for (int l = b; l <= b+2; l++) {
if ( board[k][l] == value) {
return false;
}
}
}
return true;
}
/**
* i,j分别是需要填写的横坐标和纵坐标value是要填入的值
*
* @param board
* @param i
* @param j
* @return
*/
public boolean check(char[][] board, int i, int j, int value) {
// //测试当前位置
// if (board[i][j] != '.') {
// return false;
// }
//监测横轴
for (int k = 0; k < board[i].length; k++) {
if (board[i][k] != '.' && board[i][k] - '0' == value) {
return false;
}
}
//监测纵轴
for (int k = 0; k < board[j].length; k++) {
if (board[k][j] != '.' && board[k][j] - '0' == value) {
return false;
}
}
//监测九宫格
int a = i % 3;
int b = j % 3;
int xstart = 0;
int xend = 0;
int ystart = 0;
int yend = 0;
if (a == 0) {
xstart = i + 1;
xend = i + 2;
} else if (a == 1) {
xstart = i - 1;
xend = i + 1;
} else {
xstart = i - 2;
xend = i - 1;
}
if (b == 0) {
ystart = j + 1;
yend = j + 2;
} else if (b == 1) {
ystart = j - 1;
yend = j + 1;
} else {
ystart = j - 2;
yend = j - 1;
}
for (int k = xstart; k <= xend; k++) {
for (int l = ystart; l <= yend; l++) {
if (board[k][l] != '.' && board[k][l] - '0' == value) {
return false;
}
}
}
return true;
}
/**
* 官方回溯法本质上就是记录各个位置使用的数组的情况然后根据条件
* line[i][digit] = column[j][digit] = block[i / 3][j / 3][digit] = false才能填入
* 需要填写的位置通过spaces记录下来然后按个去填写
* 官方还有两种优化方式位运算方式和枚举优化但是都不是算法层面的优化此处不一一写
* 位运算优化分别是使用int[i]=k数组去记录位置i填入的数字k
* 枚举优化如果一个空白格只有唯一的数可以填入也就是其对应的 bb 值和 b-1b1 进行按位与运算后得到 00 bb 中只有一个二进制位为 11此时我们就可以确定这个空白格填入的数而不用等到递归时再去处理它
*
*/
private boolean[][] line = new boolean[9][9];
private boolean[][] column = new boolean[9][9];
private boolean[][][] block = new boolean[3][3][9];
private boolean valid = false;
private List<int[]> spaces = new ArrayList<int[]>();
public void solveSudoku1(char[][] board) {
for (int i = 0; i < 9; ++i) {
for (int j = 0; j < 9; ++j) {
if (board[i][j] == '.') {
spaces.add(new int[]{i, j});
} else {
int digit = board[i][j] - '0' - 1;
line[i][digit] = column[j][digit] = block[i / 3][j / 3][digit] = true;
}
}
}
dfs(board, 0);
}
public void dfs(char[][] board, int pos) {
if (pos == spaces.size()) {
valid = true;
return;
}
int[] space = spaces.get(pos);
int i = space[0], j = space[1];
for (int digit = 0; digit < 9 && !valid; ++digit) {
if (!line[i][digit] && !column[j][digit] && !block[i / 3][j / 3][digit]) {
line[i][digit] = column[j][digit] = block[i / 3][j / 3][digit] = true;
board[i][j] = (char) (digit + '0' + 1);
dfs(board, pos + 1);
line[i][digit] = column[j][digit] = block[i / 3][j / 3][digit] = false;
}
}
}
}

View File

@ -21,7 +21,7 @@ K = 18
namuda = 0.01
'''保存名称'''
save_name = "./model//{0}.h5".format(model_name,
save_name = "./model/{0}.h5".format(model_name,
time_stamp,
feature_num,
batch_size,
@ -32,12 +32,12 @@ save_step_two_name = "./model/two_weight/{0}_weight_epoch6_99899_9996/weight".fo
batch_size,
EPOCH)
save_mse_name = "./mse/RNet_C/{0}_timestamp{1}_feature{2}_result.csv".format(model_name,
save_mse_name = "./mse/ResNet/{0}_timestamp{1}_feature{2}_result.csv".format(model_name,
time_stamp,
feature_num,
batch_size,
EPOCH)
save_max_name = "./mse/RNet_C/{0}_timestamp{1}_feature{2}_max.csv".format(model_name,
save_max_name = "./mse/ResNet/{0}_timestamp{1}_feature{2}_max.csv".format(model_name,
time_stamp,
feature_num,
batch_size,
@ -220,11 +220,134 @@ def resnet_Model():
dense = tf.keras.layers.Dense(128, activation=tf.nn.relu)(dropout)
dense = tf.keras.layers.BatchNormalization(name="bn_last")(dense)
dense = tf.keras.layers.Dense(2, activation=tf.nn.sigmoid)(dense)
dense = tf.keras.layers.Dense(1, activation=tf.nn.sigmoid)(dense)
model = tf.keras.Model(inputs=inputs, outputs=dense)
return model
def showResult(step_two_model: tf.keras.Model, test_data, isPlot: bool = False, isSave: bool = True):
# 获取模型的所有参数的个数
# step_two_model.count_params()
total_result = []
size, length, dims = test_data.shape
predict_label = step_two_model.predict(test_data)
total_result = np.reshape(total_result, [total_result.__len__(), -1])
total_result = np.reshape(total_result, [-1, ])
# 误报率,漏报率,准确性的计算
if isSave:
np.savetxt(save_mse_name, total_result, delimiter=',')
if isPlot:
plt.figure(1, figsize=(6.0, 2.68))
plt.subplots_adjust(left=0.1, right=0.94, bottom=0.2, top=0.9, wspace=None,
hspace=None)
plt.tight_layout()
font1 = {'family': 'Times New Roman', 'weight': 'normal', 'size': 10} # 设置坐标标签的字体大小,字体
plt.scatter(list(range(total_result.shape[0])), total_result, c='black', s=10)
# 画出 y=1 这条水平线
plt.axhline(0.5, c='red', label='Failure threshold')
# 箭头指向上面的水平线
# plt.arrow(35000, 0.9, 33000, 0.75, head_width=0.02, head_length=0.1, shape="full", fc='red', ec='red',
# alpha=0.9, overhang=0.5)
plt.text(test_data.shape[0] * 2 / 3 + 1000, 0.7, "Truth Fault", fontsize=10, color='red',
verticalalignment='top')
plt.axvline(test_data.shape[0] * 2 / 3, c='blue', ls='-.')
plt.xticks(range(6), ('06/09/17', '12/09/17', '18/09/17', '24/09/17', '29/09/17')) # 设置x轴的标尺
plt.tick_params() # 设置轴显示
plt.xlabel("time", fontdict=font1)
plt.ylabel("confience", fontdict=font1)
plt.text(total_result.shape[0] * 4 / 5, 0.6, "Fault", fontsize=10, color='black', verticalalignment='top',
horizontalalignment='center',
bbox={'facecolor': 'grey',
'pad': 10}, fontdict=font1)
plt.text(total_result.shape[0] * 1 / 3, 0.4, "Norm", fontsize=10, color='black', verticalalignment='top',
horizontalalignment='center',
bbox={'facecolor': 'grey',
'pad': 10}, fontdict=font1)
plt.grid()
# plt.ylim(0, 1)
# plt.xlim(-50, 1300)
# plt.legend("", loc='upper left')
plt.show()
return total_result
def plot_result(result_data):
parameters = {
'figure.dpi': 600,
'figure.figsize': (5, 5),
'savefig.dpi': 600,
'xtick.direction': 'in',
'ytick.direction': 'in',
'xtick.labelsize': 20,
'ytick.labelsize': 20,
'legend.fontsize': 11.3,
}
plt.rcParams.update(parameters)
plt.figure()
plt.rc('font', family='Times New Roman') # 全局字体样式#画混淆矩阵
plt.scatter(list(range(result_data.shape[0])), result_data, c='black', s=10)
# 画出 y=1 这条水平线
plt.axhline(0.5, c='red', label='Failure threshold')
# 箭头指向上面的水平线
# plt.arrow(35000, 0.9, 33000, 0.75, head_width=0.02, head_length=0.1, shape="full", fc='red', ec='red',
# alpha=0.9, overhang=0.5)
plt.text(test_data.shape[0] * 2 / 3 + 1000, 0.7, "Truth Fault", fontsize=10, color='red',
verticalalignment='top')
plt.axvline(test_data.shape[0] * 2 / 3, c='blue', ls='-.')
plt.xticks(range(6), ('06/09/17', '12/09/17', '18/09/17', '24/09/17', '29/09/17')) # 设置x轴的标尺
indices = range(result_data.shape[0])
classes = ['N', 'IF', 'OF' 'TRC', 'TSP'] # for i in range(cls):
# 第一个是迭代对象,表示坐标的显示顺序,第二个参数是坐标轴显示列表
# plt.xticks(indices. classes,rotation=45)#设置横坐标方向rotation=45为45度倾斜# plt.yticks(indices, classes, rotation=45)
plt.xticks([index + 0.5 for index in indices], classes, rotation=45) # 设置横坐标方向rotation=45为45度倾斜
# plt.yticks([index + 0.5 for index in indices], classes, rotation=45)
plt.ylabel('Actual label', fontsize=20)
plt.xlabel('Predicted label', fontsize=20)
plt.tight_layout()
plt.show()
pass
def plot_confusion_matrix_accuracy(cls, true_labels, predict_labels):
sns.set(font_scale=1.5)
parameters = {
'figure.dpi': 600,
'figure.figsize': (5, 5),
'savefig.dpi': 600,
'xtick.direction': 'in',
'ytick.direction': 'in',
'xtick.labelsize': 20,
'ytick.labelsize': 20,
'legend.fontsize': 11.3,
}
plt.rcParams.update(parameters)
plt.figure()
plt.rc('font', family='Times New Roman') # 全局字体样式#画混淆矩阵
confusion = confusion_matrix(true_labels, predict_labels)
# confusion = confusion.astype('float') / confusion.sum(axis=1)[: np.newaxis] plt.figure()
# sns.heatmap(confusion, annot=True,fmt="d",cmap="Greens")
sns.heatmap(confusion, annot=True, fmt="d", cmap="Blues", vmax=100, vmin=0, cbar=None, square=True)
indices = range(len(confusion))
classes = ['N', 'IF', 'OF' 'TRC', 'TSP'] # for i in range(cls):
# classes.append(str(i))
# 第一个是迭代对象,表示坐标的显示顺序,第二个参数是坐标轴显示列表
# plt.xticks(indices. classes,rotation=45)#设置横坐标方向rotation=45为45度倾斜# plt.yticks(indices, classes, rotation=45)
plt.xticks([index + 0.5 for index in indices], classes, rotation=45) # 设置横坐标方向rotation=45为45度倾斜
plt.yticks([index + 0.5 for index in indices], classes, rotation=45)
plt.ylabel('Actual label', fontsize=20)
plt.xlabel('Predicted label', fontsize=20)
plt.tight_layout()
plt.show()
if __name__ == '__main__':
# # 数据读入
#
@ -251,60 +374,68 @@ if __name__ == '__main__':
model.compile(optimizer=tf.optimizers.Adam(), loss=tf.losses.binary_crossentropy,
metrics=['acc'])
model.summary()
early_stop = EarlyStopping(monitor='val_loss', min_delta=0.0001, patience=5, mode='min', verbose=1)
checkpoint = tf.keras.callbacks.ModelCheckpoint(
filepath=save_name,
monitor='val_loss',
verbose=1,
save_best_only=True,
mode='min',
period=1)
history = model.fit(train_data, train_label, epochs=20, batch_size=32, validation_data=(test_data, test_label),
callbacks=[checkpoint, early_stop])
model.save("./model/ResNet.h5")
model = tf.keras.models.load_model("../model/ResNet_model.h5")
# early_stop = EarlyStopping(monitor='val_loss', min_delta=0.0001, patience=5, mode='min', verbose=1)
#
# checkpoint = tf.keras.callbacks.ModelCheckpoint(
# filepath=save_name,
# monitor='val_loss',
# verbose=1,
# save_best_only=False,
# mode='min',
# period=1)
#
# history = model.fit(train_data, train_label, epochs=20, batch_size=16, validation_data=(test_data, test_label),
# callbacks=[checkpoint, early_stop])
# model.save("./model/ResNet.h5")
model = tf.keras.models.load_model("./model/ResNet_2_9906.h5")
# 结果展示
trained_data = tf.keras.models.Model(inputs=model.input, outputs=model.get_layer('bn_last').output).predict(
train_data)
predict_label = model.predict(test_data)
predict_label_max = np.argmax(predict_label, axis=1)
predict_label = np.expand_dims(predict_label_max, axis=1)
confusion_matrix = confusion_matrix(test_label, predict_label)
healthy_size, _, _ = train_data_healthy.shape
unhealthy_size, _, _ = train_data_unhealthy.shape
all_data, _, _ = get_training_data_overlapping(
total_data[healthy_size - 2 * unhealthy_size:unhealthy_date, :], is_Healthy=True)
tsne = TSNE(n_components=3, verbose=2, perplexity=30, n_iter=5000).fit_transform(trained_data)
print("tsne[:,0]", tsne[:, 0])
print("tsne[:,1]", tsne[:, 1])
print("tsne[:,2]", tsne[:, 2])
x, y, z = tsne[:, 0], tsne[:, 1], tsne[:, 2]
x = (x - np.min(x)) / (np.max(x) - np.min(x))
y = (y - np.min(y)) / (np.max(y) - np.min(y))
z = (z - np.min(z)) / (np.max(z) - np.min(z))
showResult(step_two_model, test_data=all_data, isPlot=True)
fig1 = plt.figure()
ax1 = fig1.add_subplot(projection='3d')
ax1.scatter3D(x, y, z, c=train_label, cmap=plt.cm.get_cmap("jet", 10))
fig2 = plt.figure()
ax2 = fig2.add_subplot()
sns.heatmap(confusion_matrix, annot=True, fmt="d", cmap='Blues')
plt.ylabel('Actual label')
plt.xlabel('Predicted label')
# fig3 = plt.figure()
# ax3 = fig3.add_subplot()
# plt.plot(history.epoch, history.history.get('acc'), label='acc')
# plt.plot(history.epoch, history.history.get('val_acc'), label='val_acc')
# trained_data = tf.keras.models.Model(inputs=model.input, outputs=model.get_layer('bn_last').output).predict(
# train_data)
# predict_label = model.predict(test_data)
# predict_label_max = np.argmax(predict_label, axis=1)
# predict_label = np.expand_dims(predict_label_max, axis=1)
#
# fig4 = plt.figure()
# ax4 = fig3.add_subplot()
# plt.plot(history.epoch, history.history.get('loss'), label='loss')
# plt.plot(history.epoch, history.history.get('val_loss'), label='val_loss')
plt.legend()
plt.show()
score = model.evaluate(test_data, test_label)
print('score:', score)
# confusion_matrix = confusion_matrix(test_label, predict_label)
#
# tsne = TSNE(n_components=3, verbose=2, perplexity=30, n_iter=5000).fit_transform(trained_data)
# print("tsne[:,0]", tsne[:, 0])
# print("tsne[:,1]", tsne[:, 1])
# print("tsne[:,2]", tsne[:, 2])
# x, y, z = tsne[:, 0], tsne[:, 1], tsne[:, 2]
# x = (x - np.min(x)) / (np.max(x) - np.min(x))
# y = (y - np.min(y)) / (np.max(y) - np.min(y))
# z = (z - np.min(z)) / (np.max(z) - np.min(z))
#
# fig1 = plt.figure()
# ax1 = fig1.add_subplot(projection='3d')
# ax1.scatter3D(x, y, z, c=train_label, cmap=plt.cm.get_cmap("jet", 10))
#
# fig2 = plt.figure()
# ax2 = fig2.add_subplot()
# sns.heatmap(confusion_matrix, annot=True, fmt="d", cmap='Blues')
# plt.ylabel('Actual label')
# plt.xlabel('Predicted label')
#
# # fig3 = plt.figure()
# # ax3 = fig3.add_subplot()
# # plt.plot(history.epoch, history.history.get('acc'), label='acc')
# # plt.plot(history.epoch, history.history.get('val_acc'), label='val_acc')
# #
# # fig4 = plt.figure()
# # ax4 = fig3.add_subplot()
# # plt.plot(history.epoch, history.history.get('loss'), label='loss')
# # plt.plot(history.epoch, history.history.get('val_loss'), label='val_loss')
# plt.legend()
# plt.show()
#
# score = model.evaluate(test_data, test_label)
# print('score:', score)

View File

@ -0,0 +1,73 @@
# -*- coding: utf-8 -*-
# coding: utf-8
import matplotlib.pyplot as plt
import numpy as np
'''
@Author : dingjiawen
@Date : 2022/10/20 21:35
@Usage :
@Desc :
'''
def plot_result(result_data):
parameters = {
'figure.dpi': 600,
'figure.figsize': (2.5, 2),
'savefig.dpi': 600,
'xtick.direction': 'in',
'ytick.direction': 'in',
'xtick.labelsize': 5,
'ytick.labelsize': 5,
'legend.fontsize': 5,
}
plt.rcParams.update(parameters)
plt.figure()
plt.rc('font', family='Times New Roman') # 全局字体样式#画混淆矩阵
font1 = {'family': 'Times New Roman', 'weight': 'normal', 'size': 5} # 设置坐标标签的字体大小,字体
plt.scatter(list(range(result_data.shape[0])), result_data, c='black', s=0.001)
# 画出 y=1 这条水平线
plt.axhline(0.5, c='red', label='Failure threshold')
# 箭头指向上面的水平线
# plt.arrow(35000, 0.9, 33000, 0.75, head_width=0.02, head_length=0.1, shape="full", fc='red', ec='red',
# alpha=0.9, overhang=0.5)
plt.text(result_data.shape[0] * 2 / 3 + 1000, 0.7, "Truth Fault", fontsize=5, color='red',
verticalalignment='top')
plt.axvline(result_data.shape[0] * 2 / 3, c='blue', ls='-.')
# plt.xticks(range(6), ('06/09/17', '12/09/17', '18/09/17', '24/09/17', '29/09/17')) # 设置x轴的标尺
plt.text(result_data.shape[0] * 4 / 5, 0.6, "Fault", fontsize=5, color='black', verticalalignment='top',
horizontalalignment='center',
bbox={'facecolor': 'grey',
'pad': 2.5}, fontdict=font1)
plt.text(result_data.shape[0] * 1 / 3, 0.4, "Norm", fontsize=5, color='black', verticalalignment='top',
horizontalalignment='center',
bbox={'facecolor': 'grey',
'pad': 2.5}, fontdict=font1)
indices = [result_data.shape[0] *i / 4 for i in range(5)]
print(indices)
classes = ['N', 'IF', 'OF','TRC', 'oo'] # for i in range(cls):
# 第一个是迭代对象,表示坐标的显示顺序,第二个参数是坐标轴显示列表
# plt.xticks(indices. classes,rotation=45)#设置横坐标方向rotation=45为45度倾斜# plt.yticks(indices, classes, rotation=45)
plt.xticks([index + 0.5 for index in indices], classes, rotation=45) # 设置横坐标方向rotation=45为45度倾斜
# plt.yticks([index + 0.5 for index in indices], classes, rotation=45)
plt.ylabel('Actual label', fontsize=5)
plt.xlabel('Predicted label', fontsize=5)
plt.tight_layout()
plt.show()
pass
if __name__ == '__main__':
file_name = "E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\RNet_C\RNet_C_timestamp120_feature10_result.csv"
# result_data = np.recfromcsv(file_name)
result_data = np.loadtxt(file_name, delimiter=",")
result_data = np.array(result_data)
data = np.zeros([208, ])
result_data = np.concatenate([result_data, data], axis=0)
print(result_data)
print(result_data.shape)
plot_result(result_data)