From 2c3e6c25a8beb66f1175807d51e4318eefcd8cd3 Mon Sep 17 00:00:00 2001
From: kevinding1125 <745518019@qq.com>
Date: Thu, 9 Nov 2023 21:42:09 +0800
Subject: [PATCH] pytorch update
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 pytorch_example/RUL/__init__.py           |   8 +
 pytorch_example/RUL/otherIdea/__init__.py |   8 +
 pytorch_example/RUL/otherIdea/adaRNN.py   |   8 +
 .../RUL/otherIdea/adaRNN/__init__.py      |   8 +
 .../RUL/otherIdea/adaRNN/loss/__init__.py |   8 +
 pytorch_example/example/__init__.py       |   8 +
 pytorch_example/example/load_data.py      | 111 ++++++
 pytorch_example/example/loss.py           |  65 +++
 pytorch_example/example/mmd.py            |  56 +++
 pytorch_example/example/model.py          | 109 ++++++
 pytorch_example/example/read_data.py      | 298 +++++++++++++++
 pytorch_example/example/test.py           | 226 +++++++++++
 pytorch_example/example/train.py          | 357 ++++++++++++++++++
 13 files changed, 1270 insertions(+)
 create mode 100644 pytorch_example/RUL/__init__.py
 create mode 100644 pytorch_example/RUL/otherIdea/__init__.py
 create mode 100644 pytorch_example/RUL/otherIdea/adaRNN.py
 create mode 100644 pytorch_example/RUL/otherIdea/adaRNN/__init__.py
 create mode 100644 pytorch_example/RUL/otherIdea/adaRNN/loss/__init__.py
 create mode 100644 pytorch_example/example/__init__.py
 create mode 100644 pytorch_example/example/load_data.py
 create mode 100644 pytorch_example/example/loss.py
 create mode 100644 pytorch_example/example/mmd.py
 create mode 100644 pytorch_example/example/model.py
 create mode 100644 pytorch_example/example/read_data.py
 create mode 100644 pytorch_example/example/test.py
 create mode 100644 pytorch_example/example/train.py

diff --git a/pytorch_example/RUL/__init__.py b/pytorch_example/RUL/__init__.py
new file mode 100644
index 0000000..090d185
--- /dev/null
+++ b/pytorch_example/RUL/__init__.py
@@ -0,0 +1,8 @@
+#-*- encoding:utf-8 -*-
+
+'''
+@Author : dingjiawen
+@Date : 2023/11/9 21:32
+@Usage :
+@Desc :
+'''
\ No newline at end of file
diff --git a/pytorch_example/RUL/otherIdea/__init__.py b/pytorch_example/RUL/otherIdea/__init__.py
new file mode 100644
index 0000000..c4388ef
--- /dev/null
+++ b/pytorch_example/RUL/otherIdea/__init__.py
@@ -0,0 +1,8 @@
+#-*- encoding:utf-8 -*-
+
+'''
+@Author : dingjiawen
+@Date : 2023/11/9 21:33
+@Usage :
+@Desc :
+'''
\ No newline at end of file
diff --git a/pytorch_example/RUL/otherIdea/adaRNN.py b/pytorch_example/RUL/otherIdea/adaRNN.py
new file mode 100644
index 0000000..d96a162
--- /dev/null
+++ b/pytorch_example/RUL/otherIdea/adaRNN.py
@@ -0,0 +1,8 @@
+#-*- encoding:utf-8 -*-
+
+'''
+@Author : dingjiawen
+@Date : 2023/11/9 21:34
+@Usage :
+@Desc :
+'''
\ No newline at end of file
diff --git a/pytorch_example/RUL/otherIdea/adaRNN/__init__.py b/pytorch_example/RUL/otherIdea/adaRNN/__init__.py
new file mode 100644
index 0000000..d96a162
--- /dev/null
+++ b/pytorch_example/RUL/otherIdea/adaRNN/__init__.py
@@ -0,0 +1,8 @@
+#-*- encoding:utf-8 -*-
+
+'''
+@Author : dingjiawen
+@Date : 2023/11/9 21:34
+@Usage :
+@Desc :
+'''
\ No newline at end of file
diff --git a/pytorch_example/RUL/otherIdea/adaRNN/loss/__init__.py b/pytorch_example/RUL/otherIdea/adaRNN/loss/__init__.py
new file mode 100644
index 0000000..d96a162
--- /dev/null
+++ b/pytorch_example/RUL/otherIdea/adaRNN/loss/__init__.py
@@ -0,0 +1,8 @@
+#-*- encoding:utf-8 -*-
+
+'''
+@Author : dingjiawen
+@Date : 2023/11/9 21:34
+@Usage :
+@Desc :
+'''
\ No newline at end of file
diff --git a/pytorch_example/example/__init__.py b/pytorch_example/example/__init__.py
new file mode 100644
index 0000000..0a7fc33
--- /dev/null
+++ b/pytorch_example/example/__init__.py
@@ -0,0 +1,8 @@
+#-*- encoding:utf-8 -*-
+
+'''
+@Author : dingjiawen
+@Date : 2023/11/9 21:28
+@Usage :
+@Desc :
+'''
\ No newline at end of file
diff --git a/pytorch_example/example/load_data.py b/pytorch_example/example/load_data.py
new file mode 100644
index 0000000..1e4b4eb
--- /dev/null
+++ b/pytorch_example/example/load_data.py
@@ -0,0 +1,111 @@
+"""
+@Author: miykah
+@Email: miykah@163.com
+@FileName: load_data.py
+@DateTime: 2022/7/20 16:40
+"""
+import matplotlib.pyplot as plt
+import torch
+import numpy as np
+from torch.utils.data import Dataset, DataLoader
+import torchvision.transforms as transforms
+import random
+
+'''Standard Dataset class'''
+class Nor_Dataset(Dataset):
+    def __init__(self, datas, labels=None):
+        self.datas = torch.tensor(datas)
+        if labels is not None:
+            self.labels = torch.tensor(labels)
+        else:
+            self.labels = None
+    def __getitem__(self, index):
+        data = self.datas[index]
+        if self.labels is not None:
+            label = self.labels[index]
+            return data, label
+        return data
+    def __len__(self):
+        return len(self.datas)
+
+'''Dataset class for unlabeled target-domain data'''
+class Tar_U_Dataset(Dataset):
+    def __init__(self, datas):
+        self.datas = torch.tensor(datas)
+    def __getitem__(self, index):
+        data = self.datas[index]
+        data_bar = data.clone().detach()
+        data_bar2 = data.clone().detach()
+        mu = 0
+        sigma = 0.1
+        '''Add Gaussian noise to the unlabeled target data to obtain data' and data'' '''
+        for i, j in zip(range(data_bar[0].shape[0]), range(data_bar2[0].shape[0])):
+            data_bar[0, i] += random.gauss(mu, sigma)
+            data_bar2[0, j] += random.gauss(mu, sigma)
+        return data, data_bar, data_bar2
+    def __len__(self):
+        return len(self.datas)
+
+def draw_signal_img(data, data_bar, data_bar2):
+    pic = plt.figure(figsize=(12, 6), dpi=100)
+    plt.rcParams['font.family'] = ['Arial Unicode MS', 'Microsoft YaHei', 'SimHei', 'sans-serif']
+    plt.rcParams['axes.unicode_minus'] = False
+    plt.subplot(3, 1, 1)
+    plt.plot(data, 'b')
+    plt.subplot(3, 1, 2)
+    plt.plot(data_bar, 'b')
+    plt.subplot(3, 1, 3)
+    plt.plot(data_bar2, 'b')
+    plt.show()
+
+def get_dataset(src_condition, tar_condition):
+    if torch.cuda.is_available():
+        root_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\'
+    else:
+        root_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\'
+
+    src_data = np.load(root_dir + src_condition + '\\src\\' + 'data.npy')
+    src_label = np.load(root_dir + src_condition + '\\src\\' + 'label.npy')
+
+    tar_data = np.load(root_dir + tar_condition + '\\tar\\' + 'data.npy')
+
+    val_data = np.load(root_dir + tar_condition + '\\val\\' + 'data.npy')
+    val_label = np.load(root_dir + tar_condition + '\\val\\' + 'label.npy')
+
+    test_data = np.load(root_dir + tar_condition + '\\test\\' + 'data.npy')
+    test_label = np.load(root_dir + tar_condition + '\\test\\' + 'label.npy')
+
+    src_dataset = Nor_Dataset(src_data, src_label)
+    tar_dataset = Nor_Dataset(tar_data)
+    val_dataset = Nor_Dataset(val_data, val_label)
+    test_dataset = Nor_Dataset(test_data, test_label)
+
+    return src_dataset, tar_dataset, val_dataset, test_dataset
+
+
+
+
+if __name__ == '__main__':
+    # pass
+    # tar_data = np.load("E:\\DataSet\\DDS_data_CDAC\\processed_data\\A\\tar\\data.npy")
+    # tar_dataset = Nor_Dataset(datas=tar_data)
+    # tar_loader = DataLoader(dataset=tar_dataset, batch_size=len(tar_dataset), shuffle=False)
+    #
+    # for batch_idx, (data) in enumerate(tar_loader):
+    #     print(data[0][0])
+    #     # print(data_bar[0][0])
+    #     # print(data_bar2[0][0])
+    #     if batch_idx == 0:
+    #         break
+    src_data = np.load("E:\\DataSet\\DDS_data_CDAC\\processed_data\\A\\src\\data.npy")
+    src_label = np.load("E:\\DataSet\\DDS_data_CDAC\\processed_data\\A\\src\\label.npy")
+    src_dataset = Nor_Dataset(datas=src_data, labels=src_label)
+    src_loader = DataLoader(dataset=src_dataset, batch_size=len(src_dataset), shuffle=False)
+
+    for batch_idx, (data, label) in enumerate(src_loader):
+        print(data[10][0])
+        print(label[10])
+        # print(data_bar[0][0])
+        # print(data_bar2[0][0])
+        if batch_idx == 0:
+            break
\ No newline at end of file
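
Note: Tar_U_Dataset above appears intended for consistency-style training: each fetch returns the clean signal plus two independently noised copies. A minimal sketch of that behavior on toy data (shapes here are illustrative only; the patch itself feeds (N, 1, 2048) arrays):

    import numpy as np
    from torch.utils.data import DataLoader
    from example.load_data import Tar_U_Dataset  # added by this patch

    toy = np.random.randn(4, 1, 16).astype('float32')  # (N, C, L) toy signals
    loader = DataLoader(Tar_U_Dataset(toy), batch_size=2)
    data, data_bar, data_bar2 = next(iter(loader))
    # data_bar and data_bar2 carry independent N(0, 0.1^2) noise on channel 0,
    # so both differ from data and from each other
    print((data - data_bar).abs().max(), (data_bar - data_bar2).abs().max())
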
diff --git a/pytorch_example/example/loss.py b/pytorch_example/example/loss.py
new file mode 100644
index 0000000..faef623
--- /dev/null
+++ b/pytorch_example/example/loss.py
@@ -0,0 +1,65 @@
+"""
+@Author: miykah
+@Email: miykah@163.com
+@FileName: loss.py
+@DateTime: 2022/7/21 14:31
+"""
+
+import torch
+import numpy as np
+import torch.nn.functional as functional
+import torch.nn as nn
+import example.mmd as mmd
+
+'''
+    Compute the discrepancy between the conditional distributions.
+    mmdc = mmd_condition
+'''
+def cal_mmdc_loss(src_data_mmd1, src_data_mmd2, tar_data_mmd1, tar_data_mmd2):
+    # Each batch is laid out as 10 classes with 3 samples per class;
+    # accumulate a per-class linear MMD at both feature levels.
+    mmdc1 = 0.0
+    mmdc2 = 0.0
+    for c in range(10):
+        start, end = 3 * c, 3 * (c + 1)
+        mmdc1 += mmd.mmd_linear(src_data_mmd1[start: end], tar_data_mmd1[start: end])
+        mmdc2 += mmd.mmd_linear(src_data_mmd2[start: end], tar_data_mmd2[start: end])
+    return mmdc2 / 10
+    # return (mmdc1 + mmdc2) / 20
+
+'''Get each source-domain class's mean feature, used to compute mmdc'''
+def get_src_mean_feature(src_feature, shot, cls):
+    src_feature_list = []
+    for i in range(cls):
+        src_feature_cls = torch.mean(src_feature[shot * i: shot * (i + 1)], dim=0)
+        src_feature_list.append(src_feature_cls)
+    return src_feature_list
+
+def get_mmdc(src_feature, tar_feature, tar_pseudo_label, batch_size, shot, cls):
+    src_feature_list = get_src_mean_feature(src_feature, shot, cls)
+    pseudo_label = tar_pseudo_label.cpu().detach().numpy()
+    mmdc = 0.0
+    for i in range(batch_size):
+        mmdc += mmd.mmd_linear(src_feature_list[pseudo_label[i]].reshape(1, -1), tar_feature[i].reshape(1, -1))
+    return mmdc / batch_size
+
+class BCE(nn.Module):
+    eps = 1e-7
+    def forward(self, prob1, prob2, simi):
+        P = prob1.mul_(prob2)
+        P = P.sum(1)
+        P.mul_(simi).add_(simi.eq(-1).type_as(P))
+        neglogP = -P.add_(BCE.eps).log_()
+        return neglogP.mean()
+
+class BinaryCrossEntropyLoss(nn.Module):
+    """ Construct binary cross-entropy loss."""
+    eps = 1e-7
+    def forward(self, prob):
+        # ds = torch.ones([bs, 1]).to(device)   # domain label for source
+        # dt = torch.zeros([bs, 1]).to(device)  # domain label for target
+        # di = torch.cat((ds, dt), dim=0).to(device)
+        # neglogP = - (di * torch.log(prob + BCE.eps) + (1. - di) * torch.log(1. - prob + BCE.eps))
+        neglogP = - (prob * torch.log(prob + BinaryCrossEntropyLoss.eps)
+                     + (1. - prob) * torch.log(1. - prob + BinaryCrossEntropyLoss.eps))
+        return neglogP.mean()
\ No newline at end of file
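
Note: the class-conditional terms above assume a fixed batch layout: shot consecutive rows per class, cls classes, so a class-ordered batch holds shot * cls rows. A small sketch of that assumption with toy tensors (the sizes are hypothetical, not from the patch):

    import torch
    from example.loss import get_src_mean_feature  # added by this patch

    shot, cls, feat_dim = 10, 5, 8
    # rows 0-9 are class 0, rows 10-19 class 1, ... (class-ordered batch)
    src_feature = torch.randn(shot * cls, feat_dim)
    means = get_src_mean_feature(src_feature, shot, cls)
    assert len(means) == cls
    # e.g. means[1] is the mean of rows 10..19
    assert torch.allclose(means[1], src_feature[10:20].mean(dim=0))

This layout also explains why train.py builds src_loader with shuffle=False: shuffling would break the class-ordered blocks that get_mmdc relies on.
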
diff --git a/pytorch_example/example/mmd.py b/pytorch_example/example/mmd.py
new file mode 100644
index 0000000..65ed811
--- /dev/null
+++ b/pytorch_example/example/mmd.py
@@ -0,0 +1,56 @@
+#!/usr/bin/env python
+# encoding: utf-8
+
+import torch
+
+# Consider linear time MMD with a linear kernel:
+# K(f(x), f(y)) = f(x)^Tf(y)
+# h(z_i, z_j) = k(x_i, x_j) + k(y_i, y_j) - k(x_i, y_j) - k(x_j, y_i)
+#             = [f(x_i) - f(y_i)]^T[f(x_j) - f(y_j)]
+#
+# f_of_X: batch_size * k
+# f_of_Y: batch_size * k
+def mmd_linear(f_of_X, f_of_Y):
+    delta = f_of_X - f_of_Y
+    loss = torch.mean(torch.mm(delta, torch.transpose(delta, 0, 1)))
+    return loss
+
+def gaussian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
+    n_samples = int(source.size()[0]) + int(target.size()[0])
+    total = torch.cat([source, target], dim=0)
+    total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
+    total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
+    L2_distance = ((total0 - total1) ** 2).sum(2)
+    if fix_sigma:
+        bandwidth = fix_sigma
+    else:
+        bandwidth = torch.sum(L2_distance.data) / (n_samples ** 2 - n_samples)
+    bandwidth /= kernel_mul ** (kernel_num // 2)
+    bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]
+    kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
+    return sum(kernel_val)  # /len(kernel_val)
+
+
+def mmd_rbf_accelerate(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
+    batch_size = int(source.size()[0])
+    kernels = gaussian_kernel(source, target,
+                              kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)
+    loss = 0
+    for i in range(batch_size):
+        s1, s2 = i, (i + 1) % batch_size
+        t1, t2 = s1 + batch_size, s2 + batch_size
+        loss += kernels[s1, s2] + kernels[t1, t2]
+        loss -= kernels[s1, t2] + kernels[s2, t1]
+    return loss / float(batch_size)
+
+def mmd_rbf_noaccelerate(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
+    batch_size = int(source.size()[0])
+    kernels = gaussian_kernel(source, target,
+                              kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)
+    XX = kernels[:batch_size, :batch_size]
+    YY = kernels[batch_size:, batch_size:]
+    XY = kernels[:batch_size, batch_size:]
+    YX = kernels[batch_size:, :batch_size]
+    loss = torch.mean(XX + YY - XY - YX)
+    return loss
+
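
Note: a quick sanity check of the two estimators on toy Gaussians (illustrative only): both statistics should be near zero when the inputs share a distribution and grow under a mean shift.

    import torch
    from example.mmd import mmd_linear, mmd_rbf_noaccelerate  # added by this patch

    torch.manual_seed(0)
    x = torch.randn(64, 32)
    y_same = torch.randn(64, 32)           # same distribution as x
    y_shift = torch.randn(64, 32) + 2.0    # mean-shifted distribution

    for name, y in [('same', y_same), ('shifted', y_shift)]:
        print(name, float(mmd_linear(x, y)), float(mmd_rbf_noaccelerate(x, y)))
    # expected: both values are much larger for the 'shifted' pair
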
diff --git a/pytorch_example/example/model.py b/pytorch_example/example/model.py
new file mode 100644
index 0000000..42af880
--- /dev/null
+++ b/pytorch_example/example/model.py
@@ -0,0 +1,109 @@
+"""
+@Author: miykah
+@Email: miykah@163.com
+@FileName: model.py
+@DateTime: 2022/7/20 21:18
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Function
+
+class GradReverse(Function):
+    '''Identity in the forward pass; multiplies the gradient by -1 in the backward pass.'''
+    @staticmethod
+    def forward(ctx, x):
+        return x.view_as(x)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return grad_output * -1.0
+
+def grad_reverse(x):
+    return GradReverse.apply(x)
+    # (the old non-static autograd.Function API is deprecated in modern PyTorch)
+
+'''Feature extractor'''
+class Extractor(nn.Module):
+    def __init__(self):
+        super(Extractor, self).__init__()
+        self.conv1 = nn.Sequential(
+            nn.Conv1d(in_channels=1, out_channels=32, kernel_size=13, padding='same'),
+            nn.BatchNorm1d(32),
+            nn.ReLU(),
+            nn.MaxPool1d(2)  # (32 * 1024)
+        )
+        self.conv2 = nn.Sequential(
+            nn.Conv1d(in_channels=32, out_channels=32, kernel_size=13, padding='same'),
+            nn.BatchNorm1d(32),
+            nn.ReLU(),
+            nn.MaxPool1d(2)  # (32 * 512)
+        )
+        self.conv3 = nn.Sequential(
+            nn.Conv1d(in_channels=32, out_channels=32, kernel_size=13, padding='same'),
+            nn.BatchNorm1d(32),
+            nn.ReLU(),
+            nn.MaxPool1d(2)  # (32 * 256)
+        )
+
+    def forward(self, src_data, tar_data):
+        src_data = self.conv1(src_data)
+        src_data = self.conv2(src_data)
+        src_feature = self.conv3(src_data)
+
+        tar_data = self.conv1(tar_data)
+        tar_data = self.conv2(tar_data)
+        tar_feature = self.conv3(tar_data)
+        return src_feature, tar_feature
+
+'''Label classifier'''
+class LabelClassifier(nn.Module):
+    def __init__(self, cls_num):
+        super(LabelClassifier, self).__init__()
+        self.fc1 = nn.Sequential(
+            nn.Flatten(),  # (8192,)
+            nn.Linear(in_features=8192, out_features=256),
+        )
+        self.fc2 = nn.Sequential(
+            nn.ReLU(),
+            nn.Linear(in_features=256, out_features=cls_num)
+        )
+    def forward(self, src_feature, tar_feature):
+        src_data_mmd1 = self.fc1(src_feature)
+        src_output = self.fc2(src_data_mmd1)
+
+        tar_data_mmd1 = self.fc1(tar_feature)
+        tar_output = self.fc2(tar_data_mmd1)
+        return src_data_mmd1, src_output, tar_data_mmd1, tar_output
+
+'''Domain classifier'''
+class DomainClassifier(nn.Module):
+    def __init__(self, temp=0.05):
+        super(DomainClassifier, self).__init__()
+        self.fc = nn.Sequential(
+            nn.Flatten(),
+            nn.Linear(in_features=8192, out_features=512),
+            nn.ReLU(),
+            nn.Linear(in_features=512, out_features=128),
+            nn.ReLU(),
+            nn.Linear(in_features=128, out_features=1),
+            nn.Sigmoid()
+        )
+        self.temp = temp
+    def forward(self, x, reverse=False):
+        if reverse:
+            x = grad_reverse(x)
+        output = self.fc(x)
+        return output
+
+'''Initialize network weights'''
+def weights_init_Extractor(m):
+    if isinstance(m, nn.Conv1d):
+        nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
+        nn.init.constant_(m.bias, 0)
+
+def weights_init_Classifier(m):
+    if isinstance(m, nn.Linear):
+        nn.init.xavier_normal_(m.weight)
+        nn.init.constant_(m.bias, 0)
\ No newline at end of file
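
Note: a shape walk-through of the three networks, assuming the (batch, 1, 2048) windows produced by read_data.py: the three MaxPool1d(2) stages reduce 2048 -> 1024 -> 512 -> 256, and 32 channels x 256 steps = 8192, which is exactly the in_features of both classifier heads. A minimal check (requires PyTorch >= 1.9 for padding='same'):

    import torch
    from example.model import Extractor, LabelClassifier, DomainClassifier, grad_reverse

    G, LC, DC = Extractor(), LabelClassifier(cls_num=5), DomainClassifier()
    src, tar = torch.randn(4, 1, 2048), torch.randn(4, 1, 2048)
    src_f, tar_f = G(src, tar)
    print(src_f.shape)            # torch.Size([4, 32, 256])
    fc1, out, _, _ = LC(src_f, tar_f)
    print(fc1.shape, out.shape)   # (4, 256) and (4, 5)
    print(DC(src_f).shape)        # (4, 1) domain probability

    # gradient reversal layer: identity forward, negated gradient backward
    x = torch.randn(3, requires_grad=True)
    grad_reverse(x).sum().backward()
    assert torch.allclose(x.grad, -torch.ones(3))
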
diff --git a/pytorch_example/example/read_data.py b/pytorch_example/example/read_data.py
new file mode 100644
index 0000000..2b70b71
--- /dev/null
+++ b/pytorch_example/example/read_data.py
@@ -0,0 +1,298 @@
+"""
+@Author: miykah
+@Email: miykah@163.com
+@FileName: read_data.py
+@DateTime: 2022/7/20 15:58
+"""
+import math
+import os
+import scipy.io as scio
+import numpy as np
+import PIL.Image as Image
+import matplotlib.pyplot as plt
+import torch
+
+def dat_to_numpy(samples_num, sample_length):
+    '''Compute the sampling stride from the samples needed per class and the sample length'''
+    stride = math.floor((204800 * 4 - sample_length) / (samples_num - 1))
+
+    conditions = ['A', 'B', 'C']
+    for condition in conditions:
+        if torch.cuda.is_available():
+            root_dir = 'D:\\DataSet\\DDS_data\\5cls_11\\' + condition
+            save_dir = 'D:\\DataSet\\DDS_data\\5cls_11\\raw_data\\' + condition
+        else:
+            root_dir = 'E:\\DataSet\\DDS_data\\5cls_11\\' + condition
+            save_dir = 'E:\\DataSet\\DDS_data\\5cls_11\\raw_data\\' + condition
+
+        if condition == 'A':
+            start_idx = 1
+            end_idx = 5
+            # start_idx = 13
+            # end_idx = 17
+            # start_idx = 25
+            # end_idx = 29
+            # start_idx = 9
+            # end_idx = 13
+        elif condition == 'B':
+            start_idx = 5
+            end_idx = 9
+            # start_idx = 17
+            # end_idx = 21
+            # start_idx = 29
+            # end_idx = 33
+        elif condition == 'C':
+            start_idx = 9
+            end_idx = 13
+            # start_idx = 21
+            # end_idx = 25
+            # start_idx = 33
+            # end_idx = 37
+            # start_idx = 25
+            # end_idx = 29
+
+        dir_names = os.listdir(root_dir)
+        print(dir_names)
+        '''
+        Fault types (parallel gearbox, constant speed; the directory names on disk are in Chinese):
+        bearing inner-race fault, bearing outer-race fault,
+        bearing rolling-element fault, gear eccentricity fault,
+        gear broken-tooth fault, gear missing-tooth fault,
+        gear surface-wear fault, gear tooth-root-crack fault
+        '''
+        if not os.path.exists(save_dir):
+            os.makedirs(save_dir)
+        for dir_name in dir_names:
+            data_list = []
+            for i in range(start_idx, end_idx):
+                if i < 10:
+                    path = root_dir + '\\' + dir_name + '\dds测试故障库4.6#000' + str(i) + '.dat'
+                else:
+                    path = root_dir + '\\' + dir_name + '\dds测试故障库4.6#00' + str(i) + '.dat'
+                data = np.fromfile(path, dtype='float32')[204800 * 2: 204800 * 3]
+                data_list.append(data)
+            # data_one_cls = np.array(data_list).reshape(-1).reshape(-1, 1024)
+            # data_one_cls = np.array(data_list).reshape(-1).reshape(-1, 1, 2048)  # (400, 1, 2048)
+            data_one_cls = np.array(data_list).reshape(-1)
+            n = math.floor((204800 * 4 - sample_length) / stride + 1)
+            samples = []
+            for i in range(n):
+                start = i * stride
+                end = i * stride + sample_length
+                samples.append(data_one_cls[start: end])
+            data_one_cls = np.array(samples).reshape(-1, 1, sample_length)
+            # shuffle the samples
+            shuffle_ix = np.random.permutation(np.arange(len(data_one_cls)))
+            data_one_cls = data_one_cls[shuffle_ix]
+            print(data_one_cls.shape)
+            np.save(save_dir + '\\' + dir_name + '.npy', data_one_cls)
+
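
Note: the stride above is the usual overlapping-window identity. Each class has T = 204800 * 4 points (four .dat segments of 204800 points each), and stride = floor((T - sample_length) / (samples_num - 1)) yields at least samples_num windows. Worked numbers for the call at the bottom of this file:

    import math
    T = 204800 * 4                      # points per class
    samples_num, sample_length = 800, 2048
    stride = math.floor((T - sample_length) / (samples_num - 1))
    n = math.floor((T - sample_length) / stride) + 1
    print(stride, n)                    # 1022, 800 -> exactly 800 windows
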
+def draw_signal_img(condition):
+    if torch.cuda.is_available():
+        root_dir = 'D:\\DataSet\\DDS_data\\4cls_speed30\\raw_data\\' + condition
+    else:
+        root_dir = 'E:\\DataSet\\DDS_data\\4cls_speed30\\raw_data\\' + condition
+    file_names = os.listdir(root_dir)
+    pic = plt.figure(figsize=(12, 6), dpi=200)
+    # plt.rcParams['font.family'] = ['Arial Unicode MS', 'Microsoft YaHei', 'SimHei', 'sans-serif']
+    # plt.rcParams['axes.unicode_minus'] = False
+    plt.rc('font', family='Times New Roman')
+    clses = ['(a)', '(b)', '(c)', '(d)']
+    for file_name, i, cls in zip(file_names, range(4), clses):
+        data = np.load(root_dir + '\\' + file_name)[0].reshape(-1)
+        print(data.shape, file_name)
+        plt.tick_params(top='on', right='on', which='both')  # show ticks on the top and right axes too
+        plt.rcParams['xtick.direction'] = 'in'  # x-axis ticks point inward
+        plt.rcParams['ytick.direction'] = 'in'  # y-axis ticks point inward
+        plt.subplot(2, 2, i + 1)
+        plt.title(cls, fontsize=20)
+        plt.xlabel("Time (s)", fontsize=20)
+        plt.ylabel("Amplitude (m/s²)", fontsize=20)
+        plt.xlim(0, 2048)  # x is the sample index; 2048 samples span 0.16 s
+        plt.ylim(-3, 3)
+        plt.yticks(np.array([-2, 0, 2]), fontsize=20)
+        plt.xticks(np.array([0, 640, 1280, 1920]), np.array(['0', '0.05', '0.1', '0.15']), fontsize=20)
+        plt.plot(data)
+    plt.show()
+
+def draw_signal_img2():
+    conditions = ['A', 'B', 'C']
+    pic = plt.figure(figsize=(12, 6), dpi=600)
+    plt.text(x=1, y=3, s='A')
+    i = 1
+    for condition in conditions:
+        if torch.cuda.is_available():
+            root_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne\\raw_data\\' + condition
+        else:
+            root_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne\\raw_data\\' + condition
+        file_names = os.listdir(root_dir)
+        # plt.rcParams['font.family'] = ['Arial Unicode MS', 'Microsoft YaHei', 'SimHei', 'sans-serif']
+        # plt.rcParams['axes.unicode_minus'] = False
+        plt.rc('font', family='Times New Roman')
+        clses = ['(a)', '(b)', '(c)', '(d)', '(e)']
+        for file_name, cls in zip(file_names, clses):
+            data = np.load(root_dir + '\\' + file_name)[3].reshape(-1)
+            print(data.shape, file_name)
+            plt.tick_params(top='on', right='on', which='both')  # show ticks on the top and right axes too
+            plt.rcParams['xtick.direction'] = 'in'  # x-axis ticks point inward
+            plt.rcParams['ytick.direction'] = 'in'  # y-axis ticks point inward
+            plt.subplot(3, 5, i)
+            plt.title(cls, fontsize=15)
+            plt.xlabel("Time (s)", fontsize=15)
+            plt.ylabel("Amplitude (m/s²)", fontsize=15)
+            plt.xlim(0, 2048)  # x is the sample index; 2048 samples span 0.16 s
+            plt.ylim(-5, 5)
+            plt.yticks(np.array([-3, 0, 3]), fontsize=15)
+            plt.xticks(np.array([0, 640, 1280, 1920]), np.array(['0', '0.05', '0.1', '0.15']), fontsize=15)
+            plt.plot(data)
+            i += 1
+    plt.show()
+
+'''Build the source-domain data and labels'''
+def get_src_data(src_condition, shot):
+    if torch.cuda.is_available():
+        root_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + src_condition
+        save_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\' + src_condition + '\\src'
+    else:
+        root_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + src_condition
+        save_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\' + src_condition + '\\src'
+
+    file_names = os.listdir(root_dir)
+    # print(file_names)
+    datas = []
+    labels = []
+    for i in range(600 // shot):
+        for file_name, cls in zip(file_names, range(5)):
+            data = np.load(root_dir + '\\' + file_name)
+            data_cut = data[shot * i: shot * (i + 1), :, :]
+            label_cut = np.array([cls] * shot, dtype='int64')
+            datas.append(data_cut)
+            labels.append(label_cut)
+    datas = np.array(datas).reshape(-1, 1, 2048)
+    labels = np.array(labels).reshape(-1)
+    print(datas.shape, labels.shape)
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+    np.save(save_dir + '\\' + 'data.npy', datas)
+    np.save(save_dir + '\\' + 'label.npy', labels)
+
+'''Build the unlabeled target-domain data'''
+def get_tar_data(tar_condition):
+    if torch.cuda.is_available():
+        root_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + tar_condition
+        save_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\' + tar_condition + '\\tar'
+    else:
+        root_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + tar_condition
+        save_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\' + tar_condition + '\\tar'
+    file_names = os.listdir(root_dir)
+    datas = []
+    for file_name in file_names:
+        data = np.load(root_dir + '\\' + file_name)
+        data_cut = data[0: 600, :, :]
+        datas.append(data_cut)
+    datas = np.array(datas).reshape(-1, 1, 2048)
+    print("datas.shape: {}".format(datas.shape))
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+    np.save(save_dir + '\\' + 'data.npy', datas)
+
+def get_tar_data2(tar_condition, shot):
+    if torch.cuda.is_available():
+        root_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + tar_condition
+        save_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\' + tar_condition + '\\tar1'
+    else:
+        root_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + tar_condition
+        save_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\' + tar_condition + '\\tar1'
+    file_names = os.listdir(root_dir)
+    datas = []
+    labels = []
+    for i in range(600 // shot):
+        for file_name, cls in zip(file_names, range(5)):
+            data = np.load(root_dir + '\\' + file_name)
+            data_cut = data[shot * i: shot * (i + 1), :, :]
+            label_cut = np.array([cls] * shot, dtype='int64')
+            datas.append(data_cut)
+            labels.append(label_cut)
+    datas = np.array(datas).reshape(-1, 1, 2048)
+    labels = np.array(labels).reshape(-1)
+    print(datas.shape, labels.shape)
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+    np.save(save_dir + '\\' + 'data.npy', datas)
+    np.save(save_dir + '\\' + 'label.npy', labels)
+
+'''Build the validation set (100 samples per class)'''
+def get_val_data(tar_condition):
+    if torch.cuda.is_available():
+        root_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + tar_condition
+        save_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\' + tar_condition + '\\val'
+    else:
+        root_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + tar_condition
+        save_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\' + tar_condition + '\\val'
+    file_names = os.listdir(root_dir)
+    datas = []
+    labels = []
+    for file_name, cls in zip(file_names, range(5)):
+        data = np.load(root_dir + '\\' + file_name)
+        data_cut = data[600: 700, :, :]
+        label_cut = np.array([cls] * 100, dtype='int64')
+        datas.append(data_cut)
+        labels.append(label_cut)
+    datas = np.array(datas).reshape(-1, 1, 2048)
+    labels = np.array(labels).reshape(-1)
+    print("datas.shape: {}, labels.shape: {}".format(datas.shape, labels.shape))
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+    np.save(save_dir + '\\' + 'data.npy', datas)
+    np.save(save_dir + '\\' + 'label.npy', labels)
+
+'''Build the test set (100 samples per class)'''
+def get_test_data(tar_condition):
+    if torch.cuda.is_available():
+        root_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + tar_condition
+        save_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\' + tar_condition + '\\test'
+    else:
+        root_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + tar_condition
+        save_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\' + tar_condition + '\\test'
+    file_names = os.listdir(root_dir)
+    datas = []
+    labels = []
+    for file_name, cls in zip(file_names, range(5)):
+        data = np.load(root_dir + '\\' + file_name)
+        data_cut = data[700:, :, :]
+        label_cut = np.array([cls] * 100, dtype='int64')
+        datas.append(data_cut)
+        labels.append(label_cut)
+    datas = np.array(datas).reshape(-1, 1, 2048)
+    labels = np.array(labels).reshape(-1)
+    print("datas.shape: {}, labels.shape: {}".format(datas.shape, labels.shape))
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+    np.save(save_dir + '\\' + 'data.npy', datas)
+    np.save(save_dir + '\\' + 'label.npy', labels)
+
+
+if __name__ == '__main__':
+    # dat_to_numpy(samples_num=800, sample_length=2048)
+    # draw_signal_img('C')
+    # draw_signal_img2()
+    # get_src_data('A', shot=10)
+    # get_src_data('B', shot=10)
+    # get_src_data('C', shot=10)
+    get_tar_data('A')
+    get_tar_data('B')
+    get_tar_data('C')
+    get_val_data('A')
+    get_val_data('B')
+    get_val_data('C')
+    get_test_data('A')
+    get_test_data('B')
+    get_test_data('C')
+
+    get_src_data('A', shot=10)
+    get_src_data('B', shot=10)
+    get_src_data('C', shot=10)
+    get_tar_data2('A', shot=10)
+    get_tar_data2('B', shot=10)
+    get_tar_data2('C', shot=10)
diff --git a/pytorch_example/example/test.py b/pytorch_example/example/test.py
new file mode 100644
index 0000000..a6ca108
--- /dev/null
+++ b/pytorch_example/example/test.py
@@ -0,0 +1,226 @@
+"""
+@Author: miykah
+@Email: miykah@163.com
+@FileName: test.py
+@DateTime: 2022/7/9 14:15
+"""
+
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+from example.load_data import get_dataset
+from example.load_data import Nor_Dataset
+from example.model import Extractor, LabelClassifier
+from sklearn.metrics import confusion_matrix
+from sklearn.manifold import TSNE
+import seaborn as sns
+import matplotlib.pyplot as plt
+
+def load_data(tar_condition):
+    if torch.cuda.is_available():
+        root_dir = 'D:\\DataSet\\DDS_data\\5cls_6\\processed_data\\800EachCls_shot10\\'
+    else:
+        root_dir = 'E:\\DataSet\\DDS_data\\5cls_6\\processed_data\\800EachCls_shot10\\'
+
+    test_data = np.load(root_dir + tar_condition + '\\val\\' + 'data.npy')
+    test_label = np.load(root_dir + tar_condition + '\\val\\' + 'label.npy')
+
+    return test_data, test_label
+
+def tsne_2d_generate(cls, data, labels, pic_title):
+    tsne2D = TSNE(n_components=2, verbose=2, perplexity=30).fit_transform(data)
+    x, y = tsne2D[:, 0], tsne2D[:, 1]
+    pic = plt.figure()
+    ax1 = pic.add_subplot()
+    ax1.scatter(x, y, c=labels, cmap=plt.cm.get_cmap("jet", cls))  # one discrete color per class
+    plt.title(pic_title)
+    plt.show()
+
+def tsne_2d_generate1(cls, data, labels, pic_title):
+    parameters = {'figure.dpi': 600,
+                  'figure.figsize': (4, 3),
+                  'savefig.dpi': 600,
+                  'xtick.direction': 'in',
+                  'ytick.direction': 'in',
+                  'xtick.labelsize': 10,
+                  'ytick.labelsize': 10,
+                  'legend.fontsize': 11.3,
+                  }
+    plt.rcParams.update(parameters)
+    plt.rc('font', family='Times New Roman')  # global font style
+    tsne2D = TSNE(n_components=2, verbose=2, perplexity=30, random_state=3407, init='random', learning_rate=200).fit_transform(data)
+    tsne2D_min, tsne2D_max = tsne2D.min(0), tsne2D.max(0)
+    tsne2D_final = (tsne2D - tsne2D_min) / (tsne2D_max - tsne2D_min)
+    s1, s2 = tsne2D_final[:1000, :], tsne2D_final[1000:, :]
+    pic = plt.figure()
+    # ax1 = pic.add_subplot()
+    plt.scatter(s1[:, 0], s1[:, 1], c=labels[:1000], cmap=plt.cm.get_cmap("jet", cls), marker='o', alpha=0.3)  # source samples
+    plt.scatter(s2[:, 0], s2[:, 1], c=labels[1000:], cmap=plt.cm.get_cmap("jet", cls), marker='x', alpha=0.3)  # target samples
+    plt.title(pic_title, fontsize=10)
+    # plt.xticks([])
+    # plt.yticks([])
+    plt.show()
+
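Note: tsne_2d_generate1 min-max rescales the 2-D embedding to [0, 1] per axis so the source rows (first 1000, plotted as 'o') and target rows (the rest, plotted as 'x') share one frame. The normalization itself is plain min-max, e.g.:

    import numpy as np
    emb = np.array([[-3.0, 10.0], [1.0, 20.0], [5.0, 30.0]])  # toy t-SNE output
    lo, hi = emb.min(0), emb.max(0)
    print((emb - lo) / (hi - lo))  # each column now spans exactly [0, 1]
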
+def plot_confusion_matrix_accuracy(cls, true_labels, predict_labels):
+    # # Plot the confusion matrix (earlier green-themed version, kept for reference)
+    # confusion = confusion_matrix(true_labels, predict_labels)
+    # # confusion = confusion.astype('float') / confusion.sum(axis=1)[:, np.newaxis]
+    # plt.figure(figsize=(6.4, 6.4), dpi=100)
+    # sns.heatmap(confusion, annot=True, fmt="d", cmap="Greens")
+    # # sns.heatmap(confusion_matrix, annot=True, fmt="d", cmap="Blues")
+    # indices = range(len(confusion))
+    # classes = ['N', 'IF', 'OF', 'TRC', 'TSP']
+    # # for i in range(cls):
+    # #     classes.append(str(i))
+    # # first argument: tick positions, second: tick labels
+    # # plt.xticks(indices, classes, rotation=45)
+    # # plt.yticks(indices, classes, rotation=45)
+    # plt.xticks([index + 0.5 for index in indices], classes, rotation=45)  # rotate the labels 45 degrees
+    # plt.yticks([index + 0.5 for index in indices], classes, rotation=45)
+    # plt.ylabel('Actual label')
+    # plt.xlabel('Predicted label')
+    # plt.title('confusion matrix')
+    # plt.show()
+    sns.set(font_scale=1.5)
+    parameters = {'figure.dpi': 600,
+                  'figure.figsize': (5, 5),
+                  'savefig.dpi': 600,
+                  'xtick.direction': 'in',
+                  'ytick.direction': 'in',
+                  'xtick.labelsize': 20,
+                  'ytick.labelsize': 20,
+                  'legend.fontsize': 11.3,
+                  }
+    plt.rcParams.update(parameters)
+    plt.figure()
+    plt.rc('font', family='Times New Roman')  # global font style
+    # Plot the confusion matrix
+    confusion = confusion_matrix(true_labels, predict_labels)
+    # confusion = confusion.astype('float') / confusion.sum(axis=1)[:, np.newaxis]
+    plt.figure()
+    # sns.heatmap(confusion, annot=True, fmt="d", cmap="Greens")
+    sns.heatmap(confusion, annot=True, fmt="d", cmap="Blues", vmax=100, vmin=0, cbar=None, square=True)
+    indices = range(len(confusion))
+    classes = ['N', 'IF', 'OF', 'TRC', 'TSP']
+    # plt.xticks(indices, classes, rotation=45)
+    # plt.yticks(indices, classes, rotation=45)
+    plt.xticks([index + 0.5 for index in indices], classes, rotation=45)  # rotate the class labels 45 degrees
+    plt.yticks([index + 0.5 for index in indices], classes, rotation=45)
+    plt.ylabel('Actual label', fontsize=20)
+    plt.xlabel('Predicted label', fontsize=20)
+    # plt.tight_layout()
+    plt.show()
+
+
+def test(cls, tar_condition, G_params_path, LC_params_path):
+    test_data, test_label = load_data(tar_condition)
+    test_dataset = Nor_Dataset(test_data, test_label)
+
+    batch_size = len(test_dataset)
+
+    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)
+
+    # Load the networks
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    G = Extractor().to(device)
+    LC = LabelClassifier(cls_num=cls).to(device)
+    G.load_state_dict(
+        torch.load(G_params_path, map_location=device)
+    )
+    LC.load_state_dict(
+        torch.load(LC_params_path, map_location=device)
+    )
+    # print(net)
+    # params_num = sum(param.numel() for param in net.parameters())
+    # print('number of parameters: {}'.format(params_num))
+
+    test_acc = 0.0
+
+    G.eval()
+    LC.eval()
+    with torch.no_grad():
+        for batch_idx, (data, label) in enumerate(test_loader):
+            data, label = data.to(device), label.to(device)
+            _, feature = G(data, data)
+            _, _, _, output = LC(feature, feature)
+            test_acc += np.sum(np.argmax(output.cpu().detach().numpy(), axis=1) == label.cpu().numpy())
+
+            predict_labels = np.argmax(output.cpu().detach().numpy(), axis=1)
+            labels = label.cpu().numpy()
+
+            predictions = output.cpu().detach().numpy()
+
+    tsne_2d_generate(cls, predictions, labels, "output of neural network")
+
+    plot_confusion_matrix_accuracy(cls, labels, predict_labels)
+
+    print("Test set size: {}, correct: {}, accuracy: {:.6f}".format(test_dataset.__len__(), test_acc, test_acc / test_dataset.__len__()))
+
+def test1(cls, src_condition, tar_condition, G_params_path, LC_params_path):
+    if torch.cuda.is_available():
+        root_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\'
+    else:
+        root_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\'
+
+    src_data = np.load(root_dir + src_condition + '\\src\\' + 'data.npy')
+    src_label = np.load(root_dir + src_condition + '\\src\\' + 'label.npy')
+
+    tar_data = np.load(root_dir + tar_condition + '\\tar1\\' + 'data.npy')
+    tar_label = np.load(root_dir + tar_condition + '\\tar1\\' + 'label.npy')
+
+    src_dataset = Nor_Dataset(src_data, src_label)
+    tar_dataset = Nor_Dataset(tar_data, tar_label)
+    src_loader = DataLoader(dataset=src_dataset, batch_size=1000, shuffle=False, drop_last=True)
+    tar_loader = DataLoader(dataset=tar_dataset, batch_size=1000, shuffle=False, drop_last=True)
+
+    # Load the networks
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    G = Extractor().to(device)
+    LC = LabelClassifier(cls_num=cls).to(device)
+    G.load_state_dict(
+        torch.load(G_params_path, map_location=device)
+    )
+    LC.load_state_dict(
+        torch.load(LC_params_path, map_location=device)
+    )
+    # print(net)
+    # params_num = sum(param.numel() for param in net.parameters())
+    # print('number of parameters: {}'.format(params_num))
+
+    test_acc = 0.0
+
+    G.eval()
+    LC.eval()
+    with torch.no_grad():
+        for (src_batch_idx, (src_data, src_label)), (tar_batch_idx, (tar_data, tar_label)) in zip(enumerate(src_loader), enumerate(tar_loader)):
+            src_data, src_label = src_data.to(device), src_label.to(device)
+            tar_data, tar_label = tar_data.to(device), tar_label.to(device)
+            data = torch.concat((src_data, tar_data), dim=0)
+            label = torch.concat((src_label, tar_label), dim=0)
+            _, feature = G(data, data)
+            _, _, fc1, output = LC(feature, feature)
+            test_acc += np.sum(np.argmax(output.cpu().detach().numpy(), axis=1) == label.cpu().numpy())
+
+            predict_labels = np.argmax(output.cpu().detach().numpy(), axis=1)
+            labels = label.cpu().numpy()
+
+            outputs = output.cpu().detach().numpy()
+            fc1_outputs = fc1.cpu().detach().numpy()
+            break
+
+    tsne_2d_generate1(cls, fc1_outputs, labels, "GADAN")
+
+    plot_confusion_matrix_accuracy(cls, labels, predict_labels)
+
+    print("Accuracy: {:.6f}".format(test_acc / (src_dataset.__len__() + tar_dataset.__len__())))
+
+if __name__ == '__main__':
+    # pass
+    test(5, 'B', 'parameters/A_to_B/G/G_shot10_epoch200_lr0.002_miu0.5.pkl',
+         'parameters/A_to_B/LC/LC_shot10_epoch200_lr0.002_miu0.5.pkl')
+
+    # test1(5, 'A', 'B', 'parameters_bak/A_to_B/G/G_shot10_epoch200_lr0.002_miu0.5.pkl',
+    #       'parameters_bak/A_to_B/LC/LC_shot10_epoch200_lr0.002_miu0.5.pkl')
\ No newline at end of file
diff --git a/pytorch_example/example/train.py b/pytorch_example/example/train.py
new file mode 100644
index 0000000..6a60c6e
--- /dev/null
+++ b/pytorch_example/example/train.py
@@ -0,0 +1,357 @@
+"""
+@Author: miykah
+@Email: miykah@163.com
+@FileName: train.py
+@DateTime: 2022/7/20 20:22
+"""
+import os
+import time
+import random
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import matplotlib.pyplot as plt
+from torch.utils.data import DataLoader
+from example.load_data import get_dataset, draw_signal_img
+from example.model import Extractor, LabelClassifier, DomainClassifier, weights_init_Classifier, weights_init_Extractor
+from example.loss import cal_mmdc_loss, BinaryCrossEntropyLoss, get_mmdc
+import example.mmd as mmd
+from example.test import test, test1
+from scipy.spatial.distance import cdist
+import math
+
+def obtain_label(feature, output, bs):
+    with torch.no_grad():
+        all_fea = feature.reshape(bs, -1).float().cpu()
+        all_output = output.float().cpu()
+        # all_label = label.float()
+        all_output = nn.Softmax(dim=1)(all_output)
+        _, predict = torch.max(all_output, 1)
+        # accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
+
+        all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1)
+        all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t()
+        all_fea = all_fea.float().cpu().numpy()
+
+        K = all_output.size(1)
+        aff = all_output.float().cpu().numpy()
+        initc = aff.transpose().dot(all_fea)
+        initc = initc / (1e-8 + aff.sum(axis=0)[:, None])
+        dd = cdist(all_fea, initc, 'cosine')
+        pred_label = dd.argmin(axis=1)
+        # acc = np.sum(pred_label == all_label.float().numpy()) / len(all_fea)
+
+        for round in range(1):
+            aff = np.eye(K)[pred_label]
+            initc = aff.transpose().dot(all_fea)
+            initc = initc / (1e-8 + aff.sum(axis=0)[:, None])
+            dd = cdist(all_fea, initc, 'cosine')
+            pred_label = dd.argmin(axis=1)
+            # acc = np.sum(pred_label == all_label.float().numpy()) / len(all_fea)
+
+        # log_str = 'Accuracy = {:.2f}% -> {:.2f}%'.format(accuracy * 100, acc * 100)
+        # args.out_file.write(log_str + '\n')
+        # args.out_file.flush()
+        # print(log_str + '\n')
+    return pred_label.astype('int')
+
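
Note: obtain_label is a clustering-based refinement of the softmax pseudo-labels (a SHOT-style self-labeling step): build class centroids as prediction-weighted means of the normalized features, reassign every sample to its cosine-nearest centroid, then recompute the centroids once from the new hard labels. A compact numpy restatement of one refinement round (toy shapes, illustrative only):

    import numpy as np
    from scipy.spatial.distance import cdist

    rng = np.random.default_rng(0)
    feats = rng.normal(size=(6, 4))                       # features, L2-normalized below
    feats /= np.linalg.norm(feats, axis=1, keepdims=True)
    probs = rng.dirichlet(np.ones(3), size=6)             # softmax outputs, K = 3 classes

    centroids = probs.T @ feats / (1e-8 + probs.sum(0)[:, None])    # weighted means
    labels = cdist(feats, centroids, 'cosine').argmin(1)            # nearest centroid
    onehot = np.eye(3)[labels]
    centroids = onehot.T @ feats / (1e-8 + onehot.sum(0)[:, None])  # recompute once
    labels = cdist(feats, centroids, 'cosine').argmin(1)
    print(labels)
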
+def train(device, src_condition, tar_condition, cls, epochs, bs, shot, lr, patience, gamma, miu):
+    '''Feature extractor'''
+    G = Extractor()
+    G.apply(weights_init_Extractor)
+    G.to(device)
+    '''Label classifier'''
+    LC = LabelClassifier(cls_num=cls)
+    LC.apply(weights_init_Classifier)
+    LC.to(device)
+    '''Domain classifier'''
+    DC = DomainClassifier()
+    DC.apply(weights_init_Classifier)
+    DC.to(device)
+
+    '''Build the datasets'''
+    src_dataset, tar_dataset, val_dataset, test_dataset \
+        = get_dataset(src_condition, tar_condition)
+    '''DataLoader (src_loader must stay unshuffled to keep the class-ordered batches)'''
+    src_loader = DataLoader(dataset=src_dataset, batch_size=bs, shuffle=False, drop_last=True)
+    tar_loader = DataLoader(dataset=tar_dataset, batch_size=bs, shuffle=True, drop_last=True)
+    val_loader = DataLoader(dataset=val_dataset, batch_size=bs, shuffle=True, drop_last=False)
+    test_loader = DataLoader(dataset=test_dataset, batch_size=len(test_dataset), shuffle=True, drop_last=False)
+
+    criterion = nn.CrossEntropyLoss().to(device)
+    BCE = BinaryCrossEntropyLoss().to(device)
+
+    optimizer_g = torch.optim.Adam(G.parameters(), lr=lr)
+    optimizer_lc = torch.optim.Adam(LC.parameters(), lr=lr)
+    optimizer_dc = torch.optim.Adam(DC.parameters(), lr=lr)
+    scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_g, mode="min", factor=0.5, patience=patience)
+    scheduler_lc = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_lc, mode="min", factor=0.5, patience=patience)
+    scheduler_dc = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_dc, mode="min", factor=0.5, patience=patience)
+
+    def zero_grad_all():
+        optimizer_g.zero_grad()
+        optimizer_lc.zero_grad()
+        optimizer_dc.zero_grad()
+
+    src_acc_list = []
+    train_loss_list = []
+    val_acc_list = []
+    val_loss_list = []
+
+    G_path = "parameters_bak/" + src_condition + "_to_" + tar_condition + "/G"
+    LC_path = "parameters_bak/" + src_condition + "_to_" + tar_condition + "/LC"
+    DC_path = "parameters_bak/" + src_condition + "_to_" + tar_condition + "/DC"
+
+    for epoch in range(epochs):
+        epoch_start_time = time.time()
+        src_acc = 0.0
+        train_loss = 0.0
+        val_acc = 0.0
+        val_loss = 0.0
+
+        G.train()
+        LC.train()
+        DC.train()
+
+        for (src_batch_idx, (src_data, src_label)), (tar_batch_idx, (tar_data)) in zip(enumerate(src_loader), enumerate(tar_loader)):
+            src_data, src_label = src_data.to(device), src_label.to(device)
+            tar_data = tar_data.to(device)
+            zero_grad_all()
+
+            T1 = int(0.2 * epochs)
+            T2 = int(0.5 * epochs)
+
+            src_feature, tar_feature = G(src_data, tar_data)
+            src_data_mmd1, src_output, tar_data_mmd1, tar_output = LC(src_feature, tar_feature)
+            if epoch < T1:
+                pseudo_label = torch.argmax(F.softmax(tar_output, dim=1), dim=1)  # pseudo-labels
+            else:
+                pseudo_label = torch.tensor(obtain_label(tar_feature, tar_output, bs), dtype=torch.int64).to(device)
+
+            # pseudo_label = torch.argmax(F.softmax(tar_output, dim=1), dim=1)  # pseudo-labels
+            # weight of the conditional-MMD (mmdc) term
+            if epoch < T1:
+                miu_f = 0
+            # elif epoch > T1 and epoch < T2:
+            #     miu_f = miu * (epoch - T1) / (T2 - T1)
+            else:
+                miu_f = miu
+
+            # cross-entropy loss on the source data
+            loss_src = criterion(src_output, src_label)
+            # pseudo-label cross-entropy loss on the target data
+            loss_tar_pseudo = criterion(tar_output, pseudo_label)
+            # marginal MMD loss
+            loss_mmdm = mmd.mmd_rbf_noaccelerate(src_data_mmd1, tar_data_mmd1)
+            if epoch < T1:
+                loss_mmdc = 0
+            else:
+                loss_mmdc = get_mmdc(src_data_mmd1, tar_data_mmd1, pseudo_label, bs, shot, cls)
+            # loss_mmdc = cal_mmdc_loss(src_data_mmd1, src_output, tar_data_mmd1, tar_output)
+            # loss_mmdc = get_mmdc(src_data_mmd1, tar_data_mmd1, pseudo_label, bs, shot, cls)
+            # loss_jmmd = miu_f * loss_mmdc + (1 - miu_f) * loss_mmdm
+
+            # weight of the pseudo-label loss
+            # if epoch < T1:
+            #     beta_f = 0
+            # elif epoch > T1 and epoch < T2:
+            #     beta_f = beta * (epoch - T1) / (T2 - T1)
+            # else:
+            #     beta_f = beta
+
+            p = epoch / epochs
+            lamda = (2 / (1 + math.exp(-10 * p))) - 1
+            # gamma = (2 / (1 + math.exp(-10 * p))) - 1
+            # miu_f = (2 / (1 + math.exp(-10 * p))) - 1
+
+            loss_jmmd = miu_f * loss_mmdc + (1 - miu_f) * loss_mmdm
+
+            loss_G_LC = loss_src + gamma * loss_jmmd
+            # loss_G_LC = loss_src + beta_f * loss_tar_pseudo + gamma * loss_jmmd
+            loss_G_LC.backward()
+            optimizer_g.step()
+            optimizer_lc.step()
+            zero_grad_all()
+# -----------------------------------------------
+            # adversarial domain-adaptation loss
+            src_feature, tar_feature = G(src_data, tar_data)
+            src_data_mmd1, src_output, tar_data_mmd1, tar_output = LC(src_feature, tar_feature)
+            # pseudo_label = torch.argmax(F.softmax(tar_output, dim=1), dim=1)  # pseudo-labels
+            # cross-entropy loss on the source data
+            loss_src = criterion(src_output, src_label)
+            # pseudo-label cross-entropy loss on the target data
+            loss_tar_pseudo = criterion(tar_output, pseudo_label)
+
+            gradient_src = \
+                torch.autograd.grad(outputs=loss_src, inputs=src_feature, create_graph=True, retain_graph=True,
+                                    only_inputs=True)[0]
+            gradient_tar = \
+                torch.autograd.grad(outputs=loss_tar_pseudo, inputs=tar_feature, create_graph=True, retain_graph=True,
+                                    only_inputs=True)[0]
+            gradients_adv = torch.cat((gradient_src, gradient_tar), dim=0)
+
+            domain_label_reverse = DC(gradients_adv, reverse=True)
+            loss_adv_r = BCE(domain_label_reverse)
+            loss_G_adv = lamda * loss_adv_r
+            # update the feature extractor G
+            loss_G_adv.backward()
+            optimizer_g.step()
+            zero_grad_all()
+# ---------------------------------------------------------------------
+            src_feature, tar_feature = G(src_data, tar_data)
+            src_data_mmd1, src_output, tar_data_mmd1, tar_output = LC(src_feature, tar_feature)
+            # pseudo_label = torch.argmax(F.softmax(tar_output, dim=1), dim=1)  # pseudo-labels
+            # cross-entropy loss on the source data
+            loss_src = criterion(src_output, src_label)
+            # pseudo-label cross-entropy loss on the target data
+            loss_tar_pseudo = criterion(tar_output, pseudo_label)
+
+            gradient_src = \
+                torch.autograd.grad(outputs=loss_src, inputs=src_feature, create_graph=True, retain_graph=True,
+                                    only_inputs=True)[0]
+            gradient_tar = \
+                torch.autograd.grad(outputs=loss_tar_pseudo, inputs=tar_feature, create_graph=True, retain_graph=True,
+                                    only_inputs=True)[0]
+
+            gradients = torch.cat((gradient_src, gradient_tar), dim=0)
+            domain_label = DC(gradients, reverse=False)
+            loss_adv = BCE(domain_label)
+            loss_DC = lamda * loss_adv
+            # update the domain classifier
+            loss_DC.backward()
+            optimizer_dc.step()
+            zero_grad_all()
+
+            src_acc += np.sum(np.argmax(src_output.cpu().detach().numpy(), axis=1) == src_label.cpu().numpy())
+            train_loss += (loss_G_LC + loss_G_adv + loss_DC).item()
+
+        G.eval()
+        LC.eval()
+        DC.eval()
+        with torch.no_grad():
+            for batch_idx, (val_data, val_label) in enumerate(val_loader):
+                val_data, val_label = val_data.to(device), val_label.to(device)
+                _, val_feature = G(val_data, val_data)
+                _, _, _, val_output = LC(val_feature, val_feature)
+                loss = criterion(val_output, val_label)
+
+                val_acc += np.sum(np.argmax(val_output.cpu().detach().numpy(), axis=1) == val_label.cpu().numpy())
+                val_loss += loss.item()
+
+        scheduler_g.step(val_loss)
+        scheduler_lc.step(val_loss)
+        scheduler_dc.step(val_loss)
+
+        print("[{:03d}/{:03d}] {:2.2f} sec(s) src_acc: {:3.6f} train_loss: {:3.9f} | val_acc: {:3.6f} val_loss: {:3.9f} | Learning rate : {:3.6f}".format(
+            epoch + 1, epochs, time.time() - epoch_start_time, \
+            src_acc / src_dataset.__len__(), train_loss / src_dataset.__len__(),
+            val_acc / val_dataset.__len__(), val_loss / val_dataset.__len__(),
+            optimizer_g.state_dict()['param_groups'][0]['lr']))
+
+        # save the model with the lowest validation loss
+        # if val_loss_list.__len__() > 0 and (val_loss / val_dataset.__len__()) < min(val_loss_list):
+        # save whenever validation accuracy improves (or validation loss hits a new low)
+        if val_acc_list.__len__() > 0:
+            # if (val_acc / val_dataset.__len__()) >= max(val_acc_list):
+            if (val_acc / val_dataset.__len__()) > max(val_acc_list) or (val_loss / val_dataset.__len__()) < min(val_loss_list):
+                print("Saved new best model")
+                if not os.path.exists(G_path):
+                    os.makedirs(G_path)
+                if not os.path.exists(LC_path):
+                    os.makedirs(LC_path)
+                if not os.path.exists(DC_path):
+                    os.makedirs(DC_path)
+                # save the model parameters
"_epoch" + str(epochs) + "_lr" + str(lr) + "_miu" + str(miu) + ".pkl") + torch.save(LC.state_dict(), LC_path + "/LC_" + "shot" + str(shot) + "_epoch" + str(epochs) + "_lr" + str(lr) + "_miu" + str(miu) + ".pkl") + torch.save(DC.state_dict(), DC_path + "/DC_" + "shot" + str(shot) + "_epoch" + str(epochs) + "_lr" + str(lr) + "_miu" + str(miu) + ".pkl") + + src_acc_list.append(src_acc / src_dataset.__len__()) + train_loss_list.append(train_loss / src_dataset.__len__()) + val_acc_list.append(val_acc / val_dataset.__len__()) + val_loss_list.append(val_loss / val_dataset.__len__()) + + '''保存的模型参数的路径''' + G_params_path = G_path + "/G_" + "shot" + str(shot) + "_epoch" + str(epochs) + "_lr" + str(lr) + "_miu" + str(miu) + ".pkl" + LC_params_path = LC_path + "/LC_" + "shot" + str(shot) + "_epoch" + str(epochs) + "_lr" + str(lr) + "_miu" + str(miu) + ".pkl" + + + from matplotlib import rcParams + + config = { + "font.family": 'Times New Roman', # 设置字体类型 + "axes.unicode_minus": False, # 解决负号无法显示的问题 + "axes.labelsize": 13 + } + rcParams.update(config) + + # pic1 = plt.figure(figsize=(8, 6), dpi=200) + # plt.subplot(211) + # plt.plot(np.arange(1, epochs + 1), src_acc_list, 'b', label='TrainAcc') + # plt.plot(np.arange(1, epochs + 1), val_acc_list, 'r', label='ValAcc') + # plt.ylim(0.3, 1.0) # 设置y轴范围 + # plt.title('Training & Validation accuracy') + # plt.xlabel('epoch') + # plt.ylabel('accuracy') + # plt.legend(loc='lower right') + # plt.grid(alpha=0.4) + # + # plt.subplot(212) + # plt.plot(np.arange(1, epochs + 1), train_loss_list, 'b', label='TrainLoss') + # plt.plot(np.arange(1, epochs + 1), val_loss_list, 'r', label='ValLoss') + # plt.ylim(0, 0.08) # 设置y轴范围 + # plt.title('Training & Validation loss') + # plt.xlabel('epoch') + # plt.ylabel('loss') + # plt.legend(loc='upper right') + # plt.grid(alpha=0.4) + + pic1 = plt.figure(figsize=(12, 6), dpi=200) + + plt.plot(np.arange(1, epochs + 1), train_loss_list, 'b', label='Training Loss') + plt.plot(np.arange(1, epochs + 1), val_loss_list, 'r', label='Validation Loss') + plt.ylim(0, 0.08) # 设置y轴范围 + plt.title('Training & Validation loss') + plt.xlabel('epoch') + plt.ylabel('loss') + plt.legend(loc='upper right') + plt.grid(alpha=0.4) + + # 获取当前时间戳 + timestamp = int(time.time()) + + # 将时间戳转换为字符串 + timestamp_str = str(timestamp) + plt.savefig(timestamp_str, dpi=200) + + return G_params_path, LC_params_path + + +if __name__ == '__main__': + + begin = time.time() + + if torch.cuda.is_available(): + device = torch.device("cuda:0") + else: + device = torch.device("cpu") + + seed = 2 + torch.manual_seed(seed) + random.seed(seed) + np.random.seed(seed) + + src_condition = 'A' + tar_condition = 'B' + '''训练''' + G_params_path, LC_params_path = train(device, src_condition, tar_condition, cls=5, + epochs=200, bs=50, shot=10, lr=0.002, patience=40, gamma=1, miu=0.5) + + end = time.time() + + '''测试''' + # test1(5, src_condition, tar_condition, G_params_path, LC_params_path) + test(5, tar_condition, G_params_path, LC_params_path) + + print("训练耗时:{:3.2f}s".format(end - begin)) \ No newline at end of file