131 lines
4.1 KiB
Python
131 lines
4.1 KiB
Python
# -*- encoding:utf-8 -*-
|
|
|
|
'''
|
|
@Author : dingjiawen
|
|
@Date : 2023/12/27 16:52
|
|
@Usage :
|
|
@Desc : 测试MSE三点不足
|
|
'''
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
import torch_dct as dct
|
|
|
|
|
|
def plot(x, y, last_value=0.00832, offset=0.2):
    """Plot a series plus two candidate predictions for its final point.

    The last sample of ``y`` is pinned to ``last_value``. Two "predictions"
    are placed symmetrically at ``last_value + offset`` and
    ``last_value - offset``, so both have the same absolute (and therefore
    squared) error with respect to the pinned value. Figure 1 shows the
    series, the fitted line and both predictions.

    Args:
        x: x-coordinates of the samples; indexed with ``x[-1]``, so it must be
            a non-empty sequence the same length as ``y``.
        y: sample values. The caller's object is NOT modified; a float64 copy
            is used instead.
        last_value: value that replaces the final sample of ``y``.
        offset: distance of each prediction from ``last_value``.

    Returns:
        Tuple ``(y, a, b)``:
        - ``y``: float copy of the input with its last sample set to
          ``last_value``.
        - ``a``: ``y`` with its last sample replaced by the upper prediction.
        - ``b``: ``y`` with its last sample replaced by the lower prediction.
    """
    from matplotlib import rcParams

    # Copy as float: this avoids mutating the caller's array, and it keeps an
    # integer-dtype input from silently truncating last_value to 0.
    y = np.array(y, dtype=float)
    y[-1] = last_value

    predict1 = y[-1] + offset
    predict2 = y[-1] - offset

    rcParams.update({
        "font.family": 'Times New Roman',  # font family for all text
        # Use an ASCII hyphen for minus signs; the Unicode minus may not
        # render in the chosen font.
        "axes.unicode_minus": False,
        "axes.labelsize": 13,
    })
    # A single font spec is used for the axis labels and the legend.
    font = {'family': 'Times New Roman', 'weight': 'normal', 'size': 15}

    plt.figure(1)
    plt.scatter(x, y, c='blue', s=16, label='Actual value')
    plt.plot(x, y, linewidth=2, color='red', label='Training value')
    plt.scatter([x[-1]], [predict1], c='black', s=18, label='Predict data1')
    plt.scatter([x[-1]], [predict2], c='black', s=18, label='Predict data2')
    plt.xlabel('Serial number of the fusion feature point', font=font)
    plt.ylabel('Virtual health indicator', font=font)
    plt.legend(loc='upper left', prop=font)
    plt.show()

    # Each candidate series equals y except for its final sample.
    a = np.hstack([y[:-1], predict1])
    b = np.hstack([y[:-1], predict2])
    return y, a, b
|
|
|
|
|
|
def _spectrum(signal):
    """Return the length-normalised DFT amplitude and the DFT phase of ``signal``."""
    spectrum = np.fft.fft(signal)
    # |X[k]| / N is the usual amplitude normalisation; np.angle gives the
    # phase of each complex bin in radians.
    return np.abs(spectrum) / len(spectrum), np.angle(spectrum)


def _show_curves(fig_num, curves, ylabel):
    """Draw ``(values, color, label)`` curves on figure ``fig_num`` and show it."""
    from matplotlib import rcParams

    plt.figure(fig_num)
    rcParams.update({
        "font.family": 'Times New Roman',  # font family for all text
        # Use an ASCII hyphen for minus signs; the Unicode minus may not
        # render in the chosen font.
        "axes.unicode_minus": False,
        "axes.labelsize": 13,
    })
    # A single font spec is used for the axis labels and the legend.
    font = {'family': 'Times New Roman', 'weight': 'normal', 'size': 15}

    for values, color, label in curves:
        plt.plot(values, linewidth=2, color=color, label=label)
    plt.legend(loc='upper left', prop=font)
    plt.xlabel('Frequency bin index', font=font)
    plt.ylabel(ylabel, font=font)
    plt.show()


def fft(x, a, b):
    """Plot the amplitude and phase spectra of an original signal and two variants.

    Each signal is transformed with ``np.fft.fft``. Figure 2 shows the
    amplitude spectra (normalised by signal length) and figure 3 the phase
    spectra (radians).

    Previously no transform was applied. The "amplitude" was then just
    |signal| / N, and ``np.angle`` of a real signal can only be 0 or pi, so
    the phase plot carried no information.

    Args:
        x: the original signal ("Original data").
        a: first prediction variant ("Predict data1").
        b: second prediction variant ("Predict data2").
    """
    amp_y0, angle_y0 = _spectrum(x)
    amp_y1, angle_y1 = _spectrum(a)
    amp_y2, angle_y2 = _spectrum(b)

    # Curves are drawn in the same order as before, so legend order is kept.
    _show_curves(2, [(amp_y1, 'green', 'Predict data1'),
                     (amp_y2, 'red', 'Predict data2'),
                     (amp_y0, 'blue', 'Original data')], 'Amplitude')
    _show_curves(3, [(angle_y1, 'green', 'Predict data1'),
                     (angle_y2, 'red', 'Predict data2'),
                     (angle_y0, 'blue', 'Original data')], 'Angle')
|
|
|
|
|
|
length = 30

# Fixed 14-sample test signal that can be used instead of the random one:
# y = np.array([-0.029078494757, -0.33095228672, -0.12124221772,
#               0.553512275219, -0.158036053181, 0.268739402294,
#               -0.638222515583, 0.233140587807, -0.173265621066,
#               0.467218101025, -0.372010827065, -0.136630430818,
#               0.343256533146, 0.008932195604])

# `length` samples drawn uniformly from [0, 1) (unseeded, so each run differs).
y = np.random.random(length)
x = [*range(len(y))]
print(y)

# Plot the series with two equal-error predictions, then their spectra.
y, a, b = plot(x, y)
fft(y, a, b)
|