""" @Author: miykah @Email: miykah@163.com @FileName: train.py @DateTime: 2022/7/20 20:22 """ import os import time import random import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import matplotlib.pyplot as plt from torch.utils.data import DataLoader from example.load_data import get_dataset, draw_signal_img from example.model import Extractor, LabelClassifier, DomainClassifier, weights_init_Classifier, weights_init_Extractor from example.loss import cal_mmdc_loss, BinaryCrossEntropyLoss, get_mmdc import example.mmd as mmd from example.test import test, test1 from scipy.spatial.distance import cdist import math def obtain_label(feature, output, bs): with torch.no_grad(): all_fea = feature.reshape(bs, -1).float().cpu() all_output = output.float().cpu() # all_label = label.float() all_output = nn.Softmax(dim=1)(all_output) _, predict = torch.max(all_output, 1) # accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1) all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t() all_fea = all_fea.float().cpu().numpy() K = all_output.size(1) aff = all_output.float().cpu().numpy() initc = aff.transpose().dot(all_fea) initc = initc / (1e-8 + aff.sum(axis=0)[:, None]) dd = cdist(all_fea, initc, 'cosine') pred_label = dd.argmin(axis=1) # acc = np.sum(pred_label == all_label.float().numpy()) / len(all_fea) for round in range(1): aff = np.eye(K)[pred_label] initc = aff.transpose().dot(all_fea) initc = initc / (1e-8 + aff.sum(axis=0)[:, None]) dd = cdist(all_fea, initc, 'cosine') pred_label = dd.argmin(axis=1) # acc = np.sum(pred_label == all_label.float().numpy()) / len(all_fea) # log_str = 'Accuracy = {:.2f}% -> {:.2f}%'.format(accuracy * 100, acc * 100) # args.out_file.write(log_str + '\n') # args.out_file.flush() # print(log_str + '\n') return pred_label.astype('int') def train(device, src_condition, tar_condition, cls, epochs, bs, shot, lr, patience, gamma, miu): '''特征提取器''' G = Extractor() G.apply(weights_init_Extractor) G.to(device) '''标签分类器''' LC = LabelClassifier(cls_num=cls) LC.apply(weights_init_Classifier) LC.to(device) '''域分类器''' DC = DomainClassifier() DC.apply(weights_init_Classifier) DC.to(device) '''得到数据集''' src_dataset, tar_dataset, val_dataset, test_dataset \ = get_dataset(src_condition, tar_condition) '''DataLoader''' src_loader = DataLoader(dataset=src_dataset, batch_size=bs, shuffle=False, drop_last=True) tar_loader = DataLoader(dataset=tar_dataset, batch_size=bs, shuffle=True, drop_last=True) val_loader = DataLoader(dataset=val_dataset, batch_size=bs, shuffle=True, drop_last=False) test_loader = DataLoader(dataset=test_dataset, batch_size=len(test_dataset), shuffle=True, drop_last=False) criterion = nn.CrossEntropyLoss().to(device) BCE = BinaryCrossEntropyLoss().to(device) optimizer_g = torch.optim.Adam(G.parameters(), lr=lr) optimizer_lc = torch.optim.Adam(LC.parameters(), lr=lr) optimizer_dc = torch.optim.Adam(DC.parameters(), lr=lr) scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_g, mode="min", factor=0.5, patience=patience) scheduler_lc = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_lc, mode="min", factor=0.5, patience=patience) scheduler_dc = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_dc, mode="min", factor=0.5, patience=patience) def zero_grad_all(): optimizer_g.zero_grad() optimizer_lc.zero_grad() optimizer_dc.zero_grad() src_acc_list = [] train_loss_list = [] val_acc_list = [] val_loss_list = [] for epoch in range(epochs): epoch_start_time = time.time() src_acc = 0.0 train_loss = 0.0 val_acc = 0.0 val_loss = 0.0 G.train() LC.train() DC.train() for (src_batch_idx, (src_data, src_label)), (tar_batch_idx, (tar_data)) in zip(enumerate(src_loader), enumerate(tar_loader)): src_data, src_label = src_data.to(device), src_label.to(device) tar_data = tar_data.to(device) zero_grad_all() T1 = (int)(0.2 * epochs) T2 = (int)(0.5 * epochs) src_feature, tar_feature = G(src_data, tar_data) src_data_mmd1, src_output, tar_data_mmd1, tar_output = LC(src_feature, tar_feature) if epoch < T1: pseudo_label = torch.argmax(F.softmax(tar_output, dim=1), dim=1) # 伪标签 else: pseudo_label = torch.tensor(obtain_label(tar_feature, tar_output, bs), dtype=torch.int64).cuda() # pseudo_label = torch.argmax(F.softmax(tar_output, dim=1), dim=1) # 伪标签 # mmdc权重 if epoch < T1: miu_f = 0 # elif epoch > T1 and epoch < T2: # miu_f = miu * (epoch - T1) / (T2 - T1) else: miu_f = miu # 源数据的交叉熵损失 loss_src = criterion(src_output, src_label) # 目标数据的伪标签交叉熵损失 loss_tar_pseudo = criterion(tar_output, pseudo_label) # mmd损失 loss_mmdm = mmd.mmd_rbf_noaccelerate(src_data_mmd1, tar_data_mmd1) if epoch < T1: loss_mmdc = 0 else: loss_mmdc = get_mmdc(src_data_mmd1, tar_data_mmd1, pseudo_label, bs, shot, cls) # loss_mmdc = cal_mmdc_loss(src_data_mmd1, src_output, tar_data_mmd1, tar_output) # loss_mmdc = get_mmdc(src_data_mmd1, tar_data_mmd1, pseudo_label, bs, shot, cls) # loss_jmmd = miu_f * loss_mmdc + (1 - miu_f) * loss_mmdm # 伪标签损失的权重 # if epoch < T1: # beta_f = 0 # elif epoch > T1 and epoch < T2: # beta_f = beta * (epoch - T1) / (T2 - T1) # else: # beta_f = beta p = epoch / epochs lamda = (2 / (1 + math.exp(-10 * p))) - 1 # gamma = (2 / (1 + math.exp(-10 * p))) - 1 # miu_f = (2 / (1 + math.exp(-10 * p))) - 1 loss_jmmd = miu_f * loss_mmdc + (1 - miu_f) * loss_mmdm loss_G_LC = loss_src + gamma * loss_jmmd # loss_G_LC = loss_src + beta_f * loss_tar_pseudo + gamma * loss_jmmd loss_G_LC.backward() optimizer_g.step() optimizer_lc.step() zero_grad_all() #----------------------------------------------- # 对抗域适应的损失 src_feature, tar_feature = G(src_data, tar_data) src_data_mmd1, src_output, tar_data_mmd1, tar_output = LC(src_feature, tar_feature) # pseudo_label = torch.argmax(F.softmax(tar_output, dim=1), dim=1) # 伪标签 # 源数据的交叉熵损失 loss_src = criterion(src_output, src_label) # 目标数据的伪标签交叉熵损失 loss_tar_pseudo = criterion(tar_output, pseudo_label) gradient_src = \ torch.autograd.grad(outputs=loss_src, inputs=src_feature, create_graph=True, retain_graph=True, only_inputs=True)[0] gradient_tar = \ torch.autograd.grad(outputs=loss_tar_pseudo, inputs=tar_feature, create_graph=True, retain_graph=True, only_inputs=True)[0] gradients_adv = torch.cat((gradient_src, gradient_tar), dim=0) domain_label_reverse = DC(gradients_adv, reverse=True) loss_adv_r = BCE(domain_label_reverse) loss_G_adv = lamda * loss_adv_r # 更新特征提取器G参数 loss_G_adv.backward() optimizer_g.step() zero_grad_all() #--------------------------------------------------------------------- src_feature, tar_feature = G(src_data, tar_data) src_data_mmd1, src_output, tar_data_mmd1, tar_output = LC(src_feature, tar_feature) # pseudo_label = torch.argmax(F.softmax(tar_output, dim=1), dim=1) # 伪标签 # 源数据的交叉熵损失 loss_src = criterion(src_output, src_label) # 目标数据的伪标签交叉熵损失 loss_tar_pseudo = criterion(tar_output, pseudo_label) gradient_src = \ torch.autograd.grad(outputs=loss_src, inputs=src_feature, create_graph=True, retain_graph=True, only_inputs=True)[0] gradient_tar = \ torch.autograd.grad(outputs=loss_tar_pseudo, inputs=tar_feature, create_graph=True, retain_graph=True, only_inputs=True)[0] gradients = torch.cat((gradient_src, gradient_tar), dim=0) domain_label = DC(gradients, reverse=False) loss_adv = BCE(domain_label) loss_DC = lamda * loss_adv # 更新域分类器的参数 loss_DC.backward() optimizer_dc.step() zero_grad_all() src_acc += np.sum(np.argmax(src_output.cpu().detach().numpy(), axis=1) == src_label.cpu().numpy()) train_loss += (loss_G_LC + loss_G_adv + loss_DC).item() G.eval() LC.eval() DC.eval() with torch.no_grad(): for batch_idx, (val_data, val_label) in enumerate(val_loader): val_data, val_label = val_data.to(device), val_label.to(device) _, val_feature = G(val_data, val_data) _, _, _, val_output = LC(val_feature, val_feature) loss = criterion(val_output, val_label) val_acc += np.sum(np.argmax(val_output.cpu().detach().numpy(), axis=1) == val_label.cpu().numpy()) val_loss += loss.item() scheduler_g.step(val_loss) scheduler_lc.step(val_loss) scheduler_dc.step(val_loss) print("[{:03d}/{:03d}] {:2.2f} sec(s) src_acc: {:3.6f} train_loss: {:3.9f} | val_acc: {:3.6f} val_loss: {:3.9f} | Learning rate : {:3.6f}".format( epoch + 1, epochs, time.time() - epoch_start_time, \ src_acc / src_dataset.__len__(), train_loss / src_dataset.__len__(), val_acc / val_dataset.__len__(), val_loss / val_dataset.__len__(), optimizer_g.state_dict()['param_groups'][0]['lr'])) # 保存在验证集上loss最小的模型 # if val_loss_list.__len__() > 0 and (val_loss / val_dataset.__len__()) < min(val_loss_list): # 如果精度大于最高精度,则保存 if val_acc_list.__len__() > 0 : # if (val_acc / val_dataset.__len__()) >= max(val_acc_list): if (val_acc / val_dataset.__len__()) > max(val_acc_list) or (val_loss / val_dataset.__len__()) < min(val_loss_list): print("保存模型最佳模型成功") G_path = "parameters_bak/" + src_condition + "_to_" + tar_condition + "/G" LC_path = "parameters_bak/" + src_condition + "_to_" + tar_condition + "/LC" DC_path = "parameters_bak/" + src_condition + "_to_" + tar_condition + "/DC" if not os.path.exists(G_path): os.makedirs(G_path) if not os.path.exists(LC_path): os.makedirs(LC_path) if not os.path.exists(DC_path): os.makedirs(DC_path) # 保存模型参数 torch.save(G.state_dict(), G_path + "/G_" + "shot" + str(shot) + "_epoch" + str(epochs) + "_lr" + str(lr) + "_miu" + str(miu) + ".pkl") torch.save(LC.state_dict(), LC_path + "/LC_" + "shot" + str(shot) + "_epoch" + str(epochs) + "_lr" + str(lr) + "_miu" + str(miu) + ".pkl") torch.save(DC.state_dict(), DC_path + "/DC_" + "shot" + str(shot) + "_epoch" + str(epochs) + "_lr" + str(lr) + "_miu" + str(miu) + ".pkl") src_acc_list.append(src_acc / src_dataset.__len__()) train_loss_list.append(train_loss / src_dataset.__len__()) val_acc_list.append(val_acc / val_dataset.__len__()) val_loss_list.append(val_loss / val_dataset.__len__()) '''保存的模型参数的路径''' G_params_path = G_path + "/G_" + "shot" + str(shot) + "_epoch" + str(epochs) + "_lr" + str(lr) + "_miu" + str(miu) + ".pkl" LC_params_path = LC_path + "/LC_" + "shot" + str(shot) + "_epoch" + str(epochs) + "_lr" + str(lr) + "_miu" + str(miu) + ".pkl" from matplotlib import rcParams config = { "font.family": 'Times New Roman', # 设置字体类型 "axes.unicode_minus": False, # 解决负号无法显示的问题 "axes.labelsize": 13 } rcParams.update(config) # pic1 = plt.figure(figsize=(8, 6), dpi=200) # plt.subplot(211) # plt.plot(np.arange(1, epochs + 1), src_acc_list, 'b', label='TrainAcc') # plt.plot(np.arange(1, epochs + 1), val_acc_list, 'r', label='ValAcc') # plt.ylim(0.3, 1.0) # 设置y轴范围 # plt.title('Training & Validation accuracy') # plt.xlabel('epoch') # plt.ylabel('accuracy') # plt.legend(loc='lower right') # plt.grid(alpha=0.4) # # plt.subplot(212) # plt.plot(np.arange(1, epochs + 1), train_loss_list, 'b', label='TrainLoss') # plt.plot(np.arange(1, epochs + 1), val_loss_list, 'r', label='ValLoss') # plt.ylim(0, 0.08) # 设置y轴范围 # plt.title('Training & Validation loss') # plt.xlabel('epoch') # plt.ylabel('loss') # plt.legend(loc='upper right') # plt.grid(alpha=0.4) pic1 = plt.figure(figsize=(12, 6), dpi=200) plt.plot(np.arange(1, epochs + 1), train_loss_list, 'b', label='Training Loss') plt.plot(np.arange(1, epochs + 1), val_loss_list, 'r', label='Validation Loss') plt.ylim(0, 0.08) # 设置y轴范围 plt.title('Training & Validation loss') plt.xlabel('epoch') plt.ylabel('loss') plt.legend(loc='upper right') plt.grid(alpha=0.4) # 获取当前时间戳 timestamp = int(time.time()) # 将时间戳转换为字符串 timestamp_str = str(timestamp) plt.savefig(timestamp_str, dpi=200) return G_params_path, LC_params_path if __name__ == '__main__': begin = time.time() if torch.cuda.is_available(): device = torch.device("cuda:0") else: device = torch.device("cpu") seed = 2 torch.manual_seed(seed) random.seed(seed) np.random.seed(seed) src_condition = 'A' tar_condition = 'B' '''训练''' G_params_path, LC_params_path = train(device, src_condition, tar_condition, cls=5, epochs=200, bs=50, shot=10, lr=0.002, patience=40, gamma=1, miu=0.5) end = time.time() '''测试''' # test1(5, src_condition, tar_condition, G_params_path, LC_params_path) test(5, tar_condition, G_params_path, LC_params_path) print("训练耗时:{:3.2f}s".format(end - begin))