# -*- encoding:utf-8 -*-
"""
@Author : dingjiawen
@Date   : 2023/6/14 15:28
@Usage  :
@Desc   : Plotting helpers: error/score reporting and prediction plots.
"""

# Explicit imports: `math` used to arrive only through `from pylab import *`
# (NumPy 1.x re-exported the stdlib `math` module); NumPy 2.0 removed that,
# so relying on the star import raises NameError on current installs.
import math

import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error
from pylab import *

# Use a font that can render Chinese characters in figures.
mpl.rcParams['font.sans-serif'] = ['SimHei']
# Render the minus sign correctly with the non-default font above.
mpl.rcParams["axes.unicode_minus"] = False

# Font used for legends.
font1 = {
    'family': 'Times New Roman',
    'weight': 'normal',
    'size': 12,
}
# Font used for axis labels.
font2 = {
    'family': 'Times New Roman',
    'weight': 'normal',
    'size': 15,
}


def calScore(y_test, pred):
    """Print error metrics and a threshold-crossing score for a prediction.

    Args:
        y_test: ground-truth values. ``y_test[-1]`` is used as the threshold
            the predictions are compared against.
        pred: predicted values. If it has more than one dimension, axis 1 is
            squeezed away (so that axis must have size 1, e.g. shape (N, 1)).

    Returns:
        dict with the values that are printed: ``mse``, ``rmse``, ``mae``,
        ``mape``, ``Eri``, ``score1`` and ``score``. ``Eri``, ``score1`` and
        ``score`` are NaN when no prediction reaches the threshold.
    """
    # --- error metrics --------------------------------------------------
    mse = mean_squared_error(y_test, pred)
    mae = mean_absolute_error(y_test, pred)
    test_mse = round(mse, 4)
    test_rmse = round(math.sqrt(mse), 4)
    test_mae = round(mae, 4)
    # NOTE: placeholder kept from the original implementation -- this is
    # MAE * 100, not a true mean absolute percentage error.
    test_mape = round(mae * 100, 4)

    # --- threshold crossing ---------------------------------------------
    result = np.asarray(pred)
    if result.ndim > 1:
        result = np.squeeze(result, 1)
    result = list(result)
    threshold = y_test[-1]
    # Positions of every prediction at or above the threshold.
    exceed_indices = [i for i, res in enumerate(result) if res >= threshold]
    print("len(exceed)", len(exceed_indices))
    if exceed_indices:
        exceed_index = exceed_indices[0]  # first crossing
        n_result = len(result)
        print("len(result)", n_result)
        # Relative position of the first crossing inside the prediction window.
        eri = round(exceed_index / n_result, 4)
        # NOTE(review): 100 is a hard-coded horizon -- TODO confirm meaning.
        print("RUL", 100 - (n_result - exceed_index))
        print("Eri", eri)
        # Asymmetric exponential score; only printed (as "score1") and returned.
        if eri <= 0:
            score1 = round(math.exp(-math.log(0.5) * (eri / 5)), 4)
        else:
            score1 = round(math.exp(math.log(0.5) * (eri / 20)), 4)
        print("score1", score1)
        # The reported score is the (unrounded) relative crossing position.
        score = exceed_index / n_result
    else:
        eri = math.nan
        score1 = math.nan
        score = math.nan

    print('MSE_testScore: %.4f MSE' % test_mse)
    print('RMSE_testScore: %.4f RMSE' % test_rmse)
    print('MAE_testScore: %.4f MAE' % test_mae)
    print('MAPE_testScore: %.4f MAPE' % test_mape)
    print("score: %.4f score" % score)
    return {
        "mse": test_mse,
        "rmse": test_rmse,
        "mae": test_mae,
        "mape": test_mape,
        "Eri": eri,
        "score1": score1,
        "score": score,
    }


# Plotting
def getPlot(data, feature, time_step, x_train, x_test, pred, truePred, train_pred, saveName="../store/test"):
    """Plot the series together with training and test predictions.

    Figure 1 shows ``data`` as a line with a horizontal line at ``data[-1]``.
    Figure 2 overlays the actual points, the training predictions and the
    test predictions, and is saved to ``saveName + "single.png"``. Both
    figures are selected by number and cleared first, so repeated calls do
    not accumulate artists from earlier calls.

    Args:
        data: the series; ``data.shape[0]`` points are plotted and
            ``data[-1]`` is drawn as the failure-threshold line.
        feature, time_step: window sizes; training predictions are plotted
            starting at index ``time_step + feature - 1``.
        x_train, x_test: only ``shape[0]`` is used, to position the curves.
        pred: test predictions, plotted from index
            ``x_train.shape[0] + feature + time_step - 1``; ``pred[:, -1]`` is
            printed, so it is at least 2-D and (for the scatter to line up)
            presumably shaped (x_test.shape[0], 1) -- TODO confirm.
        truePred: only printed (``truePred[:, -1]``).
        train_pred: training predictions; axis 1 is squeezed away and the
            result must have ``x_train.shape[0]`` entries.
        saveName: path prefix of the saved figure.
    """
    train_pred = np.squeeze(train_pred, 1)
    print("train_pred", train_pred)

    # Tick direction ('in', 'out' or 'inout'). These rcParams are global and
    # apply to every axes created afterwards, including figure 2's.
    plt.rcParams['xtick.direction'] = 'in'
    plt.rcParams['ytick.direction'] = 'in'

    # --- Figure 1: the raw series ---------------------------------------
    # Select figure 1 explicitly; otherwise a second call would draw onto
    # figure 2 (the current figure left behind by the previous call).
    plt.figure(1)
    plt.clf()
    plt.plot(list(range(data.shape[0])), data)
    # Horizontal line at the last value of the series (the failure threshold).
    plt.axhline(data[-1], c='green')
    plt.grid()
    plt.ylim(0, 1)
    plt.xlim(-50, 1300)

    # --- Figure 2: actual vs. training vs. predicted --------------------
    plt.figure(2)
    # Clear leftovers from an earlier call so the saved image shows only this call.
    plt.clf()
    # First index after the training predictions, i.e. where test predictions start.
    point_len = x_train.shape[0] + feature + time_step - 1
    print("pred", pred[:, -1])
    print("truePred", truePred[:, -1])
    figname2 = saveName + "single.png"
    plt.scatter(list(range(data.shape[0])), data, c='blue', s=12, label='Actual value')
    # TODO: this should become the Training value: 10 (lost to window overlap)
    # + 9 (transpose) + 1141 (known training data) + 9 (transpose) = 1169,
    # + 81 (predicted points) = 1250. Running the training data through the
    # model once gives the training-value curve.
    plt.plot(list(range(time_step + feature - 1, point_len)), train_pred, linewidth=2, color='red',
             label='Training value')
    plt.scatter(list(range(point_len, point_len + x_test.shape[0])), pred, c='black', s=15,
                label='Predictive value')
    # Horizontal line at the last value of the series (the failure threshold).
    plt.axhline(data[-1], linewidth=2, c='green', label='Failure threshold')
    plt.ylim(-0.2, 0.95)
    plt.xlim(-50, 1300)
    plt.xlabel("Serial number of the fusion feature point", font=font2)
    plt.ylabel("Virtual health indicator", font=font2)
    plt.legend(loc='upper left', prop=font1)
    plt.savefig(figname2)