self_example/pytorch_example/temp/MSETest.py

131 lines
4.1 KiB
Python

# -*- encoding:utf-8 -*-
'''
@Author : dingjiawen
@Date : 2023/12/27 16:52
@Usage :
@Desc : Tests three shortcomings of MSE
'''
import numpy as np
import matplotlib.pyplot as plt
import torch_dct as dct
def plot(x, y, *, last_value=0.00832, offset=0.2):
    """Plot a series together with two predictions of its last point.

    The last sample of the series is replaced by ``last_value`` and two
    candidate predictions are placed at ``last_value + offset`` and
    ``last_value - offset``, i.e. at the same absolute distance from the
    last sample but on opposite sides of it.

    Args:
        x: Indexable sequence of x-coordinates; ``x[-1]`` is used as the
            position of the two predicted points.
        y: Sequence of series values. It is NOT modified: the function
            works on a float copy.
        last_value: Value written into the last sample of the copy.
        offset: Distance of each prediction from ``last_value``.

    Returns:
        A tuple ``(series, upper, lower)`` of 1-D arrays: the float copy of
        ``y`` with its last sample replaced, and that series with the last
        sample replaced by the upper / lower prediction respectively.
    """
    # Copy as float so the caller's array is not clobbered and so that an
    # integer input does not silently truncate ``last_value`` to 0.
    series = np.array(y, dtype=float)
    series[-1] = last_value
    predict1 = series[-1] + offset
    predict2 = series[-1] - offset

    # Font for the axis labels and for the legend.
    label_font = {'family': 'Times New Roman', 'weight': 'normal', 'size': 15}
    legend_font = {'family': 'Times New Roman', 'weight': 'normal', 'size': 15}

    plt.figure(1)
    plt.rcParams.update({
        "font.family": 'Times New Roman',  # font family for all text
        "axes.unicode_minus": False,  # avoid minus signs that fail to render
        "axes.labelsize": 13,
    })

    # Simple prediction plot.
    plt.scatter(x, series, c='blue', s=16, label='Actual value')
    plt.plot(x, series, linewidth=2, color='red', label='Training value')
    # Up/down triangles keep the two predictions distinguishable in the
    # legend; two identical black dots could not be told apart.
    plt.scatter([x[-1]], [predict1], c='black', marker='^',
                s=18, label='Predict data1')
    plt.scatter([x[-1]], [predict2], c='black', marker='v',
                s=18, label='Predict data2')
    plt.xlabel('Serial number of the fusion feature point', font=label_font)
    plt.ylabel('Virtual health indicator', font=label_font)
    plt.legend(loc='upper left', prop=legend_font)
    plt.show()

    upper = np.hstack([series[:-1], predict1])
    lower = np.hstack([series[:-1], predict2])
    return series, upper, lower
def _amplitude_and_phase(signal):
    """Return ``(amplitude, phase)`` of the discrete Fourier transform of ``signal``."""
    spectrum = np.fft.fft(np.asarray(signal))
    # Amplitude is normalised by the number of samples; phase is in radians.
    return np.abs(spectrum / len(spectrum)), np.angle(spectrum)


def _show_three_curves(figure_id, predict1, predict2, original, y_label):
    """Draw the two prediction curves and the original curve in one figure and show it."""
    label_font = {'family': 'Times New Roman', 'weight': 'normal', 'size': 15}
    legend_font = {'family': 'Times New Roman', 'weight': 'normal', 'size': 15}

    plt.figure(figure_id)
    plt.rcParams.update({
        "font.family": 'Times New Roman',  # font family for all text
        "axes.unicode_minus": False,  # avoid minus signs that fail to render
        "axes.labelsize": 13,
    })
    # The original series is drawn last so it sits on top of the predictions.
    plt.plot(predict1, linewidth=2, color='green', label='Predict data1')
    plt.plot(predict2, linewidth=2, color='red', label='Predict data2')
    plt.plot(original, linewidth=2, color='blue', label='Original data')
    plt.legend(loc='upper left', prop=legend_font)
    # NOTE(review): the x-axis is the DFT bin index; the label text is kept
    # as in the original figures.
    plt.xlabel('Serial number of the fusion feature point', font=label_font)
    plt.ylabel(y_label, font=label_font)
    plt.show()


def fft(x, a, b):
    """Compare three series in the frequency domain.

    Each series is transformed with ``np.fft.fft``; figure 2 then shows the
    normalised amplitude spectra and figure 3 the phase spectra.

    Previously the amplitude and angle were taken from the raw series with no
    transform applied, so the "phase" of real data could only be 0 or pi and
    the "amplitude" was just ``|sample| / N``.

    Args:
        x: Original series (1-D array-like).
        a: Series carrying the first prediction (1-D array-like).
        b: Series carrying the second prediction (1-D array-like).

    Returns:
        None. The function only draws and shows two figures.

    Raises:
        ValueError: Raised by ``np.fft.fft`` if a series is empty.
    """
    amp_y0, angle_y0 = _amplitude_and_phase(x)
    amp_y1, angle_y1 = _amplitude_and_phase(a)
    amp_y2, angle_y2 = _amplitude_and_phase(b)

    _show_three_curves(2, amp_y1, amp_y2, amp_y0, 'Amplitude')
    _show_three_curves(3, angle_y1, angle_y2, angle_y0, 'Angle')
# Number of random samples in the demo series.
length = 30

# Fixed 14-sample alternative input, currently disabled:
# y = np.array([-0.029078494757,
#               -0.33095228672,
#               -0.12124221772,
#               0.553512275219,
#               -0.158036053181,
#               0.268739402294,
#               -0.638222515583,
#               0.233140587807,
#               -0.173265621066,
#               0.467218101025,
#               -0.372010827065,
#               -0.136630430818,
#               0.343256533146,
#               0.008932195604])

# Random series with `length` samples, and its sample indices as x-axis.
y = np.random.random((length,))
x = list(range(len(y)))
print(y)

# Draw the series with its two predictions, then pass the three returned
# series on to fft().
y, a, b = plot(x, y)
fft(y, a, b)