117 lines
3.7 KiB
Python
117 lines
3.7 KiB
Python
# -*- encoding:utf-8 -*-
|
||
|
||
'''
|
||
@Author : dingjiawen
|
||
@Date : 2023/6/14 15:28
|
||
@Usage :
|
||
@Desc : 一些画图方法及误差/得分计算 (plotting helpers plus error-metric and score reporting)
|
||
'''
|
||
from sklearn.metrics import mean_absolute_error, mean_squared_error
|
||
from pylab import *
|
||
|
||
# Global matplotlib settings for this module:
#   - use the SimHei font so Chinese text in figures renders;
#   - turn off the Unicode minus glyph so negative numbers display correctly.
mpl.rcParams.update({
    'font.sans-serif': ['SimHei'],
    "axes.unicode_minus": False,
})

# Font property dicts for figure text: Times New Roman, normal weight.
# font1 is the smaller (12 pt) variant, font2 the larger (15 pt) one.
font1 = dict(family='Times New Roman', weight='normal', size=12)
font2 = {**font1, 'size': 15}
|
||
|
||
|
||
def calScore(y_test, pred):
    """Print regression error metrics and a threshold-crossing score.

    The threshold is the last actual value of ``y_test``. The score is based on
    the first position at which a prediction reaches (>=) that threshold.

    Args:
        y_test: Actual values. Its last element (taken positionally) is the
            threshold.
        pred: Predictions. They are squeezed along axis 1, so a shape of
            ``(n, 1)`` is expected. NOTE(review): confirm with callers.

    Returns:
        dict: The printed values under the keys ``"mse"``, ``"rmse"``,
        ``"mae"``, ``"mape"``, ``"Eri"`` and ``"score"``. ``Eri`` and
        ``score`` are NaN when no prediction reaches the threshold. (The
        function used to return ``None``, so callers that ignore the result
        are unaffected.)
    """
    # Imported locally: ``math`` used to leak in only through
    # ``from pylab import *`` (via NumPy's star-export), which newer NumPy
    # versions no longer provide.
    import math

    # --- Error metrics ---------------------------------------------------
    mse = mean_squared_error(y_test, pred)
    mae = mean_absolute_error(y_test, pred)
    test_mse = round(mse, 4)
    test_rmse = round(math.sqrt(mse), 4)
    test_mae = round(mae, 4)
    # NOTE: placeholder, NOT a true MAPE. This is MAE * 100; the original
    # comment said "MAPE like this for now".
    test_mape = round(mae * 100, 4)

    # --- Threshold-crossing score ------------------------------------------
    result = list(np.squeeze(pred, 1))
    # Take the last element positionally so label-indexed inputs
    # (e.g. a pandas Series) work too.
    threshold = np.ravel(y_test)[-1]
    # Indices of every prediction at or above the threshold, in order.
    exceed = [i for i, value in enumerate(result) if value >= threshold]
    print("len(exceed)", len(exceed))

    if exceed:
        exceed_index = exceed[0]  # first crossing
        print("len(result)", len(result))
        # Fraction of the prediction horizon elapsed before the first crossing.
        Eri = round(exceed_index / len(result), 4)
        # NOTE(review): 100 looks like an assumed total horizon; confirm.
        print("RUL", 100 - (len(result) - exceed_index))
        print("Eri", Eri)

        # Exponential score: 2 ** (Eri / 5) when Eri <= 0,
        # otherwise 0.5 ** (Eri / 20).
        if Eri <= 0:
            score = round(math.exp(-math.log(0.5) * (Eri / 5)), 4)
        else:
            score = round(math.exp(math.log(0.5) * (Eri / 20)), 4)
        print("score1", score)

        # The exponential score above is only printed. The reported score is
        # the unrounded fraction of the horizon before the crossing.
        score = exceed_index / len(result)
    else:
        # No prediction reaches the threshold, so there is no crossing to score.
        Eri = math.nan
        score = math.nan

    print('MSE_testScore: %.4f MSE' % test_mse)
    print('RMSE_testScore: %.4f RMSE' % test_rmse)
    print('MAE_testScore: %.4f MAE' % test_mae)
    print('MAPE_testScore: %.4f MAPE' % test_mape)
    print("score: %.4f score" % score)

    return {
        "mse": test_mse,
        "rmse": test_rmse,
        "mae": test_mae,
        "mape": test_mape,
        "Eri": Eri,
        "score": score,
    }
|
||
|
||
|
||
# 画图
|
||
def getPlot(data, feature, time_step, x_train, x_test, pred, truePred, train_pred,
            saveName="../store/test"):
    """Plot the actual series and the train/test predictions, saving the latter.

    Two figures are drawn:

    * Figure 1: ``data`` as a line, with a horizontal line at ``data[-1]``.
      This figure is not saved here.
    * Figure 2: ``data`` as a scatter, ``train_pred`` as a line,
      ``pred`` as a scatter, and a horizontal "Failure threshold" line at
      ``data[-1]``. It is saved to ``saveName + "single.png"``.

    Both figures are cleared before drawing, so repeated calls do not overlay
    old curves.

    Args:
        data: Actual values. ``data.shape[0]`` points are plotted and
            ``data[-1]`` is drawn as the threshold line.
        feature: Window size. Together with ``time_step`` and
            ``x_train.shape[0]`` it fixes where the training and prediction
            curves start on the x axis.
        time_step: Window size (see ``feature``).
        x_train: Training inputs; only ``shape[0]`` is used.
        x_test: Test inputs; only ``shape[0]`` is used.
        pred: Test-set predictions. Indexed as ``pred[:, -1]``, so 2-D.
        truePred: Only printed (``truePred[:, -1]``).
        train_pred: Training-set predictions, squeezed along axis 1.
            Presumably ``x_train.shape[0]`` values, to match the x range;
            verify against caller.
        saveName: Path prefix of the saved image.
    """
    train_pred = np.squeeze(train_pred, 1)
    print("train_pred", train_pred)

    # Tick direction (in / out / inout). rcParams are read when axes are
    # created, so set them before either figure draws.
    plt.rcParams['xtick.direction'] = 'in'
    plt.rcParams['ytick.direction'] = 'in'

    # --- Figure 1: actual values ---------------------------------------------
    # Select figure 1 explicitly and clear it. Previously this plot went to the
    # *current* figure. After one call that is figure 2, so a second call drew
    # this curve into (and saved it with) the prediction figure.
    plt.figure(1)
    plt.clf()
    plt.plot(list(range(data.shape[0])), data)
    # Horizontal line at the last actual value (the failure threshold).
    plt.axhline(data[-1], c='green')
    plt.grid()
    plt.ylim(0, 1)
    plt.xlim(-50, 1300)

    # --- Figure 2: actual vs. training / predictive values -------------------
    plt.figure(2)
    # Clear so repeated calls don't accumulate curves in the saved image.
    plt.clf()
    # x position just past the training curve; test predictions start here.
    point_len = x_train.shape[0] + feature + time_step - 1

    print("pred", pred[:, -1])
    print("truePred", truePred[:, -1])

    figname2 = saveName + "single.png"
    plt.scatter(list(range(data.shape[0])), data, c='blue', s=12, label='Actual value')
    # TODO (translated from the original): "Training value" should become
    # 10 (lost to overlap) + 9 (transpose) + 1141 (known training data)
    # + 9 (transpose) = 1169, + 81 (prediction data) = 1250.
    # The training curve is the model's prediction on the training inputs;
    # it occupies x = time_step + feature - 1 .. point_len - 1.
    plt.plot(list(range(time_step + feature - 1, point_len)), train_pred,
             linewidth=2, color='red', label='Training value')
    plt.scatter(list(range(point_len, point_len + x_test.shape[0])), pred,
                c='black', s=15, label='Predictive value')
    # Horizontal line at the last actual value, labelled as the threshold.
    plt.axhline(data[-1], linewidth=2, c='green', label='Failure threshold')
    plt.ylim(-0.2, 0.95)
    plt.xlim(-50, 1300)
    plt.xlabel("Serial number of the fusion feature point", font=font2)
    plt.ylabel("Virtual health indicator", font=font2)
    plt.legend(loc='upper left', prop=font1)
    plt.savefig(figname2)
|
||
|
||
|