"""
|
|
@Author: miykah
|
|
@Email: miykah@163.com
|
|
@FileName: train.py
|
|
@DateTime: 2022/7/20 20:22
|
|
"""
|
|
import os
import time
import random
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from scipy.spatial.distance import cdist

from example.load_data import get_dataset, draw_signal_img
from example.model import Extractor, LabelClassifier, DomainClassifier, weights_init_Classifier, weights_init_Extractor
from example.loss import cal_mmdc_loss, BinaryCrossEntropyLoss, get_mmdc
import example.mmd as mmd
from example.test import test, test1


def obtain_label(feature, output, bs):
    """Generate target pseudo-labels by class-centroid clustering (SHOT-style self-labeling)."""
    with torch.no_grad():
        all_fea = feature.reshape(bs, -1).float().cpu()
        all_output = output.float().cpu()
        # all_label = label.float()
        all_output = nn.Softmax(dim=1)(all_output)
        _, predict = torch.max(all_output, 1)
        # accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])

        # append a bias dimension and L2-normalize the features
        all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1)
        all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t()
        all_fea = all_fea.float().cpu().numpy()

        # initial class centroids, weighted by the softmax predictions
        K = all_output.size(1)
        aff = all_output.float().cpu().numpy()
        initc = aff.transpose().dot(all_fea)
        initc = initc / (1e-8 + aff.sum(axis=0)[:, None])
        dd = cdist(all_fea, initc, 'cosine')
        pred_label = dd.argmin(axis=1)
        # acc = np.sum(pred_label == all_label.float().numpy()) / len(all_fea)

        # refine the centroids once using the hard pseudo-labels
        for _ in range(1):
            aff = np.eye(K)[pred_label]
            initc = aff.transpose().dot(all_fea)
            initc = initc / (1e-8 + aff.sum(axis=0)[:, None])
            dd = cdist(all_fea, initc, 'cosine')
            pred_label = dd.argmin(axis=1)
            # acc = np.sum(pred_label == all_label.float().numpy()) / len(all_fea)

        # log_str = 'Accuracy = {:.2f}% -> {:.2f}%'.format(accuracy * 100, acc * 100)
        # args.out_file.write(log_str + '\n')
        # args.out_file.flush()
        # print(log_str + '\n')
    return pred_label.astype('int')
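

# A minimal usage sketch for obtain_label on random tensors; the shapes here
# (batch of 8, 16-dim features, 5 classes) are illustrative assumptions only,
# not values taken from the original dataset.
def _demo_obtain_label():
    bs, cls = 8, 5
    feature = torch.randn(bs, 16)   # stand-in for extractor features
    output = torch.randn(bs, cls)   # stand-in for classifier logits
    pseudo = obtain_label(feature, output, bs)
    print(pseudo.shape)             # (8,) -- one integer pseudo-label per sample

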
def train(device, src_condition, tar_condition, cls, epochs, bs, shot, lr, patience, gamma, miu):
    '''Feature extractor'''
    G = Extractor()
    G.apply(weights_init_Extractor)
    G.to(device)
    '''Label classifier'''
    LC = LabelClassifier(cls_num=cls)
    LC.apply(weights_init_Classifier)
    LC.to(device)
    '''Domain classifier'''
    DC = DomainClassifier()
    DC.apply(weights_init_Classifier)
    DC.to(device)

    '''Build the datasets'''
    src_dataset, tar_dataset, val_dataset, test_dataset \
        = get_dataset(src_condition, tar_condition)
    '''DataLoader'''
    src_loader = DataLoader(dataset=src_dataset, batch_size=bs, shuffle=False, drop_last=True)
    tar_loader = DataLoader(dataset=tar_dataset, batch_size=bs, shuffle=True, drop_last=True)
    val_loader = DataLoader(dataset=val_dataset, batch_size=bs, shuffle=True, drop_last=False)
    test_loader = DataLoader(dataset=test_dataset, batch_size=len(test_dataset), shuffle=True, drop_last=False)

    criterion = nn.CrossEntropyLoss().to(device)
    BCE = BinaryCrossEntropyLoss().to(device)

    optimizer_g = torch.optim.Adam(G.parameters(), lr=lr)
    optimizer_lc = torch.optim.Adam(LC.parameters(), lr=lr)
    optimizer_dc = torch.optim.Adam(DC.parameters(), lr=lr)
    scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_g, mode="min", factor=0.5, patience=patience)
    scheduler_lc = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_lc, mode="min", factor=0.5, patience=patience)
    scheduler_dc = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_dc, mode="min", factor=0.5, patience=patience)
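    # Each network has its own Adam optimizer, and all three are stepped on the
    # same validation loss below: ReduceLROnPlateau halves the learning rate
    # (factor=0.5) whenever the monitored loss has not decreased for `patience`
    # consecutive epochs.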

    def zero_grad_all():
        optimizer_g.zero_grad()
        optimizer_lc.zero_grad()
        optimizer_dc.zero_grad()

    src_acc_list = []
    train_loss_list = []
    val_acc_list = []
    val_loss_list = []

    # Checkpoint directories, defined up front so the paths are in scope after
    # the loop even if no "best" epoch ever triggers a save.
    G_path = "parameters_bak/" + src_condition + "_to_" + tar_condition + "/G"
    LC_path = "parameters_bak/" + src_condition + "_to_" + tar_condition + "/LC"
    DC_path = "parameters_bak/" + src_condition + "_to_" + tar_condition + "/DC"

    for epoch in range(epochs):
        epoch_start_time = time.time()
        src_acc = 0.0
        train_loss = 0.0
        val_acc = 0.0
        val_loss = 0.0

        G.train()
        LC.train()
        DC.train()

        for (src_batch_idx, (src_data, src_label)), (tar_batch_idx, tar_data) in zip(enumerate(src_loader), enumerate(tar_loader)):
            src_data, src_label = src_data.to(device), src_label.to(device)
            tar_data = tar_data.to(device)
            zero_grad_all()

            T1 = int(0.2 * epochs)
            T2 = int(0.5 * epochs)

            src_feature, tar_feature = G(src_data, tar_data)
            src_data_mmd1, src_output, tar_data_mmd1, tar_output = LC(src_feature, tar_feature)
            if epoch < T1:
                # warm-up: take the classifier's own argmax as pseudo-labels
                pseudo_label = torch.argmax(F.softmax(tar_output, dim=1), dim=1)
            else:
                # afterwards: centroid-refined pseudo-labels (see obtain_label)
                pseudo_label = torch.tensor(obtain_label(tar_feature, tar_output, bs), dtype=torch.int64).to(device)

            # weight of the conditional MMD (MMDC) term
            if epoch < T1:
                miu_f = 0
            # elif epoch > T1 and epoch < T2:
            #     miu_f = miu * (epoch - T1) / (T2 - T1)
            else:
                miu_f = miu

            # cross-entropy loss on the source data
            loss_src = criterion(src_output, src_label)
            # pseudo-label cross-entropy loss on the target data
            loss_tar_pseudo = criterion(tar_output, pseudo_label)
            # marginal MMD loss
            loss_mmdm = mmd.mmd_rbf_noaccelerate(src_data_mmd1, tar_data_mmd1)
            if epoch < T1:
                loss_mmdc = 0
            else:
                loss_mmdc = get_mmdc(src_data_mmd1, tar_data_mmd1, pseudo_label, bs, shot, cls)
                # loss_mmdc = cal_mmdc_loss(src_data_mmd1, src_output, tar_data_mmd1, tar_output)

            # weight of the pseudo-label loss
            # if epoch < T1:
            #     beta_f = 0
            # elif epoch > T1 and epoch < T2:
            #     beta_f = beta * (epoch - T1) / (T2 - T1)
            # else:
            #     beta_f = beta

            p = epoch / epochs
            lamda = (2 / (1 + math.exp(-10 * p))) - 1
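            # Worked values of this DANN-style ramp, lamda(p) = 2/(1+e^(-10p)) - 1:
            # p = 0 -> 0.00, p = 0.1 -> 0.46, p = 0.5 -> 0.99, p = 1 -> ~1.00,
            # so the adversarial terms fade in over roughly the first 20% of training.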
            # gamma = (2 / (1 + math.exp(-10 * p))) - 1
            # miu_f = (2 / (1 + math.exp(-10 * p))) - 1

            # joint alignment loss: convex combination of conditional and marginal MMD
            loss_jmmd = miu_f * loss_mmdc + (1 - miu_f) * loss_mmdm
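            # During warm-up (epoch < T1) miu_f = 0, so only the marginal MMD
            # term aligns the two domains as wholes; once the pseudo-labels are
            # refined, miu_f = miu also pulls the class-conditional
            # distributions together through the pseudo-labelled MMDC term.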

            loss_G_LC = loss_src + gamma * loss_jmmd
            # loss_G_LC = loss_src + beta_f * loss_tar_pseudo + gamma * loss_jmmd
            loss_G_LC.backward()
            optimizer_g.step()
            optimizer_lc.step()
            zero_grad_all()
            # -----------------------------------------------
            # adversarial domain-adaptation loss
            src_feature, tar_feature = G(src_data, tar_data)
            src_data_mmd1, src_output, tar_data_mmd1, tar_output = LC(src_feature, tar_feature)
            # cross-entropy loss on the source data
            loss_src = criterion(src_output, src_label)
            # pseudo-label cross-entropy loss on the target data
            loss_tar_pseudo = criterion(tar_output, pseudo_label)

            gradient_src = \
                torch.autograd.grad(outputs=loss_src, inputs=src_feature, create_graph=True, retain_graph=True,
                                    only_inputs=True)[0]
            gradient_tar = \
                torch.autograd.grad(outputs=loss_tar_pseudo, inputs=tar_feature, create_graph=True, retain_graph=True,
                                    only_inputs=True)[0]
            gradients_adv = torch.cat((gradient_src, gradient_tar), dim=0)
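            # Note: unlike standard DANN, the domain classifier here is not fed
            # the features themselves but the per-sample gradients of the
            # classification losses w.r.t. those features (gradient_src /
            # gradient_tar), and learns to tell source gradients from target ones.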

            # reversed pass: gradient reversal makes G fool the domain classifier
            domain_label_reverse = DC(gradients_adv, reverse=True)
            loss_adv_r = BCE(domain_label_reverse)
            loss_G_adv = lamda * loss_adv_r
            # update the feature extractor G
            loss_G_adv.backward()
            optimizer_g.step()
            zero_grad_all()
            # ---------------------------------------------------------------------
            src_feature, tar_feature = G(src_data, tar_data)
            src_data_mmd1, src_output, tar_data_mmd1, tar_output = LC(src_feature, tar_feature)
            # cross-entropy loss on the source data
            loss_src = criterion(src_output, src_label)
            # pseudo-label cross-entropy loss on the target data
            loss_tar_pseudo = criterion(tar_output, pseudo_label)

            gradient_src = \
                torch.autograd.grad(outputs=loss_src, inputs=src_feature, create_graph=True, retain_graph=True,
                                    only_inputs=True)[0]
            gradient_tar = \
                torch.autograd.grad(outputs=loss_tar_pseudo, inputs=tar_feature, create_graph=True, retain_graph=True,
                                    only_inputs=True)[0]

            gradients = torch.cat((gradient_src, gradient_tar), dim=0)
            domain_label = DC(gradients, reverse=False)
            loss_adv = BCE(domain_label)
            loss_DC = lamda * loss_adv
            # update the domain classifier
            loss_DC.backward()
            optimizer_dc.step()
            zero_grad_all()

            src_acc += np.sum(np.argmax(src_output.cpu().detach().numpy(), axis=1) == src_label.cpu().numpy())
            train_loss += (loss_G_LC + loss_G_adv + loss_DC).item()

        G.eval()
        LC.eval()
        DC.eval()
        with torch.no_grad():
            for batch_idx, (val_data, val_label) in enumerate(val_loader):
                val_data, val_label = val_data.to(device), val_label.to(device)
                _, val_feature = G(val_data, val_data)
                _, _, _, val_output = LC(val_feature, val_feature)
                loss = criterion(val_output, val_label)

                val_acc += np.sum(np.argmax(val_output.cpu().numpy(), axis=1) == val_label.cpu().numpy())
                val_loss += loss.item()

        scheduler_g.step(val_loss)
        scheduler_lc.step(val_loss)
        scheduler_dc.step(val_loss)

        print("[{:03d}/{:03d}] {:2.2f} sec(s) src_acc: {:3.6f} train_loss: {:3.9f} | val_acc: {:3.6f} val_loss: {:3.9f} | Learning rate : {:3.6f}".format(
            epoch + 1, epochs, time.time() - epoch_start_time,
            src_acc / len(src_dataset), train_loss / len(src_dataset),
            val_acc / len(val_dataset), val_loss / len(val_dataset),
            optimizer_g.state_dict()['param_groups'][0]['lr']))

        # save the model with the lowest loss on the validation set
        # if len(val_loss_list) > 0 and (val_loss / len(val_dataset)) < min(val_loss_list):
        # save whenever validation accuracy or validation loss improves on the best so far
        if len(val_acc_list) > 0:
            # if (val_acc / len(val_dataset)) >= max(val_acc_list):
            if (val_acc / len(val_dataset)) > max(val_acc_list) or (val_loss / len(val_dataset)) < min(val_loss_list):
                print("Best model saved successfully")
                if not os.path.exists(G_path):
                    os.makedirs(G_path)
                if not os.path.exists(LC_path):
                    os.makedirs(LC_path)
                if not os.path.exists(DC_path):
                    os.makedirs(DC_path)
                # save the model parameters
                torch.save(G.state_dict(), G_path + "/G_" + "shot" + str(shot) + "_epoch" + str(epochs) + "_lr" + str(lr) + "_miu" + str(miu) + ".pkl")
                torch.save(LC.state_dict(), LC_path + "/LC_" + "shot" + str(shot) + "_epoch" + str(epochs) + "_lr" + str(lr) + "_miu" + str(miu) + ".pkl")
                torch.save(DC.state_dict(), DC_path + "/DC_" + "shot" + str(shot) + "_epoch" + str(epochs) + "_lr" + str(lr) + "_miu" + str(miu) + ".pkl")

        src_acc_list.append(src_acc / len(src_dataset))
        train_loss_list.append(train_loss / len(src_dataset))
        val_acc_list.append(val_acc / len(val_dataset))
        val_loss_list.append(val_loss / len(val_dataset))

    '''Paths of the saved model parameters'''
    G_params_path = G_path + "/G_" + "shot" + str(shot) + "_epoch" + str(epochs) + "_lr" + str(lr) + "_miu" + str(miu) + ".pkl"
    LC_params_path = LC_path + "/LC_" + "shot" + str(shot) + "_epoch" + str(epochs) + "_lr" + str(lr) + "_miu" + str(miu) + ".pkl"

    from matplotlib import rcParams

    config = {
        "font.family": 'Times New Roman',  # font family
        "axes.unicode_minus": False,       # render minus signs correctly
        "axes.labelsize": 13
    }
    rcParams.update(config)

    # earlier two-panel version (accuracy panel kept for reference):
    # pic1 = plt.figure(figsize=(8, 6), dpi=200)
    # plt.subplot(211)
    # plt.plot(np.arange(1, epochs + 1), src_acc_list, 'b', label='TrainAcc')
    # plt.plot(np.arange(1, epochs + 1), val_acc_list, 'r', label='ValAcc')
    # plt.ylim(0.3, 1.0)  # y-axis range
    # plt.title('Training & Validation accuracy')
    # plt.xlabel('epoch')
    # plt.ylabel('accuracy')
    # plt.legend(loc='lower right')
    # plt.grid(alpha=0.4)

    pic1 = plt.figure(figsize=(12, 6), dpi=200)

    plt.plot(np.arange(1, epochs + 1), train_loss_list, 'b', label='Training Loss')
    plt.plot(np.arange(1, epochs + 1), val_loss_list, 'r', label='Validation Loss')
    plt.ylim(0, 0.08)  # y-axis range
    plt.title('Training & Validation loss')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend(loc='upper right')
    plt.grid(alpha=0.4)

    # use the current timestamp as the figure file name
    timestamp = int(time.time())
    timestamp_str = str(timestamp)
    plt.savefig(timestamp_str, dpi=200)

    return G_params_path, LC_params_path
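

# DC(..., reverse=True) above presumably routes its input through a gradient
# reversal layer (GRL). A minimal sketch of such a layer, for reference only --
# the actual implementation lives in example.model and may differ:
class _GradReverseSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, lamda=1.0):
        ctx.lamda = lamda
        return x.view_as(x)  # identity in the forward pass

    @staticmethod
    def backward(ctx, grad_output):
        # flip (and scale) the gradient on the way back, so the extractor is
        # pushed to maximize the domain classifier's loss
        return grad_output.neg() * ctx.lamda, None

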
if __name__ == '__main__':

    begin = time.time()

    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    seed = 2
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
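    # Note: torch.manual_seed seeds the CPU (and, in recent PyTorch versions,
    # the current-GPU) RNG; for fully reproducible multi-GPU runs one would
    # presumably also call torch.cuda.manual_seed_all(seed).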

    src_condition = 'A'
    tar_condition = 'B'
    '''Training'''
    G_params_path, LC_params_path = train(device, src_condition, tar_condition, cls=5,
                                          epochs=200, bs=50, shot=10, lr=0.002, patience=40, gamma=1, miu=0.5)

    end = time.time()

    '''Testing'''
    # test1(5, src_condition, tar_condition, G_params_path, LC_params_path)
    test(5, tar_condition, G_params_path, LC_params_path)

    print("Training time: {:3.2f}s".format(end - begin))