# -*- encoding:utf-8 -*-
'''
@Author : dingjiawen
@Date   : 2023/11/10 14:56
@Usage  :
@Desc   : Train the DCT-channel-embedding LSTM prediction model (RUL).
'''
import os
import time
import random
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

from RUL.otherIdea.dctChannelEmbedLSTM.model import PredictModel
from RUL.otherIdea.dctChannelEmbedLSTM.loadData import getDataset
from RUL.otherIdea.dctChannelEmbedLSTM.test import test
from RUL.baseModel.CommonFunction import IsStopTraining

'''Hyper-parameters'''
hidden_num = 40        # number of LSTM cells (window length fed to the model)
feature = 10           # dimensionality of one point
batch_size = 32
EPOCH = 1000
predict_num = 50       # number of points to predict
is_norm = False
is_single = True
seed = 334
model_name = "dctEmbedLSTM"
base_save = r"parameters/{0}_hidden{1}_feature{2}_predict{3}_seed{4}".format(
    model_name, hidden_num, feature, predict_num, seed)

if not os.path.exists("parameters"):
    os.makedirs("parameters")


def get_dataset():
    """Build the train/validation DataLoaders from the project dataset helper.

    Returns:
        (train_loader, val_loader): shuffled training loader and
        deterministic validation loader, both keeping the last partial batch.
    """
    train_dataset, val_dataset = getDataset(
        hidden_num=hidden_num,
        feature=feature,
        predict_num=predict_num,
        is_single=is_single,
        is_norm=is_norm)
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                              shuffle=True, drop_last=False)
    val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size,
                            shuffle=False, drop_last=False)
    return train_loader, val_loader


def train(device, lr, lr_patience, early_stop_patience, epochs, seed):
    """Train the prediction model with SGD + ReduceLROnPlateau + early stopping.

    Args:
        device: torch.device to run on.
        lr: initial SGD learning rate.
        lr_patience: epochs without val-loss improvement before halving the lr.
        early_stop_patience: epochs without improvement before stopping.
        epochs: maximum number of epochs.
        seed: kept for interface compatibility (seeding is done by the caller).

    Returns:
        Path (str) of the checkpoint with the lowest validation loss,
        or None if no epoch completed.

    Side effects:
        Saves model checkpoints under ``parameters/`` and a loss-curve PNG
        named by the current unix timestamp.
    '''"""
    # BUGFIX: the model must live on the same device as the data; the
    # original never called .to(device) and would crash under CUDA.
    model = PredictModel(input_dim=feature).to(device)

    train_loader, val_loader = get_dataset()

    criterion = nn.MSELoss().to(device)
    optimizer_model = torch.optim.SGD(model.parameters(), lr=lr)
    scheduler_model = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer_model, mode="min", factor=0.5, patience=lr_patience)

    best_save_path = None   # best checkpoint path (was a mutated global before)
    train_loss_list = []
    val_loss_list = []

    for epoch in range(epochs):
        epoch_start_time = time.time()
        train_loss = 0.0
        val_loss = 0.0

        model.train()
        for train_batch_idx, (train_data, train_label) in enumerate(train_loader):
            train_data, train_label = train_data.to(device), train_label.to(device)
            optimizer_model.zero_grad()
            # NOTE(review): squeeze() drops ALL size-1 dims, including the
            # batch dim when the final batch holds a single sample — confirm
            # the model's output shape tolerates this.
            predict_data = torch.squeeze(model(train_data))
            loss = criterion(predict_data, train_label)  # MSE loss
            loss.backward()
            optimizer_model.step()
            train_loss += loss.item()

        model.eval()
        with torch.no_grad():
            for val_batch_idx, (val_data, val_label) in enumerate(val_loader):
                val_data, val_label = val_data.to(device), val_label.to(device)
                val_predict_data = torch.squeeze(model(val_data))
                loss = criterion(val_predict_data, val_label)
                val_loss += loss.item()

        # The scheduler steps on the *summed* validation loss (original
        # behavior); the per-batch averages below are for logging/saving.
        scheduler_model.step(val_loss)

        train_loss = train_loss / len(train_loader)
        val_loss = val_loss / len(val_loader)
        print(
            "[{:03d}/{:03d}] {:2.2f} sec(s) train_loss: {:3.9f} | val_loss: {:3.9f} | Learning rate : {:3.6f}".format(
                epoch + 1, epochs, time.time() - epoch_start_time, train_loss, val_loss,
                optimizer_model.state_dict()['param_groups'][0]['lr']))

        # Save whenever validation loss improves on the best seen so far.
        if len(val_loss_list) == 0 or val_loss < min(val_loss_list):
            print("保存模型最佳模型成功")
            best_save_path = (base_save + "_epoch" + str(epoch) +
                              "_trainLoss" + str(train_loss) +
                              "_valLoss" + str(val_loss) + ".pkl")
            torch.save(model.state_dict(), best_save_path)

        train_loss_list.append(train_loss)
        val_loss_list.append(val_loss)

        if IsStopTraining(history_loss=val_loss_list, patience=early_stop_patience):
            break

    # --- Loss-curve plot -------------------------------------------------
    from matplotlib import rcParams
    config = {
        "font.family": 'Times New Roman',   # font family
        "axes.unicode_minus": False,        # render minus signs correctly
        "axes.labelsize": 13
    }
    rcParams.update(config)

    pic1 = plt.figure(figsize=(12, 6), dpi=200)
    plt.plot(np.arange(1, len(train_loss_list) + 1), train_loss_list, 'b', label='Training Loss')
    plt.plot(np.arange(1, len(train_loss_list) + 1), val_loss_list, 'r', label='Validation Loss')
    # plt.ylim(0, 0.08)  # optional y-axis range
    plt.title('Training & Validation loss')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend(loc='upper right')
    plt.grid(alpha=0.4)
    # Name the figure by the current unix timestamp; ".png" added so the
    # file gets a proper extension (the original saved an extension-less file).
    timestamp_str = str(int(time.time()))
    plt.savefig(timestamp_str + ".png", dpi=200)

    return best_save_path


if __name__ == '__main__':
    begin = time.time()

    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    # Seed every RNG involved for reproducibility (CUDA seeding added;
    # the original only seeded the CPU generators).
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)

    '''Training'''
    save_path = train(device, lr=0.01, lr_patience=10,
                      early_stop_patience=20, epochs=EPOCH, seed=seed)
    end = time.time()

    '''Testing'''
    # test1(5, src_condition, tar_condition, G_params_path, LC_params_path)
    test(hidden_num, feature, predict_num=predict_num, batch_size=batch_size,
         save_path=save_path, is_single=is_single, is_norm=is_norm)

    print("训练耗时:{:3.2f}s".format(end - begin))