# -*- encoding:utf-8 -*-
'''
@Author  : dingjiawen
@Date    : 2023/12/27 16:52
@Usage   :
@Desc    : Test three shortcomings of MSE.

The script builds a random series, fixes its last point, and creates two
predictions placed symmetrically above and below that point. Both have the
same squared error. It then plots the series with both predictions and
compares the FFT amplitude and phase spectra of the original series and the
two predicted series.
'''
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams
import torch_dct as dct  # NOTE(review): not used in this script; kept in case other tooling relies on it

# Shared matplotlib configuration applied before drawing every figure.
_RC_CONFIG = {
    "font.family": 'Times New Roman',  # font family for all text
    "axes.unicode_minus": False,  # make minus signs render correctly with this font
    "axes.labelsize": 13
}

# Font for axis labels and legend entries.
_LABEL_FONT = {'family': 'Times New Roman', 'weight': 'normal', 'size': 15}


def _apply_style():
    """Apply the shared rcParams configuration to matplotlib."""
    rcParams.update(_RC_CONFIG)


def plot(x, y, last_value=0.00832, offset=0.2):
    """Plot a series together with two symmetric predictions of its last point.

    The last element of (a float copy of) ``y`` is set to ``last_value``.
    Two candidate predictions, ``last_value + offset`` and
    ``last_value - offset``, are drawn at the final x position. Both have the
    same squared error, ``offset ** 2``, against the true last value.

    Args:
        x: x coordinates of the points, same length as ``y``.
        y: non-empty 1-D sequence of values. It is not modified.
        last_value: value written into the last position of the copy of ``y``.
        offset: distance of each prediction from the true last value.

    Returns:
        Tuple ``(y, a, b)``:
            y: the modified float copy of ``y``.
            a: ``y`` with its last point replaced by the upper prediction.
            b: ``y`` with its last point replaced by the lower prediction.
    """
    # Work on a float copy. Writing into the caller's array would mutate it,
    # and an integer array would truncate the fractional values assigned below.
    y = np.array(y, dtype=float)
    y[-1] = last_value

    predict1 = y[-1] + offset
    predict2 = y[-1] - offset

    plt.figure(1)
    _apply_style()

    # Simple prediction plot: the series plus both candidate last points.
    plt.scatter(x, y, c='blue', s=16, label='Actual value')
    plt.plot(x, y, linewidth=2, color='red', label='Training value')
    plt.scatter([x[-1]], [predict1], c='black', s=18, label='Predict data1')
    plt.scatter([x[-1]], [predict2], c='black', s=18, label='Predict data2')
    plt.xlabel('Serial number of the fusion feature point', font=_LABEL_FONT)
    plt.ylabel('Virtual health indicator', font=_LABEL_FONT)
    plt.legend(loc='upper left', prop=_LABEL_FONT)
    plt.show()

    a = np.hstack([y[:-1], predict1])
    b = np.hstack([y[:-1], predict2])
    return y, a, b


def fft(x, a, b):
    """Plot the FFT amplitude and phase spectra of three series.

    Args:
        x: the original series (the ``y`` returned by ``plot``).
        a: the series whose last point is the upper prediction.
        b: the series whose last point is the lower prediction.
    """
    # Transform to the frequency domain first. np.angle of real-valued samples
    # is only ever 0 or pi, so phase and amplitude must come from the spectrum.
    spec_y0 = np.fft.fft(x)
    spec_y1 = np.fft.fft(a)
    spec_y2 = np.fft.fft(b)

    # Amplitude, normalised by the series length.
    amp_y0 = np.abs(spec_y0) / len(x)
    amp_y1 = np.abs(spec_y1) / len(a)
    amp_y2 = np.abs(spec_y2) / len(b)

    # Phase angle.
    angle_y0 = np.angle(spec_y0)
    angle_y1 = np.angle(spec_y1)
    angle_y2 = np.angle(spec_y2)

    plt.figure(2)
    _apply_style()
    plt.plot(amp_y1, linewidth=2, color='green', label='Predict data1')
    plt.plot(amp_y2, linewidth=2, color='red', label='Predict data2')
    plt.plot(amp_y0, linewidth=2, color='blue', label='Original data')
    plt.legend(loc='upper left', prop=_LABEL_FONT)
    plt.xlabel('Serial number of the fusion feature point', font=_LABEL_FONT)
    plt.ylabel('Amplitude', font=_LABEL_FONT)
    plt.show()

    plt.figure(3)
    _apply_style()
    plt.plot(angle_y1, linewidth=2, color='green', label='Predict data1')
    plt.plot(angle_y2, linewidth=2, color='red', label='Predict data2')
    plt.plot(angle_y0, linewidth=2, color='blue', label='Original data')
    plt.legend(loc='upper left', prop=_LABEL_FONT)
    plt.xlabel('Serial number of the fusion feature point', font=_LABEL_FONT)
    plt.ylabel('Angle', font=_LABEL_FONT)
    plt.show()


length = 30
# Alternative fixed sample (replace the random series below to reproduce a run):
# y = np.array([-0.029078494757,
#               -0.33095228672,
#               -0.12124221772,
#               0.553512275219,
#               -0.158036053181,
#               0.268739402294,
#               -0.638222515583,
#               0.233140587807,
#               -0.173265621066,
#               0.467218101025,
#               -0.372010827065,
#               -0.136630430818,
#               0.343256533146,
#               0.008932195604])
y = np.random.random(length)
x = list(range(len(y)))
print(y)
y, a, b = plot(x, y)
fft(y, a, b)