# self_example/pytorch_example/RUL/otherIdea/dctLSTM/train.py
#
# 205 lines
# 6.3 KiB
# Python

# -*- encoding:utf-8 -*-
'''
@Author : dingjiawen
@Date : 2023/11/10 14:56
@Usage :
@Desc : 训练LSTM
'''
import os
import time
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from RUL.otherIdea.dctLSTM.modelForEasy import PredictModel
from RUL.otherIdea.dctLSTM.loadData import getDataset
from RUL.otherIdea.dctLSTM.test import test
from RUL.baseModel.CommonFunction import IsStopTraining
import RUL.baseModel.utils.utils as utils
import math
# ---------------------------------------------------------------------------
# Hyperparameter settings
# ---------------------------------------------------------------------------
# Number of LSTM cells (per the original note). It is passed to getDataset(),
# not to the model -- presumably the input window length; TODO confirm.
hidden_num = 40
feature = 10  # dimensionality of a single time point (model input_dim)
batch_size = 32
EPOCH = 1000  # training epoch count; keep in sync with the epochs passed to train()
predict_num = 200  # number of points to predict
seed = 250
is_norm = False
is_single = True
model_name = "dctLSTM"

# Prefix for checkpoint files; train() appends epoch / loss info and ".pkl".
base_save = r"parameters/{0}_hidden{1}_feature{2}_predict{3}".format(
    model_name, hidden_num, feature, predict_num)
# Figure name prefix handed to test() as save_fig_name.
save_fig_name = 'fig/seed{0}_hidden{1}_feature{2}_predict{3}'.format(
    seed, hidden_num, feature, predict_num)

# Ensure the output directories exist. exist_ok=True replaces the racy
# "check exists, then create" pattern and is a no-op if they already exist.
os.makedirs("parameters", exist_ok=True)
os.makedirs("fig", exist_ok=True)
def get_dataset():
    """Build the training and validation DataLoaders.

    The datasets come from ``getDataset`` configured with the module-level
    hyperparameters. The training loader reshuffles every epoch. The
    validation loader keeps a fixed order. Neither drops the final partial
    batch.

    Returns:
        tuple: ``(train_loader, val_loader)``.
    """
    train_set, val_set = getDataset(
        hidden_num=hidden_num,
        feature=feature,
        predict_num=predict_num,
        is_single=is_single,
        is_norm=is_norm,
    )
    # (dataset, shuffle) pairs, in the order the caller unpacks them.
    plan = ((train_set, True), (val_set, False))
    return tuple(
        DataLoader(dataset=ds, batch_size=batch_size, shuffle=shuffle, drop_last=False)
        for ds, shuffle in plan
    )
def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss."""
    model.train()
    total = 0.0
    for data, label in loader:
        data, label = data.to(device), label.to(device)
        optimizer.zero_grad()
        # NOTE(review): squeeze also drops the batch dimension when a batch
        # holds a single sample (possible because drop_last=False); confirm
        # that the resulting broadcast inside MSELoss is acceptable.
        prediction = torch.squeeze(model(data))
        loss = criterion(prediction, label)
        loss.backward()
        optimizer.step()
        total += loss.item()
    return total / len(loader)


def _evaluate(model, loader, criterion, device):
    """Return the mean batch loss of ``model`` over ``loader``, without gradients."""
    model.eval()
    total = 0.0
    with torch.no_grad():
        for data, label in loader:
            data, label = data.to(device), label.to(device)
            prediction = torch.squeeze(model(data))
            total += criterion(prediction, label).item()
    return total / len(loader)


def _plot_loss_curves(train_loss_list, val_loss_list):
    """Plot per-epoch training/validation loss, save it and show it.

    The file is written to the current working directory, named after the
    current Unix timestamp. No extension is given, so matplotlib appends
    the one for its default save format (rcParams["savefig.format"]).
    """
    from matplotlib import rcParams
    rcParams.update({
        "font.family": 'Times New Roman',  # font family
        "axes.unicode_minus": False,  # render the minus sign correctly
        "axes.labelsize": 13,
    })
    epoch_axis = np.arange(1, len(train_loss_list) + 1)
    fig = plt.figure(figsize=(12, 6), dpi=200)
    plt.plot(epoch_axis, train_loss_list, 'b', label='Training Loss')
    plt.plot(epoch_axis, val_loss_list, 'r', label='Validation Loss')
    plt.title('Training & Validation loss')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend(loc='upper right')
    plt.grid(alpha=0.4)
    plt.savefig(str(int(time.time())), dpi=200)
    plt.show()
    # Release the figure; non-interactive backends otherwise keep it alive.
    plt.close(fig)


def train(device, lr, lr_patience, early_stop_patience, epochs):
    """Train ``PredictModel`` with SGD + MSE and keep the best checkpoint on disk.

    Each epoch trains on the training loader, evaluates on the validation
    loader and feeds the mean validation loss to ``ReduceLROnPlateau``
    (factor 0.5). Whenever the validation loss is the lowest seen so far,
    the previous best checkpoint is deleted. The new ``state_dict`` is then
    saved under ``base_save``. Training stops early when ``IsStopTraining``
    says so for the validation-loss history. Finally the loss curves are
    plotted.

    Args:
        device: torch device that both the model and every batch are moved to.
        lr: initial SGD learning rate.
        lr_patience: ``patience`` passed to ``ReduceLROnPlateau``.
        early_stop_patience: ``patience`` passed to ``IsStopTraining``.
        epochs: maximum number of epochs.

    Returns:
        Path of the saved best checkpoint (``.pkl``), or ``None`` if
        ``epochs`` is 0. The same value is also stored in the module-level
        global ``best_save_path``.
    """
    # Kept as a module global for backward compatibility with any reader.
    global best_save_path
    # The model must live on the same device as the batches moved below;
    # without .to(device) training on CUDA fails with a device mismatch.
    model = PredictModel(input_dim=feature).to(device)
    train_loader, val_loader = get_dataset()
    criterion = nn.MSELoss().to(device)
    optimizer_model = torch.optim.SGD(model.parameters(), lr=lr)
    scheduler_model = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer_model, mode="min", factor=0.5, patience=lr_patience)

    train_loss_list = []
    val_loss_list = []
    best_save_path = None
    for epoch in range(epochs):
        epoch_start_time = time.time()
        train_loss = _train_one_epoch(model, train_loader, criterion, optimizer_model, device)
        val_loss = _evaluate(model, val_loader, criterion, device)
        # Step with the same mean loss that is logged and compared below.
        scheduler_model.step(val_loss)
        print(
            "[{:03d}/{:03d}] {:2.2f} sec(s) train_loss: {:3.9f} | val_loss: {:3.9f} | Learning rate : {:3.6f}".format(
                epoch + 1, epochs, time.time() - epoch_start_time,
                train_loss,
                val_loss,
                optimizer_model.param_groups[0]['lr']))

        # Keep only the checkpoint with the lowest validation loss so far.
        if not val_loss_list or val_loss < min(val_loss_list):
            if best_save_path is not None:
                utils.delete_file(best_save_path)
            best_save_path = (f"{base_save}_epoch{epoch}"
                              f"_trainLoss{train_loss}_valLoss{val_loss}.pkl")
            torch.save(model.state_dict(), best_save_path)
            print("保存模型最佳模型成功")

        train_loss_list.append(train_loss)
        val_loss_list.append(val_loss)
        if IsStopTraining(history_loss=val_loss_list, patience=early_stop_patience):
            break

    _plot_loss_curves(train_loss_list, val_loss_list)
    return best_save_path
if __name__ == '__main__':
    begin = time.time()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Seed the torch, Python and NumPy RNGs.
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    # Train; use the module-level EPOCH constant instead of a duplicated literal.
    save_path = train(device, lr=0.01, lr_patience=10,
                      early_stop_patience=20, epochs=EPOCH)
    # Timer stops here, so the reported time covers training only.
    end = time.time()

    # Evaluate the best checkpoint produced by train().
    test(hidden_num, feature, predict_num=predict_num,
         batch_size=batch_size, save_path=save_path,
         is_single=is_single, is_norm=is_norm, save_fig_name=save_fig_name)
    print("训练耗时:{:3.2f}s".format(end - begin))