模型更新
This commit is contained in:
parent
2400af4a1a
commit
c3e7572691
|
|
@ -27,8 +27,8 @@ batch_size = 32
|
|||
EPOCH = 1000
|
||||
unit = 512 # LSTM的维度
|
||||
predict_num = 50 # 预测个数
|
||||
model_name = "LSTM"
|
||||
save_name = r"selfMulti_norm_{0}_hidden{1}_unit{2}_feature{3}_predict{4}.h5".format(model_name, hidden_num, unit,
|
||||
model_name = "FTLSTM"
|
||||
save_name = r"selfMulti_{0}_hidden{1}_unit{2}_feature{3}_predict{4}.h5".format(model_name, hidden_num, unit,
|
||||
feature,
|
||||
predict_num)
|
||||
|
||||
|
|
@ -237,7 +237,7 @@ if __name__ == '__main__':
|
|||
mode='min')
|
||||
lr_scheduler = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=20, min_lr=0.001)
|
||||
|
||||
model.compile(optimizer=tf.optimizers.SGD(), loss=tf.losses.mse)
|
||||
model.compile(optimizer=tf.optimizers.SGD(), loss=FTMSE())
|
||||
model.summary()
|
||||
early_stop = EarlyStopping(monitor='val_loss', min_delta=0.0001, patience=100, mode='min', verbose=1)
|
||||
|
||||
|
|
|
|||
|
|
@ -87,6 +87,7 @@ predict_data.shape: (total_dims - filter_num, filter_num) :(571,600,30)
|
|||
train_label.shape: (total_dims - filter_num - 1, filter_num) :(570,600)
|
||||
'''
|
||||
|
||||
|
||||
def remove(train_data, train_label, batch_size):
|
||||
epoch, _, _ = train_data.shape
|
||||
size = int(epoch / batch_size)
|
||||
|
|
@ -98,6 +99,8 @@ train_data.shape: (1230, 10, 10)
|
|||
train_label.shape: (1230, 10)
|
||||
train_label_single.shape: (1230,)
|
||||
'''
|
||||
|
||||
|
||||
def splitValData(data, label, label_single, predict_num=50):
|
||||
sample, hidden, feature = data.shape
|
||||
|
||||
|
|
@ -182,27 +185,27 @@ if __name__ == '__main__':
|
|||
train_label,
|
||||
train_label_single,
|
||||
predict_num=predict_num)
|
||||
# # #### TODO 训练
|
||||
model = predict_model(hidden_num, feature)
|
||||
checkpoint = tf.keras.callbacks.ModelCheckpoint(
|
||||
filepath=save_name,
|
||||
monitor='val_loss',
|
||||
verbose=2,
|
||||
save_best_only=True,
|
||||
mode='min')
|
||||
lr_scheduler = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=20, min_lr=0.001)
|
||||
|
||||
model.compile(optimizer=tf.optimizers.SGD(), loss=tf.losses.mse)
|
||||
model.summary()
|
||||
early_stop = EarlyStopping(monitor='val_loss', min_delta=0.0001, patience=100, mode='min', verbose=1)
|
||||
|
||||
history = model.fit(train_data, train_label_single, epochs=EPOCH,
|
||||
batch_size=batch_size, validation_data=(val_data, val_label_single), shuffle=True, verbose=2,
|
||||
callbacks=[checkpoint, lr_scheduler, early_stop])
|
||||
# # # #### TODO 训练
|
||||
# model = predict_model(hidden_num, feature)
|
||||
# checkpoint = tf.keras.callbacks.ModelCheckpoint(
|
||||
# filepath=save_name,
|
||||
# monitor='val_loss',
|
||||
# verbose=2,
|
||||
# save_best_only=True,
|
||||
# mode='min')
|
||||
# lr_scheduler = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=20, min_lr=0.001)
|
||||
#
|
||||
# model.compile(optimizer=tf.optimizers.SGD(), loss=tf.losses.mse)
|
||||
# model.summary()
|
||||
# early_stop = EarlyStopping(monitor='val_loss', min_delta=0.0001, patience=100, mode='min', verbose=1)
|
||||
#
|
||||
# history = model.fit(train_data, train_label_single, epochs=EPOCH,
|
||||
# batch_size=batch_size, validation_data=(val_data, val_label_single), shuffle=True, verbose=2,
|
||||
# callbacks=[checkpoint, lr_scheduler, early_stop])
|
||||
|
||||
#### TODO 测试
|
||||
|
||||
trained_model = tf.keras.models.load_model(save_name, custom_objects={'LSTMLayer': LSTMLayer})
|
||||
trained_model = tf.keras.models.load_model(save_name, custom_objects={'LSTMLayer': LSTMLayer, 'FTMSE': FTMSE})
|
||||
|
||||
# 使用已知的点进行预测
|
||||
predicted_data = predictByEveryData(trained_model, train_data)
|
||||
|
|
|
|||
|
|
@ -20,7 +20,10 @@ class FTMSE(tf.keras.losses.Loss):
|
|||
yt_fft = tf.signal.fft(tf.cast(y_true, tf.complex64))
|
||||
yp_fft = tf.signal.fft(tf.cast(y_pred, tf.complex64))
|
||||
|
||||
epoch, length, _ = yp_fft.shape
|
||||
print("yt_amp",yt_fft)
|
||||
print("yp_fft",yp_fft)
|
||||
|
||||
epoch, length = yp_fft.shape
|
||||
# 幅值
|
||||
yt_amp = tf.abs(yt_fft / length)
|
||||
yp_amp = tf.abs(yp_fft / length)
|
||||
|
|
@ -28,9 +31,16 @@ class FTMSE(tf.keras.losses.Loss):
|
|||
yt_angle = tf.math.angle(yt_fft)
|
||||
yp_angle = tf.math.angle(yp_fft)
|
||||
|
||||
|
||||
|
||||
time_loss = tf.keras.losses.mean_squared_error(y_true, y_pred)
|
||||
amp_loss = tf.keras.losses.mean_squared_error(yt_amp, yp_amp)
|
||||
angle_loss = tf.keras.losses.mean_squared_error(yt_angle, yp_angle)
|
||||
print(time_loss)
|
||||
print(amp_loss)
|
||||
print(angle_loss)
|
||||
ftLoss = time_loss + amp_loss
|
||||
# ftLoss = time_loss + amp_loss + angle_loss
|
||||
# ftLoss = time_loss
|
||||
|
||||
ftLoss = time_loss + amp_loss + angle_loss
|
||||
return ftLoss
|
||||
|
|
|
|||
Loading…
Reference in New Issue