模型更新
This commit is contained in:
parent
9d72cb9592
commit
a35ccc2f63
|
|
@ -14,6 +14,8 @@ import matplotlib.pyplot as plt
|
||||||
from keras.callbacks import EarlyStopping
|
from keras.callbacks import EarlyStopping
|
||||||
|
|
||||||
from model.LossFunction.FTMSE import FTMSE
|
from model.LossFunction.FTMSE import FTMSE
|
||||||
|
from model.ChannelAttention.DCT_channelAttention import DCTChannelAttention
|
||||||
|
from model.ChannelAttention.Light_channelAttention import LightChannelAttention1 as LightChannelAttention
|
||||||
import math
|
import math
|
||||||
from sklearn.metrics import mean_absolute_error, mean_squared_error
|
from sklearn.metrics import mean_absolute_error, mean_squared_error
|
||||||
from pylab import *
|
from pylab import *
|
||||||
|
|
@ -27,10 +29,10 @@ batch_size = 32
|
||||||
EPOCH = 1000
|
EPOCH = 1000
|
||||||
unit = 512 # LSTM的维度
|
unit = 512 # LSTM的维度
|
||||||
predict_num = 50 # 预测个数
|
predict_num = 50 # 预测个数
|
||||||
model_name = "FTLSTM"
|
model_name = "FC_FTLSTM"
|
||||||
save_name = r"selfMulti_{0}_hidden{1}_unit{2}_feature{3}_predict{4}.h5".format(model_name, hidden_num, unit,
|
save_name = r"selfMulti_{0}_hidden{1}_unit{2}_feature{3}_predict{4}.h5".format(model_name, hidden_num, unit,
|
||||||
feature,
|
feature,
|
||||||
predict_num)
|
predict_num)
|
||||||
|
|
||||||
|
|
||||||
def standardization(data):
|
def standardization(data):
|
||||||
|
|
@ -134,6 +136,8 @@ def splitValData(data, label, label_single, predict_num=50):
|
||||||
|
|
||||||
|
|
||||||
def predict_model_multi(filter_num, dims):
|
def predict_model_multi(filter_num, dims):
|
||||||
|
tf.config.experimental_run_functions_eagerly(True)
|
||||||
|
|
||||||
input = tf.keras.Input(shape=[filter_num, dims])
|
input = tf.keras.Input(shape=[filter_num, dims])
|
||||||
input = tf.cast(input, tf.float32)
|
input = tf.cast(input, tf.float32)
|
||||||
|
|
||||||
|
|
@ -143,9 +147,11 @@ def predict_model_multi(filter_num, dims):
|
||||||
|
|
||||||
#### 自己
|
#### 自己
|
||||||
LSTM = LSTMLayer(units=512, return_sequences=True)(input)
|
LSTM = LSTMLayer(units=512, return_sequences=True)(input)
|
||||||
|
# LSTM = LightChannelAttention()(LSTM)
|
||||||
LSTM = LSTMLayer(units=256, return_sequences=True)(LSTM)
|
LSTM = LSTMLayer(units=256, return_sequences=True)(LSTM)
|
||||||
|
LSTM = LightChannelAttention()(LSTM)
|
||||||
|
|
||||||
###flatten
|
### flatten
|
||||||
x = tf.keras.layers.Flatten()(LSTM)
|
x = tf.keras.layers.Flatten()(LSTM)
|
||||||
x = tf.keras.layers.Dense(128, activation="relu")(x)
|
x = tf.keras.layers.Dense(128, activation="relu")(x)
|
||||||
x = tf.keras.layers.Dense(64, activation="relu")(x)
|
x = tf.keras.layers.Dense(64, activation="relu")(x)
|
||||||
|
|
@ -161,30 +167,7 @@ def predict_model_multi(filter_num, dims):
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def split_data(train_data, train_label):
|
|
||||||
return train_data[:1150, :, :], train_label[:1150, :], train_data[-70:, :, :], train_label[-70:, :]
|
|
||||||
|
|
||||||
|
|
||||||
# 仅使用预测出来的最新的一个点预测以后
|
# 仅使用预测出来的最新的一个点预测以后
|
||||||
def predictOneByOne(newModel, train_data, predict_num=50):
|
|
||||||
# 取出训练数据的最后一条
|
|
||||||
each_predict_data = np.expand_dims(train_data[-1, :, :], axis=0)
|
|
||||||
predicted_list = np.empty(shape=(predict_num, 1)) # (5,filter_num,30)
|
|
||||||
# all_data = total_data # (1201,)
|
|
||||||
for each_predict in range(predict_num):
|
|
||||||
# predicted_data.shape : (1,1)
|
|
||||||
predicted_data = newModel.predict(each_predict_data) # (batch_size,filer_num,1)
|
|
||||||
predicted_list[each_predict] = predicted_data
|
|
||||||
# (1,1) => (10,1)
|
|
||||||
temp1 = np.transpose(np.concatenate([each_predict_data[:, 1:, -1], predicted_data], axis=1), [1, 0])
|
|
||||||
|
|
||||||
each_predict_data = np.expand_dims(
|
|
||||||
np.concatenate([np.squeeze(each_predict_data[:, :, 1:], axis=0), temp1], axis=1), axis=0)
|
|
||||||
|
|
||||||
return predicted_list
|
|
||||||
|
|
||||||
|
|
||||||
# 使用最后预测出来的一整行与之前的拼接
|
|
||||||
def predictContinueByOne(newModel, train_data, predict_num=50):
|
def predictContinueByOne(newModel, train_data, predict_num=50):
|
||||||
# 取出训练数据的最后一条
|
# 取出训练数据的最后一条
|
||||||
each_predict_data = np.expand_dims(train_data[-1, :, :], axis=0)
|
each_predict_data = np.expand_dims(train_data[-1, :, :], axis=0)
|
||||||
|
|
@ -221,13 +204,13 @@ def predictByEveryData(trained_model: tf.keras.Model, predict_data):
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# 数据读取
|
# 数据读取
|
||||||
# 数据读入 --> 所有的原始数据;所有的训练数据;所有的训练标签(预测一个序列);所有的训练标签(预测一个点)
|
# 数据读入 --> 所有的原始数据;所有的训练数据;所有的训练标签(预测一个序列);所有的训练标签(预测一个点)
|
||||||
total_data, train_data, train_label, train_label_single = getData(hidden_num, feature, if_norm=False)
|
total_data, train_data, train_label, train_label_single = getData(hidden_num, feature)
|
||||||
# 根据预测的点数划分训练集和测试集(验证集)
|
# 根据预测的点数划分训练集和测试集(验证集)
|
||||||
train_data, val_data, train_label, val_label, train_label_single, val_label_single = splitValData(train_data,
|
train_data, val_data, train_label, val_label, train_label_single, val_label_single = splitValData(train_data,
|
||||||
train_label,
|
train_label,
|
||||||
train_label_single,
|
train_label_single,
|
||||||
predict_num=predict_num)
|
predict_num=predict_num)
|
||||||
# # #### TODO 训练
|
# #### TODO 训练
|
||||||
model = predict_model_multi(hidden_num, feature)
|
model = predict_model_multi(hidden_num, feature)
|
||||||
checkpoint = tf.keras.callbacks.ModelCheckpoint(
|
checkpoint = tf.keras.callbacks.ModelCheckpoint(
|
||||||
filepath=save_name,
|
filepath=save_name,
|
||||||
|
|
@ -237,7 +220,8 @@ if __name__ == '__main__':
|
||||||
mode='min')
|
mode='min')
|
||||||
lr_scheduler = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=20, min_lr=0.001)
|
lr_scheduler = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=20, min_lr=0.001)
|
||||||
|
|
||||||
model.compile(optimizer=tf.optimizers.SGD(), loss=FTMSE())
|
model.compile(optimizer=tf.optimizers.SGD(), loss=tf.losses.mse)
|
||||||
|
# model.compile(optimizer=tf.optimizers.SGD(learning_rate=0.001), loss=FTMSE())
|
||||||
model.summary()
|
model.summary()
|
||||||
early_stop = EarlyStopping(monitor='val_loss', min_delta=0.0001, patience=100, mode='min', verbose=1)
|
early_stop = EarlyStopping(monitor='val_loss', min_delta=0.0001, patience=100, mode='min', verbose=1)
|
||||||
|
|
||||||
|
|
@ -247,7 +231,11 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
#### TODO 测试
|
#### TODO 测试
|
||||||
|
|
||||||
trained_model = tf.keras.models.load_model(save_name, custom_objects={'LSTMLayer': LSTMLayer})
|
# trained_model = tf.keras.models.load_model(save_name, custom_objects={'LSTMLayer': LSTMLayer, 'FTMSE': FTMSE})
|
||||||
|
|
||||||
|
# todo 解决自定义loss无法导入的问题
|
||||||
|
trained_model = tf.keras.models.load_model(save_name, compile=False, custom_objects={'LSTMLayer': LSTMLayer,'DCTChannelAttention':DCTChannelAttention})
|
||||||
|
trained_model.compile(optimizer=tf.optimizers.SGD(), loss=FTMSE())
|
||||||
|
|
||||||
# 使用已知的点进行预测
|
# 使用已知的点进行预测
|
||||||
predicted_data = predictByEveryData(trained_model, train_data)
|
predicted_data = predictByEveryData(trained_model, train_data)
|
||||||
|
|
|
||||||
|
|
@ -205,7 +205,7 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
#### TODO 测试
|
#### TODO 测试
|
||||||
|
|
||||||
trained_model = tf.keras.models.load_model(save_name, custom_objects={'LSTMLayer': LSTMLayer, 'FTMSE': FTMSE})
|
trained_model = tf.keras.models.load_model(save_name, custom_objects={'LSTMLayer': LSTMLayer})
|
||||||
|
|
||||||
# 使用已知的点进行预测
|
# 使用已知的点进行预测
|
||||||
predicted_data = predictByEveryData(trained_model, train_data)
|
predicted_data = predictByEveryData(trained_model, train_data)
|
||||||
|
|
|
||||||
|
|
@ -13,62 +13,139 @@ import tensorflow.keras.layers as layers
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from model.DepthwiseCon1D.DepthwiseConv1D import DepthwiseConv1D
|
from tensorflow.keras.layers import Dense, Dropout, ReLU, BatchNormalization
|
||||||
|
from scipy.fftpack import dct
|
||||||
|
|
||||||
|
|
||||||
|
# def dct(x, norm=None):
|
||||||
|
# """
|
||||||
|
# Discrete Cosine Transform, Type II (a.k.a. the DCT)
|
||||||
|
#
|
||||||
|
# For the meaning of the parameter `norm`, see:
|
||||||
|
# https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
|
||||||
|
#
|
||||||
|
# :param x: the input signal
|
||||||
|
# :param norm: the normalization, None or 'ortho'
|
||||||
|
# :return: the DCT-II of the signal over the last dimension
|
||||||
|
# """
|
||||||
|
# x_shape = x.shape
|
||||||
|
# N = x_shape[-1]
|
||||||
|
# x = x.contiguous().view(-1, N)
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# v = tf.concat([x[:, ::2], x[:, 1::2].flip([1])], axis=1)
|
||||||
|
#
|
||||||
|
# # Vc = torch.fft.rfft(v, 1, onesided=False)
|
||||||
|
# Vc = tf.signal.fft(v, 1)
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# k = - tf.range(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
|
||||||
|
#
|
||||||
|
# W_r = tf.cos(k)
|
||||||
|
# W_i = tf.sin(k)
|
||||||
|
#
|
||||||
|
# V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i
|
||||||
|
#
|
||||||
|
# if norm == 'ortho':
|
||||||
|
# V[:, 0] /= np.sqrt(N) * 2
|
||||||
|
# V[:, 1:] /= np.sqrt(N / 2) * 2
|
||||||
|
#
|
||||||
|
# V = 2 * V.view(*x_shape)
|
||||||
|
#
|
||||||
|
# return V
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
|
def sdct_tf(signals, frame_length, frame_step, window_fn=tf.signal.hamming_window):
|
||||||
|
"""Compute Short-Time Discrete Cosine Transform of `signals`.
|
||||||
|
|
||||||
|
No padding is applied to the signals.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
signal : Time-domain input signal(s), a `[..., n_samples]` tensor.
|
||||||
|
|
||||||
|
frame_length : Window length and DCT frame length in samples.
|
||||||
|
|
||||||
|
frame_step : Number of samples between adjacent DCT columns.
|
||||||
|
|
||||||
|
window_fn : See documentation for `tf.signal.stft`.
|
||||||
|
Default: hamming window. Window to use for DCT.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
dct : Real-valued T-F domain DCT matrix/matrixes, a `[..., n_frames, frame_length]` tensor.
|
||||||
|
"""
|
||||||
|
framed = tf.signal.frame(signals, frame_length, frame_step, pad_end=False)
|
||||||
|
if window_fn is not None:
|
||||||
|
window = window_fn(frame_length, dtype=framed.dtype)
|
||||||
|
framed = framed * window[tf.newaxis, :]
|
||||||
|
return tf.signal.dct(framed, norm="ortho", axis=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def isdct_tf(dcts, *, frame_step, frame_length=None, window_fn=tf.signal.hamming_window):
|
||||||
|
"""Compute Inverse Short-Time Discrete Cosine Transform of `dct`.
|
||||||
|
|
||||||
|
Parameters other than `dcts` are keyword-only.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
dcts : DCT matrix/matrices from `sdct_tf`
|
||||||
|
|
||||||
|
frame_step : Number of samples between adjacent DCT columns (should be the
|
||||||
|
same value that was passed to `sdct_tf`).
|
||||||
|
|
||||||
|
frame_length : Ignored. Window length and DCT frame length in samples.
|
||||||
|
Can be None (default) or same value as passed to `sdct_tf`.
|
||||||
|
|
||||||
|
window_fn : See documentation for `tf.signal.istft`.
|
||||||
|
Default: hamming window. Window to use for DCT.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
signals : Time-domain signal(s) reconstructed from `dcts`, a `[..., n_samples]` tensor.
|
||||||
|
Note that `n_samples` may be different from the original signals' lengths as passed to `sdct_torch`,
|
||||||
|
because no padding is applied.
|
||||||
|
"""
|
||||||
|
*_, n_frames, frame_length2 = dcts.shape
|
||||||
|
assert frame_length in {None, frame_length2}
|
||||||
|
signals = tf.signal.overlap_and_add(
|
||||||
|
tf.signal.idct(dcts, norm="ortho", axis=-1), frame_step
|
||||||
|
)
|
||||||
|
if window_fn is not None:
|
||||||
|
window = window_fn(frame_length2, dtype=signals.dtype)
|
||||||
|
window_frames = tf.tile(window[tf.newaxis, :], (n_frames, 1))
|
||||||
|
window_signal = tf.signal.overlap_and_add(window_frames, frame_step)
|
||||||
|
signals = signals / window_signal
|
||||||
|
return signals
|
||||||
|
|
||||||
class DCTChannelAttention(layers.Layer):
|
class DCTChannelAttention(layers.Layer):
|
||||||
def __init__(self):
|
|
||||||
# 调用父类__init__()方法
|
|
||||||
super(DCTChannelAttention, self).__init__()
|
|
||||||
self.DWC = DepthwiseConv1D(kernel_size=1, padding='SAME')
|
|
||||||
|
|
||||||
def build(self, input_shape):
|
def build(self, input_shape):
|
||||||
if len(input_shape) != 3:
|
_, hidden, channel = input_shape
|
||||||
raise ValueError('Inputs to `DynamicChannelAttention` should have rank 3. '
|
self.l1 = Dense(channel * 2, use_bias=False)
|
||||||
'Received input shape:', str(input_shape))
|
self.drop1 = Dropout(0.1)
|
||||||
|
self.relu = ReLU(0.1)
|
||||||
# print(input_shape)
|
self.l2 = Dense(channel, use_bias=False)
|
||||||
# GAP
|
|
||||||
self.GAP = tf.keras.layers.GlobalAvgPool1D()
|
|
||||||
self.c1 = tf.keras.layers.Conv1D(filters=input_shape[2], kernel_size=1, padding='SAME')
|
|
||||||
# s1 = tf.nn.sigmoid(c1)
|
|
||||||
|
|
||||||
# GMP
|
|
||||||
self.GMP = tf.keras.layers.GlobalMaxPool1D()
|
|
||||||
self.c2 = tf.keras.layers.Conv1D(filters=input_shape[2], kernel_size=1, padding='SAME')
|
|
||||||
# s2 = tf.nn.sigmoid(c2)
|
|
||||||
|
|
||||||
# weight
|
|
||||||
self.weight_kernel = self.add_weight(
|
|
||||||
shape=(1, input_shape[2]),
|
|
||||||
initializer='glorot_uniform',
|
|
||||||
name='weight_kernel')
|
|
||||||
|
|
||||||
def call(self, inputs, **kwargs):
|
def call(self, inputs, **kwargs):
|
||||||
batch_size, length, channel = inputs.shape
|
batch_size, hidden, channel = inputs.shape
|
||||||
# print(batch_size,length,channel)
|
list = []
|
||||||
DWC1 = self.DWC(inputs)
|
stack_dct = tf.signal.dct(inputs, norm="ortho",axis=-1)
|
||||||
|
# for i in range(channel):
|
||||||
|
# freq = tf.signal.dct(inputs[:, i, :], norm="ortho", axis=-1)
|
||||||
|
# # print("freq-shape:",freq.shape)
|
||||||
|
# list.append(freq)
|
||||||
|
# stack_dct = tf.stack(list, dim=1)
|
||||||
|
|
||||||
# GAP
|
lr_weight = self.l1(stack_dct)
|
||||||
GAP = self.GAP(DWC1)
|
lr_weight = self.drop1(lr_weight)
|
||||||
GAP = tf.expand_dims(GAP, axis=1)
|
lr_weight = self.relu(lr_weight)
|
||||||
c1 = self.c1(GAP)
|
lr_weight = self.l2(lr_weight)
|
||||||
c1 = tf.keras.layers.BatchNormalization()(c1)
|
|
||||||
s1 = tf.nn.sigmoid(c1)
|
|
||||||
|
|
||||||
# GMP
|
lr_weight = BatchNormalization()(lr_weight)
|
||||||
GMP = self.GMP(DWC1)
|
|
||||||
GMP = tf.expand_dims(GMP, axis=1)
|
|
||||||
c2 = self.c2(GMP)
|
|
||||||
c2 = tf.keras.layers.BatchNormalization()(c2)
|
|
||||||
s2 = tf.nn.sigmoid(c2)
|
|
||||||
|
|
||||||
# print(self.weight_kernel)
|
return inputs * lr_weight
|
||||||
|
|
||||||
weight_kernel = tf.broadcast_to(self.weight_kernel, shape=[length, channel])
|
|
||||||
weight_kernel = tf.broadcast_to(weight_kernel, shape=[batch_size, length, channel])
|
|
||||||
s1 = tf.broadcast_to(s1, shape=[batch_size, length, channel])
|
|
||||||
s2 = tf.broadcast_to(s2, shape=[batch_size, length, channel])
|
|
||||||
|
|
||||||
output = tf.add(weight_kernel * s1 * inputs, (tf.ones_like(weight_kernel) - weight_kernel) * s2 * inputs)
|
|
||||||
return output
|
|
||||||
|
|
|
||||||
|
|
@ -64,15 +64,6 @@ class LightChannelAttention(layers.Layer):
|
||||||
c1 = tf.keras.layers.BatchNormalization()(c1)
|
c1 = tf.keras.layers.BatchNormalization()(c1)
|
||||||
s1 = tf.nn.sigmoid(c1)
|
s1 = tf.nn.sigmoid(c1)
|
||||||
|
|
||||||
# # GMP
|
|
||||||
# GMP = self.GMP(DWC1)
|
|
||||||
# GMP = tf.expand_dims(GMP, axis=1)
|
|
||||||
# c2 = self.c2(GMP)
|
|
||||||
# c2 = tf.keras.layers.BatchNormalization()(c2)
|
|
||||||
# s2 = tf.nn.sigmoid(c2)
|
|
||||||
|
|
||||||
# print(self.weight_kernel)
|
|
||||||
|
|
||||||
# weight_kernel = tf.broadcast_to(self.weight_kernel, shape=[length, channel])
|
# weight_kernel = tf.broadcast_to(self.weight_kernel, shape=[length, channel])
|
||||||
# weight_kernel = tf.broadcast_to(weight_kernel, shape=[batch_size, length, channel])
|
# weight_kernel = tf.broadcast_to(weight_kernel, shape=[batch_size, length, channel])
|
||||||
s1 = tf.broadcast_to(s1, shape=[batch_size, length, channel])
|
s1 = tf.broadcast_to(s1, shape=[batch_size, length, channel])
|
||||||
|
|
@ -82,6 +73,42 @@ class LightChannelAttention(layers.Layer):
|
||||||
return s1
|
return s1
|
||||||
|
|
||||||
|
|
||||||
|
class LightChannelAttention1(layers.Layer):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# 调用父类__init__()方法
|
||||||
|
super(LightChannelAttention1, self).__init__()
|
||||||
|
self.DWC = DepthwiseConv1D(kernel_size=1, padding='SAME')
|
||||||
|
# self.DWC = DepthwiseConv1D(kernel_size=1, padding='causal',dilation_rate=4,data_format='channels_last')
|
||||||
|
|
||||||
|
def build(self, input_shape):
|
||||||
|
if len(input_shape) != 3:
|
||||||
|
raise ValueError('Inputs to `DynamicChannelAttention` should have rank 3. '
|
||||||
|
'Received input shape:', str(input_shape))
|
||||||
|
|
||||||
|
print(input_shape)
|
||||||
|
# GAP
|
||||||
|
self.GAP = tf.keras.layers.GlobalAvgPool1D()
|
||||||
|
self.c1 = tf.keras.layers.Conv1D(filters=input_shape[2], kernel_size=1, padding='SAME')
|
||||||
|
|
||||||
|
|
||||||
|
def call(self, inputs, **kwargs):
|
||||||
|
batch_size, length, channel = inputs.shape
|
||||||
|
DWC1 = self.DWC(inputs)
|
||||||
|
|
||||||
|
# GAP
|
||||||
|
GAP = self.GAP(DWC1)
|
||||||
|
GAP = tf.expand_dims(GAP, axis=1)
|
||||||
|
c1 = self.c1(GAP)
|
||||||
|
c1 = tf.keras.layers.BatchNormalization()(c1)
|
||||||
|
s1 = tf.nn.sigmoid(c1)
|
||||||
|
print(s1)
|
||||||
|
|
||||||
|
s1 = tf.broadcast_to(s1, [batch_size, length, channel])
|
||||||
|
|
||||||
|
|
||||||
|
return s1 * inputs
|
||||||
|
|
||||||
class DynamicPooling(layers.Layer):
|
class DynamicPooling(layers.Layer):
|
||||||
|
|
||||||
def __init__(self, pool_size=2):
|
def __init__(self, pool_size=2):
|
||||||
|
|
|
||||||
|
|
@ -12,35 +12,42 @@ import tensorflow.keras.backend as K
|
||||||
|
|
||||||
|
|
||||||
class FTMSE(tf.keras.losses.Loss):
|
class FTMSE(tf.keras.losses.Loss):
|
||||||
|
|
||||||
def call(self, y_true, y_pred):
|
def call(self, y_true, y_pred):
|
||||||
y_true = tf.cast(y_true, tf.float32)
|
# y_true = tf.cast(y_true, tf.float32)
|
||||||
y_pred = tf.cast(y_pred, tf.float32)
|
# y_pred = tf.cast(y_pred, tf.float32)
|
||||||
|
|
||||||
|
# tf.print(y_true)
|
||||||
|
# tf.print(y_pred)
|
||||||
|
|
||||||
# 需要转为复数形式
|
# 需要转为复数形式
|
||||||
yt_fft = tf.signal.fft(tf.cast(y_true, tf.complex64))
|
_, length = y_pred.shape
|
||||||
yp_fft = tf.signal.fft(tf.cast(y_pred, tf.complex64))
|
|
||||||
|
|
||||||
print("yt_amp",yt_fft)
|
# 打印精确的实部和虚部
|
||||||
print("yp_fft",yp_fft)
|
yt_fft = tf.signal.fft(tf.complex(y_true, tf.zeros_like(y_true)))
|
||||||
|
yp_fft = tf.signal.fft(tf.complex(y_pred, tf.zeros_like(y_pred)))
|
||||||
|
|
||||||
epoch, length = yp_fft.shape
|
|
||||||
# 幅值
|
# 幅值
|
||||||
yt_amp = tf.abs(yt_fft / length)
|
yt_amp = tf.abs(yt_fft/length)
|
||||||
yp_amp = tf.abs(yp_fft / length)
|
yp_amp = tf.abs(yp_fft/length)
|
||||||
|
# yt_amp = tf.abs(yt_fft)
|
||||||
|
# yp_amp = tf.abs(yp_fft)
|
||||||
# 相角
|
# 相角
|
||||||
yt_angle = tf.math.angle(yt_fft)
|
yt_angle = tf.math.angle(yt_fft)
|
||||||
yp_angle = tf.math.angle(yp_fft)
|
yp_angle = tf.math.angle(yp_fft)
|
||||||
|
|
||||||
|
# tf.print("yt_amp",yt_amp)
|
||||||
|
# tf.print("yp_amp",yp_amp)
|
||||||
|
# tf.print("yt_angle",yt_angle)
|
||||||
|
# tf.print("yp_angle",yp_angle)
|
||||||
|
|
||||||
|
time_loss = K.mean(tf.keras.losses.mean_squared_error(y_true, y_pred),axis=-1)
|
||||||
|
amp_loss = K.mean(tf.keras.losses.mean_squared_error(yt_amp, yp_amp),axis=-1)
|
||||||
|
angle_loss = K.mean(tf.keras.losses.mean_squared_error(yt_angle, yp_angle),axis=-1)
|
||||||
|
tf.print("time_loss:", time_loss, "amp_loss", amp_loss, "angle_loss", angle_loss)
|
||||||
|
|
||||||
time_loss = tf.keras.losses.mean_squared_error(y_true, y_pred)
|
ftLoss = time_loss + amp_loss*5
|
||||||
amp_loss = tf.keras.losses.mean_squared_error(yt_amp, yp_amp)
|
# ftLoss = time_loss + 5 * amp_loss + 0.25 * angle_loss
|
||||||
angle_loss = tf.keras.losses.mean_squared_error(yt_angle, yp_angle)
|
|
||||||
print(time_loss)
|
|
||||||
print(amp_loss)
|
|
||||||
print(angle_loss)
|
|
||||||
ftLoss = time_loss + amp_loss
|
|
||||||
# ftLoss = time_loss + amp_loss + angle_loss
|
|
||||||
# ftLoss = time_loss
|
# ftLoss = time_loss
|
||||||
|
#
|
||||||
return ftLoss
|
return ftLoss
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue