# -*- encoding:utf-8 -*-
'''
@Author : dingjiawen
@Date : 2023/11/10 14:56
@Usage :
@Desc : Train the LSTM prediction model (dctLSTM) for RUL prediction.
'''
import os
import time
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from RUL.otherIdea.dctLSTM.modelForEasy import PredictModel
from RUL.otherIdea.dctLSTM.loadData import getDataset
from RUL.otherIdea.dctLSTM.test import test
from RUL.baseModel.CommonFunction import IsStopTraining
import RUL.baseModel.utils.utils as utils
import math

'''
Hyper-parameter settings:
'''
hidden_num = 40      # number of LSTM cells
feature = 10         # dimensionality of one data point
batch_size = 32
EPOCH = 1000         # maximum number of training epochs
predict_num = 200    # number of points to predict
seed = 250
is_norm = False
is_single = True
model_name = "dctLSTM"

# Checkpoint path prefix and figure name encode the key hyper-parameters
# so multiple runs can coexist without overwriting each other.
base_save = r"parameters/{0}_hidden{1}_feature{2}_predict{3}".format(
    model_name, hidden_num, feature, predict_num)
save_fig_name = 'fig/seed{0}_hidden{1}_feature{2}_predict{3}'.format(
    seed, hidden_num, feature, predict_num)

if not os.path.exists("parameters"):
    os.makedirs("parameters")
if not os.path.exists("fig"):
    os.makedirs("fig")


def get_dataset():
    """Build train/validation datasets and wrap them in DataLoaders.

    Returns:
        (train_loader, val_loader): shuffled training loader and
        order-preserving validation loader, both batched by ``batch_size``
        with ``drop_last=False``.
    """
    train_dataset, val_dataset = getDataset(
        hidden_num=hidden_num,
        feature=feature,
        predict_num=predict_num,
        is_single=is_single,
        is_norm=is_norm)

    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                              shuffle=True, drop_last=False)
    val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size,
                            shuffle=False, drop_last=False)
    return train_loader, val_loader


def train(device, lr, lr_patience, early_stop_patience, epochs):
    """Train PredictModel with SGD + MSE loss and keep the best checkpoint.

    The model is trained for at most ``epochs`` epochs; the learning rate is
    halved by ReduceLROnPlateau when validation loss stalls for
    ``lr_patience`` epochs, and training stops early after
    ``early_stop_patience`` epochs without validation improvement.
    Only the checkpoint with the lowest validation loss is kept on disk
    (the previous best file is deleted when a new best is found).

    Args:
        device: torch.device to train on (e.g. cuda:0 or cpu).
        lr: initial SGD learning rate.
        lr_patience: patience (epochs) for the LR scheduler.
        early_stop_patience: patience (epochs) for early stopping.
        epochs: maximum number of training epochs.

    Returns:
        Path (str) of the best saved model state_dict, or None if no
        checkpoint was ever saved.
    """
    # BUG FIX: the model must live on the same device as the batches,
    # otherwise the first forward pass fails on CUDA.
    model = PredictModel(input_dim=feature).to(device)

    train_loader, val_loader = get_dataset()

    criterion = nn.MSELoss().to(device)
    optimizer_model = torch.optim.SGD(model.parameters(), lr=lr)
    scheduler_model = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer_model, mode="min", factor=0.5, patience=lr_patience)

    def zero_grad_all():
        optimizer_model.zero_grad()

    train_loss_list = []
    val_loss_list = []
    best_save_path = None  # returned; no need for a module-level global

    for epoch in range(epochs):
        epoch_start_time = time.time()
        train_loss = 0.0
        val_loss = 0.0

        model.train()
        for train_batch_idx, (train_data, train_label) in enumerate(train_loader):
            train_data, train_label = train_data.to(device), train_label.to(device)
            zero_grad_all()
            predict_data = torch.squeeze(model(train_data))
            # MSE loss between prediction and label
            loss = criterion(predict_data, train_label)
            loss.backward()
            optimizer_model.step()
            zero_grad_all()
            train_loss += loss.item()

        model.eval()
        with torch.no_grad():
            for val_batch_idx, (val_data, val_label) in enumerate(val_loader):
                val_data, val_label = val_data.to(device), val_label.to(device)
                val_predict_data = torch.squeeze(model(val_data))
                loss = criterion(val_predict_data, val_label)
                val_loss += loss.item()

        # Scheduler steps on the (summed) validation loss, as before;
        # averaged losses are used for logging and model selection.
        scheduler_model.step(val_loss)
        train_loss = train_loss / len(train_loader)
        val_loss = val_loss / len(val_loader)
        print(
            "[{:03d}/{:03d}] {:2.2f} sec(s) train_loss: {:3.9f} | val_loss: {:3.9f} | Learning rate : {:3.6f}".format(
                epoch + 1, epochs, time.time() - epoch_start_time, train_loss, val_loss,
                optimizer_model.state_dict()['param_groups'][0]['lr']))

        # Save the model with the lowest validation loss seen so far;
        # delete the previous best checkpoint to keep only one on disk.
        if len(val_loss_list) == 0 or val_loss < min(val_loss_list):
            print("保存模型最佳模型成功")
            if best_save_path is not None:
                utils.delete_file(best_save_path)
            best_save_path = base_save + "_epoch" + str(epoch) + \
                             "_trainLoss" + str(train_loss) + \
                             "_valLoss" + str(val_loss) + ".pkl"
            torch.save(model.state_dict(), best_save_path)

        train_loss_list.append(train_loss)
        val_loss_list.append(val_loss)
        if IsStopTraining(history_loss=val_loss_list, patience=early_stop_patience):
            break

    # Plot the loss curves with Times New Roman and a visible minus sign.
    from matplotlib import rcParams
    config = {
        "font.family": 'Times New Roman',   # font family
        "axes.unicode_minus": False,        # render the minus sign correctly
        "axes.labelsize": 13
    }
    rcParams.update(config)

    pic1 = plt.figure(figsize=(12, 6), dpi=200)
    plt.plot(np.arange(1, len(train_loss_list) + 1), train_loss_list, 'b', label='Training Loss')
    plt.plot(np.arange(1, len(train_loss_list) + 1), val_loss_list, 'r', label='Validation Loss')
    # plt.ylim(0, 0.08)  # y-axis range
    plt.title('Training & Validation loss')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend(loc='upper right')
    plt.grid(alpha=0.4)
    # Use the current unix timestamp as a unique figure filename.
    timestamp = int(time.time())
    timestamp_str = str(timestamp)
    plt.savefig(timestamp_str, dpi=200)
    plt.show()

    return best_save_path


if __name__ == '__main__':
    begin = time.time()
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    # Seed all RNG sources for reproducibility (CUDA included when present).
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)

    '''Training'''
    # Use the EPOCH constant instead of a duplicated magic number.
    save_path = train(device, lr=0.01, lr_patience=10,
                      early_stop_patience=20, epochs=EPOCH)
    end = time.time()

    '''Testing'''
    # test1(5, src_condition, tar_condition, G_params_path, LC_params_path)
    test(hidden_num, feature, predict_num=predict_num, batch_size=batch_size,
         save_path=save_path, is_single=is_single, is_norm=is_norm,
         save_fig_name=save_fig_name)
    print("训练耗时:{:3.2f}s".format(end - begin))