diff --git a/Big_data_example/user_profile/export-clickhouse/pom.xml b/Big_data_example/user_profile/export-clickhouse/pom.xml
index 9252da0..836a83f 100644
--- a/Big_data_example/user_profile/export-clickhouse/pom.xml
+++ b/Big_data_example/user_profile/export-clickhouse/pom.xml
@@ -17,6 +17,19 @@
com.atguigu
1.0-SNAPSHOT
+
+
+ junit
+ junit
+ 4.13.2
+ test
+
+
+ junit
+ junit
+ 4.13.2
+ compile
+
diff --git a/Big_data_example/user_profile/export-clickhouse/src/main/resources/config.properties b/Big_data_example/user_profile/export-clickhouse/src/main/resources/config.properties
index d76795a..72654b4 100644
--- a/Big_data_example/user_profile/export-clickhouse/src/main/resources/config.properties
+++ b/Big_data_example/user_profile/export-clickhouse/src/main/resources/config.properties
@@ -8,4 +8,5 @@ mysql.username=root
mysql.password=123456
# clickhouse
-clickhouse.url=jdbc:clickhouse://Ding202:8123/user_profile0224
+#clickhouse.url=jdbc:clickhouse://Ding202:8123/user_profile0224
+clickhouse.url=jdbc:clickhouse://localhost:8123/default
diff --git a/Big_data_example/user_profile/export-clickhouse/src/test/java/clickhouseJDBC.scala b/Big_data_example/user_profile/export-clickhouse/src/test/java/clickhouseJDBC.scala
new file mode 100644
index 0000000..e720336
--- /dev/null
+++ b/Big_data_example/user_profile/export-clickhouse/src/test/java/clickhouseJDBC.scala
@@ -0,0 +1,44 @@
+import com.atguigu.userprofile.common.utils.ClickhouseUtils
+
+object clickhouseJDBC {
+
+ def main(args: Array[String]): Unit = {
+
+ val createSQL =
+ s"""
+ |create table t_order_mt(
+ | uid UInt32,
+ | sku_id String,
+ | total_amount Decimal(16,2),
+ | create_time Datetime
+ | ) engine =MergeTree
+ | partition by toYYYYMMDD(create_time)
+ | primary key (uid)
+ | order by (uid,sku_id)
+ |""".stripMargin
+
+ ClickhouseUtils.executeSql(createSQL)
+
+ val insertSQL=
+ """
+ |insert into t_order_mt
+ |values(101,'sku_001',1000.00,'2020-06-01 12:00:00') ,
+ |(102,'sku_002',2000.00,'2020-06-01 11:00:00'),
+ |(102,'sku_004',2500.00,'2020-06-01 12:00:00'),
+ |(102,'sku_002',2000.00,'2020-06-01 13:00:00')
+ |(102,'sku_002',12000.00,'2020-06-01 13:00:00')
+ |(102,'sku_002',600.00,'2020-06-02 12:00:00')
+ |""".stripMargin
+
+ ClickhouseUtils.executeSql(insertSQL)
+
+
+ val selectSQL=
+ """
+ |select * from t_order_mt
+ |""".stripMargin
+
+ ClickhouseUtils.executeSql(selectSQL)
+ }
+
+}
diff --git a/Big_data_example/user_profile/export-clickhouse/src/test/java/clickhouseJDNC.java b/Big_data_example/user_profile/export-clickhouse/src/test/java/clickhouseJDNC.java
new file mode 100644
index 0000000..6bf8330
--- /dev/null
+++ b/Big_data_example/user_profile/export-clickhouse/src/test/java/clickhouseJDNC.java
@@ -0,0 +1,20 @@
+import com.atguigu.userprofile.common.utils.ClickhouseUtils;
+import org.junit.Test;
+
+/**
+ * @BelongsProject: user-profile-manager0111
+ * @BelongsPackage: PACKAGE_NAME
+ * @Author: markilue
+ * @CreateTime: 2022-10-20 16:58
+ * @Description: TODO 尝试连接clickhouse
+ * @Version: 1.0
+ */
+public class clickhouseJDNC {
+
+
+ @Test
+
+ public void testClickhouse(){
+
+ }
+}
diff --git a/Leecode/src/main/java/com/markilue/leecode/backtrace/SolveSudoku.java b/Leecode/src/main/java/com/markilue/leecode/backtrace/SolveSudoku.java
new file mode 100644
index 0000000..0a81234
--- /dev/null
+++ b/Leecode/src/main/java/com/markilue/leecode/backtrace/SolveSudoku.java
@@ -0,0 +1,285 @@
+package com.markilue.leecode.backtrace;
+
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * @BelongsProject: Leecode
+ * @BelongsPackage: com.markilue.leecode.backtrace
+ * @Author: markilue
+ * @CreateTime: 2022-10-21 10:11
+ * @Description: TODO 力扣37题 解数独:
+ * 编写一个程序,通过填充空格来解决数独问题。
+ * 数独的解法需 遵循如下规则:
+ * 数字 1-9 在每一行只能出现一次。
+ * 数字 1-9 在每一列只能出现一次。
+ * 数字 1-9 在每一个以粗实线分隔的 3x3 宫内只能出现一次。(请参考示例图)
+ * 数独部分空格内已填入了数字,空白格用 '.' 表示。
+ * @Version: 1.0
+ */
+public class SolveSudoku {
+
+ @Test
+ public void test() {
+ char a = '1';
+ char[][] board = {
+ {'5', '3', '.', '.', '7', '.', '.', '.', '.'},
+ {'6', '.', '.', '1', '9', '5', '.', '.', '.'},
+ {'.', '9', '8', '.', '.', '.', '.', '6', '.'},
+ {'8', '.', '.', '.', '6', '.', '.', '.', '3'},
+ {'4', '.', '.', '8', '.', '3', '.', '.', '1'},
+ {'7', '.', '.', '.', '2', '.', '.', '.', '6'},
+ {'.', '6', '.', '.', '.', '.', '2', '8', '.'},
+ {'.', '.', '.', '4', '1', '9', '.', '.', '5'},
+ {'.', '.', '.', '.', '8', '.', '.', '7', '9'}};
+
+ solveSudoku(board);
+ for (int i = 0; i < board.length; i++) {
+ System.out.println(Arrays.toString(board[i]));
+ }
+
+ }
+
+ public void solveSudoku(char[][] board) {
+ backtracking1(board);
+ }
+
+ /**
+ * 自己思路:填写i,j位置的值
+ * 存在问题,似乎存在问题,出错之后会一直return回溯了?
+ * @param board
+ * @param i
+ * @param j
+ */
+ public void backtracking(char[][] board, int i, int j) {
+
+ if (i == board.length) {
+ return;
+ }
+
+ if (j == board.length) {
+ backtracking(board, i + 1, 0);
+ return;
+ }
+
+ if(board[i][j]!='.'){
+ backtracking(board,i,j+1);
+ return;
+ }
+
+ for (int k = 1; k <= 9; k++) {
+
+ if (!check(board, i, j, k)) {
+ continue;
+ }
+
+ board[i][j] = (char) ('0' + k);
+ backtracking(board, i, j + 1);
+// board[i][j] = '.';
+
+ }
+
+
+ }
+
+
+ /**
+ * 代码随想录思路:
+ * 速度击败6.41%,内存击败65.83%
+ * @param board
+ * @return
+ */
+ public boolean backtracking1(char[][] board) {
+
+ for (int i = 0; i < board.length; i++) {
+ for (int j = 0; j < board[0].length; j++) {
+ if(board[i][j]!='.')continue;
+ for (char k = '1'; k <= '9'; k++) {
+ if(check1(board,i,j,k)){
+ board[i][j]=k;
+ //找到第一个直接return
+ if(backtracking1(board))return true;
+ board[i][j]='.'; //回溯撤销k
+ }
+
+ }
+ return false;
+ }
+
+ }
+ return true; //遍历完了都没有返回,那么就说明找到了合适的位置
+
+
+ }
+
+ /**
+ * i,j分别是需要填写的横坐标和纵坐标,value是要填入的值
+ *
+ * @param board
+ * @param i
+ * @param j
+ * @return
+ */
+ public boolean check1(char[][] board, int i, int j, char value) {
+
+// //测试当前位置
+// if (board[i][j] != '.') {
+// return false;
+// }
+
+ //监测横轴
+ for (int k = 0; k < board[i].length; k++) {
+ if (board[i][k] == value) {
+ return false;
+ }
+ }
+ //监测纵轴
+ for (int k = 0; k < board[j].length; k++) {
+ if ( board[k][j] == value) {
+ return false;
+ }
+ }
+
+ //监测九宫格
+ int a = (i / 3)*3;
+ int b = (j / 3)*3;
+
+
+ for (int k = a; k < a+2; k++) {
+ for (int l = b; l <= b+2; l++) {
+ if ( board[k][l] == value) {
+ return false;
+ }
+ }
+ }
+
+ return true;
+
+ }
+
+ /**
+ * i,j分别是需要填写的横坐标和纵坐标,value是要填入的值
+ *
+ * @param board
+ * @param i
+ * @param j
+ * @return
+ */
+ public boolean check(char[][] board, int i, int j, int value) {
+
+// //测试当前位置
+// if (board[i][j] != '.') {
+// return false;
+// }
+
+ //监测横轴
+ for (int k = 0; k < board[i].length; k++) {
+ if (board[i][k] != '.' && board[i][k] - '0' == value) {
+ return false;
+ }
+ }
+ //监测纵轴
+ for (int k = 0; k < board[j].length; k++) {
+ if (board[k][j] != '.' && board[k][j] - '0' == value) {
+ return false;
+ }
+ }
+
+ //监测九宫格
+ int a = i % 3;
+ int b = j % 3;
+ int xstart = 0;
+ int xend = 0;
+ int ystart = 0;
+ int yend = 0;
+ if (a == 0) {
+ xstart = i + 1;
+ xend = i + 2;
+ } else if (a == 1) {
+ xstart = i - 1;
+ xend = i + 1;
+ } else {
+ xstart = i - 2;
+ xend = i - 1;
+ }
+
+ if (b == 0) {
+ ystart = j + 1;
+ yend = j + 2;
+ } else if (b == 1) {
+ ystart = j - 1;
+ yend = j + 1;
+ } else {
+ ystart = j - 2;
+ yend = j - 1;
+ }
+
+
+ for (int k = xstart; k <= xend; k++) {
+ for (int l = ystart; l <= yend; l++) {
+ if (board[k][l] != '.' && board[k][l] - '0' == value) {
+ return false;
+ }
+ }
+ }
+
+ return true;
+
+ }
+
+
+ /**
+ * 官方回溯法:本质上就是记录各个位置使用的数组的情况,然后根据条件:
+ * line[i][digit] = column[j][digit] = block[i / 3][j / 3][digit] = false才能填入
+ * 需要填写的位置通过spaces记录下来,然后按个去填写
+ * 官方还有两种优化方式:位运算方式和枚举优化,但是都不是算法层面的优化,此处不一一写
+ * 位运算优化:分别是使用int[i]=k数组去记录位置i填入的数字k
+ * 枚举优化:如果一个空白格只有唯一的数可以填入,也就是其对应的 bb 值和 b-1b−1 进行按位与运算后得到 00(即 bb 中只有一个二进制位为 11)。此时,我们就可以确定这个空白格填入的数,而不用等到递归时再去处理它。
+ *
+ */
+ private boolean[][] line = new boolean[9][9];
+ private boolean[][] column = new boolean[9][9];
+ private boolean[][][] block = new boolean[3][3][9];
+ private boolean valid = false;
+ private List spaces = new ArrayList();
+
+ public void solveSudoku1(char[][] board) {
+ for (int i = 0; i < 9; ++i) {
+ for (int j = 0; j < 9; ++j) {
+ if (board[i][j] == '.') {
+ spaces.add(new int[]{i, j});
+ } else {
+ int digit = board[i][j] - '0' - 1;
+ line[i][digit] = column[j][digit] = block[i / 3][j / 3][digit] = true;
+ }
+ }
+ }
+
+ dfs(board, 0);
+ }
+
+ public void dfs(char[][] board, int pos) {
+ if (pos == spaces.size()) {
+ valid = true;
+ return;
+ }
+
+ int[] space = spaces.get(pos);
+ int i = space[0], j = space[1];
+ for (int digit = 0; digit < 9 && !valid; ++digit) {
+ if (!line[i][digit] && !column[j][digit] && !block[i / 3][j / 3][digit]) {
+ line[i][digit] = column[j][digit] = block[i / 3][j / 3][digit] = true;
+ board[i][j] = (char) (digit + '0' + 1);
+ dfs(board, pos + 1);
+ line[i][digit] = column[j][digit] = block[i / 3][j / 3][digit] = false;
+ }
+ }
+ }
+
+
+
+
+}
diff --git a/TensorFlow_eaxmple/Model_train_test/condition_monitoring/self_try/compare/resnet_18.py b/TensorFlow_eaxmple/Model_train_test/condition_monitoring/self_try/compare/resnet_18.py
index 8e196f1..399720f 100644
--- a/TensorFlow_eaxmple/Model_train_test/condition_monitoring/self_try/compare/resnet_18.py
+++ b/TensorFlow_eaxmple/Model_train_test/condition_monitoring/self_try/compare/resnet_18.py
@@ -21,23 +21,23 @@ K = 18
namuda = 0.01
'''保存名称'''
-save_name = "./model//{0}.h5".format(model_name,
- time_stamp,
- feature_num,
- batch_size,
- EPOCH)
+save_name = "./model/{0}.h5".format(model_name,
+ time_stamp,
+ feature_num,
+ batch_size,
+ EPOCH)
save_step_two_name = "./model/two_weight/{0}_weight_epoch6_99899_9996/weight".format(model_name,
time_stamp,
feature_num,
batch_size,
EPOCH)
-save_mse_name = "./mse/RNet_C/{0}_timestamp{1}_feature{2}_result.csv".format(model_name,
+save_mse_name = "./mse/ResNet/{0}_timestamp{1}_feature{2}_result.csv".format(model_name,
time_stamp,
feature_num,
batch_size,
EPOCH)
-save_max_name = "./mse/RNet_C/{0}_timestamp{1}_feature{2}_max.csv".format(model_name,
+save_max_name = "./mse/ResNet/{0}_timestamp{1}_feature{2}_max.csv".format(model_name,
time_stamp,
feature_num,
batch_size,
@@ -220,11 +220,134 @@ def resnet_Model():
dense = tf.keras.layers.Dense(128, activation=tf.nn.relu)(dropout)
dense = tf.keras.layers.BatchNormalization(name="bn_last")(dense)
- dense = tf.keras.layers.Dense(2, activation=tf.nn.sigmoid)(dense)
+ dense = tf.keras.layers.Dense(1, activation=tf.nn.sigmoid)(dense)
model = tf.keras.Model(inputs=inputs, outputs=dense)
return model
+def showResult(step_two_model: tf.keras.Model, test_data, isPlot: bool = False, isSave: bool = True):
+ # 获取模型的所有参数的个数
+ # step_two_model.count_params()
+ total_result = []
+
+ size, length, dims = test_data.shape
+ predict_label = step_two_model.predict(test_data)
+ total_result = np.reshape(total_result, [total_result.__len__(), -1])
+ total_result = np.reshape(total_result, [-1, ])
+
+ # 误报率,漏报率,准确性的计算
+
+ if isSave:
+ np.savetxt(save_mse_name, total_result, delimiter=',')
+ if isPlot:
+ plt.figure(1, figsize=(6.0, 2.68))
+ plt.subplots_adjust(left=0.1, right=0.94, bottom=0.2, top=0.9, wspace=None,
+ hspace=None)
+ plt.tight_layout()
+ font1 = {'family': 'Times New Roman', 'weight': 'normal', 'size': 10} # 设置坐标标签的字体大小,字体
+
+ plt.scatter(list(range(total_result.shape[0])), total_result, c='black', s=10)
+ # 画出 y=1 这条水平线
+ plt.axhline(0.5, c='red', label='Failure threshold')
+ # 箭头指向上面的水平线
+ # plt.arrow(35000, 0.9, 33000, 0.75, head_width=0.02, head_length=0.1, shape="full", fc='red', ec='red',
+ # alpha=0.9, overhang=0.5)
+ plt.text(test_data.shape[0] * 2 / 3 + 1000, 0.7, "Truth Fault", fontsize=10, color='red',
+ verticalalignment='top')
+ plt.axvline(test_data.shape[0] * 2 / 3, c='blue', ls='-.')
+ plt.xticks(range(6), ('06/09/17', '12/09/17', '18/09/17', '24/09/17', '29/09/17')) # 设置x轴的标尺
+ plt.tick_params() # 设置轴显示
+ plt.xlabel("time", fontdict=font1)
+ plt.ylabel("confience", fontdict=font1)
+ plt.text(total_result.shape[0] * 4 / 5, 0.6, "Fault", fontsize=10, color='black', verticalalignment='top',
+ horizontalalignment='center',
+ bbox={'facecolor': 'grey',
+ 'pad': 10}, fontdict=font1)
+ plt.text(total_result.shape[0] * 1 / 3, 0.4, "Norm", fontsize=10, color='black', verticalalignment='top',
+ horizontalalignment='center',
+ bbox={'facecolor': 'grey',
+ 'pad': 10}, fontdict=font1)
+
+ plt.grid()
+ # plt.ylim(0, 1)
+ # plt.xlim(-50, 1300)
+ # plt.legend("", loc='upper left')
+ plt.show()
+ return total_result
+
+
+def plot_result(result_data):
+ parameters = {
+ 'figure.dpi': 600,
+ 'figure.figsize': (5, 5),
+ 'savefig.dpi': 600,
+ 'xtick.direction': 'in',
+ 'ytick.direction': 'in',
+ 'xtick.labelsize': 20,
+ 'ytick.labelsize': 20,
+ 'legend.fontsize': 11.3,
+ }
+ plt.rcParams.update(parameters)
+ plt.figure()
+ plt.rc('font', family='Times New Roman') # 全局字体样式#画混淆矩阵
+
+ plt.scatter(list(range(result_data.shape[0])), result_data, c='black', s=10)
+ # 画出 y=1 这条水平线
+ plt.axhline(0.5, c='red', label='Failure threshold')
+ # 箭头指向上面的水平线
+ # plt.arrow(35000, 0.9, 33000, 0.75, head_width=0.02, head_length=0.1, shape="full", fc='red', ec='red',
+ # alpha=0.9, overhang=0.5)
+ plt.text(test_data.shape[0] * 2 / 3 + 1000, 0.7, "Truth Fault", fontsize=10, color='red',
+ verticalalignment='top')
+ plt.axvline(test_data.shape[0] * 2 / 3, c='blue', ls='-.')
+ plt.xticks(range(6), ('06/09/17', '12/09/17', '18/09/17', '24/09/17', '29/09/17')) # 设置x轴的标尺
+
+
+ indices = range(result_data.shape[0])
+ classes = ['N', 'IF', 'OF' 'TRC', 'TSP'] # for i in range(cls):
+ # 第一个是迭代对象,表示坐标的显示顺序,第二个参数是坐标轴显示列表
+ # plt.xticks(indices. classes,rotation=45)#设置横坐标方向,rotation=45为45度倾斜# plt.yticks(indices, classes, rotation=45)
+ plt.xticks([index + 0.5 for index in indices], classes, rotation=45) # 设置横坐标方向,rotation=45为45度倾斜
+ # plt.yticks([index + 0.5 for index in indices], classes, rotation=45)
+ plt.ylabel('Actual label', fontsize=20)
+ plt.xlabel('Predicted label', fontsize=20)
+ plt.tight_layout()
+ plt.show()
+ pass
+
+
+def plot_confusion_matrix_accuracy(cls, true_labels, predict_labels):
+ sns.set(font_scale=1.5)
+ parameters = {
+ 'figure.dpi': 600,
+ 'figure.figsize': (5, 5),
+ 'savefig.dpi': 600,
+ 'xtick.direction': 'in',
+ 'ytick.direction': 'in',
+ 'xtick.labelsize': 20,
+ 'ytick.labelsize': 20,
+ 'legend.fontsize': 11.3,
+ }
+ plt.rcParams.update(parameters)
+ plt.figure()
+ plt.rc('font', family='Times New Roman') # 全局字体样式#画混淆矩阵
+ confusion = confusion_matrix(true_labels, predict_labels)
+ # confusion = confusion.astype('float') / confusion.sum(axis=1)[: np.newaxis] plt.figure()
+ # sns.heatmap(confusion, annot=True,fmt="d",cmap="Greens")
+ sns.heatmap(confusion, annot=True, fmt="d", cmap="Blues", vmax=100, vmin=0, cbar=None, square=True)
+ indices = range(len(confusion))
+ classes = ['N', 'IF', 'OF' 'TRC', 'TSP'] # for i in range(cls):
+ # classes.append(str(i))
+ # 第一个是迭代对象,表示坐标的显示顺序,第二个参数是坐标轴显示列表
+ # plt.xticks(indices. classes,rotation=45)#设置横坐标方向,rotation=45为45度倾斜# plt.yticks(indices, classes, rotation=45)
+ plt.xticks([index + 0.5 for index in indices], classes, rotation=45) # 设置横坐标方向,rotation=45为45度倾斜
+ plt.yticks([index + 0.5 for index in indices], classes, rotation=45)
+ plt.ylabel('Actual label', fontsize=20)
+ plt.xlabel('Predicted label', fontsize=20)
+ plt.tight_layout()
+ plt.show()
+
+
if __name__ == '__main__':
# # 数据读入
#
@@ -251,60 +374,68 @@ if __name__ == '__main__':
model.compile(optimizer=tf.optimizers.Adam(), loss=tf.losses.binary_crossentropy,
metrics=['acc'])
model.summary()
- early_stop = EarlyStopping(monitor='val_loss', min_delta=0.0001, patience=5, mode='min', verbose=1)
-
- checkpoint = tf.keras.callbacks.ModelCheckpoint(
- filepath=save_name,
- monitor='val_loss',
- verbose=1,
- save_best_only=True,
- mode='min',
- period=1)
-
- history = model.fit(train_data, train_label, epochs=20, batch_size=32, validation_data=(test_data, test_label),
- callbacks=[checkpoint, early_stop])
- model.save("./model/ResNet.h5")
- model = tf.keras.models.load_model("../model/ResNet_model.h5")
+ # early_stop = EarlyStopping(monitor='val_loss', min_delta=0.0001, patience=5, mode='min', verbose=1)
+ #
+ # checkpoint = tf.keras.callbacks.ModelCheckpoint(
+ # filepath=save_name,
+ # monitor='val_loss',
+ # verbose=1,
+ # save_best_only=False,
+ # mode='min',
+ # period=1)
+ #
+ # history = model.fit(train_data, train_label, epochs=20, batch_size=16, validation_data=(test_data, test_label),
+ # callbacks=[checkpoint, early_stop])
+ # model.save("./model/ResNet.h5")
+ model = tf.keras.models.load_model("./model/ResNet_2_9906.h5")
# 结果展示
- trained_data = tf.keras.models.Model(inputs=model.input, outputs=model.get_layer('bn_last').output).predict(
- train_data)
- predict_label = model.predict(test_data)
- predict_label_max = np.argmax(predict_label, axis=1)
- predict_label = np.expand_dims(predict_label_max, axis=1)
- confusion_matrix = confusion_matrix(test_label, predict_label)
+ healthy_size, _, _ = train_data_healthy.shape
+ unhealthy_size, _, _ = train_data_unhealthy.shape
+ all_data, _, _ = get_training_data_overlapping(
+ total_data[healthy_size - 2 * unhealthy_size:unhealthy_date, :], is_Healthy=True)
- tsne = TSNE(n_components=3, verbose=2, perplexity=30, n_iter=5000).fit_transform(trained_data)
- print("tsne[:,0]", tsne[:, 0])
- print("tsne[:,1]", tsne[:, 1])
- print("tsne[:,2]", tsne[:, 2])
- x, y, z = tsne[:, 0], tsne[:, 1], tsne[:, 2]
- x = (x - np.min(x)) / (np.max(x) - np.min(x))
- y = (y - np.min(y)) / (np.max(y) - np.min(y))
- z = (z - np.min(z)) / (np.max(z) - np.min(z))
+ showResult(step_two_model, test_data=all_data, isPlot=True)
- fig1 = plt.figure()
- ax1 = fig1.add_subplot(projection='3d')
- ax1.scatter3D(x, y, z, c=train_label, cmap=plt.cm.get_cmap("jet", 10))
-
- fig2 = plt.figure()
- ax2 = fig2.add_subplot()
- sns.heatmap(confusion_matrix, annot=True, fmt="d", cmap='Blues')
- plt.ylabel('Actual label')
- plt.xlabel('Predicted label')
-
- # fig3 = plt.figure()
- # ax3 = fig3.add_subplot()
- # plt.plot(history.epoch, history.history.get('acc'), label='acc')
- # plt.plot(history.epoch, history.history.get('val_acc'), label='val_acc')
+ # trained_data = tf.keras.models.Model(inputs=model.input, outputs=model.get_layer('bn_last').output).predict(
+ # train_data)
+ # predict_label = model.predict(test_data)
+ # predict_label_max = np.argmax(predict_label, axis=1)
+ # predict_label = np.expand_dims(predict_label_max, axis=1)
#
- # fig4 = plt.figure()
- # ax4 = fig3.add_subplot()
- # plt.plot(history.epoch, history.history.get('loss'), label='loss')
- # plt.plot(history.epoch, history.history.get('val_loss'), label='val_loss')
- plt.legend()
- plt.show()
-
- score = model.evaluate(test_data, test_label)
- print('score:', score)
+ # confusion_matrix = confusion_matrix(test_label, predict_label)
+ #
+ # tsne = TSNE(n_components=3, verbose=2, perplexity=30, n_iter=5000).fit_transform(trained_data)
+ # print("tsne[:,0]", tsne[:, 0])
+ # print("tsne[:,1]", tsne[:, 1])
+ # print("tsne[:,2]", tsne[:, 2])
+ # x, y, z = tsne[:, 0], tsne[:, 1], tsne[:, 2]
+ # x = (x - np.min(x)) / (np.max(x) - np.min(x))
+ # y = (y - np.min(y)) / (np.max(y) - np.min(y))
+ # z = (z - np.min(z)) / (np.max(z) - np.min(z))
+ #
+ # fig1 = plt.figure()
+ # ax1 = fig1.add_subplot(projection='3d')
+ # ax1.scatter3D(x, y, z, c=train_label, cmap=plt.cm.get_cmap("jet", 10))
+ #
+ # fig2 = plt.figure()
+ # ax2 = fig2.add_subplot()
+ # sns.heatmap(confusion_matrix, annot=True, fmt="d", cmap='Blues')
+ # plt.ylabel('Actual label')
+ # plt.xlabel('Predicted label')
+ #
+ # # fig3 = plt.figure()
+ # # ax3 = fig3.add_subplot()
+ # # plt.plot(history.epoch, history.history.get('acc'), label='acc')
+ # # plt.plot(history.epoch, history.history.get('val_acc'), label='val_acc')
+ # #
+ # # fig4 = plt.figure()
+ # # ax4 = fig3.add_subplot()
+ # # plt.plot(history.epoch, history.history.get('loss'), label='loss')
+ # # plt.plot(history.epoch, history.history.get('val_loss'), label='val_loss')
+ # plt.legend()
+ # plt.show()
+ #
+ # score = model.evaluate(test_data, test_label)
+ # print('score:', score)
diff --git a/TensorFlow_eaxmple/Model_train_test/condition_monitoring/test.py b/TensorFlow_eaxmple/Model_train_test/condition_monitoring/test.py
new file mode 100644
index 0000000..f6cf365
--- /dev/null
+++ b/TensorFlow_eaxmple/Model_train_test/condition_monitoring/test.py
@@ -0,0 +1,73 @@
+# -*- coding: utf-8 -*-
+
+# coding: utf-8
+import matplotlib.pyplot as plt
+import numpy as np
+
+'''
+@Author : dingjiawen
+@Date : 2022/10/20 21:35
+@Usage :
+@Desc :
+'''
+
+
+def plot_result(result_data):
+ parameters = {
+ 'figure.dpi': 600,
+ 'figure.figsize': (2.5, 2),
+ 'savefig.dpi': 600,
+ 'xtick.direction': 'in',
+ 'ytick.direction': 'in',
+ 'xtick.labelsize': 5,
+ 'ytick.labelsize': 5,
+ 'legend.fontsize': 5,
+ }
+ plt.rcParams.update(parameters)
+ plt.figure()
+ plt.rc('font', family='Times New Roman') # 全局字体样式#画混淆矩阵
+ font1 = {'family': 'Times New Roman', 'weight': 'normal', 'size': 5} # 设置坐标标签的字体大小,字体
+ plt.scatter(list(range(result_data.shape[0])), result_data, c='black', s=0.001)
+ # 画出 y=1 这条水平线
+ plt.axhline(0.5, c='red', label='Failure threshold')
+ # 箭头指向上面的水平线
+ # plt.arrow(35000, 0.9, 33000, 0.75, head_width=0.02, head_length=0.1, shape="full", fc='red', ec='red',
+ # alpha=0.9, overhang=0.5)
+ plt.text(result_data.shape[0] * 2 / 3 + 1000, 0.7, "Truth Fault", fontsize=5, color='red',
+ verticalalignment='top')
+ plt.axvline(result_data.shape[0] * 2 / 3, c='blue', ls='-.')
+ # plt.xticks(range(6), ('06/09/17', '12/09/17', '18/09/17', '24/09/17', '29/09/17')) # 设置x轴的标尺
+ plt.text(result_data.shape[0] * 4 / 5, 0.6, "Fault", fontsize=5, color='black', verticalalignment='top',
+ horizontalalignment='center',
+ bbox={'facecolor': 'grey',
+ 'pad': 2.5}, fontdict=font1)
+ plt.text(result_data.shape[0] * 1 / 3, 0.4, "Norm", fontsize=5, color='black', verticalalignment='top',
+ horizontalalignment='center',
+ bbox={'facecolor': 'grey',
+ 'pad': 2.5}, fontdict=font1)
+
+
+ indices = [result_data.shape[0] *i / 4 for i in range(5)]
+ print(indices)
+ classes = ['N', 'IF', 'OF','TRC', 'oo'] # for i in range(cls):
+ # 第一个是迭代对象,表示坐标的显示顺序,第二个参数是坐标轴显示列表
+ # plt.xticks(indices. classes,rotation=45)#设置横坐标方向,rotation=45为45度倾斜# plt.yticks(indices, classes, rotation=45)
+ plt.xticks([index + 0.5 for index in indices], classes, rotation=45) # 设置横坐标方向,rotation=45为45度倾斜
+ # plt.yticks([index + 0.5 for index in indices], classes, rotation=45)
+ plt.ylabel('Actual label', fontsize=5)
+ plt.xlabel('Predicted label', fontsize=5)
+ plt.tight_layout()
+ plt.show()
+ pass
+
+
+if __name__ == '__main__':
+ file_name = "E:\self_example\TensorFlow_eaxmple\Model_train_test\condition_monitoring\self_try\compare\mse\RNet_C\RNet_C_timestamp120_feature10_result.csv"
+ # result_data = np.recfromcsv(file_name)
+ result_data = np.loadtxt(file_name, delimiter=",")
+ result_data = np.array(result_data)
+ data = np.zeros([208, ])
+ result_data = np.concatenate([result_data, data], axis=0)
+ print(result_data)
+ print(result_data.shape)
+ plot_result(result_data)