pytorch更新
This commit is contained in:
parent
564ba3f669
commit
2c3e6c25a8
|
|
@ -0,0 +1,8 @@
|
||||||
|
#-*- encoding:utf-8 -*-
|
||||||
|
|
||||||
|
'''
|
||||||
|
@Author : dingjiawen
|
||||||
|
@Date : 2023/11/9 21:32
|
||||||
|
@Usage :
|
||||||
|
@Desc :
|
||||||
|
'''
|
||||||
|
|
@ -0,0 +1,8 @@
|
||||||
|
#-*- encoding:utf-8 -*-
|
||||||
|
|
||||||
|
'''
|
||||||
|
@Author : dingjiawen
|
||||||
|
@Date : 2023/11/9 21:33
|
||||||
|
@Usage :
|
||||||
|
@Desc :
|
||||||
|
'''
|
||||||
|
|
@ -0,0 +1,8 @@
|
||||||
|
#-*- encoding:utf-8 -*-
|
||||||
|
|
||||||
|
'''
|
||||||
|
@Author : dingjiawen
|
||||||
|
@Date : 2023/11/9 21:34
|
||||||
|
@Usage :
|
||||||
|
@Desc :
|
||||||
|
'''
|
||||||
|
|
@ -0,0 +1,8 @@
|
||||||
|
#-*- encoding:utf-8 -*-
|
||||||
|
|
||||||
|
'''
|
||||||
|
@Author : dingjiawen
|
||||||
|
@Date : 2023/11/9 21:34
|
||||||
|
@Usage :
|
||||||
|
@Desc :
|
||||||
|
'''
|
||||||
|
|
@ -0,0 +1,8 @@
|
||||||
|
#-*- encoding:utf-8 -*-
|
||||||
|
|
||||||
|
'''
|
||||||
|
@Author : dingjiawen
|
||||||
|
@Date : 2023/11/9 21:34
|
||||||
|
@Usage :
|
||||||
|
@Desc :
|
||||||
|
'''
|
||||||
|
|
@ -0,0 +1,8 @@
|
||||||
|
#-*- encoding:utf-8 -*-
|
||||||
|
|
||||||
|
'''
|
||||||
|
@Author : dingjiawen
|
||||||
|
@Date : 2023/11/9 21:28
|
||||||
|
@Usage :
|
||||||
|
@Desc :
|
||||||
|
'''
|
||||||
|
|
@ -0,0 +1,111 @@
|
||||||
|
"""
|
||||||
|
@Author: miykah
|
||||||
|
@Email: miykah@163.com
|
||||||
|
@FileName: load_data.py
|
||||||
|
@DateTime: 2022/7/20 16:40
|
||||||
|
"""
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
import random
|
||||||
|
|
||||||
|
'''正常Dataset类'''
|
||||||
|
class Nor_Dataset(Dataset):
|
||||||
|
def __init__(self, datas, labels=None):
|
||||||
|
self.datas = torch.tensor(datas)
|
||||||
|
if labels is not None:
|
||||||
|
self.labels = torch.tensor(labels)
|
||||||
|
else:
|
||||||
|
self.labels = None
|
||||||
|
def __getitem__(self, index):
|
||||||
|
data = self.datas[index]
|
||||||
|
if self.labels is not None:
|
||||||
|
label = self.labels[index]
|
||||||
|
return data, label
|
||||||
|
return data
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.datas)
|
||||||
|
|
||||||
|
'''未标记目标数据的Dataset类'''
|
||||||
|
class Tar_U_Dataset(Dataset):
|
||||||
|
def __init__(self, datas):
|
||||||
|
self.datas = torch.tensor(datas)
|
||||||
|
def __getitem__(self, index):
|
||||||
|
data = self.datas[index]
|
||||||
|
data_bar = data.clone().detach()
|
||||||
|
data_bar2 = data.clone().detach()
|
||||||
|
mu = 0
|
||||||
|
sigma = 0.1
|
||||||
|
'''对未标记目标数据加噪,得到data'和data'' '''
|
||||||
|
for i, j in zip(range(data_bar[0].shape[0]), range(data_bar2[0].shape[0])):
|
||||||
|
data_bar[0, i] += random.gauss(mu, sigma)
|
||||||
|
data_bar2[0, j] += random.gauss(mu, sigma)
|
||||||
|
return data, data_bar, data_bar2
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.datas)
|
||||||
|
|
||||||
|
def draw_signal_img(data, data_bar, data_bar2):
|
||||||
|
pic = plt.figure(figsize=(12, 6), dpi=100)
|
||||||
|
plt.rcParams['font.family'] = ['Arial Unicode MS', 'Microsoft YaHei', 'SimHei', 'sans-serif']
|
||||||
|
plt.rcParams['axes.unicode_minus'] = False
|
||||||
|
plt.subplot(3, 1, 1)
|
||||||
|
plt.plot(data, 'b')
|
||||||
|
plt.subplot(3, 1, 2)
|
||||||
|
plt.plot(data_bar, 'b')
|
||||||
|
plt.subplot(3, 1, 3)
|
||||||
|
plt.plot(data_bar2, 'b')
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
def get_dataset(src_condition, tar_condition):
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
root_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\'
|
||||||
|
else:
|
||||||
|
root_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\'
|
||||||
|
|
||||||
|
src_data = np.load(root_dir + src_condition + '\\src\\' + 'data.npy')
|
||||||
|
src_label = np.load(root_dir + src_condition + '\\src\\' + 'label.npy')
|
||||||
|
|
||||||
|
tar_data = np.load(root_dir + tar_condition + '\\tar\\' + 'data.npy')
|
||||||
|
|
||||||
|
val_data = np.load(root_dir + tar_condition + '\\val\\' + 'data.npy')
|
||||||
|
val_label = np.load(root_dir + tar_condition + '\\val\\' + 'label.npy')
|
||||||
|
|
||||||
|
test_data = np.load(root_dir + tar_condition + '\\test\\' + 'data.npy')
|
||||||
|
test_label = np.load(root_dir + tar_condition + '\\test\\' + 'label.npy')
|
||||||
|
|
||||||
|
src_dataset = Nor_Dataset(src_data, src_label)
|
||||||
|
tar_dataset = Nor_Dataset(tar_data)
|
||||||
|
val_dataset = Nor_Dataset(val_data, val_label)
|
||||||
|
test_dataset = Nor_Dataset(test_data, test_label)
|
||||||
|
|
||||||
|
return src_dataset, tar_dataset, val_dataset, test_dataset
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# pass
|
||||||
|
# tar_data = np.load("E:\\DataSet\\DDS_data_CDAC\\processed_data\\A\\tar\\data.npy")
|
||||||
|
# tar_dataset = Nor_Dataset(datas=tar_data)
|
||||||
|
# tar_loader = DataLoader(dataset=tar_dataset, batch_size=len(tar_dataset), shuffle=False)
|
||||||
|
#
|
||||||
|
# for batch_idx, (data) in enumerate(tar_loader):
|
||||||
|
# print(data[0][0])
|
||||||
|
# # print(data_bar[0][0])
|
||||||
|
# # print(data_bar2[0][0])
|
||||||
|
# if batch_idx == 0:
|
||||||
|
# break
|
||||||
|
src_data = np.load("E:\\DataSet\\DDS_data_CDAC\\processed_data\\A\\src\\data.npy")
|
||||||
|
src_label = np.load("E:\\DataSet\\DDS_data_CDAC\\processed_data\\A\\src\\label.npy")
|
||||||
|
src_dataset = Nor_Dataset(datas=src_data, labels=src_label)
|
||||||
|
src_loader = DataLoader(dataset=src_dataset, batch_size=len(src_dataset), shuffle=False)
|
||||||
|
|
||||||
|
for batch_idx, (data, label) in enumerate(src_loader):
|
||||||
|
print(data[10][0])
|
||||||
|
print(label[10])
|
||||||
|
# print(data_bar[0][0])
|
||||||
|
# print(data_bar2[0][0])
|
||||||
|
if batch_idx == 0:
|
||||||
|
break
|
||||||
|
|
@ -0,0 +1,127 @@
|
||||||
|
"""
|
||||||
|
@Author: miykah
|
||||||
|
@Email: miykah@163.com
|
||||||
|
@FileName: loss.py
|
||||||
|
@DateTime: 2022/7/21 14:31
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import torch.nn.functional as functional
|
||||||
|
import torch.nn as nn
|
||||||
|
import example.mmd as mmd
|
||||||
|
|
||||||
|
'''
|
||||||
|
计算条件概率分布的差异
|
||||||
|
mmdc = mmd_condition
|
||||||
|
'''
|
||||||
|
def cal_mmdc_loss(src_data_mmd1, src_data_mmd2, tar_data_mmd1, tar_data_mmd2):
|
||||||
|
src_cls_10 = src_data_mmd1[3 * 0: 3 * 1]
|
||||||
|
src_cls_11 = src_data_mmd1[3 * 1: 3 * 2]
|
||||||
|
src_cls_12 = src_data_mmd1[3 * 2: 3 * 3]
|
||||||
|
src_cls_13 = src_data_mmd1[3 * 3: 3 * 4]
|
||||||
|
src_cls_14 = src_data_mmd1[3 * 4: 3 * 5]
|
||||||
|
src_cls_15 = src_data_mmd1[3 * 5: 3 * 6]
|
||||||
|
src_cls_16 = src_data_mmd1[3 * 6: 3 * 7]
|
||||||
|
src_cls_17 = src_data_mmd1[3 * 7: 3 * 8]
|
||||||
|
src_cls_18 = src_data_mmd1[3 * 8: 3 * 9]
|
||||||
|
src_cls_19 = src_data_mmd1[3 * 9: 3 * 10]
|
||||||
|
|
||||||
|
tar_cls_10 = tar_data_mmd1[3 * 0: 3 * 1]
|
||||||
|
tar_cls_11 = tar_data_mmd1[3 * 1: 3 * 2]
|
||||||
|
tar_cls_12 = tar_data_mmd1[3 * 2: 3 * 3]
|
||||||
|
tar_cls_13 = tar_data_mmd1[3 * 3: 3 * 4]
|
||||||
|
tar_cls_14 = tar_data_mmd1[3 * 4: 3 * 5]
|
||||||
|
tar_cls_15 = tar_data_mmd1[3 * 5: 3 * 6]
|
||||||
|
tar_cls_16 = tar_data_mmd1[3 * 6: 3 * 7]
|
||||||
|
tar_cls_17 = tar_data_mmd1[3 * 7: 3 * 8]
|
||||||
|
tar_cls_18 = tar_data_mmd1[3 * 8: 3 * 9]
|
||||||
|
tar_cls_19 = tar_data_mmd1[3 * 9: 3 * 10]
|
||||||
|
|
||||||
|
|
||||||
|
src_cls_20 = src_data_mmd2[3 * 0: 3 * 1]
|
||||||
|
src_cls_21 = src_data_mmd2[3 * 1: 3 * 2]
|
||||||
|
src_cls_22 = src_data_mmd2[3 * 2: 3 * 3]
|
||||||
|
src_cls_23 = src_data_mmd2[3 * 3: 3 * 4]
|
||||||
|
src_cls_24 = src_data_mmd2[3 * 4: 3 * 5]
|
||||||
|
src_cls_25 = src_data_mmd2[3 * 5: 3 * 6]
|
||||||
|
src_cls_26 = src_data_mmd2[3 * 6: 3 * 7]
|
||||||
|
src_cls_27 = src_data_mmd2[3 * 7: 3 * 8]
|
||||||
|
src_cls_28 = src_data_mmd2[3 * 8: 3 * 9]
|
||||||
|
src_cls_29 = src_data_mmd2[3 * 9: 3 * 10]
|
||||||
|
|
||||||
|
tar_cls_20 = tar_data_mmd2[3 * 0: 3 * 1]
|
||||||
|
tar_cls_21 = tar_data_mmd2[3 * 1: 3 * 2]
|
||||||
|
tar_cls_22 = tar_data_mmd2[3 * 2: 3 * 3]
|
||||||
|
tar_cls_23 = tar_data_mmd2[3 * 3: 3 * 4]
|
||||||
|
tar_cls_24 = tar_data_mmd2[3 * 4: 3 * 5]
|
||||||
|
tar_cls_25 = tar_data_mmd2[3 * 5: 3 * 6]
|
||||||
|
tar_cls_26 = tar_data_mmd2[3 * 6: 3 * 7]
|
||||||
|
tar_cls_27 = tar_data_mmd2[3 * 7: 3 * 8]
|
||||||
|
tar_cls_28 = tar_data_mmd2[3 * 8: 3 * 9]
|
||||||
|
tar_cls_29 = tar_data_mmd2[3 * 9: 3 * 10]
|
||||||
|
|
||||||
|
mmd_10 = mmd.mmd_linear(src_cls_10, tar_cls_10)
|
||||||
|
mmd_11 = mmd.mmd_linear(src_cls_11, tar_cls_11)
|
||||||
|
mmd_12 = mmd.mmd_linear(src_cls_12, tar_cls_12)
|
||||||
|
mmd_13 = mmd.mmd_linear(src_cls_13, tar_cls_13)
|
||||||
|
mmd_14 = mmd.mmd_linear(src_cls_14, tar_cls_14)
|
||||||
|
mmd_15 = mmd.mmd_linear(src_cls_15, tar_cls_15)
|
||||||
|
mmd_16 = mmd.mmd_linear(src_cls_16, tar_cls_16)
|
||||||
|
mmd_17 = mmd.mmd_linear(src_cls_17, tar_cls_17)
|
||||||
|
mmd_18 = mmd.mmd_linear(src_cls_18, tar_cls_18)
|
||||||
|
mmd_19 = mmd.mmd_linear(src_cls_19, tar_cls_19)
|
||||||
|
|
||||||
|
mmd_20 = mmd.mmd_linear(src_cls_20, tar_cls_20)
|
||||||
|
mmd_21 = mmd.mmd_linear(src_cls_21, tar_cls_21)
|
||||||
|
mmd_22 = mmd.mmd_linear(src_cls_22, tar_cls_22)
|
||||||
|
mmd_23 = mmd.mmd_linear(src_cls_23, tar_cls_23)
|
||||||
|
mmd_24 = mmd.mmd_linear(src_cls_24, tar_cls_24)
|
||||||
|
mmd_25 = mmd.mmd_linear(src_cls_25, tar_cls_25)
|
||||||
|
mmd_26 = mmd.mmd_linear(src_cls_26, tar_cls_26)
|
||||||
|
mmd_27 = mmd.mmd_linear(src_cls_27, tar_cls_27)
|
||||||
|
mmd_28 = mmd.mmd_linear(src_cls_28, tar_cls_28)
|
||||||
|
mmd_29 = mmd.mmd_linear(src_cls_29, tar_cls_29)
|
||||||
|
|
||||||
|
mmdc1 = mmd_10 + mmd_11 + mmd_12 + mmd_13 + mmd_14 + mmd_15 + mmd_16 + mmd_17 + mmd_18 + mmd_19
|
||||||
|
mmdc2 = mmd_20 + mmd_21 + mmd_22 + mmd_23 + mmd_24 + mmd_25 + mmd_26 + mmd_27 + mmd_28 + mmd_29
|
||||||
|
|
||||||
|
return (mmdc2) / 10
|
||||||
|
# return (mmdc1 + mmdc2) / 20
|
||||||
|
|
||||||
|
'''得到源域每类特征,用于计算mmdc'''
|
||||||
|
def get_src_mean_feature(src_feature, shot, cls):
|
||||||
|
src_feature_list = []
|
||||||
|
for i in range(cls):
|
||||||
|
src_feature_cls = torch.mean(src_feature[shot * i: shot * (i + 1)], dim=0)
|
||||||
|
src_feature_list.append(src_feature_cls)
|
||||||
|
return src_feature_list
|
||||||
|
|
||||||
|
def get_mmdc(src_feature, tar_feature, tar_pseudo_label, batch_size, shot, cls):
|
||||||
|
src_feature_list = get_src_mean_feature(src_feature, shot, cls)
|
||||||
|
pseudo_label = tar_pseudo_label.cpu().detach().numpy()
|
||||||
|
mmdc = 0.0
|
||||||
|
for i in range(batch_size):
|
||||||
|
# mmdc += mmd.mmd_linear(src_feature_list[pseudo_label[i]].reshape(1, -1), tar_feature[i].reshape(1, -1))
|
||||||
|
mmdc += mmd.mmd_linear(src_feature_list[pseudo_label[i]].reshape(1, -1), tar_feature[i].reshape(1, -1))
|
||||||
|
return mmdc / batch_size
|
||||||
|
|
||||||
|
class BCE(nn.Module):
|
||||||
|
eps = 1e-7
|
||||||
|
def forward(self, prob1, prob2, simi):
|
||||||
|
P = prob1.mul_(prob2)
|
||||||
|
P = P.sum(1)
|
||||||
|
P.mul_(simi).add_(simi.eq(-1).type_as(P))
|
||||||
|
neglogP = -P.add_(BCE.eps).log_()
|
||||||
|
return neglogP.mean()
|
||||||
|
|
||||||
|
class BinaryCrossEntropyLoss(nn.Module):
|
||||||
|
""" Construct binary cross-entropy loss."""
|
||||||
|
eps = 1e-7
|
||||||
|
def forward(self, prob):
|
||||||
|
# ds = torch.ones([bs, 1]).to(device) # domain label for source
|
||||||
|
# dt = torch.zeros([bs, 1]).to(device) # domain label for target
|
||||||
|
# di = torch.cat((ds, dt), dim=0).to(device)
|
||||||
|
# neglogP = - (di * torch.log(prob + BCE.eps) + (1. - di) * torch.log(1. - prob + BCE.eps))
|
||||||
|
neglogP = - (prob * torch.log(prob + BCE.eps) + (1. - prob) * torch.log(1. - prob + BCE.eps))
|
||||||
|
return neglogP.mean()
|
||||||
|
|
@ -0,0 +1,56 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
# encoding: utf-8
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# Consider linear time MMD with a linear kernel:
|
||||||
|
# K(f(x), f(y)) = f(x)^Tf(y)
|
||||||
|
# h(z_i, z_j) = k(x_i, x_j) + k(y_i, y_j) - k(x_i, y_j) - k(x_j, y_i)
|
||||||
|
# = [f(x_i) - f(y_i)]^T[f(x_j) - f(y_j)]
|
||||||
|
#
|
||||||
|
# f_of_X: batch_size * k
|
||||||
|
# f_of_Y: batch_size * k
|
||||||
|
def mmd_linear(f_of_X, f_of_Y):
|
||||||
|
delta = f_of_X - f_of_Y
|
||||||
|
loss = torch.mean(torch.mm(delta, torch.transpose(delta, 0, 1)))
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
|
||||||
|
n_samples = int(source.size()[0])+int(target.size()[0])
|
||||||
|
total = torch.cat([source, target], dim=0)
|
||||||
|
total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
|
||||||
|
total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
|
||||||
|
L2_distance = ((total0-total1)**2).sum(2)
|
||||||
|
if fix_sigma:
|
||||||
|
bandwidth = fix_sigma
|
||||||
|
else:
|
||||||
|
bandwidth = torch.sum(L2_distance.data) / (n_samples**2-n_samples)
|
||||||
|
bandwidth /= kernel_mul ** (kernel_num // 2)
|
||||||
|
bandwidth_list = [bandwidth * (kernel_mul**i) for i in range(kernel_num)]
|
||||||
|
kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
|
||||||
|
return sum(kernel_val)#/len(kernel_val)
|
||||||
|
|
||||||
|
|
||||||
|
def mmd_rbf_accelerate(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
|
||||||
|
batch_size = int(source.size()[0])
|
||||||
|
kernels = guassian_kernel(source, target,
|
||||||
|
kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)
|
||||||
|
loss = 0
|
||||||
|
for i in range(batch_size):
|
||||||
|
s1, s2 = i, (i+1)%batch_size
|
||||||
|
t1, t2 = s1+batch_size, s2+batch_size
|
||||||
|
loss += kernels[s1, s2] + kernels[t1, t2]
|
||||||
|
loss -= kernels[s1, t2] + kernels[s2, t1]
|
||||||
|
return loss / float(batch_size)
|
||||||
|
|
||||||
|
def mmd_rbf_noaccelerate(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
|
||||||
|
batch_size = int(source.size()[0])
|
||||||
|
kernels = guassian_kernel(source, target,
|
||||||
|
kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)
|
||||||
|
XX = kernels[:batch_size, :batch_size]
|
||||||
|
YY = kernels[batch_size:, batch_size:]
|
||||||
|
XY = kernels[:batch_size, batch_size:]
|
||||||
|
YX = kernels[batch_size:, :batch_size]
|
||||||
|
loss = torch.mean(XX + YY - XY -YX)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
@ -0,0 +1,109 @@
|
||||||
|
"""
|
||||||
|
@Author: miykah
|
||||||
|
@Email: miykah@163.com
|
||||||
|
@FileName: model.py
|
||||||
|
@DateTime: 2022/7/20 21:18
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.autograd import Function
|
||||||
|
|
||||||
|
class GradReverse(Function):
|
||||||
|
def __init__(self):
|
||||||
|
self.lambd = 1.0
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return x.view_as(x)
|
||||||
|
|
||||||
|
def backward(self, grad_output):
|
||||||
|
return (grad_output * - 1.0)
|
||||||
|
|
||||||
|
def grad_reverse(x):
|
||||||
|
return GradReverse.apply(x)
|
||||||
|
# return GradReverse(lambd)(x)
|
||||||
|
|
||||||
|
'''特征提取器'''
|
||||||
|
class Extractor(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(Extractor, self).__init__()
|
||||||
|
self.conv1 = nn.Sequential(
|
||||||
|
nn.Conv1d(in_channels=1, out_channels=32, kernel_size=13, padding='same'),
|
||||||
|
nn.BatchNorm1d(32),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.MaxPool1d(2) # (32 * 1024)
|
||||||
|
)
|
||||||
|
self.conv2 = nn.Sequential(
|
||||||
|
nn.Conv1d(in_channels=32, out_channels=32, kernel_size=13, padding='same'),
|
||||||
|
nn.BatchNorm1d(32),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.MaxPool1d(2) # (32 * 512)
|
||||||
|
)
|
||||||
|
self.conv3 = nn.Sequential(
|
||||||
|
nn.Conv1d(in_channels=32, out_channels=32, kernel_size=13, padding='same'),
|
||||||
|
nn.BatchNorm1d(32),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.MaxPool1d(2) # (32 * 256)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, src_data, tar_data):
|
||||||
|
src_data = self.conv1(src_data)
|
||||||
|
src_data = self.conv2(src_data)
|
||||||
|
src_feature = self.conv3(src_data)
|
||||||
|
|
||||||
|
tar_data = self.conv1(tar_data)
|
||||||
|
tar_data = self.conv2(tar_data)
|
||||||
|
tar_feature = self.conv3(tar_data)
|
||||||
|
return src_feature, tar_feature
|
||||||
|
|
||||||
|
'''标签分类器'''
|
||||||
|
class LabelClassifier(nn.Module):
|
||||||
|
def __init__(self, cls_num):
|
||||||
|
super(LabelClassifier, self).__init__()
|
||||||
|
self.fc1 = nn.Sequential(
|
||||||
|
nn.Flatten(), # (8192,)
|
||||||
|
nn.Linear(in_features=8192, out_features=256),
|
||||||
|
)
|
||||||
|
self.fc2 = nn.Sequential(
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(in_features=256, out_features=cls_num)
|
||||||
|
)
|
||||||
|
def forward(self, src_feature, tar_feature):
|
||||||
|
src_data_mmd1 = self.fc1(src_feature)
|
||||||
|
src_output = self.fc2(src_data_mmd1)
|
||||||
|
|
||||||
|
tar_data_mmd1 = self.fc1(tar_feature)
|
||||||
|
tar_output = self.fc2(tar_data_mmd1)
|
||||||
|
return src_data_mmd1, src_output, tar_data_mmd1, tar_output
|
||||||
|
|
||||||
|
'''分类器'''
|
||||||
|
class DomainClassifier(nn.Module):
|
||||||
|
def __init__(self, temp=0.05):
|
||||||
|
super(DomainClassifier, self).__init__()
|
||||||
|
self.fc = nn.Sequential(
|
||||||
|
nn.Flatten(),
|
||||||
|
nn.Linear(in_features=8192, out_features=512),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(in_features=512, out_features=128),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(in_features=128, out_features=1),
|
||||||
|
nn.Sigmoid()
|
||||||
|
)
|
||||||
|
self.temp = temp
|
||||||
|
def forward(self, x, reverse=False):
|
||||||
|
if reverse:
|
||||||
|
x = grad_reverse(x)
|
||||||
|
output = self.fc(x)
|
||||||
|
return output
|
||||||
|
|
||||||
|
'''初始化网络权重'''
|
||||||
|
def weights_init_Extractor(m):
|
||||||
|
if isinstance(m, nn.Conv1d):
|
||||||
|
nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
|
||||||
|
def weights_init_Classifier(m):
|
||||||
|
if isinstance(m, nn.Linear):
|
||||||
|
nn.init.xavier_normal_(m.weight)
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
|
@ -0,0 +1,298 @@
|
||||||
|
"""
|
||||||
|
@Author: miykah
|
||||||
|
@Email: miykah@163.com
|
||||||
|
@FileName: read_data_cwru.py
|
||||||
|
@DateTime: 2022/7/20 15:58
|
||||||
|
"""
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import scipy.io as scio
|
||||||
|
import numpy as np
|
||||||
|
import PIL.Image as Image
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import torch
|
||||||
|
|
||||||
|
def dat_to_numpy(samples_num, sample_length):
|
||||||
|
'''根据每个类别需要的样本数和样本长度计算采样步长'''
|
||||||
|
stride = math.floor((204800 * 4 - sample_length) / (samples_num - 1))
|
||||||
|
|
||||||
|
conditions = ['A', 'B', 'C']
|
||||||
|
for condition in conditions:
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
root_dir = 'D:\\DataSet\\DDS_data\\5cls_11\\' + condition
|
||||||
|
save_dir = 'D:\\DataSet\\DDS_data\\5cls_11\\raw_data\\' + condition
|
||||||
|
else:
|
||||||
|
root_dir = 'E:\\DataSet\\DDS_data\\5cls_11\\' + condition
|
||||||
|
save_dir = 'E:\\DataSet\\DDS_data\\5cls_11\\raw_data\\' + condition
|
||||||
|
|
||||||
|
if condition == 'A':
|
||||||
|
start_idx = 1
|
||||||
|
end_idx = 5
|
||||||
|
# start_idx = 13
|
||||||
|
# end_idx = 17
|
||||||
|
# start_idx = 25
|
||||||
|
# end_idx = 29
|
||||||
|
# start_idx = 9
|
||||||
|
# end_idx = 13
|
||||||
|
elif condition == 'B':
|
||||||
|
start_idx = 5
|
||||||
|
end_idx = 9
|
||||||
|
# start_idx = 17
|
||||||
|
# end_idx = 21
|
||||||
|
# start_idx = 29
|
||||||
|
# end_idx = 33
|
||||||
|
elif condition == 'C':
|
||||||
|
start_idx = 9
|
||||||
|
end_idx = 13
|
||||||
|
# start_idx = 21
|
||||||
|
# end_idx = 25
|
||||||
|
# start_idx = 33
|
||||||
|
# end_idx = 37
|
||||||
|
# start_idx = 25
|
||||||
|
# end_idx = 29
|
||||||
|
|
||||||
|
dir_names = os.listdir(root_dir)
|
||||||
|
print(dir_names)
|
||||||
|
'''
|
||||||
|
故障类型
|
||||||
|
['平行齿轮箱轴承内圈故障恒速', '平行齿轮箱轴承外圈故障恒速',
|
||||||
|
'平行齿轮箱轴承滚动体故障恒速', '平行齿轮箱齿轮偏心故障恒速',
|
||||||
|
'平行齿轮箱齿轮断齿故障恒速', '平行齿轮箱齿轮缺齿故障恒速',
|
||||||
|
'平行齿轮箱齿轮表面磨损故障恒速', '平行齿轮箱齿轮齿根裂纹故障恒速']
|
||||||
|
'''
|
||||||
|
if not os.path.exists(save_dir):
|
||||||
|
os.makedirs(save_dir)
|
||||||
|
for dir_name in dir_names:
|
||||||
|
data_list = []
|
||||||
|
for i in range(start_idx, end_idx):
|
||||||
|
if i < 10:
|
||||||
|
path = root_dir + '\\' + dir_name + '\dds测试故障库4.6#000' + str(i) + '.dat'
|
||||||
|
else:
|
||||||
|
path = root_dir + '\\' + dir_name + '\dds测试故障库4.6#00' + str(i) + '.dat'
|
||||||
|
data = np.fromfile(path, dtype='float32')[204800 * 2: 204800 * 3]
|
||||||
|
data_list.append(data)
|
||||||
|
# data_one_cls = np.array(data_list).reshape(-1).reshape(-1, 1024)
|
||||||
|
# data_one_cls = np.array(data_list).reshape(-1).reshape(-1, 1, 2048) #(400, 1, 2048)
|
||||||
|
data_one_cls = np.array(data_list).reshape(-1)
|
||||||
|
n = math.floor((204800 * 4 - sample_length) / stride + 1)
|
||||||
|
list = []
|
||||||
|
for i in range(n):
|
||||||
|
start = i * stride
|
||||||
|
end = i * stride + sample_length
|
||||||
|
list.append(data_one_cls[start: end])
|
||||||
|
data_one_cls = np.array(list).reshape(-1, 1, sample_length)
|
||||||
|
# 打乱数据
|
||||||
|
shuffle_ix = np.random.permutation(np.arange(len(data_one_cls)))
|
||||||
|
data_one_cls = data_one_cls[shuffle_ix]
|
||||||
|
print(data_one_cls.shape)
|
||||||
|
np.save(save_dir + '\\' + dir_name + '.npy', data_one_cls)
|
||||||
|
|
||||||
|
|
||||||
|
def draw_signal_img(condition):
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
root_dir = 'D:\\DataSet\\DDS_data\\4cls_speed30\\raw_data\\' + condition
|
||||||
|
else:
|
||||||
|
root_dir = 'E:\\DataSet\\DDS_data\\4cls_speed30\\raw_data\\' + condition
|
||||||
|
file_names = os.listdir(root_dir)
|
||||||
|
pic = plt.figure(figsize=(12, 6), dpi=200)
|
||||||
|
# plt.rcParams['font.family'] = ['Arial Unicode MS', 'Microsoft YaHei', 'SimHei', 'sans-serif']
|
||||||
|
# plt.rcParams['axes.unicode_minus'] = False
|
||||||
|
plt.rc('font',family='Times New Roman')
|
||||||
|
clses = ['(a)', '(b)', '(c)', '(d)']
|
||||||
|
for file_name, i, cls in zip(file_names, range(4), clses):
|
||||||
|
data = np.load(root_dir + '\\' + file_name)[0].reshape(-1)
|
||||||
|
print(data.shape, file_name)
|
||||||
|
plt.tick_params(top='on', right='on', which='both') # 设置上面和右面也有刻度
|
||||||
|
plt.rcParams['xtick.direction'] = 'in' # 将x周的刻度线方向设置向内
|
||||||
|
plt.rcParams['ytick.direction'] = 'in' #将y轴的刻度方向设置向内
|
||||||
|
plt.subplot(2, 2, i + 1)
|
||||||
|
plt.title(cls, fontsize=20)
|
||||||
|
plt.xlabel("Time (s)", fontsize=20)
|
||||||
|
plt.ylabel("Amplitude (m/s²)", fontsize=20)
|
||||||
|
plt.xlim(0, 0.16)
|
||||||
|
plt.ylim(-3, 3)
|
||||||
|
plt.yticks(np.array([-2, 0, 2]), fontsize=20)
|
||||||
|
plt.xticks(np.array([0, 640, 1280, 1920, 2048]), np.array(['0', 0.05, 0.1, 0.15]), fontsize=20)
|
||||||
|
plt.plot(data)
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
def draw_signal_img2():
|
||||||
|
conditions = ['A', 'B', 'C']
|
||||||
|
pic = plt.figure(figsize=(12, 6), dpi=600)
|
||||||
|
plt.text(x=1, y=3, s='A')
|
||||||
|
i = 1
|
||||||
|
for condition in conditions:
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
root_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne\\raw_data\\' + condition
|
||||||
|
else:
|
||||||
|
root_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne\\raw_data\\' + condition
|
||||||
|
file_names = os.listdir(root_dir)
|
||||||
|
# plt.rcParams['font.family'] = ['Arial Unicode MS', 'Microsoft YaHei', 'SimHei', 'sans-serif']
|
||||||
|
# plt.rcParams['axes.unicode_minus'] = False
|
||||||
|
plt.rc('font',family='Times New Roman')
|
||||||
|
clses = ['(a)', '(b)', '(c)', '(d)', '(e)']
|
||||||
|
for file_name, cls in zip(file_names, clses):
|
||||||
|
data = np.load(root_dir + '\\' + file_name)[3].reshape(-1)
|
||||||
|
print(data.shape, file_name)
|
||||||
|
plt.tick_params(top='on', right='on', which='both') # 设置上面和右面也有刻度
|
||||||
|
plt.rcParams['xtick.direction'] = 'in' # 将x周的刻度线方向设置向内
|
||||||
|
plt.rcParams['ytick.direction'] = 'in' #将y轴的刻度方向设置向内
|
||||||
|
plt.subplot(3, 5, i)
|
||||||
|
plt.title(cls, fontsize=15)
|
||||||
|
plt.xlabel("Time (s)", fontsize=15)
|
||||||
|
plt.ylabel("Amplitude (m/s²)", fontsize=15)
|
||||||
|
plt.xlim(0, 0.16)
|
||||||
|
plt.ylim(-5, 5)
|
||||||
|
plt.yticks(np.array([-3, 0, 3]), fontsize=15)
|
||||||
|
plt.xticks(np.array([0, 640, 1280, 1920, 2048]), np.array(['0', 0.05, 0.1, 0.15]), fontsize=15)
|
||||||
|
plt.plot(data)
|
||||||
|
i += 1
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
'''得到源域数据和标签'''
|
||||||
|
def get_src_data(src_condition, shot):
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
root_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + src_condition
|
||||||
|
save_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\' + src_condition + '\\src'
|
||||||
|
else:
|
||||||
|
root_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + src_condition
|
||||||
|
save_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\' + src_condition + '\\src'
|
||||||
|
|
||||||
|
file_names = os.listdir(root_dir)
|
||||||
|
# print(file_names)
|
||||||
|
datas = []
|
||||||
|
labels = []
|
||||||
|
for i in range(600 // shot):
|
||||||
|
for file_name, cls in zip(file_names, range(5)):
|
||||||
|
data = np.load(root_dir + '\\' + file_name)
|
||||||
|
data_cut = data[shot * i: shot * (i + 1), :, :]
|
||||||
|
label_cut = np.array([cls] * shot, dtype='int64')
|
||||||
|
datas.append(data_cut)
|
||||||
|
labels.append(label_cut)
|
||||||
|
datas = np.array(datas).reshape(-1, 1, 2048)
|
||||||
|
labels = np.array(labels).reshape(-1)
|
||||||
|
print(datas.shape, labels.shape)
|
||||||
|
if not os.path.exists(save_dir):
|
||||||
|
os.makedirs(save_dir)
|
||||||
|
np.save(save_dir + '\\' + 'data.npy', datas)
|
||||||
|
np.save(save_dir + '\\' + 'label.npy', labels)
|
||||||
|
|
||||||
|
'''得到没有标签的目标域数据'''
|
||||||
|
def get_tar_data(tar_condition):
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
root_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + tar_condition
|
||||||
|
save_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\' + tar_condition + '\\tar'
|
||||||
|
else:
|
||||||
|
root_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + tar_condition
|
||||||
|
save_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\' + tar_condition + '\\tar'
|
||||||
|
file_names = os.listdir(root_dir)
|
||||||
|
datas = []
|
||||||
|
for file_name in file_names:
|
||||||
|
data = np.load(root_dir + '\\' + file_name)
|
||||||
|
data_cut = data[0: 600, :, :]
|
||||||
|
datas.append(data_cut)
|
||||||
|
datas = np.array(datas).reshape(-1, 1, 2048)
|
||||||
|
print("datas.shape: {}".format(datas.shape))
|
||||||
|
if not os.path.exists(save_dir):
|
||||||
|
os.makedirs(save_dir)
|
||||||
|
np.save(save_dir + '\\' + 'data.npy', datas)
|
||||||
|
|
||||||
|
def get_tar_data2(tar_condition, shot):
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
root_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + tar_condition
|
||||||
|
save_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\' + tar_condition + '\\tar1'
|
||||||
|
else:
|
||||||
|
root_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + tar_condition
|
||||||
|
save_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\' + tar_condition + '\\tar1'
|
||||||
|
file_names = os.listdir(root_dir)
|
||||||
|
datas = []
|
||||||
|
labels = []
|
||||||
|
for i in range(600 // shot):
|
||||||
|
for file_name, cls in zip(file_names, range(5)):
|
||||||
|
data = np.load(root_dir + '\\' + file_name)
|
||||||
|
data_cut = data[shot * i: shot * (i + 1), :, :]
|
||||||
|
label_cut = np.array([cls] * shot, dtype='int64')
|
||||||
|
datas.append(data_cut)
|
||||||
|
labels.append(label_cut)
|
||||||
|
datas = np.array(datas).reshape(-1, 1, 2048)
|
||||||
|
labels = np.array(labels).reshape(-1)
|
||||||
|
print(datas.shape, labels.shape)
|
||||||
|
if not os.path.exists(save_dir):
|
||||||
|
os.makedirs(save_dir)
|
||||||
|
np.save(save_dir + '\\' + 'data.npy', datas)
|
||||||
|
np.save(save_dir + '\\' + 'label.npy', labels)
|
||||||
|
|
||||||
|
'''得到验证集(每个类别100个样本)'''
|
||||||
|
def get_val_data(tar_condition):
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
root_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + tar_condition
|
||||||
|
save_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\' + tar_condition + '\\val'
|
||||||
|
else:
|
||||||
|
root_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + tar_condition
|
||||||
|
save_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\' + tar_condition + '\\val'
|
||||||
|
file_names = os.listdir(root_dir)
|
||||||
|
datas = []
|
||||||
|
labels = []
|
||||||
|
for file_name, cls in zip(file_names, range(5)):
|
||||||
|
data = np.load(root_dir + '\\' + file_name)
|
||||||
|
data_cut = data[600: 700, :, :]
|
||||||
|
label_cut = np.array([cls] * 100, dtype='int64')
|
||||||
|
datas.append(data_cut)
|
||||||
|
labels.append(label_cut)
|
||||||
|
datas = np.array(datas).reshape(-1, 1, 2048)
|
||||||
|
labels = np.array(labels).reshape(-1)
|
||||||
|
print("datas.shape: {}, labels.shape: {}".format(datas.shape, labels.shape))
|
||||||
|
if not os.path.exists(save_dir):
|
||||||
|
os.makedirs(save_dir)
|
||||||
|
np.save(save_dir + '\\' + 'data.npy', datas)
|
||||||
|
np.save(save_dir + '\\' + 'label.npy', labels)
|
||||||
|
|
||||||
|
'''得到测试集(100个样本)'''
|
||||||
|
def get_test_data(tar_condition):
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
root_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + tar_condition
|
||||||
|
save_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\' + tar_condition + '\\test'
|
||||||
|
else:
|
||||||
|
root_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + tar_condition
|
||||||
|
save_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\' + tar_condition + '\\test'
|
||||||
|
file_names = os.listdir(root_dir)
|
||||||
|
datas = []
|
||||||
|
labels = []
|
||||||
|
for file_name, cls in zip(file_names, range(5)):
|
||||||
|
data = np.load(root_dir + '\\' + file_name)
|
||||||
|
data_cut = data[700: , :, :]
|
||||||
|
label_cut = np.array([cls] * 100, dtype='int64')
|
||||||
|
datas.append(data_cut)
|
||||||
|
labels.append(label_cut)
|
||||||
|
datas = np.array(datas).reshape(-1, 1, 2048)
|
||||||
|
labels = np.array(labels).reshape(-1)
|
||||||
|
print("datas.shape: {}, labels.shape: {}".format(datas.shape, labels.shape))
|
||||||
|
if not os.path.exists(save_dir):
|
||||||
|
os.makedirs(save_dir)
|
||||||
|
np.save(save_dir + '\\' + 'data.npy', datas)
|
||||||
|
np.save(save_dir + '\\' + 'label.npy', labels)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# dat_to_numpy(samples_num=800, sample_length=2048)
|
||||||
|
# draw_signal_img('C')
|
||||||
|
# draw_signal_img2()
|
||||||
|
# get_src_data('A', shot=10)
|
||||||
|
# get_src_data('B', shot=10)
|
||||||
|
# get_src_data('C', shot=10)
|
||||||
|
get_tar_data('A')
|
||||||
|
get_tar_data('B')
|
||||||
|
get_tar_data('C')
|
||||||
|
get_val_data('A')
|
||||||
|
get_val_data('B')
|
||||||
|
get_val_data('C')
|
||||||
|
get_test_data('A')
|
||||||
|
get_test_data('B')
|
||||||
|
get_test_data('C')
|
||||||
|
|
||||||
|
get_src_data('A', shot=10)
|
||||||
|
get_src_data('B', shot=10)
|
||||||
|
get_src_data('C', shot=10)
|
||||||
|
get_tar_data2('A', shot=10)
|
||||||
|
get_tar_data2('B', shot=10)
|
||||||
|
get_tar_data2('C', shot=10)
|
||||||
|
|
@ -0,0 +1,226 @@
|
||||||
|
"""
|
||||||
|
@Author: miykah
|
||||||
|
@Email: miykah@163.com
|
||||||
|
@FileName: test.py
|
||||||
|
@DateTime: 2022/7/9 14:15
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from DDS_GADAN.load_data import get_dataset
|
||||||
|
from DDS_GADAN.load_data import Nor_Dataset
|
||||||
|
from DDS_GADAN.model import Extractor, LabelClassifier
|
||||||
|
from sklearn.metrics import confusion_matrix
|
||||||
|
from sklearn.manifold import TSNE
|
||||||
|
import seaborn as sns
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
def load_data(tar_condition):
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
root_dir = 'D:\\DataSet\\DDS_data\\5cls_6\\processed_data\\800EachCls_shot10\\'
|
||||||
|
else:
|
||||||
|
root_dir = 'E:\\DataSet\\DDS_data\\5cls_6\\processed_data\\800EachCls_shot10\\'
|
||||||
|
|
||||||
|
test_data = np.load(root_dir + tar_condition + '\\val\\' + 'data.npy')
|
||||||
|
test_label = np.load(root_dir + tar_condition + '\\val\\' + 'label.npy')
|
||||||
|
|
||||||
|
return test_data, test_label
|
||||||
|
|
||||||
|
def tsne_2d_generate(cls, data, labels, pic_title):
|
||||||
|
tsne2D = TSNE(n_components=2, verbose=2, perplexity=30).fit_transform(data)
|
||||||
|
x, y = tsne2D[:, 0], tsne2D[:, 1]
|
||||||
|
pic = plt.figure()
|
||||||
|
ax1 = pic.add_subplot()
|
||||||
|
ax1.scatter(x, y, c=labels, cmap=plt.cm.get_cmap("jet", cls)) # 9为9种颜色,因为标签有9类
|
||||||
|
plt.title(pic_title)
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
def tsne_2d_generate1(cls, data, labels, pic_title):
|
||||||
|
parameters = {'figure.dpi': 600,
|
||||||
|
'figure.figsize': (4, 3),
|
||||||
|
'savefig.dpi': 600,
|
||||||
|
'xtick.direction': 'in',
|
||||||
|
'ytick.direction': 'in',
|
||||||
|
'xtick.labelsize': 10,
|
||||||
|
'ytick.labelsize': 10,
|
||||||
|
'legend.fontsize': 11.3,
|
||||||
|
}
|
||||||
|
plt.rcParams.update(parameters)
|
||||||
|
plt.rc('font', family='Times New Roman') # 全局字体样式
|
||||||
|
tsne2D = TSNE(n_components=2, verbose=2, perplexity=30, random_state=3407, init='random', learning_rate=200).fit_transform(data)
|
||||||
|
tsne2D_min, tsne2D_max = tsne2D.min(0), tsne2D.max(0)
|
||||||
|
tsne2D_final = (tsne2D - tsne2D_min) / (tsne2D_max - tsne2D_min)
|
||||||
|
s1, s2 = tsne2D_final[:1000, :], tsne2D_final[1000:, :]
|
||||||
|
pic = plt.figure()
|
||||||
|
# ax1 = pic.add_subplot()
|
||||||
|
plt.scatter(s1[:, 0], s1[:, 1], c=labels[:1000], cmap=plt.cm.get_cmap("jet", cls), marker='o', alpha=0.3) # 9为9种颜色,因为标签有9类
|
||||||
|
plt.scatter(s2[:, 0], s2[:, 1], c=labels[1000:], cmap=plt.cm.get_cmap("jet", cls), marker='x', alpha=0.3) # 9为9种颜色,因为标签有9类
|
||||||
|
plt.title(pic_title, fontsize=10)
|
||||||
|
# plt.xticks([])
|
||||||
|
# plt.yticks([])
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
def plot_confusion_matrix_accuracy(cls, true_labels, predict_labels):
|
||||||
|
# # 画混淆矩阵
|
||||||
|
# confusion = confusion_matrix(true_labels, predict_labels)
|
||||||
|
# # confusion = confusion.astype('float') / confusion.sum(axis=1)[:, np.newaxis]
|
||||||
|
# plt.figure(figsize=(6.4,6.4), dpi=100)
|
||||||
|
# sns.heatmap(confusion, annot=True, fmt="d", cmap="Greens")
|
||||||
|
# # sns.heatmap(confusion_matrix, annot=True, fmt="d", cmap="Blues")
|
||||||
|
# indices = range(len(confusion))
|
||||||
|
# classes = ['N', 'IF', 'OF', 'TRC', 'TSP']
|
||||||
|
# # for i in range(cls):
|
||||||
|
# # classes.append(str(i))
|
||||||
|
# # 第一个是迭代对象,表示坐标的显示顺序,第二个参数是坐标轴显示列表
|
||||||
|
# # plt.xticks(indices, classes, rotation=45) # 设置横坐标方向,rotation=45为45度倾斜
|
||||||
|
# # plt.yticks(indices, classes, rotation=45)
|
||||||
|
# plt.xticks([index + 0.5 for index in indices], classes, rotation=45) # 设置横坐标方向,rotation=45为45度倾斜
|
||||||
|
# plt.yticks([index + 0.5 for index in indices], classes, rotation=45)
|
||||||
|
# plt.ylabel('Actual label')
|
||||||
|
# plt.xlabel('Predicted label')
|
||||||
|
# plt.title('confusion matrix')
|
||||||
|
# plt.show()
|
||||||
|
sns.set(font_scale=1.5)
|
||||||
|
parameters = {'figure.dpi': 600,
|
||||||
|
'figure.figsize': (5, 5),
|
||||||
|
'savefig.dpi': 600,
|
||||||
|
'xtick.direction': 'in',
|
||||||
|
'ytick.direction': 'in',
|
||||||
|
'xtick.labelsize': 20,
|
||||||
|
'ytick.labelsize': 20,
|
||||||
|
'legend.fontsize': 11.3,
|
||||||
|
}
|
||||||
|
plt.rcParams.update(parameters)
|
||||||
|
plt.figure()
|
||||||
|
plt.rc('font', family='Times New Roman') # 全局字体样式
|
||||||
|
# 画混淆矩阵
|
||||||
|
confusion = confusion_matrix(true_labels, predict_labels)
|
||||||
|
# confusion = confusion.astype('float') / confusion.sum(axis=1)[:, np.newaxis]
|
||||||
|
plt.figure()
|
||||||
|
# sns.heatmap(confusion, annot=True, fmt="d", cmap="Greens")
|
||||||
|
sns.heatmap(confusion, annot=True, fmt="d", cmap="Blues", vmax=100, vmin=0, cbar=None, square=True)
|
||||||
|
indices = range(len(confusion))
|
||||||
|
classes = ['N', 'IF', 'OF', 'TRC', 'TSP']
|
||||||
|
# for i in range(cls):
|
||||||
|
# classes.append(str(i))
|
||||||
|
# 第一个是迭代对象,表示坐标的显示顺序,第二个参数是坐标轴显示列表
|
||||||
|
# plt.xticks(indices, classes, rotation=45) # 设置横坐标方向,rotation=45为45度倾斜
|
||||||
|
# plt.yticks(indices, classes, rotation=45)
|
||||||
|
plt.xticks([index + 0.5 for index in indices], classes, rotation=45) # 设置横坐标方向,rotation=45为45度倾斜
|
||||||
|
plt.yticks([index + 0.5 for index in indices], classes, rotation=45)
|
||||||
|
plt.ylabel('Actual label', fontsize=20)
|
||||||
|
plt.xlabel('Predicted label', fontsize=20)
|
||||||
|
# plt.tight_layout()
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
def test(cls, tar_condition, G_params_path, LC_params_path):
|
||||||
|
test_data, test_label = load_data(tar_condition)
|
||||||
|
test_dataset = Nor_Dataset(test_data, test_label)
|
||||||
|
|
||||||
|
batch_size = len(test_dataset)
|
||||||
|
|
||||||
|
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)
|
||||||
|
|
||||||
|
# 加载网络
|
||||||
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
|
G = Extractor().to(device)
|
||||||
|
LC = LabelClassifier(cls_num=cls).to(device)
|
||||||
|
G.load_state_dict(
|
||||||
|
torch.load(G_params_path, map_location=device)
|
||||||
|
)
|
||||||
|
LC.load_state_dict(
|
||||||
|
torch.load(LC_params_path, map_location=device)
|
||||||
|
)
|
||||||
|
# print(net)
|
||||||
|
# params_num = sum(param.numel() for param in net.parameters_bak())
|
||||||
|
# print('参数数量:{}'.format(params_num))
|
||||||
|
|
||||||
|
test_acc = 0.0
|
||||||
|
|
||||||
|
G.eval()
|
||||||
|
LC.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch_idx, (data, label) in enumerate(test_loader):
|
||||||
|
data, label = data.to(device), label.to(device)
|
||||||
|
_, feature = G(data, data)
|
||||||
|
_, _, _, output = LC(feature, feature)
|
||||||
|
test_acc += np.sum(np.argmax(output.cpu().detach().numpy(), axis=1) == label.cpu().numpy())
|
||||||
|
|
||||||
|
predict_labels = np.argmax(output.cpu().detach().numpy(), axis=1)
|
||||||
|
labels = label.cpu().numpy()
|
||||||
|
|
||||||
|
predictions = output.cpu().detach().numpy()
|
||||||
|
|
||||||
|
tsne_2d_generate(cls, predictions, labels, "output of neural network")
|
||||||
|
|
||||||
|
plot_confusion_matrix_accuracy(cls, labels, predict_labels)
|
||||||
|
|
||||||
|
print("测试集大小为{}, 成功{},准确率为{:.6f}".format(test_dataset.__len__(), test_acc, test_acc / test_dataset.__len__()))
|
||||||
|
|
||||||
|
def test1(cls, src_condition, tar_condition, G_params_path, LC_params_path):
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
root_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\'
|
||||||
|
else:
|
||||||
|
root_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\'
|
||||||
|
|
||||||
|
src_data = np.load(root_dir + src_condition + '\\src\\' + 'data.npy')
|
||||||
|
src_label = np.load(root_dir + src_condition + '\\src\\' + 'label.npy')
|
||||||
|
|
||||||
|
tar_data = np.load(root_dir + tar_condition + '\\tar1\\' + 'data.npy')
|
||||||
|
tar_label = np.load(root_dir + tar_condition + '\\tar1\\' + 'label.npy')
|
||||||
|
|
||||||
|
src_dataset = Nor_Dataset(src_data, src_label)
|
||||||
|
tar_dataset = Nor_Dataset(tar_data, tar_label)
|
||||||
|
src_loader = DataLoader(dataset=src_dataset, batch_size=1000, shuffle=False, drop_last=True)
|
||||||
|
tar_loader = DataLoader(dataset=tar_dataset, batch_size=1000, shuffle=False, drop_last=True)
|
||||||
|
|
||||||
|
# 加载网络
|
||||||
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
|
G = Extractor().to(device)
|
||||||
|
LC = LabelClassifier(cls_num=cls).to(device)
|
||||||
|
G.load_state_dict(
|
||||||
|
torch.load(G_params_path, map_location=device)
|
||||||
|
)
|
||||||
|
LC.load_state_dict(
|
||||||
|
torch.load(LC_params_path, map_location=device)
|
||||||
|
)
|
||||||
|
# print(net)
|
||||||
|
# params_num = sum(param.numel() for param in net.parameters_bak())
|
||||||
|
# print('参数数量:{}'.format(params_num))
|
||||||
|
|
||||||
|
test_acc = 0.0
|
||||||
|
|
||||||
|
G.eval()
|
||||||
|
LC.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
for (src_batch_idx, (src_data, src_label)), (tar_batch_idx, (tar_data, tar_label)) in zip(enumerate(src_loader), enumerate(tar_loader)):
|
||||||
|
src_data, src_label = src_data.to(device), src_label.to(device)
|
||||||
|
tar_data, tar_label = tar_data.to(device), tar_label.to(device)
|
||||||
|
data = torch.concat((src_data, tar_data), dim=0)
|
||||||
|
label = torch.concat((src_label, tar_label), dim=0)
|
||||||
|
_, feature = G(data, data)
|
||||||
|
_, _, fc1, output = LC(feature, feature)
|
||||||
|
test_acc += np.sum(np.argmax(output.cpu().detach().numpy(), axis=1) == label.cpu().numpy())
|
||||||
|
|
||||||
|
predict_labels = np.argmax(output.cpu().detach().numpy(), axis=1)
|
||||||
|
labels = label.cpu().numpy()
|
||||||
|
|
||||||
|
outputs = output.cpu().detach().numpy()
|
||||||
|
fc1_outputs = fc1.cpu().detach().numpy()
|
||||||
|
break
|
||||||
|
|
||||||
|
tsne_2d_generate1(cls, fc1_outputs, labels, "GADAN")
|
||||||
|
|
||||||
|
plot_confusion_matrix_accuracy(cls, labels, predict_labels)
|
||||||
|
|
||||||
|
print("准确率为{:.6f}".format(test_acc / (src_dataset.__len__() + tar_dataset.__len__())))
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# pass
|
||||||
|
test(5, 'B', 'parameters/A_to_B/G/G_shot10_epoch200_lr0.002_miu0.5.pkl',
|
||||||
|
'parameters/A_to_B/LC/LC_shot10_epoch200_lr0.002_miu0.5.pkl')
|
||||||
|
|
||||||
|
# test1(5, 'A', 'B', 'parameters_bak/A_to_B/G/G_shot10_epoch200_lr0.002_miu0.5.pkl',
|
||||||
|
# 'parameters_bak/A_to_B/LC/LC_shot10_epoch200_lr0.002_miu0.5.pkl')
|
||||||
|
|
@ -0,0 +1,357 @@
|
||||||
|
"""
|
||||||
|
@Author: miykah
|
||||||
|
@Email: miykah@163.com
|
||||||
|
@FileName: train.py
|
||||||
|
@DateTime: 2022/7/20 20:22
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import random
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from example.load_data import get_dataset, draw_signal_img
|
||||||
|
from example.model import Extractor, LabelClassifier, DomainClassifier, weights_init_Classifier, weights_init_Extractor
|
||||||
|
from example.loss import cal_mmdc_loss, BinaryCrossEntropyLoss, get_mmdc
|
||||||
|
import example.mmd as mmd
|
||||||
|
from example.test import test, test1
|
||||||
|
from scipy.spatial.distance import cdist
|
||||||
|
import math
|
||||||
|
|
||||||
|
def obtain_label(feature, output, bs):
|
||||||
|
with torch.no_grad():
|
||||||
|
all_fea = feature.reshape(bs, -1).float().cpu()
|
||||||
|
all_output = output.float().cpu()
|
||||||
|
# all_label = label.float()
|
||||||
|
all_output = nn.Softmax(dim=1)(all_output)
|
||||||
|
_, predict = torch.max(all_output, 1)
|
||||||
|
# accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
|
||||||
|
|
||||||
|
all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1)
|
||||||
|
all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t()
|
||||||
|
all_fea = all_fea.float().cpu().numpy()
|
||||||
|
|
||||||
|
K = all_output.size(1)
|
||||||
|
aff = all_output.float().cpu().numpy()
|
||||||
|
initc = aff.transpose().dot(all_fea)
|
||||||
|
initc = initc / (1e-8 + aff.sum(axis=0)[:, None])
|
||||||
|
dd = cdist(all_fea, initc, 'cosine')
|
||||||
|
pred_label = dd.argmin(axis=1)
|
||||||
|
# acc = np.sum(pred_label == all_label.float().numpy()) / len(all_fea)
|
||||||
|
|
||||||
|
for round in range(1):
|
||||||
|
aff = np.eye(K)[pred_label]
|
||||||
|
initc = aff.transpose().dot(all_fea)
|
||||||
|
initc = initc / (1e-8 + aff.sum(axis=0)[:, None])
|
||||||
|
dd = cdist(all_fea, initc, 'cosine')
|
||||||
|
pred_label = dd.argmin(axis=1)
|
||||||
|
# acc = np.sum(pred_label == all_label.float().numpy()) / len(all_fea)
|
||||||
|
|
||||||
|
# log_str = 'Accuracy = {:.2f}% -> {:.2f}%'.format(accuracy * 100, acc * 100)
|
||||||
|
# args.out_file.write(log_str + '\n')
|
||||||
|
# args.out_file.flush()
|
||||||
|
# print(log_str + '\n')
|
||||||
|
return pred_label.astype('int')
|
||||||
|
|
||||||
|
|
||||||
|
def train(device, src_condition, tar_condition, cls, epochs, bs, shot, lr, patience, gamma, miu):
|
||||||
|
'''特征提取器'''
|
||||||
|
G = Extractor()
|
||||||
|
G.apply(weights_init_Extractor)
|
||||||
|
G.to(device)
|
||||||
|
'''标签分类器'''
|
||||||
|
LC = LabelClassifier(cls_num=cls)
|
||||||
|
LC.apply(weights_init_Classifier)
|
||||||
|
LC.to(device)
|
||||||
|
'''域分类器'''
|
||||||
|
DC = DomainClassifier()
|
||||||
|
DC.apply(weights_init_Classifier)
|
||||||
|
DC.to(device)
|
||||||
|
|
||||||
|
'''得到数据集'''
|
||||||
|
src_dataset, tar_dataset, val_dataset, test_dataset \
|
||||||
|
= get_dataset(src_condition, tar_condition)
|
||||||
|
'''DataLoader'''
|
||||||
|
src_loader = DataLoader(dataset=src_dataset, batch_size=bs, shuffle=False, drop_last=True)
|
||||||
|
tar_loader = DataLoader(dataset=tar_dataset, batch_size=bs, shuffle=True, drop_last=True)
|
||||||
|
val_loader = DataLoader(dataset=val_dataset, batch_size=bs, shuffle=True, drop_last=False)
|
||||||
|
test_loader = DataLoader(dataset=test_dataset, batch_size=len(test_dataset), shuffle=True, drop_last=False)
|
||||||
|
|
||||||
|
criterion = nn.CrossEntropyLoss().to(device)
|
||||||
|
BCE = BinaryCrossEntropyLoss().to(device)
|
||||||
|
|
||||||
|
optimizer_g = torch.optim.Adam(G.parameters(), lr=lr)
|
||||||
|
optimizer_lc = torch.optim.Adam(LC.parameters(), lr=lr)
|
||||||
|
optimizer_dc = torch.optim.Adam(DC.parameters(), lr=lr)
|
||||||
|
scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_g, mode="min", factor=0.5, patience=patience)
|
||||||
|
scheduler_lc = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_lc, mode="min", factor=0.5, patience=patience)
|
||||||
|
scheduler_dc = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_dc, mode="min", factor=0.5, patience=patience)
|
||||||
|
|
||||||
|
def zero_grad_all():
|
||||||
|
optimizer_g.zero_grad()
|
||||||
|
optimizer_lc.zero_grad()
|
||||||
|
optimizer_dc.zero_grad()
|
||||||
|
|
||||||
|
src_acc_list = []
|
||||||
|
train_loss_list = []
|
||||||
|
val_acc_list = []
|
||||||
|
val_loss_list = []
|
||||||
|
|
||||||
|
for epoch in range(epochs):
|
||||||
|
epoch_start_time = time.time()
|
||||||
|
src_acc = 0.0
|
||||||
|
train_loss = 0.0
|
||||||
|
val_acc = 0.0
|
||||||
|
val_loss = 0.0
|
||||||
|
|
||||||
|
G.train()
|
||||||
|
LC.train()
|
||||||
|
DC.train()
|
||||||
|
|
||||||
|
for (src_batch_idx, (src_data, src_label)), (tar_batch_idx, (tar_data)) in zip(enumerate(src_loader), enumerate(tar_loader)):
|
||||||
|
src_data, src_label = src_data.to(device), src_label.to(device)
|
||||||
|
tar_data = tar_data.to(device)
|
||||||
|
zero_grad_all()
|
||||||
|
|
||||||
|
T1 = (int)(0.2 * epochs)
|
||||||
|
T2 = (int)(0.5 * epochs)
|
||||||
|
|
||||||
|
src_feature, tar_feature = G(src_data, tar_data)
|
||||||
|
src_data_mmd1, src_output, tar_data_mmd1, tar_output = LC(src_feature, tar_feature)
|
||||||
|
if epoch < T1:
|
||||||
|
pseudo_label = torch.argmax(F.softmax(tar_output, dim=1), dim=1) # 伪标签
|
||||||
|
else:
|
||||||
|
pseudo_label = torch.tensor(obtain_label(tar_feature, tar_output, bs), dtype=torch.int64).cuda()
|
||||||
|
|
||||||
|
# pseudo_label = torch.argmax(F.softmax(tar_output, dim=1), dim=1) # 伪标签
|
||||||
|
# mmdc权重
|
||||||
|
if epoch < T1:
|
||||||
|
miu_f = 0
|
||||||
|
# elif epoch > T1 and epoch < T2:
|
||||||
|
# miu_f = miu * (epoch - T1) / (T2 - T1)
|
||||||
|
else:
|
||||||
|
miu_f = miu
|
||||||
|
|
||||||
|
# 源数据的交叉熵损失
|
||||||
|
loss_src = criterion(src_output, src_label)
|
||||||
|
# 目标数据的伪标签交叉熵损失
|
||||||
|
loss_tar_pseudo = criterion(tar_output, pseudo_label)
|
||||||
|
# mmd损失
|
||||||
|
loss_mmdm = mmd.mmd_rbf_noaccelerate(src_data_mmd1, tar_data_mmd1)
|
||||||
|
if epoch < T1:
|
||||||
|
loss_mmdc = 0
|
||||||
|
else:
|
||||||
|
loss_mmdc = get_mmdc(src_data_mmd1, tar_data_mmd1, pseudo_label, bs, shot, cls)
|
||||||
|
# loss_mmdc = cal_mmdc_loss(src_data_mmd1, src_output, tar_data_mmd1, tar_output)
|
||||||
|
# loss_mmdc = get_mmdc(src_data_mmd1, tar_data_mmd1, pseudo_label, bs, shot, cls)
|
||||||
|
# loss_jmmd = miu_f * loss_mmdc + (1 - miu_f) * loss_mmdm
|
||||||
|
|
||||||
|
# 伪标签损失的权重
|
||||||
|
# if epoch < T1:
|
||||||
|
# beta_f = 0
|
||||||
|
# elif epoch > T1 and epoch < T2:
|
||||||
|
# beta_f = beta * (epoch - T1) / (T2 - T1)
|
||||||
|
# else:
|
||||||
|
# beta_f = beta
|
||||||
|
|
||||||
|
p = epoch / epochs
|
||||||
|
lamda = (2 / (1 + math.exp(-10 * p))) - 1
|
||||||
|
# gamma = (2 / (1 + math.exp(-10 * p))) - 1
|
||||||
|
# miu_f = (2 / (1 + math.exp(-10 * p))) - 1
|
||||||
|
|
||||||
|
loss_jmmd = miu_f * loss_mmdc + (1 - miu_f) * loss_mmdm
|
||||||
|
|
||||||
|
loss_G_LC = loss_src + gamma * loss_jmmd
|
||||||
|
# loss_G_LC = loss_src + beta_f * loss_tar_pseudo + gamma * loss_jmmd
|
||||||
|
loss_G_LC.backward()
|
||||||
|
optimizer_g.step()
|
||||||
|
optimizer_lc.step()
|
||||||
|
zero_grad_all()
|
||||||
|
#-----------------------------------------------
|
||||||
|
# 对抗域适应的损失
|
||||||
|
src_feature, tar_feature = G(src_data, tar_data)
|
||||||
|
src_data_mmd1, src_output, tar_data_mmd1, tar_output = LC(src_feature, tar_feature)
|
||||||
|
# pseudo_label = torch.argmax(F.softmax(tar_output, dim=1), dim=1) # 伪标签
|
||||||
|
# 源数据的交叉熵损失
|
||||||
|
loss_src = criterion(src_output, src_label)
|
||||||
|
# 目标数据的伪标签交叉熵损失
|
||||||
|
loss_tar_pseudo = criterion(tar_output, pseudo_label)
|
||||||
|
|
||||||
|
gradient_src = \
|
||||||
|
torch.autograd.grad(outputs=loss_src, inputs=src_feature, create_graph=True, retain_graph=True,
|
||||||
|
only_inputs=True)[0]
|
||||||
|
gradient_tar = \
|
||||||
|
torch.autograd.grad(outputs=loss_tar_pseudo, inputs=tar_feature, create_graph=True, retain_graph=True,
|
||||||
|
only_inputs=True)[0]
|
||||||
|
gradients_adv = torch.cat((gradient_src, gradient_tar), dim=0)
|
||||||
|
|
||||||
|
domain_label_reverse = DC(gradients_adv, reverse=True)
|
||||||
|
loss_adv_r = BCE(domain_label_reverse)
|
||||||
|
loss_G_adv = lamda * loss_adv_r
|
||||||
|
# 更新特征提取器G参数
|
||||||
|
loss_G_adv.backward()
|
||||||
|
optimizer_g.step()
|
||||||
|
zero_grad_all()
|
||||||
|
#---------------------------------------------------------------------
|
||||||
|
src_feature, tar_feature = G(src_data, tar_data)
|
||||||
|
src_data_mmd1, src_output, tar_data_mmd1, tar_output = LC(src_feature, tar_feature)
|
||||||
|
# pseudo_label = torch.argmax(F.softmax(tar_output, dim=1), dim=1) # 伪标签
|
||||||
|
# 源数据的交叉熵损失
|
||||||
|
loss_src = criterion(src_output, src_label)
|
||||||
|
# 目标数据的伪标签交叉熵损失
|
||||||
|
loss_tar_pseudo = criterion(tar_output, pseudo_label)
|
||||||
|
|
||||||
|
gradient_src = \
|
||||||
|
torch.autograd.grad(outputs=loss_src, inputs=src_feature, create_graph=True, retain_graph=True,
|
||||||
|
only_inputs=True)[0]
|
||||||
|
gradient_tar = \
|
||||||
|
torch.autograd.grad(outputs=loss_tar_pseudo, inputs=tar_feature, create_graph=True, retain_graph=True,
|
||||||
|
only_inputs=True)[0]
|
||||||
|
|
||||||
|
gradients = torch.cat((gradient_src, gradient_tar), dim=0)
|
||||||
|
domain_label = DC(gradients, reverse=False)
|
||||||
|
loss_adv = BCE(domain_label)
|
||||||
|
loss_DC = lamda * loss_adv
|
||||||
|
# 更新域分类器的参数
|
||||||
|
loss_DC.backward()
|
||||||
|
optimizer_dc.step()
|
||||||
|
zero_grad_all()
|
||||||
|
|
||||||
|
src_acc += np.sum(np.argmax(src_output.cpu().detach().numpy(), axis=1) == src_label.cpu().numpy())
|
||||||
|
train_loss += (loss_G_LC + loss_G_adv + loss_DC).item()
|
||||||
|
|
||||||
|
G.eval()
|
||||||
|
LC.eval()
|
||||||
|
DC.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch_idx, (val_data, val_label) in enumerate(val_loader):
|
||||||
|
val_data, val_label = val_data.to(device), val_label.to(device)
|
||||||
|
_, val_feature = G(val_data, val_data)
|
||||||
|
_, _, _, val_output = LC(val_feature, val_feature)
|
||||||
|
loss = criterion(val_output, val_label)
|
||||||
|
|
||||||
|
val_acc += np.sum(np.argmax(val_output.cpu().detach().numpy(), axis=1) == val_label.cpu().numpy())
|
||||||
|
val_loss += loss.item()
|
||||||
|
|
||||||
|
scheduler_g.step(val_loss)
|
||||||
|
scheduler_lc.step(val_loss)
|
||||||
|
scheduler_dc.step(val_loss)
|
||||||
|
|
||||||
|
print("[{:03d}/{:03d}] {:2.2f} sec(s) src_acc: {:3.6f} train_loss: {:3.9f} | val_acc: {:3.6f} val_loss: {:3.9f} | Learning rate : {:3.6f}".format(
|
||||||
|
epoch + 1, epochs, time.time() - epoch_start_time, \
|
||||||
|
src_acc / src_dataset.__len__(), train_loss / src_dataset.__len__(),
|
||||||
|
val_acc / val_dataset.__len__(), val_loss / val_dataset.__len__(),
|
||||||
|
optimizer_g.state_dict()['param_groups'][0]['lr']))
|
||||||
|
|
||||||
|
# 保存在验证集上loss最小的模型
|
||||||
|
# if val_loss_list.__len__() > 0 and (val_loss / val_dataset.__len__()) < min(val_loss_list):
|
||||||
|
# 如果精度大于最高精度,则保存
|
||||||
|
if val_acc_list.__len__() > 0 :
|
||||||
|
# if (val_acc / val_dataset.__len__()) >= max(val_acc_list):
|
||||||
|
if (val_acc / val_dataset.__len__()) > max(val_acc_list) or (val_loss / val_dataset.__len__()) < min(val_loss_list):
|
||||||
|
print("保存模型最佳模型成功")
|
||||||
|
G_path = "parameters_bak/" + src_condition + "_to_" + tar_condition + "/G"
|
||||||
|
LC_path = "parameters_bak/" + src_condition + "_to_" + tar_condition + "/LC"
|
||||||
|
DC_path = "parameters_bak/" + src_condition + "_to_" + tar_condition + "/DC"
|
||||||
|
if not os.path.exists(G_path):
|
||||||
|
os.makedirs(G_path)
|
||||||
|
if not os.path.exists(LC_path):
|
||||||
|
os.makedirs(LC_path)
|
||||||
|
if not os.path.exists(DC_path):
|
||||||
|
os.makedirs(DC_path)
|
||||||
|
# 保存模型参数
|
||||||
|
torch.save(G.state_dict(), G_path + "/G_" + "shot" + str(shot) + "_epoch" + str(epochs) + "_lr" + str(lr) + "_miu" + str(miu) + ".pkl")
|
||||||
|
torch.save(LC.state_dict(), LC_path + "/LC_" + "shot" + str(shot) + "_epoch" + str(epochs) + "_lr" + str(lr) + "_miu" + str(miu) + ".pkl")
|
||||||
|
torch.save(DC.state_dict(), DC_path + "/DC_" + "shot" + str(shot) + "_epoch" + str(epochs) + "_lr" + str(lr) + "_miu" + str(miu) + ".pkl")
|
||||||
|
|
||||||
|
src_acc_list.append(src_acc / src_dataset.__len__())
|
||||||
|
train_loss_list.append(train_loss / src_dataset.__len__())
|
||||||
|
val_acc_list.append(val_acc / val_dataset.__len__())
|
||||||
|
val_loss_list.append(val_loss / val_dataset.__len__())
|
||||||
|
|
||||||
|
'''保存的模型参数的路径'''
|
||||||
|
G_params_path = G_path + "/G_" + "shot" + str(shot) + "_epoch" + str(epochs) + "_lr" + str(lr) + "_miu" + str(miu) + ".pkl"
|
||||||
|
LC_params_path = LC_path + "/LC_" + "shot" + str(shot) + "_epoch" + str(epochs) + "_lr" + str(lr) + "_miu" + str(miu) + ".pkl"
|
||||||
|
|
||||||
|
|
||||||
|
from matplotlib import rcParams
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"font.family": 'Times New Roman', # 设置字体类型
|
||||||
|
"axes.unicode_minus": False, # 解决负号无法显示的问题
|
||||||
|
"axes.labelsize": 13
|
||||||
|
}
|
||||||
|
rcParams.update(config)
|
||||||
|
|
||||||
|
# pic1 = plt.figure(figsize=(8, 6), dpi=200)
|
||||||
|
# plt.subplot(211)
|
||||||
|
# plt.plot(np.arange(1, epochs + 1), src_acc_list, 'b', label='TrainAcc')
|
||||||
|
# plt.plot(np.arange(1, epochs + 1), val_acc_list, 'r', label='ValAcc')
|
||||||
|
# plt.ylim(0.3, 1.0) # 设置y轴范围
|
||||||
|
# plt.title('Training & Validation accuracy')
|
||||||
|
# plt.xlabel('epoch')
|
||||||
|
# plt.ylabel('accuracy')
|
||||||
|
# plt.legend(loc='lower right')
|
||||||
|
# plt.grid(alpha=0.4)
|
||||||
|
#
|
||||||
|
# plt.subplot(212)
|
||||||
|
# plt.plot(np.arange(1, epochs + 1), train_loss_list, 'b', label='TrainLoss')
|
||||||
|
# plt.plot(np.arange(1, epochs + 1), val_loss_list, 'r', label='ValLoss')
|
||||||
|
# plt.ylim(0, 0.08) # 设置y轴范围
|
||||||
|
# plt.title('Training & Validation loss')
|
||||||
|
# plt.xlabel('epoch')
|
||||||
|
# plt.ylabel('loss')
|
||||||
|
# plt.legend(loc='upper right')
|
||||||
|
# plt.grid(alpha=0.4)
|
||||||
|
|
||||||
|
pic1 = plt.figure(figsize=(12, 6), dpi=200)
|
||||||
|
|
||||||
|
plt.plot(np.arange(1, epochs + 1), train_loss_list, 'b', label='Training Loss')
|
||||||
|
plt.plot(np.arange(1, epochs + 1), val_loss_list, 'r', label='Validation Loss')
|
||||||
|
plt.ylim(0, 0.08) # 设置y轴范围
|
||||||
|
plt.title('Training & Validation loss')
|
||||||
|
plt.xlabel('epoch')
|
||||||
|
plt.ylabel('loss')
|
||||||
|
plt.legend(loc='upper right')
|
||||||
|
plt.grid(alpha=0.4)
|
||||||
|
|
||||||
|
# 获取当前时间戳
|
||||||
|
timestamp = int(time.time())
|
||||||
|
|
||||||
|
# 将时间戳转换为字符串
|
||||||
|
timestamp_str = str(timestamp)
|
||||||
|
plt.savefig(timestamp_str, dpi=200)
|
||||||
|
|
||||||
|
return G_params_path, LC_params_path
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
begin = time.time()
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda:0")
|
||||||
|
else:
|
||||||
|
device = torch.device("cpu")
|
||||||
|
|
||||||
|
seed = 2
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
random.seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
|
||||||
|
src_condition = 'A'
|
||||||
|
tar_condition = 'B'
|
||||||
|
'''训练'''
|
||||||
|
G_params_path, LC_params_path = train(device, src_condition, tar_condition, cls=5,
|
||||||
|
epochs=200, bs=50, shot=10, lr=0.002, patience=40, gamma=1, miu=0.5)
|
||||||
|
|
||||||
|
end = time.time()
|
||||||
|
|
||||||
|
'''测试'''
|
||||||
|
# test1(5, src_condition, tar_condition, G_params_path, LC_params_path)
|
||||||
|
test(5, tar_condition, G_params_path, LC_params_path)
|
||||||
|
|
||||||
|
print("训练耗时:{:3.2f}s".format(end - begin))
|
||||||
Loading…
Reference in New Issue