pytorch更新
This commit is contained in:
parent
564ba3f669
commit
2c3e6c25a8
|
|
@ -0,0 +1,8 @@
|
|||
#-*- encoding:utf-8 -*-
|
||||
|
||||
'''
|
||||
@Author : dingjiawen
|
||||
@Date : 2023/11/9 21:32
|
||||
@Usage :
|
||||
@Desc :
|
||||
'''
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
#-*- encoding:utf-8 -*-
|
||||
|
||||
'''
|
||||
@Author : dingjiawen
|
||||
@Date : 2023/11/9 21:33
|
||||
@Usage :
|
||||
@Desc :
|
||||
'''
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
#-*- encoding:utf-8 -*-
|
||||
|
||||
'''
|
||||
@Author : dingjiawen
|
||||
@Date : 2023/11/9 21:34
|
||||
@Usage :
|
||||
@Desc :
|
||||
'''
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
#-*- encoding:utf-8 -*-
|
||||
|
||||
'''
|
||||
@Author : dingjiawen
|
||||
@Date : 2023/11/9 21:34
|
||||
@Usage :
|
||||
@Desc :
|
||||
'''
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
#-*- encoding:utf-8 -*-
|
||||
|
||||
'''
|
||||
@Author : dingjiawen
|
||||
@Date : 2023/11/9 21:34
|
||||
@Usage :
|
||||
@Desc :
|
||||
'''
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
#-*- encoding:utf-8 -*-
|
||||
|
||||
'''
|
||||
@Author : dingjiawen
|
||||
@Date : 2023/11/9 21:28
|
||||
@Usage :
|
||||
@Desc :
|
||||
'''
|
||||
|
|
@ -0,0 +1,111 @@
|
|||
"""
|
||||
@Author: miykah
|
||||
@Email: miykah@163.com
|
||||
@FileName: load_data.py
|
||||
@DateTime: 2022/7/20 16:40
|
||||
"""
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import torchvision.transforms as transforms
|
||||
import random
|
||||
|
||||
'''正常Dataset类'''
|
||||
class Nor_Dataset(Dataset):
|
||||
def __init__(self, datas, labels=None):
|
||||
self.datas = torch.tensor(datas)
|
||||
if labels is not None:
|
||||
self.labels = torch.tensor(labels)
|
||||
else:
|
||||
self.labels = None
|
||||
def __getitem__(self, index):
|
||||
data = self.datas[index]
|
||||
if self.labels is not None:
|
||||
label = self.labels[index]
|
||||
return data, label
|
||||
return data
|
||||
def __len__(self):
|
||||
return len(self.datas)
|
||||
|
||||
'''未标记目标数据的Dataset类'''
|
||||
class Tar_U_Dataset(Dataset):
|
||||
def __init__(self, datas):
|
||||
self.datas = torch.tensor(datas)
|
||||
def __getitem__(self, index):
|
||||
data = self.datas[index]
|
||||
data_bar = data.clone().detach()
|
||||
data_bar2 = data.clone().detach()
|
||||
mu = 0
|
||||
sigma = 0.1
|
||||
'''对未标记目标数据加噪,得到data'和data'' '''
|
||||
for i, j in zip(range(data_bar[0].shape[0]), range(data_bar2[0].shape[0])):
|
||||
data_bar[0, i] += random.gauss(mu, sigma)
|
||||
data_bar2[0, j] += random.gauss(mu, sigma)
|
||||
return data, data_bar, data_bar2
|
||||
def __len__(self):
|
||||
return len(self.datas)
|
||||
|
||||
def draw_signal_img(data, data_bar, data_bar2):
|
||||
pic = plt.figure(figsize=(12, 6), dpi=100)
|
||||
plt.rcParams['font.family'] = ['Arial Unicode MS', 'Microsoft YaHei', 'SimHei', 'sans-serif']
|
||||
plt.rcParams['axes.unicode_minus'] = False
|
||||
plt.subplot(3, 1, 1)
|
||||
plt.plot(data, 'b')
|
||||
plt.subplot(3, 1, 2)
|
||||
plt.plot(data_bar, 'b')
|
||||
plt.subplot(3, 1, 3)
|
||||
plt.plot(data_bar2, 'b')
|
||||
plt.show()
|
||||
|
||||
def get_dataset(src_condition, tar_condition):
|
||||
if torch.cuda.is_available():
|
||||
root_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\'
|
||||
else:
|
||||
root_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\'
|
||||
|
||||
src_data = np.load(root_dir + src_condition + '\\src\\' + 'data.npy')
|
||||
src_label = np.load(root_dir + src_condition + '\\src\\' + 'label.npy')
|
||||
|
||||
tar_data = np.load(root_dir + tar_condition + '\\tar\\' + 'data.npy')
|
||||
|
||||
val_data = np.load(root_dir + tar_condition + '\\val\\' + 'data.npy')
|
||||
val_label = np.load(root_dir + tar_condition + '\\val\\' + 'label.npy')
|
||||
|
||||
test_data = np.load(root_dir + tar_condition + '\\test\\' + 'data.npy')
|
||||
test_label = np.load(root_dir + tar_condition + '\\test\\' + 'label.npy')
|
||||
|
||||
src_dataset = Nor_Dataset(src_data, src_label)
|
||||
tar_dataset = Nor_Dataset(tar_data)
|
||||
val_dataset = Nor_Dataset(val_data, val_label)
|
||||
test_dataset = Nor_Dataset(test_data, test_label)
|
||||
|
||||
return src_dataset, tar_dataset, val_dataset, test_dataset
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# pass
|
||||
# tar_data = np.load("E:\\DataSet\\DDS_data_CDAC\\processed_data\\A\\tar\\data.npy")
|
||||
# tar_dataset = Nor_Dataset(datas=tar_data)
|
||||
# tar_loader = DataLoader(dataset=tar_dataset, batch_size=len(tar_dataset), shuffle=False)
|
||||
#
|
||||
# for batch_idx, (data) in enumerate(tar_loader):
|
||||
# print(data[0][0])
|
||||
# # print(data_bar[0][0])
|
||||
# # print(data_bar2[0][0])
|
||||
# if batch_idx == 0:
|
||||
# break
|
||||
src_data = np.load("E:\\DataSet\\DDS_data_CDAC\\processed_data\\A\\src\\data.npy")
|
||||
src_label = np.load("E:\\DataSet\\DDS_data_CDAC\\processed_data\\A\\src\\label.npy")
|
||||
src_dataset = Nor_Dataset(datas=src_data, labels=src_label)
|
||||
src_loader = DataLoader(dataset=src_dataset, batch_size=len(src_dataset), shuffle=False)
|
||||
|
||||
for batch_idx, (data, label) in enumerate(src_loader):
|
||||
print(data[10][0])
|
||||
print(label[10])
|
||||
# print(data_bar[0][0])
|
||||
# print(data_bar2[0][0])
|
||||
if batch_idx == 0:
|
||||
break
|
||||
|
|
@ -0,0 +1,127 @@
|
|||
"""
|
||||
@Author: miykah
|
||||
@Email: miykah@163.com
|
||||
@FileName: loss.py
|
||||
@DateTime: 2022/7/21 14:31
|
||||
"""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.nn.functional as functional
|
||||
import torch.nn as nn
|
||||
import example.mmd as mmd
|
||||
|
||||
'''
|
||||
计算条件概率分布的差异
|
||||
mmdc = mmd_condition
|
||||
'''
|
||||
def cal_mmdc_loss(src_data_mmd1, src_data_mmd2, tar_data_mmd1, tar_data_mmd2):
|
||||
src_cls_10 = src_data_mmd1[3 * 0: 3 * 1]
|
||||
src_cls_11 = src_data_mmd1[3 * 1: 3 * 2]
|
||||
src_cls_12 = src_data_mmd1[3 * 2: 3 * 3]
|
||||
src_cls_13 = src_data_mmd1[3 * 3: 3 * 4]
|
||||
src_cls_14 = src_data_mmd1[3 * 4: 3 * 5]
|
||||
src_cls_15 = src_data_mmd1[3 * 5: 3 * 6]
|
||||
src_cls_16 = src_data_mmd1[3 * 6: 3 * 7]
|
||||
src_cls_17 = src_data_mmd1[3 * 7: 3 * 8]
|
||||
src_cls_18 = src_data_mmd1[3 * 8: 3 * 9]
|
||||
src_cls_19 = src_data_mmd1[3 * 9: 3 * 10]
|
||||
|
||||
tar_cls_10 = tar_data_mmd1[3 * 0: 3 * 1]
|
||||
tar_cls_11 = tar_data_mmd1[3 * 1: 3 * 2]
|
||||
tar_cls_12 = tar_data_mmd1[3 * 2: 3 * 3]
|
||||
tar_cls_13 = tar_data_mmd1[3 * 3: 3 * 4]
|
||||
tar_cls_14 = tar_data_mmd1[3 * 4: 3 * 5]
|
||||
tar_cls_15 = tar_data_mmd1[3 * 5: 3 * 6]
|
||||
tar_cls_16 = tar_data_mmd1[3 * 6: 3 * 7]
|
||||
tar_cls_17 = tar_data_mmd1[3 * 7: 3 * 8]
|
||||
tar_cls_18 = tar_data_mmd1[3 * 8: 3 * 9]
|
||||
tar_cls_19 = tar_data_mmd1[3 * 9: 3 * 10]
|
||||
|
||||
|
||||
src_cls_20 = src_data_mmd2[3 * 0: 3 * 1]
|
||||
src_cls_21 = src_data_mmd2[3 * 1: 3 * 2]
|
||||
src_cls_22 = src_data_mmd2[3 * 2: 3 * 3]
|
||||
src_cls_23 = src_data_mmd2[3 * 3: 3 * 4]
|
||||
src_cls_24 = src_data_mmd2[3 * 4: 3 * 5]
|
||||
src_cls_25 = src_data_mmd2[3 * 5: 3 * 6]
|
||||
src_cls_26 = src_data_mmd2[3 * 6: 3 * 7]
|
||||
src_cls_27 = src_data_mmd2[3 * 7: 3 * 8]
|
||||
src_cls_28 = src_data_mmd2[3 * 8: 3 * 9]
|
||||
src_cls_29 = src_data_mmd2[3 * 9: 3 * 10]
|
||||
|
||||
tar_cls_20 = tar_data_mmd2[3 * 0: 3 * 1]
|
||||
tar_cls_21 = tar_data_mmd2[3 * 1: 3 * 2]
|
||||
tar_cls_22 = tar_data_mmd2[3 * 2: 3 * 3]
|
||||
tar_cls_23 = tar_data_mmd2[3 * 3: 3 * 4]
|
||||
tar_cls_24 = tar_data_mmd2[3 * 4: 3 * 5]
|
||||
tar_cls_25 = tar_data_mmd2[3 * 5: 3 * 6]
|
||||
tar_cls_26 = tar_data_mmd2[3 * 6: 3 * 7]
|
||||
tar_cls_27 = tar_data_mmd2[3 * 7: 3 * 8]
|
||||
tar_cls_28 = tar_data_mmd2[3 * 8: 3 * 9]
|
||||
tar_cls_29 = tar_data_mmd2[3 * 9: 3 * 10]
|
||||
|
||||
mmd_10 = mmd.mmd_linear(src_cls_10, tar_cls_10)
|
||||
mmd_11 = mmd.mmd_linear(src_cls_11, tar_cls_11)
|
||||
mmd_12 = mmd.mmd_linear(src_cls_12, tar_cls_12)
|
||||
mmd_13 = mmd.mmd_linear(src_cls_13, tar_cls_13)
|
||||
mmd_14 = mmd.mmd_linear(src_cls_14, tar_cls_14)
|
||||
mmd_15 = mmd.mmd_linear(src_cls_15, tar_cls_15)
|
||||
mmd_16 = mmd.mmd_linear(src_cls_16, tar_cls_16)
|
||||
mmd_17 = mmd.mmd_linear(src_cls_17, tar_cls_17)
|
||||
mmd_18 = mmd.mmd_linear(src_cls_18, tar_cls_18)
|
||||
mmd_19 = mmd.mmd_linear(src_cls_19, tar_cls_19)
|
||||
|
||||
mmd_20 = mmd.mmd_linear(src_cls_20, tar_cls_20)
|
||||
mmd_21 = mmd.mmd_linear(src_cls_21, tar_cls_21)
|
||||
mmd_22 = mmd.mmd_linear(src_cls_22, tar_cls_22)
|
||||
mmd_23 = mmd.mmd_linear(src_cls_23, tar_cls_23)
|
||||
mmd_24 = mmd.mmd_linear(src_cls_24, tar_cls_24)
|
||||
mmd_25 = mmd.mmd_linear(src_cls_25, tar_cls_25)
|
||||
mmd_26 = mmd.mmd_linear(src_cls_26, tar_cls_26)
|
||||
mmd_27 = mmd.mmd_linear(src_cls_27, tar_cls_27)
|
||||
mmd_28 = mmd.mmd_linear(src_cls_28, tar_cls_28)
|
||||
mmd_29 = mmd.mmd_linear(src_cls_29, tar_cls_29)
|
||||
|
||||
mmdc1 = mmd_10 + mmd_11 + mmd_12 + mmd_13 + mmd_14 + mmd_15 + mmd_16 + mmd_17 + mmd_18 + mmd_19
|
||||
mmdc2 = mmd_20 + mmd_21 + mmd_22 + mmd_23 + mmd_24 + mmd_25 + mmd_26 + mmd_27 + mmd_28 + mmd_29
|
||||
|
||||
return (mmdc2) / 10
|
||||
# return (mmdc1 + mmdc2) / 20
|
||||
|
||||
'''得到源域每类特征,用于计算mmdc'''
|
||||
def get_src_mean_feature(src_feature, shot, cls):
|
||||
src_feature_list = []
|
||||
for i in range(cls):
|
||||
src_feature_cls = torch.mean(src_feature[shot * i: shot * (i + 1)], dim=0)
|
||||
src_feature_list.append(src_feature_cls)
|
||||
return src_feature_list
|
||||
|
||||
def get_mmdc(src_feature, tar_feature, tar_pseudo_label, batch_size, shot, cls):
|
||||
src_feature_list = get_src_mean_feature(src_feature, shot, cls)
|
||||
pseudo_label = tar_pseudo_label.cpu().detach().numpy()
|
||||
mmdc = 0.0
|
||||
for i in range(batch_size):
|
||||
# mmdc += mmd.mmd_linear(src_feature_list[pseudo_label[i]].reshape(1, -1), tar_feature[i].reshape(1, -1))
|
||||
mmdc += mmd.mmd_linear(src_feature_list[pseudo_label[i]].reshape(1, -1), tar_feature[i].reshape(1, -1))
|
||||
return mmdc / batch_size
|
||||
|
||||
class BCE(nn.Module):
|
||||
eps = 1e-7
|
||||
def forward(self, prob1, prob2, simi):
|
||||
P = prob1.mul_(prob2)
|
||||
P = P.sum(1)
|
||||
P.mul_(simi).add_(simi.eq(-1).type_as(P))
|
||||
neglogP = -P.add_(BCE.eps).log_()
|
||||
return neglogP.mean()
|
||||
|
||||
class BinaryCrossEntropyLoss(nn.Module):
|
||||
""" Construct binary cross-entropy loss."""
|
||||
eps = 1e-7
|
||||
def forward(self, prob):
|
||||
# ds = torch.ones([bs, 1]).to(device) # domain label for source
|
||||
# dt = torch.zeros([bs, 1]).to(device) # domain label for target
|
||||
# di = torch.cat((ds, dt), dim=0).to(device)
|
||||
# neglogP = - (di * torch.log(prob + BCE.eps) + (1. - di) * torch.log(1. - prob + BCE.eps))
|
||||
neglogP = - (prob * torch.log(prob + BCE.eps) + (1. - prob) * torch.log(1. - prob + BCE.eps))
|
||||
return neglogP.mean()
|
||||
|
|
@ -0,0 +1,56 @@
|
|||
#!/usr/bin/env python
|
||||
# encoding: utf-8
|
||||
|
||||
import torch
|
||||
|
||||
# Consider linear time MMD with a linear kernel:
|
||||
# K(f(x), f(y)) = f(x)^Tf(y)
|
||||
# h(z_i, z_j) = k(x_i, x_j) + k(y_i, y_j) - k(x_i, y_j) - k(x_j, y_i)
|
||||
# = [f(x_i) - f(y_i)]^T[f(x_j) - f(y_j)]
|
||||
#
|
||||
# f_of_X: batch_size * k
|
||||
# f_of_Y: batch_size * k
|
||||
def mmd_linear(f_of_X, f_of_Y):
|
||||
delta = f_of_X - f_of_Y
|
||||
loss = torch.mean(torch.mm(delta, torch.transpose(delta, 0, 1)))
|
||||
return loss
|
||||
|
||||
def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
|
||||
n_samples = int(source.size()[0])+int(target.size()[0])
|
||||
total = torch.cat([source, target], dim=0)
|
||||
total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
|
||||
total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
|
||||
L2_distance = ((total0-total1)**2).sum(2)
|
||||
if fix_sigma:
|
||||
bandwidth = fix_sigma
|
||||
else:
|
||||
bandwidth = torch.sum(L2_distance.data) / (n_samples**2-n_samples)
|
||||
bandwidth /= kernel_mul ** (kernel_num // 2)
|
||||
bandwidth_list = [bandwidth * (kernel_mul**i) for i in range(kernel_num)]
|
||||
kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
|
||||
return sum(kernel_val)#/len(kernel_val)
|
||||
|
||||
|
||||
def mmd_rbf_accelerate(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
|
||||
batch_size = int(source.size()[0])
|
||||
kernels = guassian_kernel(source, target,
|
||||
kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)
|
||||
loss = 0
|
||||
for i in range(batch_size):
|
||||
s1, s2 = i, (i+1)%batch_size
|
||||
t1, t2 = s1+batch_size, s2+batch_size
|
||||
loss += kernels[s1, s2] + kernels[t1, t2]
|
||||
loss -= kernels[s1, t2] + kernels[s2, t1]
|
||||
return loss / float(batch_size)
|
||||
|
||||
def mmd_rbf_noaccelerate(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
|
||||
batch_size = int(source.size()[0])
|
||||
kernels = guassian_kernel(source, target,
|
||||
kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)
|
||||
XX = kernels[:batch_size, :batch_size]
|
||||
YY = kernels[batch_size:, batch_size:]
|
||||
XY = kernels[:batch_size, batch_size:]
|
||||
YX = kernels[batch_size:, :batch_size]
|
||||
loss = torch.mean(XX + YY - XY -YX)
|
||||
return loss
|
||||
|
||||
|
|
@ -0,0 +1,109 @@
|
|||
"""
|
||||
@Author: miykah
|
||||
@Email: miykah@163.com
|
||||
@FileName: model.py
|
||||
@DateTime: 2022/7/20 21:18
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import Function
|
||||
|
||||
class GradReverse(Function):
|
||||
def __init__(self):
|
||||
self.lambd = 1.0
|
||||
|
||||
def forward(self, x):
|
||||
return x.view_as(x)
|
||||
|
||||
def backward(self, grad_output):
|
||||
return (grad_output * - 1.0)
|
||||
|
||||
def grad_reverse(x):
|
||||
return GradReverse.apply(x)
|
||||
# return GradReverse(lambd)(x)
|
||||
|
||||
'''特征提取器'''
|
||||
class Extractor(nn.Module):
|
||||
def __init__(self):
|
||||
super(Extractor, self).__init__()
|
||||
self.conv1 = nn.Sequential(
|
||||
nn.Conv1d(in_channels=1, out_channels=32, kernel_size=13, padding='same'),
|
||||
nn.BatchNorm1d(32),
|
||||
nn.ReLU(),
|
||||
nn.MaxPool1d(2) # (32 * 1024)
|
||||
)
|
||||
self.conv2 = nn.Sequential(
|
||||
nn.Conv1d(in_channels=32, out_channels=32, kernel_size=13, padding='same'),
|
||||
nn.BatchNorm1d(32),
|
||||
nn.ReLU(),
|
||||
nn.MaxPool1d(2) # (32 * 512)
|
||||
)
|
||||
self.conv3 = nn.Sequential(
|
||||
nn.Conv1d(in_channels=32, out_channels=32, kernel_size=13, padding='same'),
|
||||
nn.BatchNorm1d(32),
|
||||
nn.ReLU(),
|
||||
nn.MaxPool1d(2) # (32 * 256)
|
||||
)
|
||||
|
||||
def forward(self, src_data, tar_data):
|
||||
src_data = self.conv1(src_data)
|
||||
src_data = self.conv2(src_data)
|
||||
src_feature = self.conv3(src_data)
|
||||
|
||||
tar_data = self.conv1(tar_data)
|
||||
tar_data = self.conv2(tar_data)
|
||||
tar_feature = self.conv3(tar_data)
|
||||
return src_feature, tar_feature
|
||||
|
||||
'''标签分类器'''
|
||||
class LabelClassifier(nn.Module):
|
||||
def __init__(self, cls_num):
|
||||
super(LabelClassifier, self).__init__()
|
||||
self.fc1 = nn.Sequential(
|
||||
nn.Flatten(), # (8192,)
|
||||
nn.Linear(in_features=8192, out_features=256),
|
||||
)
|
||||
self.fc2 = nn.Sequential(
|
||||
nn.ReLU(),
|
||||
nn.Linear(in_features=256, out_features=cls_num)
|
||||
)
|
||||
def forward(self, src_feature, tar_feature):
|
||||
src_data_mmd1 = self.fc1(src_feature)
|
||||
src_output = self.fc2(src_data_mmd1)
|
||||
|
||||
tar_data_mmd1 = self.fc1(tar_feature)
|
||||
tar_output = self.fc2(tar_data_mmd1)
|
||||
return src_data_mmd1, src_output, tar_data_mmd1, tar_output
|
||||
|
||||
'''分类器'''
|
||||
class DomainClassifier(nn.Module):
|
||||
def __init__(self, temp=0.05):
|
||||
super(DomainClassifier, self).__init__()
|
||||
self.fc = nn.Sequential(
|
||||
nn.Flatten(),
|
||||
nn.Linear(in_features=8192, out_features=512),
|
||||
nn.ReLU(),
|
||||
nn.Linear(in_features=512, out_features=128),
|
||||
nn.ReLU(),
|
||||
nn.Linear(in_features=128, out_features=1),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
self.temp = temp
|
||||
def forward(self, x, reverse=False):
|
||||
if reverse:
|
||||
x = grad_reverse(x)
|
||||
output = self.fc(x)
|
||||
return output
|
||||
|
||||
'''初始化网络权重'''
|
||||
def weights_init_Extractor(m):
|
||||
if isinstance(m, nn.Conv1d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def weights_init_Classifier(m):
|
||||
if isinstance(m, nn.Linear):
|
||||
nn.init.xavier_normal_(m.weight)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
|
@ -0,0 +1,298 @@
|
|||
"""
|
||||
@Author: miykah
|
||||
@Email: miykah@163.com
|
||||
@FileName: read_data_cwru.py
|
||||
@DateTime: 2022/7/20 15:58
|
||||
"""
|
||||
import math
|
||||
import os
|
||||
import scipy.io as scio
|
||||
import numpy as np
|
||||
import PIL.Image as Image
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
|
||||
def dat_to_numpy(samples_num, sample_length):
|
||||
'''根据每个类别需要的样本数和样本长度计算采样步长'''
|
||||
stride = math.floor((204800 * 4 - sample_length) / (samples_num - 1))
|
||||
|
||||
conditions = ['A', 'B', 'C']
|
||||
for condition in conditions:
|
||||
if torch.cuda.is_available():
|
||||
root_dir = 'D:\\DataSet\\DDS_data\\5cls_11\\' + condition
|
||||
save_dir = 'D:\\DataSet\\DDS_data\\5cls_11\\raw_data\\' + condition
|
||||
else:
|
||||
root_dir = 'E:\\DataSet\\DDS_data\\5cls_11\\' + condition
|
||||
save_dir = 'E:\\DataSet\\DDS_data\\5cls_11\\raw_data\\' + condition
|
||||
|
||||
if condition == 'A':
|
||||
start_idx = 1
|
||||
end_idx = 5
|
||||
# start_idx = 13
|
||||
# end_idx = 17
|
||||
# start_idx = 25
|
||||
# end_idx = 29
|
||||
# start_idx = 9
|
||||
# end_idx = 13
|
||||
elif condition == 'B':
|
||||
start_idx = 5
|
||||
end_idx = 9
|
||||
# start_idx = 17
|
||||
# end_idx = 21
|
||||
# start_idx = 29
|
||||
# end_idx = 33
|
||||
elif condition == 'C':
|
||||
start_idx = 9
|
||||
end_idx = 13
|
||||
# start_idx = 21
|
||||
# end_idx = 25
|
||||
# start_idx = 33
|
||||
# end_idx = 37
|
||||
# start_idx = 25
|
||||
# end_idx = 29
|
||||
|
||||
dir_names = os.listdir(root_dir)
|
||||
print(dir_names)
|
||||
'''
|
||||
故障类型
|
||||
['平行齿轮箱轴承内圈故障恒速', '平行齿轮箱轴承外圈故障恒速',
|
||||
'平行齿轮箱轴承滚动体故障恒速', '平行齿轮箱齿轮偏心故障恒速',
|
||||
'平行齿轮箱齿轮断齿故障恒速', '平行齿轮箱齿轮缺齿故障恒速',
|
||||
'平行齿轮箱齿轮表面磨损故障恒速', '平行齿轮箱齿轮齿根裂纹故障恒速']
|
||||
'''
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir)
|
||||
for dir_name in dir_names:
|
||||
data_list = []
|
||||
for i in range(start_idx, end_idx):
|
||||
if i < 10:
|
||||
path = root_dir + '\\' + dir_name + '\dds测试故障库4.6#000' + str(i) + '.dat'
|
||||
else:
|
||||
path = root_dir + '\\' + dir_name + '\dds测试故障库4.6#00' + str(i) + '.dat'
|
||||
data = np.fromfile(path, dtype='float32')[204800 * 2: 204800 * 3]
|
||||
data_list.append(data)
|
||||
# data_one_cls = np.array(data_list).reshape(-1).reshape(-1, 1024)
|
||||
# data_one_cls = np.array(data_list).reshape(-1).reshape(-1, 1, 2048) #(400, 1, 2048)
|
||||
data_one_cls = np.array(data_list).reshape(-1)
|
||||
n = math.floor((204800 * 4 - sample_length) / stride + 1)
|
||||
list = []
|
||||
for i in range(n):
|
||||
start = i * stride
|
||||
end = i * stride + sample_length
|
||||
list.append(data_one_cls[start: end])
|
||||
data_one_cls = np.array(list).reshape(-1, 1, sample_length)
|
||||
# 打乱数据
|
||||
shuffle_ix = np.random.permutation(np.arange(len(data_one_cls)))
|
||||
data_one_cls = data_one_cls[shuffle_ix]
|
||||
print(data_one_cls.shape)
|
||||
np.save(save_dir + '\\' + dir_name + '.npy', data_one_cls)
|
||||
|
||||
|
||||
def draw_signal_img(condition):
|
||||
if torch.cuda.is_available():
|
||||
root_dir = 'D:\\DataSet\\DDS_data\\4cls_speed30\\raw_data\\' + condition
|
||||
else:
|
||||
root_dir = 'E:\\DataSet\\DDS_data\\4cls_speed30\\raw_data\\' + condition
|
||||
file_names = os.listdir(root_dir)
|
||||
pic = plt.figure(figsize=(12, 6), dpi=200)
|
||||
# plt.rcParams['font.family'] = ['Arial Unicode MS', 'Microsoft YaHei', 'SimHei', 'sans-serif']
|
||||
# plt.rcParams['axes.unicode_minus'] = False
|
||||
plt.rc('font',family='Times New Roman')
|
||||
clses = ['(a)', '(b)', '(c)', '(d)']
|
||||
for file_name, i, cls in zip(file_names, range(4), clses):
|
||||
data = np.load(root_dir + '\\' + file_name)[0].reshape(-1)
|
||||
print(data.shape, file_name)
|
||||
plt.tick_params(top='on', right='on', which='both') # 设置上面和右面也有刻度
|
||||
plt.rcParams['xtick.direction'] = 'in' # 将x周的刻度线方向设置向内
|
||||
plt.rcParams['ytick.direction'] = 'in' #将y轴的刻度方向设置向内
|
||||
plt.subplot(2, 2, i + 1)
|
||||
plt.title(cls, fontsize=20)
|
||||
plt.xlabel("Time (s)", fontsize=20)
|
||||
plt.ylabel("Amplitude (m/s²)", fontsize=20)
|
||||
plt.xlim(0, 0.16)
|
||||
plt.ylim(-3, 3)
|
||||
plt.yticks(np.array([-2, 0, 2]), fontsize=20)
|
||||
plt.xticks(np.array([0, 640, 1280, 1920, 2048]), np.array(['0', 0.05, 0.1, 0.15]), fontsize=20)
|
||||
plt.plot(data)
|
||||
plt.show()
|
||||
|
||||
def draw_signal_img2():
|
||||
conditions = ['A', 'B', 'C']
|
||||
pic = plt.figure(figsize=(12, 6), dpi=600)
|
||||
plt.text(x=1, y=3, s='A')
|
||||
i = 1
|
||||
for condition in conditions:
|
||||
if torch.cuda.is_available():
|
||||
root_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne\\raw_data\\' + condition
|
||||
else:
|
||||
root_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne\\raw_data\\' + condition
|
||||
file_names = os.listdir(root_dir)
|
||||
# plt.rcParams['font.family'] = ['Arial Unicode MS', 'Microsoft YaHei', 'SimHei', 'sans-serif']
|
||||
# plt.rcParams['axes.unicode_minus'] = False
|
||||
plt.rc('font',family='Times New Roman')
|
||||
clses = ['(a)', '(b)', '(c)', '(d)', '(e)']
|
||||
for file_name, cls in zip(file_names, clses):
|
||||
data = np.load(root_dir + '\\' + file_name)[3].reshape(-1)
|
||||
print(data.shape, file_name)
|
||||
plt.tick_params(top='on', right='on', which='both') # 设置上面和右面也有刻度
|
||||
plt.rcParams['xtick.direction'] = 'in' # 将x周的刻度线方向设置向内
|
||||
plt.rcParams['ytick.direction'] = 'in' #将y轴的刻度方向设置向内
|
||||
plt.subplot(3, 5, i)
|
||||
plt.title(cls, fontsize=15)
|
||||
plt.xlabel("Time (s)", fontsize=15)
|
||||
plt.ylabel("Amplitude (m/s²)", fontsize=15)
|
||||
plt.xlim(0, 0.16)
|
||||
plt.ylim(-5, 5)
|
||||
plt.yticks(np.array([-3, 0, 3]), fontsize=15)
|
||||
plt.xticks(np.array([0, 640, 1280, 1920, 2048]), np.array(['0', 0.05, 0.1, 0.15]), fontsize=15)
|
||||
plt.plot(data)
|
||||
i += 1
|
||||
plt.show()
|
||||
|
||||
'''得到源域数据和标签'''
|
||||
def get_src_data(src_condition, shot):
|
||||
if torch.cuda.is_available():
|
||||
root_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + src_condition
|
||||
save_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\' + src_condition + '\\src'
|
||||
else:
|
||||
root_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + src_condition
|
||||
save_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\' + src_condition + '\\src'
|
||||
|
||||
file_names = os.listdir(root_dir)
|
||||
# print(file_names)
|
||||
datas = []
|
||||
labels = []
|
||||
for i in range(600 // shot):
|
||||
for file_name, cls in zip(file_names, range(5)):
|
||||
data = np.load(root_dir + '\\' + file_name)
|
||||
data_cut = data[shot * i: shot * (i + 1), :, :]
|
||||
label_cut = np.array([cls] * shot, dtype='int64')
|
||||
datas.append(data_cut)
|
||||
labels.append(label_cut)
|
||||
datas = np.array(datas).reshape(-1, 1, 2048)
|
||||
labels = np.array(labels).reshape(-1)
|
||||
print(datas.shape, labels.shape)
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir)
|
||||
np.save(save_dir + '\\' + 'data.npy', datas)
|
||||
np.save(save_dir + '\\' + 'label.npy', labels)
|
||||
|
||||
'''得到没有标签的目标域数据'''
|
||||
def get_tar_data(tar_condition):
|
||||
if torch.cuda.is_available():
|
||||
root_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + tar_condition
|
||||
save_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\' + tar_condition + '\\tar'
|
||||
else:
|
||||
root_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + tar_condition
|
||||
save_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\' + tar_condition + '\\tar'
|
||||
file_names = os.listdir(root_dir)
|
||||
datas = []
|
||||
for file_name in file_names:
|
||||
data = np.load(root_dir + '\\' + file_name)
|
||||
data_cut = data[0: 600, :, :]
|
||||
datas.append(data_cut)
|
||||
datas = np.array(datas).reshape(-1, 1, 2048)
|
||||
print("datas.shape: {}".format(datas.shape))
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir)
|
||||
np.save(save_dir + '\\' + 'data.npy', datas)
|
||||
|
||||
def get_tar_data2(tar_condition, shot):
|
||||
if torch.cuda.is_available():
|
||||
root_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + tar_condition
|
||||
save_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\' + tar_condition + '\\tar1'
|
||||
else:
|
||||
root_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + tar_condition
|
||||
save_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\' + tar_condition + '\\tar1'
|
||||
file_names = os.listdir(root_dir)
|
||||
datas = []
|
||||
labels = []
|
||||
for i in range(600 // shot):
|
||||
for file_name, cls in zip(file_names, range(5)):
|
||||
data = np.load(root_dir + '\\' + file_name)
|
||||
data_cut = data[shot * i: shot * (i + 1), :, :]
|
||||
label_cut = np.array([cls] * shot, dtype='int64')
|
||||
datas.append(data_cut)
|
||||
labels.append(label_cut)
|
||||
datas = np.array(datas).reshape(-1, 1, 2048)
|
||||
labels = np.array(labels).reshape(-1)
|
||||
print(datas.shape, labels.shape)
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir)
|
||||
np.save(save_dir + '\\' + 'data.npy', datas)
|
||||
np.save(save_dir + '\\' + 'label.npy', labels)
|
||||
|
||||
'''得到验证集(每个类别100个样本)'''
|
||||
def get_val_data(tar_condition):
|
||||
if torch.cuda.is_available():
|
||||
root_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + tar_condition
|
||||
save_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\' + tar_condition + '\\val'
|
||||
else:
|
||||
root_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + tar_condition
|
||||
save_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\' + tar_condition + '\\val'
|
||||
file_names = os.listdir(root_dir)
|
||||
datas = []
|
||||
labels = []
|
||||
for file_name, cls in zip(file_names, range(5)):
|
||||
data = np.load(root_dir + '\\' + file_name)
|
||||
data_cut = data[600: 700, :, :]
|
||||
label_cut = np.array([cls] * 100, dtype='int64')
|
||||
datas.append(data_cut)
|
||||
labels.append(label_cut)
|
||||
datas = np.array(datas).reshape(-1, 1, 2048)
|
||||
labels = np.array(labels).reshape(-1)
|
||||
print("datas.shape: {}, labels.shape: {}".format(datas.shape, labels.shape))
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir)
|
||||
np.save(save_dir + '\\' + 'data.npy', datas)
|
||||
np.save(save_dir + '\\' + 'label.npy', labels)
|
||||
|
||||
'''得到测试集(100个样本)'''
|
||||
def get_test_data(tar_condition):
|
||||
if torch.cuda.is_available():
|
||||
root_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + tar_condition
|
||||
save_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\' + tar_condition + '\\test'
|
||||
else:
|
||||
root_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + tar_condition
|
||||
save_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\' + tar_condition + '\\test'
|
||||
file_names = os.listdir(root_dir)
|
||||
datas = []
|
||||
labels = []
|
||||
for file_name, cls in zip(file_names, range(5)):
|
||||
data = np.load(root_dir + '\\' + file_name)
|
||||
data_cut = data[700: , :, :]
|
||||
label_cut = np.array([cls] * 100, dtype='int64')
|
||||
datas.append(data_cut)
|
||||
labels.append(label_cut)
|
||||
datas = np.array(datas).reshape(-1, 1, 2048)
|
||||
labels = np.array(labels).reshape(-1)
|
||||
print("datas.shape: {}, labels.shape: {}".format(datas.shape, labels.shape))
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir)
|
||||
np.save(save_dir + '\\' + 'data.npy', datas)
|
||||
np.save(save_dir + '\\' + 'label.npy', labels)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# dat_to_numpy(samples_num=800, sample_length=2048)
|
||||
# draw_signal_img('C')
|
||||
# draw_signal_img2()
|
||||
# get_src_data('A', shot=10)
|
||||
# get_src_data('B', shot=10)
|
||||
# get_src_data('C', shot=10)
|
||||
get_tar_data('A')
|
||||
get_tar_data('B')
|
||||
get_tar_data('C')
|
||||
get_val_data('A')
|
||||
get_val_data('B')
|
||||
get_val_data('C')
|
||||
get_test_data('A')
|
||||
get_test_data('B')
|
||||
get_test_data('C')
|
||||
|
||||
get_src_data('A', shot=10)
|
||||
get_src_data('B', shot=10)
|
||||
get_src_data('C', shot=10)
|
||||
get_tar_data2('A', shot=10)
|
||||
get_tar_data2('B', shot=10)
|
||||
get_tar_data2('C', shot=10)
|
||||
|
|
@ -0,0 +1,226 @@
|
|||
"""
|
||||
@Author: miykah
|
||||
@Email: miykah@163.com
|
||||
@FileName: test.py
|
||||
@DateTime: 2022/7/9 14:15
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from DDS_GADAN.load_data import get_dataset
|
||||
from DDS_GADAN.load_data import Nor_Dataset
|
||||
from DDS_GADAN.model import Extractor, LabelClassifier
|
||||
from sklearn.metrics import confusion_matrix
|
||||
from sklearn.manifold import TSNE
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
def load_data(tar_condition):
|
||||
if torch.cuda.is_available():
|
||||
root_dir = 'D:\\DataSet\\DDS_data\\5cls_6\\processed_data\\800EachCls_shot10\\'
|
||||
else:
|
||||
root_dir = 'E:\\DataSet\\DDS_data\\5cls_6\\processed_data\\800EachCls_shot10\\'
|
||||
|
||||
test_data = np.load(root_dir + tar_condition + '\\val\\' + 'data.npy')
|
||||
test_label = np.load(root_dir + tar_condition + '\\val\\' + 'label.npy')
|
||||
|
||||
return test_data, test_label
|
||||
|
||||
def tsne_2d_generate(cls, data, labels, pic_title):
|
||||
tsne2D = TSNE(n_components=2, verbose=2, perplexity=30).fit_transform(data)
|
||||
x, y = tsne2D[:, 0], tsne2D[:, 1]
|
||||
pic = plt.figure()
|
||||
ax1 = pic.add_subplot()
|
||||
ax1.scatter(x, y, c=labels, cmap=plt.cm.get_cmap("jet", cls)) # 9为9种颜色,因为标签有9类
|
||||
plt.title(pic_title)
|
||||
plt.show()
|
||||
|
||||
def tsne_2d_generate1(cls, data, labels, pic_title):
|
||||
parameters = {'figure.dpi': 600,
|
||||
'figure.figsize': (4, 3),
|
||||
'savefig.dpi': 600,
|
||||
'xtick.direction': 'in',
|
||||
'ytick.direction': 'in',
|
||||
'xtick.labelsize': 10,
|
||||
'ytick.labelsize': 10,
|
||||
'legend.fontsize': 11.3,
|
||||
}
|
||||
plt.rcParams.update(parameters)
|
||||
plt.rc('font', family='Times New Roman') # 全局字体样式
|
||||
tsne2D = TSNE(n_components=2, verbose=2, perplexity=30, random_state=3407, init='random', learning_rate=200).fit_transform(data)
|
||||
tsne2D_min, tsne2D_max = tsne2D.min(0), tsne2D.max(0)
|
||||
tsne2D_final = (tsne2D - tsne2D_min) / (tsne2D_max - tsne2D_min)
|
||||
s1, s2 = tsne2D_final[:1000, :], tsne2D_final[1000:, :]
|
||||
pic = plt.figure()
|
||||
# ax1 = pic.add_subplot()
|
||||
plt.scatter(s1[:, 0], s1[:, 1], c=labels[:1000], cmap=plt.cm.get_cmap("jet", cls), marker='o', alpha=0.3) # 9为9种颜色,因为标签有9类
|
||||
plt.scatter(s2[:, 0], s2[:, 1], c=labels[1000:], cmap=plt.cm.get_cmap("jet", cls), marker='x', alpha=0.3) # 9为9种颜色,因为标签有9类
|
||||
plt.title(pic_title, fontsize=10)
|
||||
# plt.xticks([])
|
||||
# plt.yticks([])
|
||||
plt.show()
|
||||
|
||||
def plot_confusion_matrix_accuracy(cls, true_labels, predict_labels):
|
||||
# # 画混淆矩阵
|
||||
# confusion = confusion_matrix(true_labels, predict_labels)
|
||||
# # confusion = confusion.astype('float') / confusion.sum(axis=1)[:, np.newaxis]
|
||||
# plt.figure(figsize=(6.4,6.4), dpi=100)
|
||||
# sns.heatmap(confusion, annot=True, fmt="d", cmap="Greens")
|
||||
# # sns.heatmap(confusion_matrix, annot=True, fmt="d", cmap="Blues")
|
||||
# indices = range(len(confusion))
|
||||
# classes = ['N', 'IF', 'OF', 'TRC', 'TSP']
|
||||
# # for i in range(cls):
|
||||
# # classes.append(str(i))
|
||||
# # 第一个是迭代对象,表示坐标的显示顺序,第二个参数是坐标轴显示列表
|
||||
# # plt.xticks(indices, classes, rotation=45) # 设置横坐标方向,rotation=45为45度倾斜
|
||||
# # plt.yticks(indices, classes, rotation=45)
|
||||
# plt.xticks([index + 0.5 for index in indices], classes, rotation=45) # 设置横坐标方向,rotation=45为45度倾斜
|
||||
# plt.yticks([index + 0.5 for index in indices], classes, rotation=45)
|
||||
# plt.ylabel('Actual label')
|
||||
# plt.xlabel('Predicted label')
|
||||
# plt.title('confusion matrix')
|
||||
# plt.show()
|
||||
sns.set(font_scale=1.5)
|
||||
parameters = {'figure.dpi': 600,
|
||||
'figure.figsize': (5, 5),
|
||||
'savefig.dpi': 600,
|
||||
'xtick.direction': 'in',
|
||||
'ytick.direction': 'in',
|
||||
'xtick.labelsize': 20,
|
||||
'ytick.labelsize': 20,
|
||||
'legend.fontsize': 11.3,
|
||||
}
|
||||
plt.rcParams.update(parameters)
|
||||
plt.figure()
|
||||
plt.rc('font', family='Times New Roman') # 全局字体样式
|
||||
# 画混淆矩阵
|
||||
confusion = confusion_matrix(true_labels, predict_labels)
|
||||
# confusion = confusion.astype('float') / confusion.sum(axis=1)[:, np.newaxis]
|
||||
plt.figure()
|
||||
# sns.heatmap(confusion, annot=True, fmt="d", cmap="Greens")
|
||||
sns.heatmap(confusion, annot=True, fmt="d", cmap="Blues", vmax=100, vmin=0, cbar=None, square=True)
|
||||
indices = range(len(confusion))
|
||||
classes = ['N', 'IF', 'OF', 'TRC', 'TSP']
|
||||
# for i in range(cls):
|
||||
# classes.append(str(i))
|
||||
# 第一个是迭代对象,表示坐标的显示顺序,第二个参数是坐标轴显示列表
|
||||
# plt.xticks(indices, classes, rotation=45) # 设置横坐标方向,rotation=45为45度倾斜
|
||||
# plt.yticks(indices, classes, rotation=45)
|
||||
plt.xticks([index + 0.5 for index in indices], classes, rotation=45) # 设置横坐标方向,rotation=45为45度倾斜
|
||||
plt.yticks([index + 0.5 for index in indices], classes, rotation=45)
|
||||
plt.ylabel('Actual label', fontsize=20)
|
||||
plt.xlabel('Predicted label', fontsize=20)
|
||||
# plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
|
||||
def test(cls, tar_condition, G_params_path, LC_params_path):
|
||||
test_data, test_label = load_data(tar_condition)
|
||||
test_dataset = Nor_Dataset(test_data, test_label)
|
||||
|
||||
batch_size = len(test_dataset)
|
||||
|
||||
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)
|
||||
|
||||
# 加载网络
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
G = Extractor().to(device)
|
||||
LC = LabelClassifier(cls_num=cls).to(device)
|
||||
G.load_state_dict(
|
||||
torch.load(G_params_path, map_location=device)
|
||||
)
|
||||
LC.load_state_dict(
|
||||
torch.load(LC_params_path, map_location=device)
|
||||
)
|
||||
# print(net)
|
||||
# params_num = sum(param.numel() for param in net.parameters_bak())
|
||||
# print('参数数量:{}'.format(params_num))
|
||||
|
||||
test_acc = 0.0
|
||||
|
||||
G.eval()
|
||||
LC.eval()
|
||||
with torch.no_grad():
|
||||
for batch_idx, (data, label) in enumerate(test_loader):
|
||||
data, label = data.to(device), label.to(device)
|
||||
_, feature = G(data, data)
|
||||
_, _, _, output = LC(feature, feature)
|
||||
test_acc += np.sum(np.argmax(output.cpu().detach().numpy(), axis=1) == label.cpu().numpy())
|
||||
|
||||
predict_labels = np.argmax(output.cpu().detach().numpy(), axis=1)
|
||||
labels = label.cpu().numpy()
|
||||
|
||||
predictions = output.cpu().detach().numpy()
|
||||
|
||||
tsne_2d_generate(cls, predictions, labels, "output of neural network")
|
||||
|
||||
plot_confusion_matrix_accuracy(cls, labels, predict_labels)
|
||||
|
||||
print("测试集大小为{}, 成功{},准确率为{:.6f}".format(test_dataset.__len__(), test_acc, test_acc / test_dataset.__len__()))
|
||||
|
||||
def test1(cls, src_condition, tar_condition, G_params_path, LC_params_path):
|
||||
if torch.cuda.is_available():
|
||||
root_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\'
|
||||
else:
|
||||
root_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\'
|
||||
|
||||
src_data = np.load(root_dir + src_condition + '\\src\\' + 'data.npy')
|
||||
src_label = np.load(root_dir + src_condition + '\\src\\' + 'label.npy')
|
||||
|
||||
tar_data = np.load(root_dir + tar_condition + '\\tar1\\' + 'data.npy')
|
||||
tar_label = np.load(root_dir + tar_condition + '\\tar1\\' + 'label.npy')
|
||||
|
||||
src_dataset = Nor_Dataset(src_data, src_label)
|
||||
tar_dataset = Nor_Dataset(tar_data, tar_label)
|
||||
src_loader = DataLoader(dataset=src_dataset, batch_size=1000, shuffle=False, drop_last=True)
|
||||
tar_loader = DataLoader(dataset=tar_dataset, batch_size=1000, shuffle=False, drop_last=True)
|
||||
|
||||
# 加载网络
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
G = Extractor().to(device)
|
||||
LC = LabelClassifier(cls_num=cls).to(device)
|
||||
G.load_state_dict(
|
||||
torch.load(G_params_path, map_location=device)
|
||||
)
|
||||
LC.load_state_dict(
|
||||
torch.load(LC_params_path, map_location=device)
|
||||
)
|
||||
# print(net)
|
||||
# params_num = sum(param.numel() for param in net.parameters_bak())
|
||||
# print('参数数量:{}'.format(params_num))
|
||||
|
||||
test_acc = 0.0
|
||||
|
||||
G.eval()
|
||||
LC.eval()
|
||||
with torch.no_grad():
|
||||
for (src_batch_idx, (src_data, src_label)), (tar_batch_idx, (tar_data, tar_label)) in zip(enumerate(src_loader), enumerate(tar_loader)):
|
||||
src_data, src_label = src_data.to(device), src_label.to(device)
|
||||
tar_data, tar_label = tar_data.to(device), tar_label.to(device)
|
||||
data = torch.concat((src_data, tar_data), dim=0)
|
||||
label = torch.concat((src_label, tar_label), dim=0)
|
||||
_, feature = G(data, data)
|
||||
_, _, fc1, output = LC(feature, feature)
|
||||
test_acc += np.sum(np.argmax(output.cpu().detach().numpy(), axis=1) == label.cpu().numpy())
|
||||
|
||||
predict_labels = np.argmax(output.cpu().detach().numpy(), axis=1)
|
||||
labels = label.cpu().numpy()
|
||||
|
||||
outputs = output.cpu().detach().numpy()
|
||||
fc1_outputs = fc1.cpu().detach().numpy()
|
||||
break
|
||||
|
||||
tsne_2d_generate1(cls, fc1_outputs, labels, "GADAN")
|
||||
|
||||
plot_confusion_matrix_accuracy(cls, labels, predict_labels)
|
||||
|
||||
print("准确率为{:.6f}".format(test_acc / (src_dataset.__len__() + tar_dataset.__len__())))
|
||||
|
||||
if __name__ == '__main__':
|
||||
# pass
|
||||
test(5, 'B', 'parameters/A_to_B/G/G_shot10_epoch200_lr0.002_miu0.5.pkl',
|
||||
'parameters/A_to_B/LC/LC_shot10_epoch200_lr0.002_miu0.5.pkl')
|
||||
|
||||
# test1(5, 'A', 'B', 'parameters_bak/A_to_B/G/G_shot10_epoch200_lr0.002_miu0.5.pkl',
|
||||
# 'parameters_bak/A_to_B/LC/LC_shot10_epoch200_lr0.002_miu0.5.pkl')
|
||||
|
|
@ -0,0 +1,357 @@
|
|||
"""
|
||||
@Author: miykah
|
||||
@Email: miykah@163.com
|
||||
@FileName: train.py
|
||||
@DateTime: 2022/7/20 20:22
|
||||
"""
|
||||
import os
|
||||
import time
|
||||
import random
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import matplotlib.pyplot as plt
|
||||
from torch.utils.data import DataLoader
|
||||
from example.load_data import get_dataset, draw_signal_img
|
||||
from example.model import Extractor, LabelClassifier, DomainClassifier, weights_init_Classifier, weights_init_Extractor
|
||||
from example.loss import cal_mmdc_loss, BinaryCrossEntropyLoss, get_mmdc
|
||||
import example.mmd as mmd
|
||||
from example.test import test, test1
|
||||
from scipy.spatial.distance import cdist
|
||||
import math
|
||||
|
||||
def obtain_label(feature, output, bs):
|
||||
with torch.no_grad():
|
||||
all_fea = feature.reshape(bs, -1).float().cpu()
|
||||
all_output = output.float().cpu()
|
||||
# all_label = label.float()
|
||||
all_output = nn.Softmax(dim=1)(all_output)
|
||||
_, predict = torch.max(all_output, 1)
|
||||
# accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
|
||||
|
||||
all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1)
|
||||
all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t()
|
||||
all_fea = all_fea.float().cpu().numpy()
|
||||
|
||||
K = all_output.size(1)
|
||||
aff = all_output.float().cpu().numpy()
|
||||
initc = aff.transpose().dot(all_fea)
|
||||
initc = initc / (1e-8 + aff.sum(axis=0)[:, None])
|
||||
dd = cdist(all_fea, initc, 'cosine')
|
||||
pred_label = dd.argmin(axis=1)
|
||||
# acc = np.sum(pred_label == all_label.float().numpy()) / len(all_fea)
|
||||
|
||||
for round in range(1):
|
||||
aff = np.eye(K)[pred_label]
|
||||
initc = aff.transpose().dot(all_fea)
|
||||
initc = initc / (1e-8 + aff.sum(axis=0)[:, None])
|
||||
dd = cdist(all_fea, initc, 'cosine')
|
||||
pred_label = dd.argmin(axis=1)
|
||||
# acc = np.sum(pred_label == all_label.float().numpy()) / len(all_fea)
|
||||
|
||||
# log_str = 'Accuracy = {:.2f}% -> {:.2f}%'.format(accuracy * 100, acc * 100)
|
||||
# args.out_file.write(log_str + '\n')
|
||||
# args.out_file.flush()
|
||||
# print(log_str + '\n')
|
||||
return pred_label.astype('int')
|
||||
|
||||
|
||||
def train(device, src_condition, tar_condition, cls, epochs, bs, shot, lr, patience, gamma, miu):
|
||||
'''特征提取器'''
|
||||
G = Extractor()
|
||||
G.apply(weights_init_Extractor)
|
||||
G.to(device)
|
||||
'''标签分类器'''
|
||||
LC = LabelClassifier(cls_num=cls)
|
||||
LC.apply(weights_init_Classifier)
|
||||
LC.to(device)
|
||||
'''域分类器'''
|
||||
DC = DomainClassifier()
|
||||
DC.apply(weights_init_Classifier)
|
||||
DC.to(device)
|
||||
|
||||
'''得到数据集'''
|
||||
src_dataset, tar_dataset, val_dataset, test_dataset \
|
||||
= get_dataset(src_condition, tar_condition)
|
||||
'''DataLoader'''
|
||||
src_loader = DataLoader(dataset=src_dataset, batch_size=bs, shuffle=False, drop_last=True)
|
||||
tar_loader = DataLoader(dataset=tar_dataset, batch_size=bs, shuffle=True, drop_last=True)
|
||||
val_loader = DataLoader(dataset=val_dataset, batch_size=bs, shuffle=True, drop_last=False)
|
||||
test_loader = DataLoader(dataset=test_dataset, batch_size=len(test_dataset), shuffle=True, drop_last=False)
|
||||
|
||||
criterion = nn.CrossEntropyLoss().to(device)
|
||||
BCE = BinaryCrossEntropyLoss().to(device)
|
||||
|
||||
optimizer_g = torch.optim.Adam(G.parameters(), lr=lr)
|
||||
optimizer_lc = torch.optim.Adam(LC.parameters(), lr=lr)
|
||||
optimizer_dc = torch.optim.Adam(DC.parameters(), lr=lr)
|
||||
scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_g, mode="min", factor=0.5, patience=patience)
|
||||
scheduler_lc = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_lc, mode="min", factor=0.5, patience=patience)
|
||||
scheduler_dc = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_dc, mode="min", factor=0.5, patience=patience)
|
||||
|
||||
def zero_grad_all():
|
||||
optimizer_g.zero_grad()
|
||||
optimizer_lc.zero_grad()
|
||||
optimizer_dc.zero_grad()
|
||||
|
||||
src_acc_list = []
|
||||
train_loss_list = []
|
||||
val_acc_list = []
|
||||
val_loss_list = []
|
||||
|
||||
for epoch in range(epochs):
|
||||
epoch_start_time = time.time()
|
||||
src_acc = 0.0
|
||||
train_loss = 0.0
|
||||
val_acc = 0.0
|
||||
val_loss = 0.0
|
||||
|
||||
G.train()
|
||||
LC.train()
|
||||
DC.train()
|
||||
|
||||
for (src_batch_idx, (src_data, src_label)), (tar_batch_idx, (tar_data)) in zip(enumerate(src_loader), enumerate(tar_loader)):
|
||||
src_data, src_label = src_data.to(device), src_label.to(device)
|
||||
tar_data = tar_data.to(device)
|
||||
zero_grad_all()
|
||||
|
||||
T1 = (int)(0.2 * epochs)
|
||||
T2 = (int)(0.5 * epochs)
|
||||
|
||||
src_feature, tar_feature = G(src_data, tar_data)
|
||||
src_data_mmd1, src_output, tar_data_mmd1, tar_output = LC(src_feature, tar_feature)
|
||||
if epoch < T1:
|
||||
pseudo_label = torch.argmax(F.softmax(tar_output, dim=1), dim=1) # 伪标签
|
||||
else:
|
||||
pseudo_label = torch.tensor(obtain_label(tar_feature, tar_output, bs), dtype=torch.int64).cuda()
|
||||
|
||||
# pseudo_label = torch.argmax(F.softmax(tar_output, dim=1), dim=1) # 伪标签
|
||||
# mmdc权重
|
||||
if epoch < T1:
|
||||
miu_f = 0
|
||||
# elif epoch > T1 and epoch < T2:
|
||||
# miu_f = miu * (epoch - T1) / (T2 - T1)
|
||||
else:
|
||||
miu_f = miu
|
||||
|
||||
# 源数据的交叉熵损失
|
||||
loss_src = criterion(src_output, src_label)
|
||||
# 目标数据的伪标签交叉熵损失
|
||||
loss_tar_pseudo = criterion(tar_output, pseudo_label)
|
||||
# mmd损失
|
||||
loss_mmdm = mmd.mmd_rbf_noaccelerate(src_data_mmd1, tar_data_mmd1)
|
||||
if epoch < T1:
|
||||
loss_mmdc = 0
|
||||
else:
|
||||
loss_mmdc = get_mmdc(src_data_mmd1, tar_data_mmd1, pseudo_label, bs, shot, cls)
|
||||
# loss_mmdc = cal_mmdc_loss(src_data_mmd1, src_output, tar_data_mmd1, tar_output)
|
||||
# loss_mmdc = get_mmdc(src_data_mmd1, tar_data_mmd1, pseudo_label, bs, shot, cls)
|
||||
# loss_jmmd = miu_f * loss_mmdc + (1 - miu_f) * loss_mmdm
|
||||
|
||||
# 伪标签损失的权重
|
||||
# if epoch < T1:
|
||||
# beta_f = 0
|
||||
# elif epoch > T1 and epoch < T2:
|
||||
# beta_f = beta * (epoch - T1) / (T2 - T1)
|
||||
# else:
|
||||
# beta_f = beta
|
||||
|
||||
p = epoch / epochs
|
||||
lamda = (2 / (1 + math.exp(-10 * p))) - 1
|
||||
# gamma = (2 / (1 + math.exp(-10 * p))) - 1
|
||||
# miu_f = (2 / (1 + math.exp(-10 * p))) - 1
|
||||
|
||||
loss_jmmd = miu_f * loss_mmdc + (1 - miu_f) * loss_mmdm
|
||||
|
||||
loss_G_LC = loss_src + gamma * loss_jmmd
|
||||
# loss_G_LC = loss_src + beta_f * loss_tar_pseudo + gamma * loss_jmmd
|
||||
loss_G_LC.backward()
|
||||
optimizer_g.step()
|
||||
optimizer_lc.step()
|
||||
zero_grad_all()
|
||||
#-----------------------------------------------
|
||||
# 对抗域适应的损失
|
||||
src_feature, tar_feature = G(src_data, tar_data)
|
||||
src_data_mmd1, src_output, tar_data_mmd1, tar_output = LC(src_feature, tar_feature)
|
||||
# pseudo_label = torch.argmax(F.softmax(tar_output, dim=1), dim=1) # 伪标签
|
||||
# 源数据的交叉熵损失
|
||||
loss_src = criterion(src_output, src_label)
|
||||
# 目标数据的伪标签交叉熵损失
|
||||
loss_tar_pseudo = criterion(tar_output, pseudo_label)
|
||||
|
||||
gradient_src = \
|
||||
torch.autograd.grad(outputs=loss_src, inputs=src_feature, create_graph=True, retain_graph=True,
|
||||
only_inputs=True)[0]
|
||||
gradient_tar = \
|
||||
torch.autograd.grad(outputs=loss_tar_pseudo, inputs=tar_feature, create_graph=True, retain_graph=True,
|
||||
only_inputs=True)[0]
|
||||
gradients_adv = torch.cat((gradient_src, gradient_tar), dim=0)
|
||||
|
||||
domain_label_reverse = DC(gradients_adv, reverse=True)
|
||||
loss_adv_r = BCE(domain_label_reverse)
|
||||
loss_G_adv = lamda * loss_adv_r
|
||||
# 更新特征提取器G参数
|
||||
loss_G_adv.backward()
|
||||
optimizer_g.step()
|
||||
zero_grad_all()
|
||||
#---------------------------------------------------------------------
|
||||
src_feature, tar_feature = G(src_data, tar_data)
|
||||
src_data_mmd1, src_output, tar_data_mmd1, tar_output = LC(src_feature, tar_feature)
|
||||
# pseudo_label = torch.argmax(F.softmax(tar_output, dim=1), dim=1) # 伪标签
|
||||
# 源数据的交叉熵损失
|
||||
loss_src = criterion(src_output, src_label)
|
||||
# 目标数据的伪标签交叉熵损失
|
||||
loss_tar_pseudo = criterion(tar_output, pseudo_label)
|
||||
|
||||
gradient_src = \
|
||||
torch.autograd.grad(outputs=loss_src, inputs=src_feature, create_graph=True, retain_graph=True,
|
||||
only_inputs=True)[0]
|
||||
gradient_tar = \
|
||||
torch.autograd.grad(outputs=loss_tar_pseudo, inputs=tar_feature, create_graph=True, retain_graph=True,
|
||||
only_inputs=True)[0]
|
||||
|
||||
gradients = torch.cat((gradient_src, gradient_tar), dim=0)
|
||||
domain_label = DC(gradients, reverse=False)
|
||||
loss_adv = BCE(domain_label)
|
||||
loss_DC = lamda * loss_adv
|
||||
# 更新域分类器的参数
|
||||
loss_DC.backward()
|
||||
optimizer_dc.step()
|
||||
zero_grad_all()
|
||||
|
||||
src_acc += np.sum(np.argmax(src_output.cpu().detach().numpy(), axis=1) == src_label.cpu().numpy())
|
||||
train_loss += (loss_G_LC + loss_G_adv + loss_DC).item()
|
||||
|
||||
G.eval()
|
||||
LC.eval()
|
||||
DC.eval()
|
||||
with torch.no_grad():
|
||||
for batch_idx, (val_data, val_label) in enumerate(val_loader):
|
||||
val_data, val_label = val_data.to(device), val_label.to(device)
|
||||
_, val_feature = G(val_data, val_data)
|
||||
_, _, _, val_output = LC(val_feature, val_feature)
|
||||
loss = criterion(val_output, val_label)
|
||||
|
||||
val_acc += np.sum(np.argmax(val_output.cpu().detach().numpy(), axis=1) == val_label.cpu().numpy())
|
||||
val_loss += loss.item()
|
||||
|
||||
scheduler_g.step(val_loss)
|
||||
scheduler_lc.step(val_loss)
|
||||
scheduler_dc.step(val_loss)
|
||||
|
||||
print("[{:03d}/{:03d}] {:2.2f} sec(s) src_acc: {:3.6f} train_loss: {:3.9f} | val_acc: {:3.6f} val_loss: {:3.9f} | Learning rate : {:3.6f}".format(
|
||||
epoch + 1, epochs, time.time() - epoch_start_time, \
|
||||
src_acc / src_dataset.__len__(), train_loss / src_dataset.__len__(),
|
||||
val_acc / val_dataset.__len__(), val_loss / val_dataset.__len__(),
|
||||
optimizer_g.state_dict()['param_groups'][0]['lr']))
|
||||
|
||||
# 保存在验证集上loss最小的模型
|
||||
# if val_loss_list.__len__() > 0 and (val_loss / val_dataset.__len__()) < min(val_loss_list):
|
||||
# 如果精度大于最高精度,则保存
|
||||
if val_acc_list.__len__() > 0 :
|
||||
# if (val_acc / val_dataset.__len__()) >= max(val_acc_list):
|
||||
if (val_acc / val_dataset.__len__()) > max(val_acc_list) or (val_loss / val_dataset.__len__()) < min(val_loss_list):
|
||||
print("保存模型最佳模型成功")
|
||||
G_path = "parameters_bak/" + src_condition + "_to_" + tar_condition + "/G"
|
||||
LC_path = "parameters_bak/" + src_condition + "_to_" + tar_condition + "/LC"
|
||||
DC_path = "parameters_bak/" + src_condition + "_to_" + tar_condition + "/DC"
|
||||
if not os.path.exists(G_path):
|
||||
os.makedirs(G_path)
|
||||
if not os.path.exists(LC_path):
|
||||
os.makedirs(LC_path)
|
||||
if not os.path.exists(DC_path):
|
||||
os.makedirs(DC_path)
|
||||
# 保存模型参数
|
||||
torch.save(G.state_dict(), G_path + "/G_" + "shot" + str(shot) + "_epoch" + str(epochs) + "_lr" + str(lr) + "_miu" + str(miu) + ".pkl")
|
||||
torch.save(LC.state_dict(), LC_path + "/LC_" + "shot" + str(shot) + "_epoch" + str(epochs) + "_lr" + str(lr) + "_miu" + str(miu) + ".pkl")
|
||||
torch.save(DC.state_dict(), DC_path + "/DC_" + "shot" + str(shot) + "_epoch" + str(epochs) + "_lr" + str(lr) + "_miu" + str(miu) + ".pkl")
|
||||
|
||||
src_acc_list.append(src_acc / src_dataset.__len__())
|
||||
train_loss_list.append(train_loss / src_dataset.__len__())
|
||||
val_acc_list.append(val_acc / val_dataset.__len__())
|
||||
val_loss_list.append(val_loss / val_dataset.__len__())
|
||||
|
||||
'''保存的模型参数的路径'''
|
||||
G_params_path = G_path + "/G_" + "shot" + str(shot) + "_epoch" + str(epochs) + "_lr" + str(lr) + "_miu" + str(miu) + ".pkl"
|
||||
LC_params_path = LC_path + "/LC_" + "shot" + str(shot) + "_epoch" + str(epochs) + "_lr" + str(lr) + "_miu" + str(miu) + ".pkl"
|
||||
|
||||
|
||||
from matplotlib import rcParams
|
||||
|
||||
config = {
|
||||
"font.family": 'Times New Roman', # 设置字体类型
|
||||
"axes.unicode_minus": False, # 解决负号无法显示的问题
|
||||
"axes.labelsize": 13
|
||||
}
|
||||
rcParams.update(config)
|
||||
|
||||
# pic1 = plt.figure(figsize=(8, 6), dpi=200)
|
||||
# plt.subplot(211)
|
||||
# plt.plot(np.arange(1, epochs + 1), src_acc_list, 'b', label='TrainAcc')
|
||||
# plt.plot(np.arange(1, epochs + 1), val_acc_list, 'r', label='ValAcc')
|
||||
# plt.ylim(0.3, 1.0) # 设置y轴范围
|
||||
# plt.title('Training & Validation accuracy')
|
||||
# plt.xlabel('epoch')
|
||||
# plt.ylabel('accuracy')
|
||||
# plt.legend(loc='lower right')
|
||||
# plt.grid(alpha=0.4)
|
||||
#
|
||||
# plt.subplot(212)
|
||||
# plt.plot(np.arange(1, epochs + 1), train_loss_list, 'b', label='TrainLoss')
|
||||
# plt.plot(np.arange(1, epochs + 1), val_loss_list, 'r', label='ValLoss')
|
||||
# plt.ylim(0, 0.08) # 设置y轴范围
|
||||
# plt.title('Training & Validation loss')
|
||||
# plt.xlabel('epoch')
|
||||
# plt.ylabel('loss')
|
||||
# plt.legend(loc='upper right')
|
||||
# plt.grid(alpha=0.4)
|
||||
|
||||
pic1 = plt.figure(figsize=(12, 6), dpi=200)
|
||||
|
||||
plt.plot(np.arange(1, epochs + 1), train_loss_list, 'b', label='Training Loss')
|
||||
plt.plot(np.arange(1, epochs + 1), val_loss_list, 'r', label='Validation Loss')
|
||||
plt.ylim(0, 0.08) # 设置y轴范围
|
||||
plt.title('Training & Validation loss')
|
||||
plt.xlabel('epoch')
|
||||
plt.ylabel('loss')
|
||||
plt.legend(loc='upper right')
|
||||
plt.grid(alpha=0.4)
|
||||
|
||||
# 获取当前时间戳
|
||||
timestamp = int(time.time())
|
||||
|
||||
# 将时间戳转换为字符串
|
||||
timestamp_str = str(timestamp)
|
||||
plt.savefig(timestamp_str, dpi=200)
|
||||
|
||||
return G_params_path, LC_params_path
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
begin = time.time()
|
||||
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda:0")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
|
||||
seed = 2
|
||||
torch.manual_seed(seed)
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
src_condition = 'A'
|
||||
tar_condition = 'B'
|
||||
'''训练'''
|
||||
G_params_path, LC_params_path = train(device, src_condition, tar_condition, cls=5,
|
||||
epochs=200, bs=50, shot=10, lr=0.002, patience=40, gamma=1, miu=0.5)
|
||||
|
||||
end = time.time()
|
||||
|
||||
'''测试'''
|
||||
# test1(5, src_condition, tar_condition, G_params_path, LC_params_path)
|
||||
test(5, tar_condition, G_params_path, LC_params_path)
|
||||
|
||||
print("训练耗时:{:3.2f}s".format(end - begin))
|
||||
Loading…
Reference in New Issue