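# -*- encoding:utf-8 -*-
'''
@Author : dingjiawen
@Date : 2023/11/20 19:30
@Usage : plot_prediction(...) saves and shows the easy/hard prediction figures;
         plot_forSelf(...) shows a quick two-panel comparison.
@Desc : Matplotlib helpers that plot actual data against model predictions
        ("easy" and "hard" prediction series).
'''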
|
|
|
|
import matplotlib.pyplot as plt
|
|
import time
|
|
|
|
|
|
def plot_prediction(total_data, predicted_data_easy, predicted_data_hard, save_fig_name, predict_num=50):
    """Plot, save and show the "easy" and "hard" predictions against the actual data.

    Two figures are produced, one per prediction series. Each one shows:

    * the actual values as a blue scatter;
    * the fitted part of the prediction (all but the last ``predict_num``
      points) as a red line;
    * the last ``predict_num`` predicted points as a black scatter;
    * a green horizontal line at ``total_data[-1]``, labelled as the
      failure threshold.

    Args:
        total_data: Actual values. They must support ``len()`` and ``[-1]``
            and are plotted against ``0 .. len(total_data) - 1``.
        predicted_data_easy: Predicted series for the "easy" figure. It must
            support ``len()`` and slicing.
        predicted_data_hard: Predicted series for the "hard" figure, with the
            same requirements as ``predicted_data_easy``.
        save_fig_name: Path prefix. The figures are saved as
            ``save_fig_name + 'easy<ts>.png'`` and
            ``save_fig_name + 'hard<ts>.png'``, where ``<ts>`` is the last four
            digits of the current Unix time.
        predict_num: Number of trailing points treated as the prediction
            horizon. Values larger than a series' length are clamped to that
            length.

    Side effects:
        Updates the global matplotlib ``rcParams`` (font family, minus-sign
        rendering, axis-label size), writes two PNG files, and calls
        ``plt.show()`` after each figure.
    """
    font_label = {'family': 'Times New Roman', 'weight': 'normal', 'size': 15}  # axis-label font
    font_legend = {'family': 'Times New Roman', 'weight': 'normal', 'size': 15}  # legend font
    # The last four digits of the Unix time keep file names distinct between runs.
    timestamp = str(int(time.time()))[-4:]

    # Global style shared by both figures. It is applied before any axes are
    # created so that the label size and font family take effect.
    plt.rcParams.update({
        "font.family": 'Times New Roman',
        "axes.unicode_minus": False,  # ASCII '-' for negatives; some fonts lack the Unicode minus glyph
        "axes.labelsize": 13,
    })

    for fig_num, predicted, tag in ((1, predicted_data_easy, 'easy'),
                                    (2, predicted_data_hard, 'hard')):
        # The figure numbers are fixed, so clear=True stops a repeated call
        # from drawing on top of the previous run's plot (e.g. on backends
        # where show() does not close figures).
        plt.figure(fig_num, clear=True)
        _draw_prediction(total_data, predicted, predict_num, font_label, font_legend)
        plt.savefig(save_fig_name + '{0}{1}.png'.format(tag, timestamp))
        plt.show()


def _draw_prediction(total_data, predicted, predict_num, font_label, font_legend):
    """Draw one actual-vs-predicted chart on the current pyplot axes."""
    length = len(predicted)
    # Index where the prediction horizon begins. It is clamped at 0 so an
    # oversized predict_num cannot produce a negative slice bound (which
    # would mismatch the x and y lengths).
    split = max(length - predict_num, 0)

    # Actual values use their own length rather than the prediction's length.
    plt.scatter(list(range(len(total_data))), total_data, c='blue', s=12, label='Actual value')
    plt.plot(list(range(split)), predicted[:split], linewidth=2, color='red',
             label='Training value')
    plt.scatter(list(range(split, length)), predicted[split:length], c='black',
                s=15, label='Predictive value')
    # The last actual value is drawn as the failure threshold.
    plt.axhline(total_data[-1], linewidth=2, c='green', label='Failure threshold')

    plt.xlabel('Serial number of the fusion feature point', font=font_label)
    plt.ylabel('Virtual health indicator', font=font_label)
    plt.legend(loc='upper left', prop=font_legend)
|
|
|
|
|
|
def plot_forSelf(total_data, predicted_data_easy, predicted_data_hard):
    """Show the actual series overlaid with the easy and hard predictions.

    Opens one 8x6-inch, 200-dpi figure with two stacked subplots. The top panel
    overlays ``total_data`` with ``predicted_data_easy``, and the bottom panel
    overlays it with ``predicted_data_hard``. Nothing is saved; the figure is
    only displayed with ``plt.show()``.

    Side effects:
        Updates the global matplotlib ``rcParams`` (font family, minus-sign
        rendering, axis-label size).
    """
    plt.figure(figsize=(8, 6), dpi=200)

    # Applied before the subplots are created so the label size and font apply.
    plt.rcParams.update({
        "font.family": 'Times New Roman',
        "axes.unicode_minus": False,  # ASCII '-' for negatives; some fonts lack the Unicode minus glyph
        "axes.labelsize": 13,
    })

    panels = (
        (1, predicted_data_easy, 'Easy Prediction'),
        (2, predicted_data_hard, 'Hard Prediction'),
    )
    for position, predicted, title in panels:
        plt.subplot(2, 1, position)
        plt.plot(total_data)
        plt.plot(predicted)
        plt.title(title)
        plt.xlabel('time')
        plt.ylabel('loss')

    # Without this, the top panel's x-label collides with the bottom panel's title.
    plt.tight_layout()
    plt.show()
|