# self_example/pytorch_example/RUL/otherIdea/LSTM/train.py
# -*- encoding:utf-8 -*-
'''
@Author : dingjiawen
@Date : 2023/11/10 14:56
@Usage :
@Desc : Train the LSTM model
'''
import os
import time
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from RUL.otherIdea.LSTM.modelForEasy import PredictModel
from RUL.otherIdea.LSTM.loadData import getDataset
from RUL.otherIdea.LSTM.test import test
from RUL.baseModel.CommonFunction import IsStopTraining
from scipy.spatial.distance import cdist
import math
import RUL.baseModel.utils.utils as utils
'''
Hyperparameter settings:
'''
hidden_num = 10  # number of LSTM cells
feature = 10  # dimension of a single point
batch_size = 32
EPOCH = 1000
seed = 5
predict_num = 200  # number of points to predict
is_norm = True
is_single = True
model_name = "LSTM"
base_save = r"parameters/{0}_hidden{1}_feature{2}_predict{3}".format(model_name, hidden_num, feature,
                                                                     predict_num)
save_fig_name = 'fig/seed{0}_hidden{1}_feature{2}_predict{3}'.format(seed, hidden_num, feature, predict_num)
if not os.path.exists("parameters"):
    os.makedirs("parameters")
if not os.path.exists("fig"):
    os.makedirs("fig")
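
# Illustrative only: the real PredictModel is imported from
# RUL.otherIdea.LSTM.modelForEasy. The sketch below shows one architecture
# consistent with how the model is used in train() (input_dim=feature,
# one output value per sample that gets squeezed by the caller); the layer
# sizes here are assumptions, not the author's definition.
class _PredictModelSketch(nn.Module):
    def __init__(self, input_dim, lstm_hidden=32):
        super().__init__()
        self.lstm = nn.LSTM(input_size=input_dim, hidden_size=lstm_hidden, batch_first=True)
        self.fc = nn.Linear(lstm_hidden, 1)

    def forward(self, x):
        # x: (batch, seq_len, input_dim); use the hidden state of the last time step
        out, _ = self.lstm(x)
        return self.fc(out[:, -1, :])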

def get_dataset():
    '''Get the datasets'''
    train_dataset, val_dataset = getDataset(
        hidden_num=hidden_num, feature=feature, predict_num=predict_num, is_single=is_single, is_norm=is_norm)
    '''DataLoader'''
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, drop_last=False)
    val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
    return train_loader, val_loader
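
def peek_batch_shapes():
    '''Illustrative helper (an assumption, not part of the original pipeline):
    pull one batch from the training DataLoader and print its tensor shapes,
    so the window/feature layout produced by getDataset can be checked before
    committing to a long training run.'''
    loader, _ = get_dataset()
    sample_data, sample_label = next(iter(loader))
    print("batch data shape:", tuple(sample_data.shape),
          "| batch label shape:", tuple(sample_label.shape))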

def train(device, lr, lr_patience, early_stop_patience, epochs):
    '''Prediction model'''
    global best_save_path
    model = PredictModel(input_dim=feature).to(device)
    '''Get the datasets'''
    train_loader, val_loader = get_dataset()
    criterion = nn.MSELoss().to(device)
    optimizer_model = torch.optim.SGD(model.parameters(), lr=lr)
    scheduler_model = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_model, mode="min", factor=0.5,
                                                                 patience=lr_patience)

    def zero_grad_all():
        optimizer_model.zero_grad()

    train_loss_list = []
    val_loss_list = []
    best_save_path = None
    for epoch in range(epochs):
        epoch_start_time = time.time()
        train_loss = 0.0
        val_loss = 0.0
        model.train()
        for (train_batch_idx, (train_data, train_label)) in enumerate(train_loader):
            train_data, train_label = train_data.float().to(device), train_label.float().to(device)
            zero_grad_all()
            predict_data = torch.squeeze(model(train_data))
            # MSE loss
            loss = criterion(predict_data, train_label)
            loss.backward()
            optimizer_model.step()
            zero_grad_all()
            train_loss += loss.item()
        model.eval()
        with torch.no_grad():
            for val_batch_idx, (val_data, val_label) in enumerate(val_loader):
                val_data, val_label = val_data.float().to(device), val_label.float().to(device)
                val_predict_data = torch.squeeze(model(val_data))
                loss = criterion(val_predict_data, val_label)
                val_loss += loss.item()
        scheduler_model.step(val_loss)
        train_loss = train_loss / len(train_loader)
        val_loss = val_loss / len(val_loader)
        print(
            "[{:03d}/{:03d}] {:2.2f} sec(s) train_loss: {:3.9f} | val_loss: {:3.9f} | Learning rate : {:3.6f}".format(
                epoch + 1, epochs, time.time() - epoch_start_time,
                train_loss,
                val_loss,
                optimizer_model.state_dict()['param_groups'][0]['lr']))
        # Save the model with the lowest loss on the validation set
        # if val_loss_list.__len__() > 0 and (val_loss / val_dataset.__len__()) < min(val_loss_list):
        # Save whenever the validation loss improves on the best so far
        if len(val_loss_list) == 0 or val_loss < min(val_loss_list):
            print("Saved new best model")
            # Save the model parameters, removing the previous best checkpoint
            if best_save_path is not None:
                utils.delete_file(best_save_path)
            best_save_path = base_save + "_epoch" + str(epoch) + \
                             "_trainLoss" + str(train_loss) + \
                             "_valLoss" + str(val_loss) + ".pkl"
            torch.save(model.state_dict(),
                       best_save_path)
        train_loss_list.append(train_loss)
        val_loss_list.append(val_loss)
        if IsStopTraining(history_loss=val_loss_list, patience=early_stop_patience):
            break
    '''Path of the saved model parameters'''
    return best_save_path
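
# For reference: a minimal sketch of the early-stopping rule assumed above.
# The real IsStopTraining comes from RUL.baseModel.CommonFunction; this
# illustrative version only shows the expected behaviour: return True once
# the best validation loss has not improved for `patience` consecutive epochs.
def _is_stop_training_sketch(history_loss, patience):
    if len(history_loss) <= patience:
        return False
    best_epoch = history_loss.index(min(history_loss))
    # Stop when the best epoch lies more than `patience` epochs in the past.
    return best_epoch < len(history_loss) - patience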

if __name__ == '__main__':
    begin = time.time()
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    '''Train'''
    save_path = train(device, lr=0.01, lr_patience=10, early_stop_patience=20, epochs=EPOCH)
    end = time.time()
    '''Test'''
    # test1(5, src_condition, tar_condition, G_params_path, LC_params_path)
    test(hidden_num, feature, predict_num=predict_num,
         batch_size=batch_size, save_path=save_path,
         is_single=is_single, is_norm=is_norm, save_fig_name=save_fig_name)
    print("Training time: {:3.2f}s".format(end - begin))