# (file-listing metadata: 178 lines, 5.5 KiB, Python — kept as a comment so the module stays importable)
# -*- encoding:utf-8 -*-
|
|
|
|
'''
|
|
@Author : dingjiawen
|
|
@Date : 2023/11/10 14:56
|
|
@Usage :
|
|
@Desc : 训练LSTM
|
|
'''
|
|
import os
|
|
import time
|
|
import random
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import matplotlib.pyplot as plt
|
|
from torch.utils.data import DataLoader
|
|
from RUL.otherIdea.LSTM.modelForEasy import PredictModel
|
|
from RUL.otherIdea.LSTM.loadData import getDataset
|
|
from RUL.otherIdea.LSTM.test import test
|
|
from RUL.baseModel.CommonFunction import IsStopTraining
|
|
from scipy.spatial.distance import cdist
|
|
import math
|
|
import RUL.baseModel.utils.utils as utils
|
|
|
|
'''
超参数设置:
'''
hidden_num = 10  # number of LSTM cells (sequence length fed to the model)
feature = 10  # dimensionality of a single input point
batch_size = 32
EPOCH = 1000  # NOTE(review): unused — __main__ passes epochs=1000 literally; confirm before relying on it
seed = 5  # RNG seed applied to torch / random / numpy in __main__
predict_num = 200  # number of points to predict
is_norm = True  # whether the dataset is normalized (forwarded to getDataset/test)
is_single = True  # forwarded to getDataset/test; semantics defined in loadData — confirm there
model_name = "LSTM"
# Checkpoint filename prefix; train() appends epoch/loss info and ".pkl".
base_save = r"parameters/{0}_hidden{1}_feature{2}_predict{3}".format(model_name, hidden_num, feature,
                                                                     predict_num)
# Figure path used by test() when plotting results.
save_fig_name = 'fig/seed{0}_hidden{1}_feature{2}_predict{3}'.format(seed, hidden_num, feature, predict_num)

# Ensure the output directories exist before training/plotting writes into them.
if not os.path.exists("parameters"):
    os.makedirs("parameters")
if not os.path.exists("fig"):
    os.makedirs("fig")
|
|
|
|
|
|
def get_dataset():
    """Build the train/validation DataLoaders from the project dataset.

    Uses the module-level hyperparameters (hidden_num, feature, predict_num,
    is_single, is_norm, batch_size).

    Returns:
        (train_loader, val_loader): training loader shuffled, validation not.
    """
    train_set, val_set = getDataset(hidden_num=hidden_num,
                                    feature=feature,
                                    predict_num=predict_num,
                                    is_single=is_single,
                                    is_norm=is_norm)

    # Both loaders share everything except the shuffle flag.
    common = dict(batch_size=batch_size, drop_last=False)
    train_loader = DataLoader(dataset=train_set, shuffle=True, **common)
    val_loader = DataLoader(dataset=val_set, shuffle=False, **common)
    return train_loader, val_loader
|
|
|
|
|
|
def train(device, lr, lr_patience, early_stop_patience, epochs):
    """Train the LSTM prediction model and checkpoint the best epoch.

    Args:
        device: torch.device the model, loss, and batches run on.
        lr: initial SGD learning rate.
        lr_patience: epochs without val-loss improvement before the LR is halved.
        early_stop_patience: patience forwarded to IsStopTraining for early stop.
        epochs: maximum number of training epochs.

    Returns:
        Path of the saved checkpoint with the lowest validation loss
        (str), or None if no epoch completed.
    """
    # BUG FIX: the model was never moved to `device`, while the criterion and
    # the batches were — a CUDA run would crash on a CPU/GPU tensor mismatch.
    model = PredictModel(input_dim=feature).to(device)

    train_loader, val_loader = get_dataset()

    criterion = nn.MSELoss().to(device)
    optimizer_model = torch.optim.SGD(model.parameters(), lr=lr)
    # Halve the LR when the validation loss plateaus.
    scheduler_model = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer_model, mode="min", factor=0.5, patience=lr_patience)

    train_loss_list = []
    val_loss_list = []
    best_save_path = None  # path of the best checkpoint so far (local, returned)

    for epoch in range(epochs):
        epoch_start_time = time.time()
        train_loss = 0.0
        val_loss = 0.0

        model.train()
        for train_data, train_label in train_loader:
            train_data = train_data.float().to(device)
            train_label = train_label.float().to(device)

            optimizer_model.zero_grad()
            predict_data = torch.squeeze(model(train_data))
            loss = criterion(predict_data, train_label)  # MSE loss
            loss.backward()
            optimizer_model.step()

            train_loss += loss.item()

        model.eval()
        with torch.no_grad():
            for val_data, val_label in val_loader:
                val_data = val_data.float().to(device)
                val_label = val_label.float().to(device)
                val_predict_data = torch.squeeze(model(val_data))
                loss = criterion(val_predict_data, val_label)
                val_loss += loss.item()

        # Average the accumulated losses per batch, then let the scheduler see
        # the averaged value (the original stepped on the raw sum; with
        # ReduceLROnPlateau's default relative threshold the decisions are the
        # same since len(val_loader) is constant — averaging first is just the
        # conventional form).
        train_loss = train_loss / len(train_loader)
        val_loss = val_loss / len(val_loader)
        scheduler_model.step(val_loss)

        print(
            "[{:03d}/{:03d}] {:2.2f} sec(s) train_loss: {:3.9f} | val_loss: {:3.9f} | Learning rate : {:3.6f}".format(
                epoch + 1, epochs, time.time() - epoch_start_time,
                train_loss,
                val_loss,
                optimizer_model.state_dict()['param_groups'][0]['lr']))

        # Keep only the checkpoint with the best validation loss so far:
        # delete the previous best file, then save the new one.
        if len(val_loss_list) == 0 or val_loss < min(val_loss_list):
            print("保存模型最佳模型成功")
            if best_save_path is not None:
                utils.delete_file(best_save_path)
            best_save_path = base_save + "_epoch" + str(epoch) + \
                             "_trainLoss" + str(train_loss) + \
                             "_valLoss" + str(val_loss) + ".pkl"
            torch.save(model.state_dict(), best_save_path)

        train_loss_list.append(train_loss)
        val_loss_list.append(val_loss)

        # Early stopping on the validation-loss history.
        if IsStopTraining(history_loss=val_loss_list, patience=early_stop_patience):
            break

    # Path of the best saved model parameters.
    return best_save_path
|
|
|
|
|
|
if __name__ == '__main__':
    begin = time.time()

    # Prefer the first CUDA device when available, otherwise fall back to CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Seed every RNG the pipeline touches for reproducibility.
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    # Train and keep the path of the best checkpoint.
    save_path = train(device, lr=0.01, lr_patience=10, early_stop_patience=20, epochs=1000)
    end = time.time()

    # Evaluate the best checkpoint on the test routine.
    test(hidden_num, feature, predict_num=predict_num,
         batch_size=batch_size, save_path=save_path,
         is_single=is_single, is_norm=is_norm, save_fig_name=save_fig_name)

    print("训练耗时:{:3.2f}s".format(end - begin))