# -*- encoding:utf-8 -*-
'''
@Author : dingjiawen
@Date   : 2023/11/10 14:56
@Usage  :
@Desc   : Train the LSTM model
'''
import os
import time
import random

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from RUL.otherIdea.LSTM.modelForEasy import PredictModel
from RUL.otherIdea.LSTM.loadData import getDataset
from RUL.otherIdea.LSTM.test import test
from RUL.baseModel.CommonFunction import IsStopTraining
import RUL.baseModel.utils.utils as utils

'''
Hyperparameter settings:
'''
hidden_num = 10     # number of LSTM cells
feature = 10        # dimension of a single point
batch_size = 32
EPOCH = 1000
seed = 5
predict_num = 200   # number of points to predict
is_norm = True
is_single = True
model_name = "LSTM"

base_save = r"parameters/{0}_hidden{1}_feature{2}_predict{3}".format(
    model_name, hidden_num, feature, predict_num)
save_fig_name = 'fig/seed{0}_hidden{1}_feature{2}_predict{3}'.format(
    seed, hidden_num, feature, predict_num)

if not os.path.exists("parameters"):
    os.makedirs("parameters")
if not os.path.exists("fig"):
    os.makedirs("fig")


def get_dataset():
    '''Build the datasets'''
    train_dataset, val_dataset = getDataset(
        hidden_num=hidden_num,
        feature=feature,
        predict_num=predict_num,
        is_single=is_single,
        is_norm=is_norm)

    '''DataLoader'''
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                              shuffle=True, drop_last=False)
    val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size,
                            shuffle=False, drop_last=False)

    return train_loader, val_loader
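
# `IsStopTraining` (imported from RUL.baseModel.CommonFunction) drives early
# stopping in `train()` below; its actual implementation is not part of this
# file. The unused reference sketch here only illustrates the assumed
# patience-based semantics: stop once the validation loss has not improved
# for `patience` consecutive epochs.
def _is_stop_training_sketch(history_loss, patience):
    '''Sketch only (assumption) -- never called; see IsStopTraining for the real check.'''
    if len(history_loss) <= patience:
        return False
    # Stop if no loss in the last `patience` epochs beat the best loss seen before them.
    return min(history_loss[-patience:]) >= min(history_loss[:-patience])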
def train(device, lr, lr_patience, early_stop_patience, epochs):
    '''Prediction model'''
    model = PredictModel(input_dim=feature).to(device)  # move the model to the same device as the data

    '''Build the datasets'''
    train_loader, val_loader = get_dataset()

    criterion = nn.MSELoss().to(device)
    optimizer_model = torch.optim.SGD(model.parameters(), lr=lr)
    scheduler_model = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer_model, mode="min", factor=0.5, patience=lr_patience)

    train_loss_list = []
    val_loss_list = []
    best_save_path = None

    for epoch in range(epochs):
        epoch_start_time = time.time()
        train_loss = 0.0
        val_loss = 0.0

        model.train()
        for train_batch_idx, (train_data, train_label) in enumerate(train_loader):
            train_data = train_data.float().to(device)
            train_label = train_label.float().to(device)

            optimizer_model.zero_grad()
            # squeeze(-1) keeps the batch dimension even when the last batch holds a single sample
            predict_data = model(train_data).squeeze(-1)
            # MSE loss
            loss = criterion(predict_data, train_label)
            loss.backward()
            optimizer_model.step()

            train_loss += loss.item()

        model.eval()
        with torch.no_grad():
            for val_batch_idx, (val_data, val_label) in enumerate(val_loader):
                val_data = val_data.float().to(device)
                val_label = val_label.float().to(device)

                val_predict_data = model(val_data).squeeze(-1)
                loss = criterion(val_predict_data, val_label)
                val_loss += loss.item()

        train_loss = train_loss / len(train_loader)
        val_loss = val_loss / len(val_loader)
        scheduler_model.step(val_loss)

        print("[{:03d}/{:03d}] {:2.2f} sec(s) train_loss: {:3.9f} | val_loss: {:3.9f} | Learning rate : {:3.6f}".format(
            epoch + 1, epochs, time.time() - epoch_start_time, train_loss, val_loss,
            optimizer_model.state_dict()['param_groups'][0]['lr']))

        # Keep only the model with the lowest validation loss
        if len(val_loss_list) == 0 or val_loss < min(val_loss_list):
            print("Best model saved")
            # Save the model parameters, deleting the previous best checkpoint
            if best_save_path is not None:
                utils.delete_file(best_save_path)
            best_save_path = "{0}_epoch{1}_trainLoss{2}_valLoss{3}.pkl".format(
                base_save, epoch, train_loss, val_loss)
            torch.save(model.state_dict(), best_save_path)

        train_loss_list.append(train_loss)
        val_loss_list.append(val_loss)

        if IsStopTraining(history_loss=val_loss_list, patience=early_stop_patience):
            break

    '''Path of the saved model parameters'''
    return best_save_path


if __name__ == '__main__':
    begin = time.time()

    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    '''Training'''
    save_path = train(device, lr=0.01, lr_patience=10,
                      early_stop_patience=20, epochs=EPOCH)
    end = time.time()

    '''Testing'''
    test(hidden_num, feature, predict_num=predict_num, batch_size=batch_size,
         save_path=save_path, is_single=is_single, is_norm=is_norm,
         save_fig_name=save_fig_name)

    print("Training took: {:3.2f}s".format(end - begin))
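
# For context: `PredictModel` is imported from RUL.otherIdea.LSTM.modelForEasy and
# its definition is not part of this file. Based on how it is called above
# (constructed with input_dim=feature, output squeezed to one value per sample),
# it is assumed to be a small LSTM regressor roughly like the unused sketch below.
# The name and hidden_size are illustrative assumptions, not the actual definition.
class _PredictModelSketch(nn.Module):
    '''Sketch only (assumption) -- never instantiated; the real model is PredictModel.'''

    def __init__(self, input_dim, hidden_size=32):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):                 # x: (batch, seq_len, input_dim)
        out, _ = self.lstm(x)             # out: (batch, seq_len, hidden_size)
        return self.fc(out[:, -1, :])     # (batch, 1); the caller squeezes to (batch,)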