PyTorch update

This commit is contained in:
kevinding1125 2023-11-09 21:42:09 +08:00
parent 564ba3f669
commit 2c3e6c25a8
13 changed files with 1332 additions and 0 deletions

View File

@@ -0,0 +1,8 @@
# -*- coding: utf-8 -*-
'''
@Author : dingjiawen
@Date : 2023/11/9 21:32
@Usage :
@Desc :
'''

View File

@@ -0,0 +1,8 @@
# -*- coding: utf-8 -*-
'''
@Author : dingjiawen
@Date : 2023/11/9 21:33
@Usage :
@Desc :
'''

View File

@@ -0,0 +1,8 @@
# -*- coding: utf-8 -*-
'''
@Author : dingjiawen
@Date : 2023/11/9 21:34
@Usage :
@Desc :
'''

View File

@@ -0,0 +1,8 @@
# -*- coding: utf-8 -*-
'''
@Author : dingjiawen
@Date : 2023/11/9 21:34
@Usage :
@Desc :
'''

View File

@@ -0,0 +1,8 @@
# -*- coding: utf-8 -*-
'''
@Author : dingjiawen
@Date : 2023/11/9 21:34
@Usage :
@Desc :
'''

View File

@@ -0,0 +1,8 @@
# -*- coding: utf-8 -*-
'''
@Author : dingjiawen
@Date : 2023/11/9 21:28
@Usage :
@Desc :
'''

View File

@@ -0,0 +1,111 @@
"""
@Author: miykah
@Email: miykah@163.com
@FileName: load_data.py
@DateTime: 2022/7/20 16:40
"""
import matplotlib.pyplot as plt
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import random
'''Standard Dataset class (optionally labeled)'''
class Nor_Dataset(Dataset):
def __init__(self, datas, labels=None):
self.datas = torch.tensor(datas)
if labels is not None:
self.labels = torch.tensor(labels)
else:
self.labels = None
def __getitem__(self, index):
data = self.datas[index]
if self.labels is not None:
label = self.labels[index]
return data, label
return data
def __len__(self):
return len(self.datas)
'''Dataset class for unlabeled target-domain data'''
class Tar_U_Dataset(Dataset):
def __init__(self, datas):
self.datas = torch.tensor(datas)
def __getitem__(self, index):
data = self.datas[index]
data_bar = data.clone().detach()
data_bar2 = data.clone().detach()
mu = 0
sigma = 0.1
        '''Add Gaussian noise to the unlabeled target sample to obtain data' and data'' '''
        for i in range(data_bar.shape[1]):
            data_bar[0, i] += random.gauss(mu, sigma)
            data_bar2[0, i] += random.gauss(mu, sigma)
return data, data_bar, data_bar2
def __len__(self):
return len(self.datas)
def draw_signal_img(data, data_bar, data_bar2):
pic = plt.figure(figsize=(12, 6), dpi=100)
plt.rcParams['font.family'] = ['Arial Unicode MS', 'Microsoft YaHei', 'SimHei', 'sans-serif']
plt.rcParams['axes.unicode_minus'] = False
plt.subplot(3, 1, 1)
plt.plot(data, 'b')
plt.subplot(3, 1, 2)
plt.plot(data_bar, 'b')
plt.subplot(3, 1, 3)
plt.plot(data_bar2, 'b')
plt.show()
def get_dataset(src_condition, tar_condition):
if torch.cuda.is_available():
root_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\'
else:
root_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\'
src_data = np.load(root_dir + src_condition + '\\src\\' + 'data.npy')
src_label = np.load(root_dir + src_condition + '\\src\\' + 'label.npy')
tar_data = np.load(root_dir + tar_condition + '\\tar\\' + 'data.npy')
val_data = np.load(root_dir + tar_condition + '\\val\\' + 'data.npy')
val_label = np.load(root_dir + tar_condition + '\\val\\' + 'label.npy')
test_data = np.load(root_dir + tar_condition + '\\test\\' + 'data.npy')
test_label = np.load(root_dir + tar_condition + '\\test\\' + 'label.npy')
src_dataset = Nor_Dataset(src_data, src_label)
tar_dataset = Nor_Dataset(tar_data)
val_dataset = Nor_Dataset(val_data, val_label)
test_dataset = Nor_Dataset(test_data, test_label)
return src_dataset, tar_dataset, val_dataset, test_dataset
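# A typical consumption pattern (an assumption, mirroring train.py): wrap each
# returned dataset in a DataLoader, e.g.
#     src_loader = DataLoader(src_dataset, batch_size=50, shuffle=False, drop_last=True)
#     tar_loader = DataLoader(tar_dataset, batch_size=50, shuffle=True, drop_last=True)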
if __name__ == '__main__':
# pass
# tar_data = np.load("E:\\DataSet\\DDS_data_CDAC\\processed_data\\A\\tar\\data.npy")
# tar_dataset = Nor_Dataset(datas=tar_data)
# tar_loader = DataLoader(dataset=tar_dataset, batch_size=len(tar_dataset), shuffle=False)
#
# for batch_idx, (data) in enumerate(tar_loader):
# print(data[0][0])
# # print(data_bar[0][0])
# # print(data_bar2[0][0])
# if batch_idx == 0:
# break
src_data = np.load("E:\\DataSet\\DDS_data_CDAC\\processed_data\\A\\src\\data.npy")
src_label = np.load("E:\\DataSet\\DDS_data_CDAC\\processed_data\\A\\src\\label.npy")
src_dataset = Nor_Dataset(datas=src_data, labels=src_label)
src_loader = DataLoader(dataset=src_dataset, batch_size=len(src_dataset), shuffle=False)
for batch_idx, (data, label) in enumerate(src_loader):
print(data[10][0])
print(label[10])
# print(data_bar[0][0])
# print(data_bar2[0][0])
if batch_idx == 0:
break

View File

@@ -0,0 +1,127 @@
"""
@Author: miykah
@Email: miykah@163.com
@FileName: loss.py
@DateTime: 2022/7/21 14:31
"""
import torch
import numpy as np
import torch.nn.functional as functional
import torch.nn as nn
import example.mmd as mmd
'''
Measure the discrepancy between the class-conditional distributions
(mmdc = mmd_condition)
'''
def cal_mmdc_loss(src_data_mmd1, src_data_mmd2, tar_data_mmd1, tar_data_mmd2):
    # Each batch is laid out as 10 consecutive per-class chunks of 3 samples;
    # sum the linear MMD between matching source/target chunks at both layers.
    # (mmdc1 is computed only for the commented-out variant below.)
    mmdc1 = sum(mmd.mmd_linear(src_data_mmd1[3 * i: 3 * (i + 1)],
                               tar_data_mmd1[3 * i: 3 * (i + 1)]) for i in range(10))
    mmdc2 = sum(mmd.mmd_linear(src_data_mmd2[3 * i: 3 * (i + 1)],
                               tar_data_mmd2[3 * i: 3 * (i + 1)]) for i in range(10))
    return mmdc2 / 10
    # return (mmdc1 + mmdc2) / 20
'''Mean feature of each source class, used when computing mmdc'''
def get_src_mean_feature(src_feature, shot, cls):
src_feature_list = []
for i in range(cls):
src_feature_cls = torch.mean(src_feature[shot * i: shot * (i + 1)], dim=0)
src_feature_list.append(src_feature_cls)
return src_feature_list
def get_mmdc(src_feature, tar_feature, tar_pseudo_label, batch_size, shot, cls):
src_feature_list = get_src_mean_feature(src_feature, shot, cls)
pseudo_label = tar_pseudo_label.cpu().detach().numpy()
mmdc = 0.0
for i in range(batch_size):
# mmdc += mmd.mmd_linear(src_feature_list[pseudo_label[i]].reshape(1, -1), tar_feature[i].reshape(1, -1))
mmdc += mmd.mmd_linear(src_feature_list[pseudo_label[i]].reshape(1, -1), tar_feature[i].reshape(1, -1))
return mmdc / batch_size
class BCE(nn.Module):
    eps = 1e-7
    def forward(self, prob1, prob2, simi):
        # simi holds pairwise similarity labels in {1, -1};
        # note the in-place ops (mul_, add_, log_) mutate their operands
        P = prob1.mul_(prob2)
        P = P.sum(1)
        P.mul_(simi).add_(simi.eq(-1).type_as(P))
        neglogP = -P.add_(BCE.eps).log_()
        return neglogP.mean()
class BinaryCrossEntropyLoss(nn.Module):
""" Construct binary cross-entropy loss."""
eps = 1e-7
def forward(self, prob):
# ds = torch.ones([bs, 1]).to(device) # domain label for source
# dt = torch.zeros([bs, 1]).to(device) # domain label for target
# di = torch.cat((ds, dt), dim=0).to(device)
# neglogP = - (di * torch.log(prob + BCE.eps) + (1. - di) * torch.log(1. - prob + BCE.eps))
neglogP = - (prob * torch.log(prob + BCE.eps) + (1. - prob) * torch.log(1. - prob + BCE.eps))
return neglogP.mean()
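# A minimal usage sketch (an assumption, for illustration only): the
# entropy-style BinaryCrossEntropyLoss above only needs probabilities in (0, 1),
# such as the sigmoid outputs of the domain classifier.
if __name__ == '__main__':
    probs = torch.sigmoid(torch.randn(8, 1))
    print(BinaryCrossEntropyLoss()(probs))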

View File

@@ -0,0 +1,56 @@
#!/usr/bin/env python
# encoding: utf-8
import torch
# Consider linear time MMD with a linear kernel:
# K(f(x), f(y)) = f(x)^Tf(y)
# h(z_i, z_j) = k(x_i, x_j) + k(y_i, y_j) - k(x_i, y_j) - k(x_j, y_i)
# = [f(x_i) - f(y_i)]^T[f(x_j) - f(y_j)]
#
# f_of_X: batch_size * k
# f_of_Y: batch_size * k
def mmd_linear(f_of_X, f_of_Y):
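    # Linear-kernel MMD: with paired differences delta_i = f(x_i) - f(y_i),
    # this averages delta_i^T delta_j over all pairs, so both inputs must
    # share the same batch size.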
delta = f_of_X - f_of_Y
loss = torch.mean(torch.mm(delta, torch.transpose(delta, 0, 1)))
return loss
def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
n_samples = int(source.size()[0])+int(target.size()[0])
total = torch.cat([source, target], dim=0)
total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
L2_distance = ((total0-total1)**2).sum(2)
if fix_sigma:
bandwidth = fix_sigma
else:
bandwidth = torch.sum(L2_distance.data) / (n_samples**2-n_samples)
bandwidth /= kernel_mul ** (kernel_num // 2)
bandwidth_list = [bandwidth * (kernel_mul**i) for i in range(kernel_num)]
kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
return sum(kernel_val)#/len(kernel_val)
def mmd_rbf_accelerate(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
batch_size = int(source.size()[0])
kernels = guassian_kernel(source, target,
kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)
loss = 0
for i in range(batch_size):
s1, s2 = i, (i+1)%batch_size
t1, t2 = s1+batch_size, s2+batch_size
loss += kernels[s1, s2] + kernels[t1, t2]
loss -= kernels[s1, t2] + kernels[s2, t1]
return loss / float(batch_size)
def mmd_rbf_noaccelerate(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
batch_size = int(source.size()[0])
kernels = guassian_kernel(source, target,
kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)
XX = kernels[:batch_size, :batch_size]
YY = kernels[batch_size:, batch_size:]
XY = kernels[:batch_size, batch_size:]
YX = kernels[batch_size:, :batch_size]
loss = torch.mean(XX + YY - XY -YX)
return loss
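# A minimal smoke-test sketch (shapes are an assumption: batch_size x dim,
# equal for source and target); identical inputs should give zero MMD.
if __name__ == '__main__':
    x = torch.randn(16, 8)
    y = torch.randn(16, 8)
    print(mmd_linear(x, x))            # exactly 0 for identical inputs
    print(mmd_linear(x, y))
    print(mmd_rbf_noaccelerate(x, y))  # multi-kernel RBF estimate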

View File

@@ -0,0 +1,109 @@
"""
@Author: miykah
@Email: miykah@163.com
@FileName: model.py
@DateTime: 2022/7/20 21:18
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
'''Gradient-reversal layer: identity in the forward pass, negated gradient in the backward pass'''
class GradReverse(Function):
    @staticmethod
    def forward(ctx, x):
        return x.view_as(x)
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * -1.0
def grad_reverse(x):
    return GradReverse.apply(x)
'''Feature extractor'''
class Extractor(nn.Module):
def __init__(self):
super(Extractor, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv1d(in_channels=1, out_channels=32, kernel_size=13, padding='same'),
nn.BatchNorm1d(32),
nn.ReLU(),
nn.MaxPool1d(2) # (32 * 1024)
)
self.conv2 = nn.Sequential(
nn.Conv1d(in_channels=32, out_channels=32, kernel_size=13, padding='same'),
nn.BatchNorm1d(32),
nn.ReLU(),
nn.MaxPool1d(2) # (32 * 512)
)
self.conv3 = nn.Sequential(
nn.Conv1d(in_channels=32, out_channels=32, kernel_size=13, padding='same'),
nn.BatchNorm1d(32),
nn.ReLU(),
nn.MaxPool1d(2) # (32 * 256)
)
def forward(self, src_data, tar_data):
src_data = self.conv1(src_data)
src_data = self.conv2(src_data)
src_feature = self.conv3(src_data)
tar_data = self.conv1(tar_data)
tar_data = self.conv2(tar_data)
tar_feature = self.conv3(tar_data)
return src_feature, tar_feature
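# Note: the Extractor applies the same convolutional stack (shared weights) to
# the source batch and the target batch in one forward call.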
'''Label classifier'''
class LabelClassifier(nn.Module):
def __init__(self, cls_num):
super(LabelClassifier, self).__init__()
self.fc1 = nn.Sequential(
nn.Flatten(), # (8192,)
nn.Linear(in_features=8192, out_features=256),
)
self.fc2 = nn.Sequential(
nn.ReLU(),
nn.Linear(in_features=256, out_features=cls_num)
)
def forward(self, src_feature, tar_feature):
src_data_mmd1 = self.fc1(src_feature)
src_output = self.fc2(src_data_mmd1)
tar_data_mmd1 = self.fc1(tar_feature)
tar_output = self.fc2(tar_data_mmd1)
return src_data_mmd1, src_output, tar_data_mmd1, tar_output
'''Domain classifier'''
class DomainClassifier(nn.Module):
def __init__(self, temp=0.05):
super(DomainClassifier, self).__init__()
self.fc = nn.Sequential(
nn.Flatten(),
nn.Linear(in_features=8192, out_features=512),
nn.ReLU(),
nn.Linear(in_features=512, out_features=128),
nn.ReLU(),
nn.Linear(in_features=128, out_features=1),
nn.Sigmoid()
)
self.temp = temp
def forward(self, x, reverse=False):
if reverse:
x = grad_reverse(x)
output = self.fc(x)
return output
'''Initialize network weights'''
def weights_init_Extractor(m):
if isinstance(m, nn.Conv1d):
nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
nn.init.constant_(m.bias, 0)
def weights_init_Classifier(m):
if isinstance(m, nn.Linear):
nn.init.xavier_normal_(m.weight)
nn.init.constant_(m.bias, 0)
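# A minimal shape-check sketch, assuming 2048-point inputs (as produced by
# read_data_cwru.py): three stride-2 poolings leave 32 channels x 256 points,
# i.e. the 8192 features expected by the Linear layers above.
if __name__ == '__main__':
    x = torch.randn(4, 1, 2048)
    G = Extractor()
    src_f, tar_f = G(x, x)
    print(src_f.shape)  # expected: torch.Size([4, 32, 256])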

View File

@@ -0,0 +1,298 @@
"""
@Author: miykah
@Email: miykah@163.com
@FileName: read_data_cwru.py
@DateTime: 2022/7/20 15:58
"""
import math
import os
import scipy.io as scio
import numpy as np
import PIL.Image as Image
import matplotlib.pyplot as plt
import torch
def dat_to_numpy(samples_num, sample_length):
    '''Compute the sampling stride from the number of samples needed per class and the sample length'''
stride = math.floor((204800 * 4 - sample_length) / (samples_num - 1))
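    # e.g. with samples_num=800 and sample_length=2048 (the values used in
    # __main__): stride = floor((819200 - 2048) / 799) = 1022, which later
    # gives n = floor((819200 - 2048) / 1022 + 1) = 800 windows per class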
conditions = ['A', 'B', 'C']
for condition in conditions:
if torch.cuda.is_available():
root_dir = 'D:\\DataSet\\DDS_data\\5cls_11\\' + condition
save_dir = 'D:\\DataSet\\DDS_data\\5cls_11\\raw_data\\' + condition
else:
root_dir = 'E:\\DataSet\\DDS_data\\5cls_11\\' + condition
save_dir = 'E:\\DataSet\\DDS_data\\5cls_11\\raw_data\\' + condition
if condition == 'A':
start_idx = 1
end_idx = 5
# start_idx = 13
# end_idx = 17
# start_idx = 25
# end_idx = 29
# start_idx = 9
# end_idx = 13
elif condition == 'B':
start_idx = 5
end_idx = 9
# start_idx = 17
# end_idx = 21
# start_idx = 29
# end_idx = 33
elif condition == 'C':
start_idx = 9
end_idx = 13
# start_idx = 21
# end_idx = 25
# start_idx = 33
# end_idx = 37
# start_idx = 25
# end_idx = 29
dir_names = os.listdir(root_dir)
print(dir_names)
        '''
        Fault-type directories (the on-disk folder names are Chinese):
        parallel-gearbox bearing inner-race / outer-race / rolling-element
        faults, and gear eccentricity / broken-tooth / missing-tooth /
        surface-wear / tooth-root-crack faults, all at constant speed
        '''
if not os.path.exists(save_dir):
os.makedirs(save_dir)
for dir_name in dir_names:
data_list = []
for i in range(start_idx, end_idx):
                if i < 10:
                    path = root_dir + '\\' + dir_name + '\\dds测试故障库4.6#000' + str(i) + '.dat'
                else:
                    path = root_dir + '\\' + dir_name + '\\dds测试故障库4.6#00' + str(i) + '.dat'
data = np.fromfile(path, dtype='float32')[204800 * 2: 204800 * 3]
data_list.append(data)
# data_one_cls = np.array(data_list).reshape(-1).reshape(-1, 1024)
# data_one_cls = np.array(data_list).reshape(-1).reshape(-1, 1, 2048) #(400, 1, 2048)
data_one_cls = np.array(data_list).reshape(-1)
n = math.floor((204800 * 4 - sample_length) / stride + 1)
            segments = []
            for i in range(n):
                start = i * stride
                end = i * stride + sample_length
                segments.append(data_one_cls[start: end])
            data_one_cls = np.array(segments).reshape(-1, 1, sample_length)
            # shuffle the samples of this class
shuffle_ix = np.random.permutation(np.arange(len(data_one_cls)))
data_one_cls = data_one_cls[shuffle_ix]
print(data_one_cls.shape)
np.save(save_dir + '\\' + dir_name + '.npy', data_one_cls)
def draw_signal_img(condition):
if torch.cuda.is_available():
root_dir = 'D:\\DataSet\\DDS_data\\4cls_speed30\\raw_data\\' + condition
else:
root_dir = 'E:\\DataSet\\DDS_data\\4cls_speed30\\raw_data\\' + condition
file_names = os.listdir(root_dir)
pic = plt.figure(figsize=(12, 6), dpi=200)
# plt.rcParams['font.family'] = ['Arial Unicode MS', 'Microsoft YaHei', 'SimHei', 'sans-serif']
# plt.rcParams['axes.unicode_minus'] = False
plt.rc('font',family='Times New Roman')
clses = ['(a)', '(b)', '(c)', '(d)']
for file_name, i, cls in zip(file_names, range(4), clses):
data = np.load(root_dir + '\\' + file_name)[0].reshape(-1)
print(data.shape, file_name)
        plt.tick_params(top='on', right='on', which='both')  # show ticks on the top and right edges too
        plt.rcParams['xtick.direction'] = 'in'  # x-axis ticks point inward
        plt.rcParams['ytick.direction'] = 'in'  # y-axis ticks point inward
plt.subplot(2, 2, i + 1)
plt.title(cls, fontsize=20)
plt.xlabel("Time (s)", fontsize=20)
plt.ylabel("Amplitude (m/s²)", fontsize=20)
        plt.xlim(0, 2048)
        plt.ylim(-3, 3)
        plt.yticks(np.array([-2, 0, 2]), fontsize=20)
        # assuming a 12.8 kHz sampling rate, 640 samples = 0.05 s
        plt.xticks(np.array([0, 640, 1280, 1920]), np.array(['0', '0.05', '0.1', '0.15']), fontsize=20)
plt.plot(data)
plt.show()
def draw_signal_img2():
conditions = ['A', 'B', 'C']
pic = plt.figure(figsize=(12, 6), dpi=600)
plt.text(x=1, y=3, s='A')
i = 1
for condition in conditions:
if torch.cuda.is_available():
root_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne\\raw_data\\' + condition
else:
root_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne\\raw_data\\' + condition
file_names = os.listdir(root_dir)
# plt.rcParams['font.family'] = ['Arial Unicode MS', 'Microsoft YaHei', 'SimHei', 'sans-serif']
# plt.rcParams['axes.unicode_minus'] = False
plt.rc('font',family='Times New Roman')
clses = ['(a)', '(b)', '(c)', '(d)', '(e)']
for file_name, cls in zip(file_names, clses):
data = np.load(root_dir + '\\' + file_name)[3].reshape(-1)
print(data.shape, file_name)
            plt.tick_params(top='on', right='on', which='both')  # show ticks on the top and right edges too
            plt.rcParams['xtick.direction'] = 'in'  # x-axis ticks point inward
            plt.rcParams['ytick.direction'] = 'in'  # y-axis ticks point inward
plt.subplot(3, 5, i)
plt.title(cls, fontsize=15)
plt.xlabel("Time (s)", fontsize=15)
plt.ylabel("Amplitude (m/s²)", fontsize=15)
            plt.xlim(0, 2048)
            plt.ylim(-5, 5)
            plt.yticks(np.array([-3, 0, 3]), fontsize=15)
            # assuming a 12.8 kHz sampling rate, 640 samples = 0.05 s
            plt.xticks(np.array([0, 640, 1280, 1920]), np.array(['0', '0.05', '0.1', '0.15']), fontsize=15)
plt.plot(data)
i += 1
plt.show()
'''Build the source-domain data and labels'''
def get_src_data(src_condition, shot):
if torch.cuda.is_available():
root_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + src_condition
save_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\' + src_condition + '\\src'
else:
root_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + src_condition
save_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\' + src_condition + '\\src'
file_names = os.listdir(root_dir)
# print(file_names)
datas = []
labels = []
for i in range(600 // shot):
for file_name, cls in zip(file_names, range(5)):
data = np.load(root_dir + '\\' + file_name)
data_cut = data[shot * i: shot * (i + 1), :, :]
label_cut = np.array([cls] * shot, dtype='int64')
datas.append(data_cut)
labels.append(label_cut)
datas = np.array(datas).reshape(-1, 1, 2048)
labels = np.array(labels).reshape(-1)
print(datas.shape, labels.shape)
if not os.path.exists(save_dir):
os.makedirs(save_dir)
np.save(save_dir + '\\' + 'data.npy', datas)
np.save(save_dir + '\\' + 'label.npy', labels)
'''Build the unlabeled target-domain data'''
def get_tar_data(tar_condition):
if torch.cuda.is_available():
root_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + tar_condition
save_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\' + tar_condition + '\\tar'
else:
root_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + tar_condition
save_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\' + tar_condition + '\\tar'
file_names = os.listdir(root_dir)
datas = []
for file_name in file_names:
data = np.load(root_dir + '\\' + file_name)
data_cut = data[0: 600, :, :]
datas.append(data_cut)
datas = np.array(datas).reshape(-1, 1, 2048)
print("datas.shape: {}".format(datas.shape))
if not os.path.exists(save_dir):
os.makedirs(save_dir)
np.save(save_dir + '\\' + 'data.npy', datas)
def get_tar_data2(tar_condition, shot):
if torch.cuda.is_available():
root_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + tar_condition
save_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\' + tar_condition + '\\tar1'
else:
root_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + tar_condition
save_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\' + tar_condition + '\\tar1'
file_names = os.listdir(root_dir)
datas = []
labels = []
for i in range(600 // shot):
for file_name, cls in zip(file_names, range(5)):
data = np.load(root_dir + '\\' + file_name)
data_cut = data[shot * i: shot * (i + 1), :, :]
label_cut = np.array([cls] * shot, dtype='int64')
datas.append(data_cut)
labels.append(label_cut)
datas = np.array(datas).reshape(-1, 1, 2048)
labels = np.array(labels).reshape(-1)
print(datas.shape, labels.shape)
if not os.path.exists(save_dir):
os.makedirs(save_dir)
np.save(save_dir + '\\' + 'data.npy', datas)
np.save(save_dir + '\\' + 'label.npy', labels)
'''Build the validation set (100 samples per class)'''
def get_val_data(tar_condition):
if torch.cuda.is_available():
root_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + tar_condition
save_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\' + tar_condition + '\\val'
else:
root_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + tar_condition
save_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\' + tar_condition + '\\val'
file_names = os.listdir(root_dir)
datas = []
labels = []
for file_name, cls in zip(file_names, range(5)):
data = np.load(root_dir + '\\' + file_name)
data_cut = data[600: 700, :, :]
label_cut = np.array([cls] * 100, dtype='int64')
datas.append(data_cut)
labels.append(label_cut)
datas = np.array(datas).reshape(-1, 1, 2048)
labels = np.array(labels).reshape(-1)
print("datas.shape: {}, labels.shape: {}".format(datas.shape, labels.shape))
if not os.path.exists(save_dir):
os.makedirs(save_dir)
np.save(save_dir + '\\' + 'data.npy', datas)
np.save(save_dir + '\\' + 'label.npy', labels)
'''Build the test set (100 samples per class)'''
def get_test_data(tar_condition):
if torch.cuda.is_available():
root_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + tar_condition
save_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\' + tar_condition + '\\test'
else:
root_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + tar_condition
save_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\' + tar_condition + '\\test'
file_names = os.listdir(root_dir)
datas = []
labels = []
for file_name, cls in zip(file_names, range(5)):
data = np.load(root_dir + '\\' + file_name)
data_cut = data[700: , :, :]
label_cut = np.array([cls] * 100, dtype='int64')
datas.append(data_cut)
labels.append(label_cut)
datas = np.array(datas).reshape(-1, 1, 2048)
labels = np.array(labels).reshape(-1)
print("datas.shape: {}, labels.shape: {}".format(datas.shape, labels.shape))
if not os.path.exists(save_dir):
os.makedirs(save_dir)
np.save(save_dir + '\\' + 'data.npy', datas)
np.save(save_dir + '\\' + 'label.npy', labels)
if __name__ == '__main__':
# dat_to_numpy(samples_num=800, sample_length=2048)
# draw_signal_img('C')
# draw_signal_img2()
# get_src_data('A', shot=10)
# get_src_data('B', shot=10)
# get_src_data('C', shot=10)
get_tar_data('A')
get_tar_data('B')
get_tar_data('C')
get_val_data('A')
get_val_data('B')
get_val_data('C')
get_test_data('A')
get_test_data('B')
get_test_data('C')
get_src_data('A', shot=10)
get_src_data('B', shot=10)
get_src_data('C', shot=10)
get_tar_data2('A', shot=10)
get_tar_data2('B', shot=10)
get_tar_data2('C', shot=10)

View File

@@ -0,0 +1,226 @@
"""
@Author: miykah
@Email: miykah@163.com
@FileName: test.py
@DateTime: 2022/7/9 14:15
"""
import numpy as np
import torch
from torch.utils.data import DataLoader
from DDS_GADAN.load_data import get_dataset
from DDS_GADAN.load_data import Nor_Dataset
from DDS_GADAN.model import Extractor, LabelClassifier
from sklearn.metrics import confusion_matrix
from sklearn.manifold import TSNE
import seaborn as sns
import matplotlib.pyplot as plt
def load_data(tar_condition):
if torch.cuda.is_available():
root_dir = 'D:\\DataSet\\DDS_data\\5cls_6\\processed_data\\800EachCls_shot10\\'
else:
root_dir = 'E:\\DataSet\\DDS_data\\5cls_6\\processed_data\\800EachCls_shot10\\'
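    # Note: this loads the saved validation split ('val'), not the 'test' split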
test_data = np.load(root_dir + tar_condition + '\\val\\' + 'data.npy')
test_label = np.load(root_dir + tar_condition + '\\val\\' + 'label.npy')
return test_data, test_label
def tsne_2d_generate(cls, data, labels, pic_title):
tsne2D = TSNE(n_components=2, verbose=2, perplexity=30).fit_transform(data)
x, y = tsne2D[:, 0], tsne2D[:, 1]
pic = plt.figure()
ax1 = pic.add_subplot()
    ax1.scatter(x, y, c=labels, cmap=plt.cm.get_cmap("jet", cls))  # one discrete color per class
plt.title(pic_title)
plt.show()
def tsne_2d_generate1(cls, data, labels, pic_title):
parameters = {'figure.dpi': 600,
'figure.figsize': (4, 3),
'savefig.dpi': 600,
'xtick.direction': 'in',
'ytick.direction': 'in',
'xtick.labelsize': 10,
'ytick.labelsize': 10,
'legend.fontsize': 11.3,
}
plt.rcParams.update(parameters)
    plt.rc('font', family='Times New Roman')  # global font style
tsne2D = TSNE(n_components=2, verbose=2, perplexity=30, random_state=3407, init='random', learning_rate=200).fit_transform(data)
tsne2D_min, tsne2D_max = tsne2D.min(0), tsne2D.max(0)
tsne2D_final = (tsne2D - tsne2D_min) / (tsne2D_max - tsne2D_min)
s1, s2 = tsne2D_final[:1000, :], tsne2D_final[1000:, :]
pic = plt.figure()
# ax1 = pic.add_subplot()
    plt.scatter(s1[:, 0], s1[:, 1], c=labels[:1000], cmap=plt.cm.get_cmap("jet", cls), marker='o', alpha=0.3)  # one discrete color per class
    plt.scatter(s2[:, 0], s2[:, 1], c=labels[1000:], cmap=plt.cm.get_cmap("jet", cls), marker='x', alpha=0.3)
plt.title(pic_title, fontsize=10)
# plt.xticks([])
# plt.yticks([])
plt.show()
def plot_confusion_matrix_accuracy(cls, true_labels, predict_labels):
    # # plot the confusion matrix
    # confusion = confusion_matrix(true_labels, predict_labels)
    # # confusion = confusion.astype('float') / confusion.sum(axis=1)[:, np.newaxis]
    # plt.figure(figsize=(6.4,6.4), dpi=100)
    # sns.heatmap(confusion, annot=True, fmt="d", cmap="Greens")
    # # sns.heatmap(confusion_matrix, annot=True, fmt="d", cmap="Blues")
    # indices = range(len(confusion))
    # classes = ['N', 'IF', 'OF', 'TRC', 'TSP']
    # # for i in range(cls):
    # #     classes.append(str(i))
    # # first arg: tick positions (display order); second arg: tick labels
    # # plt.xticks(indices, classes, rotation=45)  # rotate the x tick labels by 45 degrees
    # # plt.yticks(indices, classes, rotation=45)
    # plt.xticks([index + 0.5 for index in indices], classes, rotation=45)  # rotate the x tick labels by 45 degrees
    # plt.yticks([index + 0.5 for index in indices], classes, rotation=45)
    # plt.ylabel('Actual label')
    # plt.xlabel('Predicted label')
    # plt.title('confusion matrix')
    # plt.show()
sns.set(font_scale=1.5)
parameters = {'figure.dpi': 600,
'figure.figsize': (5, 5),
'savefig.dpi': 600,
'xtick.direction': 'in',
'ytick.direction': 'in',
'xtick.labelsize': 20,
'ytick.labelsize': 20,
'legend.fontsize': 11.3,
}
plt.rcParams.update(parameters)
plt.figure()
    plt.rc('font', family='Times New Roman')  # global font style
    # plot the confusion matrix
confusion = confusion_matrix(true_labels, predict_labels)
# confusion = confusion.astype('float') / confusion.sum(axis=1)[:, np.newaxis]
plt.figure()
# sns.heatmap(confusion, annot=True, fmt="d", cmap="Greens")
sns.heatmap(confusion, annot=True, fmt="d", cmap="Blues", vmax=100, vmin=0, cbar=None, square=True)
indices = range(len(confusion))
classes = ['N', 'IF', 'OF', 'TRC', 'TSP']
# for i in range(cls):
# classes.append(str(i))
    # first arg: tick positions (display order); second arg: tick labels
    # plt.xticks(indices, classes, rotation=45)  # rotate the x tick labels by 45 degrees
    # plt.yticks(indices, classes, rotation=45)
    plt.xticks([index + 0.5 for index in indices], classes, rotation=45)  # rotate the x tick labels by 45 degrees
plt.yticks([index + 0.5 for index in indices], classes, rotation=45)
plt.ylabel('Actual label', fontsize=20)
plt.xlabel('Predicted label', fontsize=20)
# plt.tight_layout()
plt.show()
def test(cls, tar_condition, G_params_path, LC_params_path):
test_data, test_label = load_data(tar_condition)
test_dataset = Nor_Dataset(test_data, test_label)
batch_size = len(test_dataset)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)
    # load the networks
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
G = Extractor().to(device)
LC = LabelClassifier(cls_num=cls).to(device)
G.load_state_dict(
torch.load(G_params_path, map_location=device)
)
LC.load_state_dict(
torch.load(LC_params_path, map_location=device)
)
# print(net)
    # params_num = sum(param.numel() for param in net.parameters())
    # print('number of parameters: {}'.format(params_num))
test_acc = 0.0
G.eval()
LC.eval()
with torch.no_grad():
for batch_idx, (data, label) in enumerate(test_loader):
data, label = data.to(device), label.to(device)
_, feature = G(data, data)
_, _, _, output = LC(feature, feature)
test_acc += np.sum(np.argmax(output.cpu().detach().numpy(), axis=1) == label.cpu().numpy())
predict_labels = np.argmax(output.cpu().detach().numpy(), axis=1)
labels = label.cpu().numpy()
predictions = output.cpu().detach().numpy()
tsne_2d_generate(cls, predictions, labels, "output of neural network")
plot_confusion_matrix_accuracy(cls, labels, predict_labels)
print("测试集大小为{}, 成功{},准确率为{:.6f}".format(test_dataset.__len__(), test_acc, test_acc / test_dataset.__len__()))
def test1(cls, src_condition, tar_condition, G_params_path, LC_params_path):
if torch.cuda.is_available():
root_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\'
else:
root_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\'
src_data = np.load(root_dir + src_condition + '\\src\\' + 'data.npy')
src_label = np.load(root_dir + src_condition + '\\src\\' + 'label.npy')
tar_data = np.load(root_dir + tar_condition + '\\tar1\\' + 'data.npy')
tar_label = np.load(root_dir + tar_condition + '\\tar1\\' + 'label.npy')
src_dataset = Nor_Dataset(src_data, src_label)
tar_dataset = Nor_Dataset(tar_data, tar_label)
src_loader = DataLoader(dataset=src_dataset, batch_size=1000, shuffle=False, drop_last=True)
tar_loader = DataLoader(dataset=tar_dataset, batch_size=1000, shuffle=False, drop_last=True)
    # load the networks
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
G = Extractor().to(device)
LC = LabelClassifier(cls_num=cls).to(device)
G.load_state_dict(
torch.load(G_params_path, map_location=device)
)
LC.load_state_dict(
torch.load(LC_params_path, map_location=device)
)
# print(net)
    # params_num = sum(param.numel() for param in net.parameters())
    # print('number of parameters: {}'.format(params_num))
test_acc = 0.0
G.eval()
LC.eval()
with torch.no_grad():
for (src_batch_idx, (src_data, src_label)), (tar_batch_idx, (tar_data, tar_label)) in zip(enumerate(src_loader), enumerate(tar_loader)):
src_data, src_label = src_data.to(device), src_label.to(device)
tar_data, tar_label = tar_data.to(device), tar_label.to(device)
data = torch.concat((src_data, tar_data), dim=0)
label = torch.concat((src_label, tar_label), dim=0)
_, feature = G(data, data)
_, _, fc1, output = LC(feature, feature)
test_acc += np.sum(np.argmax(output.cpu().detach().numpy(), axis=1) == label.cpu().numpy())
predict_labels = np.argmax(output.cpu().detach().numpy(), axis=1)
labels = label.cpu().numpy()
outputs = output.cpu().detach().numpy()
fc1_outputs = fc1.cpu().detach().numpy()
break
tsne_2d_generate1(cls, fc1_outputs, labels, "GADAN")
plot_confusion_matrix_accuracy(cls, labels, predict_labels)
print("准确率为{:.6f}".format(test_acc / (src_dataset.__len__() + tar_dataset.__len__())))
if __name__ == '__main__':
# pass
test(5, 'B', 'parameters/A_to_B/G/G_shot10_epoch200_lr0.002_miu0.5.pkl',
'parameters/A_to_B/LC/LC_shot10_epoch200_lr0.002_miu0.5.pkl')
# test1(5, 'A', 'B', 'parameters_bak/A_to_B/G/G_shot10_epoch200_lr0.002_miu0.5.pkl',
# 'parameters_bak/A_to_B/LC/LC_shot10_epoch200_lr0.002_miu0.5.pkl')

View File

@@ -0,0 +1,357 @@
"""
@Author: miykah
@Email: miykah@163.com
@FileName: train.py
@DateTime: 2022/7/20 20:22
"""
import os
import time
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from example.load_data import get_dataset, draw_signal_img
from example.model import Extractor, LabelClassifier, DomainClassifier, weights_init_Classifier, weights_init_Extractor
from example.loss import cal_mmdc_loss, BinaryCrossEntropyLoss, get_mmdc
import example.mmd as mmd
from example.test import test, test1
from scipy.spatial.distance import cdist
import math
def obtain_label(feature, output, bs):
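    # Pseudo-labeling by clustering (an assumption: this mirrors the SHOT-style
    # recipe): build class centroids from softmax-weighted features, assign
    # each sample to the nearest centroid under cosine distance, then refine
    # the centroids once using the hard assignments.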
with torch.no_grad():
all_fea = feature.reshape(bs, -1).float().cpu()
all_output = output.float().cpu()
# all_label = label.float()
all_output = nn.Softmax(dim=1)(all_output)
_, predict = torch.max(all_output, 1)
# accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1)
all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t()
all_fea = all_fea.float().cpu().numpy()
K = all_output.size(1)
aff = all_output.float().cpu().numpy()
initc = aff.transpose().dot(all_fea)
initc = initc / (1e-8 + aff.sum(axis=0)[:, None])
dd = cdist(all_fea, initc, 'cosine')
pred_label = dd.argmin(axis=1)
# acc = np.sum(pred_label == all_label.float().numpy()) / len(all_fea)
for round in range(1):
aff = np.eye(K)[pred_label]
initc = aff.transpose().dot(all_fea)
initc = initc / (1e-8 + aff.sum(axis=0)[:, None])
dd = cdist(all_fea, initc, 'cosine')
pred_label = dd.argmin(axis=1)
# acc = np.sum(pred_label == all_label.float().numpy()) / len(all_fea)
# log_str = 'Accuracy = {:.2f}% -> {:.2f}%'.format(accuracy * 100, acc * 100)
# args.out_file.write(log_str + '\n')
# args.out_file.flush()
# print(log_str + '\n')
return pred_label.astype('int')
def train(device, src_condition, tar_condition, cls, epochs, bs, shot, lr, patience, gamma, miu):
    '''Feature extractor'''
G = Extractor()
G.apply(weights_init_Extractor)
G.to(device)
    '''Label classifier'''
LC = LabelClassifier(cls_num=cls)
LC.apply(weights_init_Classifier)
LC.to(device)
    '''Domain classifier'''
DC = DomainClassifier()
DC.apply(weights_init_Classifier)
DC.to(device)
    '''Load the datasets'''
src_dataset, tar_dataset, val_dataset, test_dataset \
= get_dataset(src_condition, tar_condition)
'''DataLoader'''
src_loader = DataLoader(dataset=src_dataset, batch_size=bs, shuffle=False, drop_last=True)
tar_loader = DataLoader(dataset=tar_dataset, batch_size=bs, shuffle=True, drop_last=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=bs, shuffle=True, drop_last=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=len(test_dataset), shuffle=True, drop_last=False)
criterion = nn.CrossEntropyLoss().to(device)
BCE = BinaryCrossEntropyLoss().to(device)
optimizer_g = torch.optim.Adam(G.parameters(), lr=lr)
optimizer_lc = torch.optim.Adam(LC.parameters(), lr=lr)
optimizer_dc = torch.optim.Adam(DC.parameters(), lr=lr)
scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_g, mode="min", factor=0.5, patience=patience)
scheduler_lc = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_lc, mode="min", factor=0.5, patience=patience)
scheduler_dc = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_dc, mode="min", factor=0.5, patience=patience)
def zero_grad_all():
optimizer_g.zero_grad()
optimizer_lc.zero_grad()
optimizer_dc.zero_grad()
src_acc_list = []
train_loss_list = []
val_acc_list = []
val_loss_list = []
for epoch in range(epochs):
epoch_start_time = time.time()
src_acc = 0.0
train_loss = 0.0
val_acc = 0.0
val_loss = 0.0
G.train()
LC.train()
DC.train()
for (src_batch_idx, (src_data, src_label)), (tar_batch_idx, (tar_data)) in zip(enumerate(src_loader), enumerate(tar_loader)):
src_data, src_label = src_data.to(device), src_label.to(device)
tar_data = tar_data.to(device)
zero_grad_all()
            T1 = int(0.2 * epochs)
            T2 = int(0.5 * epochs)
src_feature, tar_feature = G(src_data, tar_data)
src_data_mmd1, src_output, tar_data_mmd1, tar_output = LC(src_feature, tar_feature)
            if epoch < T1:
                pseudo_label = torch.argmax(F.softmax(tar_output, dim=1), dim=1)  # pseudo-labels
            else:
                pseudo_label = torch.tensor(obtain_label(tar_feature, tar_output, bs), dtype=torch.int64).to(device)
                # pseudo_label = torch.argmax(F.softmax(tar_output, dim=1), dim=1)  # pseudo-labels
            # weight of the mmdc term
if epoch < T1:
miu_f = 0
# elif epoch > T1 and epoch < T2:
# miu_f = miu * (epoch - T1) / (T2 - T1)
else:
miu_f = miu
            # cross-entropy loss on the source data
            loss_src = criterion(src_output, src_label)
            # pseudo-label cross-entropy loss on the target data
            loss_tar_pseudo = criterion(tar_output, pseudo_label)
            # MMD loss on the marginal feature distributions
            loss_mmdm = mmd.mmd_rbf_noaccelerate(src_data_mmd1, tar_data_mmd1)
if epoch < T1:
loss_mmdc = 0
else:
loss_mmdc = get_mmdc(src_data_mmd1, tar_data_mmd1, pseudo_label, bs, shot, cls)
# loss_mmdc = cal_mmdc_loss(src_data_mmd1, src_output, tar_data_mmd1, tar_output)
# loss_mmdc = get_mmdc(src_data_mmd1, tar_data_mmd1, pseudo_label, bs, shot, cls)
# loss_jmmd = miu_f * loss_mmdc + (1 - miu_f) * loss_mmdm
            # weight of the pseudo-label loss
# if epoch < T1:
# beta_f = 0
# elif epoch > T1 and epoch < T2:
# beta_f = beta * (epoch - T1) / (T2 - T1)
# else:
# beta_f = beta
p = epoch / epochs
lamda = (2 / (1 + math.exp(-10 * p))) - 1
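            # lamda follows the standard DANN adaptation-factor schedule,
            # ramping smoothly from 0 toward 1 as training progresses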
# gamma = (2 / (1 + math.exp(-10 * p))) - 1
# miu_f = (2 / (1 + math.exp(-10 * p))) - 1
loss_jmmd = miu_f * loss_mmdc + (1 - miu_f) * loss_mmdm
loss_G_LC = loss_src + gamma * loss_jmmd
# loss_G_LC = loss_src + beta_f * loss_tar_pseudo + gamma * loss_jmmd
loss_G_LC.backward()
optimizer_g.step()
optimizer_lc.step()
zero_grad_all()
#-----------------------------------------------
            # adversarial domain-adaptation loss
src_feature, tar_feature = G(src_data, tar_data)
src_data_mmd1, src_output, tar_data_mmd1, tar_output = LC(src_feature, tar_feature)
            # pseudo_label = torch.argmax(F.softmax(tar_output, dim=1), dim=1)  # pseudo-labels
            # cross-entropy loss on the source data
            loss_src = criterion(src_output, src_label)
            # pseudo-label cross-entropy loss on the target data
            loss_tar_pseudo = criterion(tar_output, pseudo_label)
gradient_src = \
torch.autograd.grad(outputs=loss_src, inputs=src_feature, create_graph=True, retain_graph=True,
only_inputs=True)[0]
gradient_tar = \
torch.autograd.grad(outputs=loss_tar_pseudo, inputs=tar_feature, create_graph=True, retain_graph=True,
only_inputs=True)[0]
gradients_adv = torch.cat((gradient_src, gradient_tar), dim=0)
domain_label_reverse = DC(gradients_adv, reverse=True)
loss_adv_r = BCE(domain_label_reverse)
loss_G_adv = lamda * loss_adv_r
            # update the parameters of the feature extractor G
loss_G_adv.backward()
optimizer_g.step()
zero_grad_all()
#---------------------------------------------------------------------
src_feature, tar_feature = G(src_data, tar_data)
src_data_mmd1, src_output, tar_data_mmd1, tar_output = LC(src_feature, tar_feature)
            # pseudo_label = torch.argmax(F.softmax(tar_output, dim=1), dim=1)  # pseudo-labels
            # cross-entropy loss on the source data
            loss_src = criterion(src_output, src_label)
            # pseudo-label cross-entropy loss on the target data
            loss_tar_pseudo = criterion(tar_output, pseudo_label)
gradient_src = \
torch.autograd.grad(outputs=loss_src, inputs=src_feature, create_graph=True, retain_graph=True,
only_inputs=True)[0]
gradient_tar = \
torch.autograd.grad(outputs=loss_tar_pseudo, inputs=tar_feature, create_graph=True, retain_graph=True,
only_inputs=True)[0]
gradients = torch.cat((gradient_src, gradient_tar), dim=0)
domain_label = DC(gradients, reverse=False)
loss_adv = BCE(domain_label)
loss_DC = lamda * loss_adv
            # update the parameters of the domain classifier
loss_DC.backward()
optimizer_dc.step()
zero_grad_all()
src_acc += np.sum(np.argmax(src_output.cpu().detach().numpy(), axis=1) == src_label.cpu().numpy())
train_loss += (loss_G_LC + loss_G_adv + loss_DC).item()
G.eval()
LC.eval()
DC.eval()
with torch.no_grad():
for batch_idx, (val_data, val_label) in enumerate(val_loader):
val_data, val_label = val_data.to(device), val_label.to(device)
_, val_feature = G(val_data, val_data)
_, _, _, val_output = LC(val_feature, val_feature)
loss = criterion(val_output, val_label)
val_acc += np.sum(np.argmax(val_output.cpu().detach().numpy(), axis=1) == val_label.cpu().numpy())
val_loss += loss.item()
scheduler_g.step(val_loss)
scheduler_lc.step(val_loss)
scheduler_dc.step(val_loss)
print("[{:03d}/{:03d}] {:2.2f} sec(s) src_acc: {:3.6f} train_loss: {:3.9f} | val_acc: {:3.6f} val_loss: {:3.9f} | Learning rate : {:3.6f}".format(
epoch + 1, epochs, time.time() - epoch_start_time, \
src_acc / src_dataset.__len__(), train_loss / src_dataset.__len__(),
val_acc / val_dataset.__len__(), val_loss / val_dataset.__len__(),
optimizer_g.state_dict()['param_groups'][0]['lr']))
        # save the model with the lowest validation loss
        # if val_loss_list.__len__() > 0 and (val_loss / val_dataset.__len__()) < min(val_loss_list):
        # save whenever validation accuracy or loss improves on the best so far
        if val_acc_list.__len__() > 0:
            # if (val_acc / val_dataset.__len__()) >= max(val_acc_list):
            if (val_acc / val_dataset.__len__()) > max(val_acc_list) or (val_loss / val_dataset.__len__()) < min(val_loss_list):
                print("Saved new best model")
G_path = "parameters_bak/" + src_condition + "_to_" + tar_condition + "/G"
LC_path = "parameters_bak/" + src_condition + "_to_" + tar_condition + "/LC"
DC_path = "parameters_bak/" + src_condition + "_to_" + tar_condition + "/DC"
if not os.path.exists(G_path):
os.makedirs(G_path)
if not os.path.exists(LC_path):
os.makedirs(LC_path)
if not os.path.exists(DC_path):
os.makedirs(DC_path)
                # save the model parameters
torch.save(G.state_dict(), G_path + "/G_" + "shot" + str(shot) + "_epoch" + str(epochs) + "_lr" + str(lr) + "_miu" + str(miu) + ".pkl")
torch.save(LC.state_dict(), LC_path + "/LC_" + "shot" + str(shot) + "_epoch" + str(epochs) + "_lr" + str(lr) + "_miu" + str(miu) + ".pkl")
torch.save(DC.state_dict(), DC_path + "/DC_" + "shot" + str(shot) + "_epoch" + str(epochs) + "_lr" + str(lr) + "_miu" + str(miu) + ".pkl")
src_acc_list.append(src_acc / src_dataset.__len__())
train_loss_list.append(train_loss / src_dataset.__len__())
val_acc_list.append(val_acc / val_dataset.__len__())
val_loss_list.append(val_loss / val_dataset.__len__())
    '''Paths of the saved model parameters (rebuilt here so they are defined
    even if the best-model branch above never ran in the final epoch)'''
    G_path = "parameters_bak/" + src_condition + "_to_" + tar_condition + "/G"
    LC_path = "parameters_bak/" + src_condition + "_to_" + tar_condition + "/LC"
    G_params_path = G_path + "/G_" + "shot" + str(shot) + "_epoch" + str(epochs) + "_lr" + str(lr) + "_miu" + str(miu) + ".pkl"
    LC_params_path = LC_path + "/LC_" + "shot" + str(shot) + "_epoch" + str(epochs) + "_lr" + str(lr) + "_miu" + str(miu) + ".pkl"
from matplotlib import rcParams
config = {
"font.family": 'Times New Roman', # 设置字体类型
"axes.unicode_minus": False, # 解决负号无法显示的问题
"axes.labelsize": 13
}
rcParams.update(config)
# pic1 = plt.figure(figsize=(8, 6), dpi=200)
# plt.subplot(211)
# plt.plot(np.arange(1, epochs + 1), src_acc_list, 'b', label='TrainAcc')
# plt.plot(np.arange(1, epochs + 1), val_acc_list, 'r', label='ValAcc')
    # plt.ylim(0.3, 1.0)  # set the y-axis range
# plt.title('Training & Validation accuracy')
# plt.xlabel('epoch')
# plt.ylabel('accuracy')
# plt.legend(loc='lower right')
# plt.grid(alpha=0.4)
#
# plt.subplot(212)
# plt.plot(np.arange(1, epochs + 1), train_loss_list, 'b', label='TrainLoss')
# plt.plot(np.arange(1, epochs + 1), val_loss_list, 'r', label='ValLoss')
# plt.ylim(0, 0.08) # 设置y轴范围
# plt.title('Training & Validation loss')
# plt.xlabel('epoch')
# plt.ylabel('loss')
# plt.legend(loc='upper right')
# plt.grid(alpha=0.4)
pic1 = plt.figure(figsize=(12, 6), dpi=200)
plt.plot(np.arange(1, epochs + 1), train_loss_list, 'b', label='Training Loss')
plt.plot(np.arange(1, epochs + 1), val_loss_list, 'r', label='Validation Loss')
    plt.ylim(0, 0.08)  # set the y-axis range
plt.title('Training & Validation loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend(loc='upper right')
plt.grid(alpha=0.4)
    # use the current Unix timestamp (as a string) for the figure file name
    timestamp = int(time.time())
    timestamp_str = str(timestamp)
plt.savefig(timestamp_str, dpi=200)
return G_params_path, LC_params_path
if __name__ == '__main__':
begin = time.time()
if torch.cuda.is_available():
device = torch.device("cuda:0")
else:
device = torch.device("cpu")
seed = 2
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
src_condition = 'A'
tar_condition = 'B'
    '''Train'''
G_params_path, LC_params_path = train(device, src_condition, tar_condition, cls=5,
epochs=200, bs=50, shot=10, lr=0.002, patience=40, gamma=1, miu=0.5)
end = time.time()
    '''Test'''
# test1(5, src_condition, tar_condition, G_params_path, LC_params_path)
test(5, tar_condition, G_params_path, LC_params_path)
print("训练耗时:{:3.2f}s".format(end - begin))