self_example/TensorFlow_eaxmple/Model_train_test/RUL/ResultShowUtils.py

117 lines
3.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- encoding:utf-8 -*-
'''
@Author : dingjiawen
@Date : 2023/6/14 15:28
@Usage :
@Desc : Helper methods for plotting results
'''
import math

import numpy as np
from pylab import *
from sklearn.metrics import mean_absolute_error, mean_squared_error
# Use the SimHei font so that Chinese text can be displayed in figures.
mpl.rcParams['font.sans-serif'] = ['SimHei']
# Adjust the minus-sign setting so that negative numbers render on the axes.
mpl.rcParams["axes.unicode_minus"] = False
# Font settings (Times New Roman, size 12); passed as ``prop`` to the legend
# in getPlot.
font1 = {
    'family': 'Times New Roman',
    'weight': 'normal',
    'size': 12,
}
# Font settings (Times New Roman, size 15); passed as ``font`` to the axis
# labels in getPlot.
font2 = {
    'family': 'Times New Roman',
    'weight': 'normal',
    'size': 15,
}
def calScore(y_test, pred):
    """Print error metrics and a threshold-crossing score for a prediction.

    MSE, RMSE and MAE are computed with scikit-learn's
    ``mean_squared_error`` / ``mean_absolute_error`` and rounded to four
    decimals. The value printed as "MAPE" is currently just ``MAE * 100``
    (a placeholder kept from the original code, not a true MAPE).

    The score is derived from the first prediction that is greater than or
    equal to ``y_test[-1]``: it is the index of that prediction divided by
    the number of predictions. When no prediction reaches that value the
    score is NaN.

    Args:
        y_test: Ground-truth values; ``y_test[-1]`` is used as the threshold.
        pred: Predicted values with a singleton second axis (it is squeezed
            with ``np.squeeze(pred, 1)``).

    Returns:
        None. All results are written to stdout only.
    """
    # Error metrics: compute each underlying metric once and derive the rest.
    mse = mean_squared_error(y_test, pred)
    mae = mean_absolute_error(pred, y_test)
    test_mse = round(mse, 4)
    test_rmse = round(math.sqrt(mse), 4)
    # Placeholder: "MAPE" is MAE scaled by 100 for now.
    test_mape = round(mae * 100, 4)
    test_mae = round(mae, 4)

    # Score: locate the predictions that reach the threshold (last true value).
    result = list(np.squeeze(pred, 1))
    threshold = y_test[-1]
    exceed_indices = [i for i, value in enumerate(result) if value >= threshold]
    print("len(exceed)", len(exceed_indices))

    if exceed_indices:
        exceed_index = exceed_indices[0]
        print("len(result)", len(result))
        Eri = round(exceed_index / len(result), 4)
        # NOTE(review): 100 is a hard-coded constant here; presumably an
        # assumed horizon length -- confirm against the caller.
        print("RUL", 100 - (len(result) - exceed_index))
        print("Eri", Eri)

        # Exponential score with base 0.5; math.log(0.5) is the natural log.
        if Eri <= 0:
            score = round(math.exp(-math.log(0.5) * (Eri / 5)), 4)
        else:
            score = round(math.exp(math.log(0.5) * (Eri / 20)), 4)
        print("score1", score)
        # The exponential score above is only printed; the reported score is
        # the relative position of the first threshold crossing.
        score = exceed_index / len(result)
    else:
        score = math.nan

    print('MSE_testScore: %.4f MSE' % test_mse)
    print('RMSE_testScore: %.4f RMSE' % test_rmse)
    print('MAE_testScore: %.4f MAE' % test_mae)
    print('MAPE_testScore: %.4f MAPE' % test_mape)
    print("score: %.4f score" % score)
# 画图
def getPlot(data, feature, time_step, x_train, x_test, pred, truePred, train_pred,
            saveName="../store/test"):
    """Plot the raw series and the training/prediction overlay, then save it.

    Two figures are drawn. The first (on the current figure) shows ``data``
    as a line with a horizontal line at ``data[-1]``. The second (figure 2)
    shows ``data`` as a scatter, the training output as a red line, the
    predictions as black points and the same horizontal line labelled
    "Failure threshold". Only figure 2 is saved, to ``saveName + "single.png"``.

    Args:
        data: Series to plot; needs ``.shape[0]`` and ``data[-1]``.
        feature: Integer used, together with ``time_step``, to offset the
            x-positions of the training and prediction curves.
        time_step: Integer offset, see ``feature``.
        x_train: Only ``x_train.shape[0]`` is read (number of training points).
        x_test: Only ``x_test.shape[0]`` is read (number of predicted points).
        pred: Predicted values; scattered after the training range and its
            last column is printed.
        truePred: Only its last column is printed.
        train_pred: Training output with a singleton second axis (squeezed
            on axis 1 before plotting).
        saveName: Path prefix for the saved image.

    Returns:
        None.
    """
    fitted = np.squeeze(train_pred, 1)
    print("train_pred", fitted)

    sample_axis = list(range(data.shape[0]))
    threshold = data[-1]
    tick_keys = ('xtick.direction', 'ytick.direction')

    # ---- Figure 1: the actual values ----
    # Make x/y ticks point inwards (options are in / out / inout).
    for key in tick_keys:
        plt.rcParams[key] = 'in'
    plt.plot(sample_axis, data)
    # Horizontal line at the last value of the series.
    plt.axhline(threshold, c='green')
    plt.grid()
    plt.ylim(0, 1)
    plt.xlim(-50, 1300)

    # ---- Figure 2: actual values vs. training output vs. predictions ----
    plt.figure(2)
    train_start = time_step + feature - 1
    train_end = x_train.shape[0] + feature + time_step - 1
    for key in tick_keys:
        plt.rcParams[key] = 'in'
    print("pred", pred[:, -1])
    print("truePred", truePred[:, -1])

    plt.scatter(sample_axis, data, c='blue', s=12, label='Actual value')
    # Original author's note on the x-offsets: 10 (lost to overlap)
    # + 9 (transpose) + 1141 (known training data) + 9 (transpose) = 1169,
    # plus 81 predicted points = 1250. The training curve is the model
    # output for the training data itself.
    plt.plot(list(range(train_start, train_end)), fitted, linewidth=2, color='red',
             label='Training value')
    plt.scatter(list(range(train_end, train_end + x_test.shape[0])), pred, c='black', s=15,
                label='Predictive value')
    # Horizontal line at the last value of the series, used as the threshold.
    plt.axhline(threshold, linewidth=2, c='green', label='Failure threshold')
    plt.ylim(-0.2, 0.95)
    plt.xlim(-50, 1300)
    plt.xlabel("Serial number of the fusion feature point", font=font2)
    plt.ylabel("Virtual health indicator", font=font2)
    plt.legend(loc='upper left', prop=font1)

    plt.savefig(saveName + "single.png")