# self_example/pytorch_example/RUL/otherIdea/dctChannelEmbedLSTM/train.py
# -*- encoding:utf-8 -*-
'''
@Author : dingjiawen
@Date : 2023/11/10 14:56
@Usage :
@Desc : Train the LSTM
'''
import os
import time
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from RUL.otherIdea.dctChannelEmbedLSTM.model import PredictModel
from RUL.otherIdea.dctChannelEmbedLSTM.loadData import getDataset
from RUL.otherIdea.dctChannelEmbedLSTM.test import test
from RUL.baseModel.CommonFunction import IsStopTraining
'''
Hyperparameter settings:
'''
hidden_num = 40  # number of LSTM cells
feature = 10  # dimensionality of a single point
batch_size = 32
EPOCH = 1000
predict_num = 50  # number of points to predict
is_norm = False
is_single = True
seed = 334
model_name = "dctEmbedLSTM"
base_save = r"parameters/{0}_hidden{1}_feature{2}_predict{3}_seed{4}".format(model_name, hidden_num, feature,
                                                                             predict_num, seed)
if not os.path.exists("parameters"):
    os.makedirs("parameters")
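
# Checkpoints are written under "parameters/" with the hyperparameters and the
# epoch/loss values encoded in the file name, e.g. (epoch and loss values vary per run):
#   parameters/dctEmbedLSTM_hidden40_feature10_predict50_seed334_epoch12_trainLoss0.01_valLoss0.02.pkl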

def get_dataset():
    '''Build the train/validation datasets and wrap them in DataLoaders'''
    train_dataset, val_dataset = getDataset(
        hidden_num=hidden_num, feature=feature, predict_num=predict_num, is_single=is_single, is_norm=is_norm)
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, drop_last=False)
    val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
    return train_loader, val_loader
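
# getDataset (from loadData) is assumed here to return torch-style Dataset objects
# whose items are (window, label) pairs, matching how the loaders are unpacked in
# train() below; the exact tensor shapes depend on loadData.getDataset.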

def train(device, lr, lr_patience, early_stop_patience, epochs, seed):
    '''Train the prediction model and return the path of the best checkpoint'''
    best_save_path = None
    # Prediction model
    model = PredictModel(input_dim=feature).to(device)
    # Build the datasets
    train_loader, val_loader = get_dataset()
    criterion = nn.MSELoss().to(device)
    optimizer_model = torch.optim.SGD(model.parameters(), lr=lr)
    # Halve the learning rate once the validation loss has not improved for lr_patience epochs
    scheduler_model = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_model, mode="min", factor=0.5,
                                                                 patience=lr_patience)

    def zero_grad_all():
        optimizer_model.zero_grad()

    train_loss_list = []
    val_loss_list = []
    for epoch in range(epochs):
        epoch_start_time = time.time()
        train_loss = 0.0
        val_loss = 0.0
        model.train()
        for train_batch_idx, (train_data, train_label) in enumerate(train_loader):
            train_data, train_label = train_data.to(device), train_label.to(device)
            zero_grad_all()
            # squeeze drops the trailing singleton dimension so the prediction
            # matches the shape of train_label (note: a batch of size 1 would
            # also lose its batch dimension here)
            predict_data = torch.squeeze(model(train_data))
            # MSE loss
            loss = criterion(predict_data, train_label)
            loss.backward()
            optimizer_model.step()
            zero_grad_all()
            train_loss += loss.item()
        model.eval()
        with torch.no_grad():
            for val_batch_idx, (val_data, val_label) in enumerate(val_loader):
                val_data, val_label = val_data.to(device), val_label.to(device)
                val_predict_data = torch.squeeze(model(val_data))
                loss = criterion(val_predict_data, val_label)
                val_loss += loss.item()
        train_loss = train_loss / len(train_loader)
        val_loss = val_loss / len(val_loader)
        # Step the scheduler on the (averaged) validation loss
        scheduler_model.step(val_loss)
        print(
            "[{:03d}/{:03d}] {:2.2f} sec(s) train_loss: {:3.9f} | val_loss: {:3.9f} | Learning rate : {:3.6f}".format(
                epoch + 1, epochs, time.time() - epoch_start_time,
                train_loss,
                val_loss,
                optimizer_model.state_dict()['param_groups'][0]['lr']))
        # Save the model that achieves the lowest loss on the validation set
        if len(val_loss_list) == 0 or val_loss < min(val_loss_list):
            print("Best model saved")
            # Save the model parameters
            best_save_path = base_save + "_epoch" + str(epoch) + \
                             "_trainLoss" + str(train_loss) + \
                             "_valLoss" + str(val_loss) + ".pkl"
            torch.save(model.state_dict(), best_save_path)
        train_loss_list.append(train_loss)
        val_loss_list.append(val_loss)
        if IsStopTraining(history_loss=val_loss_list, patience=early_stop_patience):
            break

    # Plot and save the training / validation loss curves
    from matplotlib import rcParams
    config = {
        "font.family": 'Times New Roman',  # font family
        "axes.unicode_minus": False,  # make sure minus signs render correctly
        "axes.labelsize": 13
    }
    rcParams.update(config)
    plt.figure(figsize=(12, 6), dpi=200)
    plt.plot(np.arange(1, len(train_loss_list) + 1), train_loss_list, 'b', label='Training Loss')
    plt.plot(np.arange(1, len(val_loss_list) + 1), val_loss_list, 'r', label='Validation Loss')
    # plt.ylim(0, 0.08)  # set the y-axis range
    plt.title('Training & Validation loss')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend(loc='upper right')
    plt.grid(alpha=0.4)
    # Name the figure file with the current Unix timestamp
    timestamp_str = str(int(time.time()))
    plt.savefig(timestamp_str + ".png", dpi=200)
    # Return the path of the saved (best) model parameters
    return best_save_path
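
# For reference, a minimal early-stopping check with the same call signature as
# IsStopTraining (the real helper lives in RUL.baseModel.CommonFunction and may
# differ): stop once the best validation loss is older than `patience` epochs.
# This sketch is an assumption and is not used by the training loop above.
def _is_stop_training_sketch(history_loss, patience):
    if len(history_loss) <= patience:
        return False
    # True when none of the last `patience` losses improved on the earlier best
    return min(history_loss[-patience:]) > min(history_loss[:-patience])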

if __name__ == '__main__':
    begin = time.time()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Fix the random seeds for reproducibility. In recent PyTorch versions
    # torch.manual_seed also seeds the CUDA RNGs; fully deterministic cuDNN
    # behavior would additionally require torch.backends.cudnn.deterministic = True.
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    # Training
    save_path = train(device, lr=0.01, lr_patience=10, early_stop_patience=20, epochs=EPOCH, seed=seed)
    end = time.time()
    # Testing
    # test1(5, src_condition, tar_condition, G_params_path, LC_params_path)
    test(hidden_num, feature, predict_num=predict_num,
         batch_size=batch_size, save_path=save_path,
         is_single=is_single, is_norm=is_norm)
    print("Training time: {:3.2f}s".format(end - begin))