# -*- encoding:utf-8 -*-
'''
@Author  : dingjiawen
@Date    : 2023/11/20 19:30
@Usage   :
@Desc    : Plot actual data against "easy" and "hard" prediction results.
'''
import time

import matplotlib.pyplot as plt

# rc settings shared by every plotting function in this module.
_PLOT_RC_CONFIG = {
    "font.family": 'Times New Roman',  # font family used for all text
    "axes.unicode_minus": False,  # draw minus signs with an ASCII hyphen so they render in this font
    "axes.labelsize": 13,
}


def _apply_plot_style():
    """Apply ``_PLOT_RC_CONFIG`` to matplotlib's global rcParams.

    This mutates global matplotlib state (as the original code did), so the
    style also applies to figures created later in the same process.
    """
    plt.rcParams.update(_PLOT_RC_CONFIG)


def _show_and_release(fig):
    """Show all open figures, then close ``fig`` unless in interactive mode.

    Outside interactive mode ``plt.show()`` either blocks until the windows
    are closed or does nothing (e.g. the Agg backend); closing ``fig`` then
    frees it, so repeated calls neither leak figures nor draw over a stale
    one. In interactive mode the window is left open for the user.
    """
    plt.show()
    if not plt.isinteractive():
        plt.close(fig)


def _plot_prediction_figure(total_data, predicted_data, predict_num,
                            save_path, label_font, legend_font):
    """Draw one actual-vs-predicted figure on a fresh figure, save it and show it.

    The first ``len(predicted_data) - predict_num`` values of
    ``predicted_data`` are drawn as a red line ('Training value'). The
    trailing ``predict_num`` values are drawn as black points
    ('Predictive value'). A green horizontal line is drawn at
    ``total_data[-1]`` ('Failure threshold').

    Raises:
        ValueError: if ``predict_num`` is outside ``[0, len(predicted_data)]``.
    """
    length = len(predicted_data)
    if not 0 <= predict_num <= length:
        raise ValueError('predict_num must be between 0 and {0}, got {1}'.format(length, predict_num))
    # Index where the fitted (training) part ends and the prediction part begins.
    split = length - predict_num

    # A fresh figure per call: reusing fixed figure numbers would draw on top
    # of a previous call's figure when plt.show() does not close it.
    fig = plt.figure()
    plt.scatter(range(length), total_data, c='blue', s=12, label='Actual value')
    plt.plot(range(split), predicted_data[:split], linewidth=2, color='red', label='Training value')
    plt.scatter(range(split, length), predicted_data[split:], c='black', s=15, label='Predictive value')
    plt.axhline(total_data[-1], linewidth=2, c='green', label='Failure threshold')
    plt.xlabel('Serial number of the fusion feature point', font=label_font)
    plt.ylabel('Virtual health indicator', font=label_font)
    plt.legend(loc='upper left', prop=legend_font)
    plt.savefig(save_path)
    _show_and_release(fig)


def plot_prediction(total_data, predicted_data_easy, predicted_data_hard, save_fig_name, predict_num=50):
    """Plot, save and show the easy and hard prediction results as two figures.

    Args:
        total_data: Actual values, plotted at x = 0 .. len(prediction) - 1.
            Presumably the same length as each prediction series (TODO:
            confirm with callers).
        predicted_data_easy: Prediction series for the "easy" figure.
        predicted_data_hard: Prediction series for the "hard" figure.
        save_fig_name: Path prefix. The figures are saved to
            ``save_fig_name + 'easy<tttt>.png'`` and
            ``save_fig_name + 'hard<tttt>.png'``, where ``<tttt>`` is the
            last four digits of the current Unix time.
        predict_num: Number of trailing points of each prediction series
            drawn as predictions (scatter) rather than the fitted part (line).

    Raises:
        ValueError: if ``predict_num`` is outside ``[0, len(series)]`` for a
            prediction series.
    """
    label_font = {'family': 'Times New Roman', 'weight': 'normal', 'size': 15}  # axis-label font
    legend_font = {'family': 'Times New Roman', 'weight': 'normal', 'size': 15}  # legend font
    # Computed once so the easy/hard files from the same call share a suffix.
    timestamp = str(int(time.time()))[-4:]

    _apply_plot_style()
    _plot_prediction_figure(total_data, predicted_data_easy, predict_num,
                            save_fig_name + 'easy{0}.png'.format(timestamp),
                            label_font, legend_font)
    _plot_prediction_figure(total_data, predicted_data_hard, predict_num,
                            save_fig_name + 'hard{0}.png'.format(timestamp),
                            label_font, legend_font)


def plot_forSelf(total_data, predicted_data_easy, predicted_data_hard):
    """Show a quick two-panel comparison of actual data and both predictions.

    Top panel: ``total_data`` with ``predicted_data_easy``. Bottom panel:
    ``total_data`` with ``predicted_data_hard``. Nothing is saved to disk.
    """
    _apply_plot_style()
    fig = plt.figure(figsize=(8, 6), dpi=200)

    panels = (
        (1, predicted_data_easy, 'Easy Prediction'),
        (2, predicted_data_hard, 'Hard Prediction'),
    )
    for position, predicted, title in panels:
        plt.subplot(2, 1, position)
        plt.plot(total_data)
        plt.plot(predicted)
        plt.title(title)
        plt.xlabel('time')
        plt.ylabel('loss')

    _show_and_release(fig)