self_example/pytorch_example/RUL/baseModel/plot.py

114 lines
3.7 KiB
Python

# -*- encoding:utf-8 -*-
'''
@Author : dingjiawen
@Date : 2023/11/20 19:30
@Usage :
@Desc :
'''
import matplotlib.pyplot as plt
import time
def _draw_prediction_figure(fig_num, total_data, predicted_data, predict_num, save_path):
    """Draw, save and show one actual-vs-predicted figure on pyplot figure ``fig_num``."""
    # Font settings for the axis labels and for the legend entries.
    label_font = {'family': 'Times New Roman', 'weight': 'normal', 'size': 15}
    legend_font = {'family': 'Times New Roman', 'weight': 'normal', 'size': 15}

    # The figure number is fixed, so clear whatever an earlier call may have
    # left on it; otherwise repeated calls on a non-blocking backend overplot.
    plt.figure(fig_num, clear=True)
    plt.rcParams.update({
        "font.family": 'Times New Roman',  # font family used for all text
        "axes.unicode_minus": False,  # render minus signs correctly
        "axes.labelsize": 13,
    })

    length = len(predicted_data)
    # Index where the fitted (training) part ends and the forecast begins.
    # Clamped at 0 so predict_num > length treats every point as forecast
    # instead of producing mismatched x/y lengths.
    split = max(length - predict_num, 0)

    # NOTE(review): assumes total_data has the same length as predicted_data;
    # plt.scatter requires x and y of equal size.
    plt.scatter(list(range(length)), total_data, c='blue', s=12, label='Actual value')
    plt.plot(list(range(split)), predicted_data[:split], linewidth=2, color='red',
             label='Training value')
    plt.scatter(list(range(split, length)), predicted_data[split:], c='black',
                s=15, label='Predictive value')
    # The last actual value is drawn as the failure threshold.
    plt.axhline(total_data[-1], linewidth=2, c='green', label='Failure threshold')

    plt.xlabel('Serial number of the fusion feature point', font=label_font)
    plt.ylabel('Virtual health indicator', font=label_font)
    plt.legend(loc='upper left', prop=legend_font)
    plt.savefig(save_path)
    plt.show()


def plot_prediction(total_data, predicted_data_easy, predicted_data_hard, save_fig_name: str,
                    predict_num: int = 50):
    """Plot, save and show the "easy" and "hard" predictions against the actual data.

    Two figures are produced (pyplot figure numbers 1 and 2). Each one shows
    the actual values as blue points, the first ``length - predict_num``
    predicted values as a red line, the last ``predict_num`` predicted values
    as black points, and a green horizontal line at ``total_data[-1]``
    labelled as the failure threshold.

    Each figure is written to ``save_fig_name + 'easy<ts>.png'`` or
    ``save_fig_name + 'hard<ts>.png'``, where ``<ts>`` is the last four digits
    of the current Unix time in seconds, and then shown with ``plt.show()``.

    Args:
        total_data: Indexable sequence of actual values; its last element is
            used as the failure threshold.
        predicted_data_easy: Sliceable sequence of predictions for figure 1.
        predicted_data_hard: Sliceable sequence of predictions for figure 2.
        save_fig_name: Path prefix that the output file names are appended to.
        predict_num: Number of trailing points drawn as forecast rather than
            as fitted/training values.
    """
    # Both files share one timestamp (last four digits of the epoch seconds).
    timestamp = str(int(time.time()))[-4:]

    # Easy prediction figure.
    _draw_prediction_figure(1, total_data, predicted_data_easy, predict_num,
                            f"{save_fig_name}easy{timestamp}.png")
    # Hard prediction figure.
    _draw_prediction_figure(2, total_data, predicted_data_hard, predict_num,
                            f"{save_fig_name}hard{timestamp}.png")
def plot_forSelf(total_data, predicted_data_easy, predicted_data_hard):
    """Show the easy and hard predictions against the actual data in two stacked subplots.

    Creates an 8x6 inch figure at 200 dpi with a 2x1 grid. The top panel
    plots ``total_data`` together with ``predicted_data_easy``, the bottom
    panel plots ``total_data`` together with ``predicted_data_hard``. Nothing
    is saved; the figure is only displayed with ``plt.show()``.

    Args:
        total_data: Actual values, drawn in both panels.
        predicted_data_easy: Predicted values drawn in the top panel.
        predicted_data_hard: Predicted values drawn in the bottom panel.
    """
    plt.figure(figsize=(8, 6), dpi=200)
    plt.rcParams.update({
        "font.family": 'Times New Roman',  # font family used for all text
        "axes.unicode_minus": False,  # render minus signs correctly
        "axes.labelsize": 13,
    })

    # One (title, predicted series) pair per subplot row.
    panels = (
        ('Easy Prediction', predicted_data_easy),
        # The bottom panel shows the hard prediction, so title it accordingly
        # (it was previously mislabelled 'Easy Prediction').
        ('Hard Prediction', predicted_data_hard),
    )
    for row, (title, predicted) in enumerate(panels, start=1):
        plt.subplot(2, 1, row)
        plt.plot(total_data)
        plt.plot(predicted)
        plt.title(title)
        plt.xlabel('time')
        plt.ylabel('loss')

    plt.show()