"""
@Author: miykah
@Email: miykah@163.com
@FileName: model.py
@DateTime: 2022/7/20 21:18

Model components: a shared 1-D convolutional feature extractor, a label
classifier, a domain classifier with optional gradient reversal, and
weight-initialisation helpers.
"""
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function


class GradReverse(Function):
    """Gradient reversal layer.

    The forward pass is the identity.  The backward pass multiplies the
    incoming gradient by ``-lambd`` (``lambd`` defaults to 1.0, i.e. a plain
    sign flip).

    This is a new-style autograd ``Function`` with static ``forward`` and
    ``backward`` methods that take ``ctx``.  Use it through
    ``GradReverse.apply(x[, lambd])`` or ``grad_reverse``; do not instantiate it.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, lambd: float = 1.0) -> torch.Tensor:
        ctx.lambd = lambd
        # view_as hands autograd a distinct tensor object (sharing x's storage)
        # rather than returning the input object itself.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # One gradient per forward input: x receives the reversed, scaled
        # gradient; lambd is a plain number and gets None.
        return grad_output * -ctx.lambd, None


def grad_reverse(x: torch.Tensor, lambd: float = 1.0) -> torch.Tensor:
    """Return ``x`` unchanged, but reverse (and scale by ``lambd``) its gradient."""
    return GradReverse.apply(x, lambd)


def _conv_block(in_channels: int, out_channels: int, kernel_size: int = 13) -> nn.Sequential:
    """Build Conv1d('same') -> BatchNorm1d -> ReLU -> MaxPool1d(2); the length is halved."""
    return nn.Sequential(
        nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
                  kernel_size=kernel_size, padding='same'),
        nn.BatchNorm1d(out_channels),
        nn.ReLU(),
        nn.MaxPool1d(2),
    )


class Extractor(nn.Module):
    """Feature extractor whose layers are shared by the source and target inputs.

    It has three blocks of Conv1d(kernel_size=13, padding='same') ->
    BatchNorm1d -> ReLU -> MaxPool1d(2).  Each block halves the signal length,
    so an input of shape (N, in_channels, L) becomes (N, 32, L // 8).  The
    shape comments assume L = 2048, which gives (N, 32, 256), i.e. 8192
    features per sample.  That is the default ``in_features`` of the
    classifiers in this module.

    Args:
        in_channels: number of input channels (default 1).
    """

    def __init__(self, in_channels: int = 1):
        super().__init__()
        self.conv1 = _conv_block(in_channels, 32)  # (32, L/2), e.g. (32 * 1024)
        self.conv2 = _conv_block(32, 32)           # (32, L/4), e.g. (32 * 512)
        self.conv3 = _conv_block(32, 32)           # (32, L/8), e.g. (32 * 256)

    def _encode(self, data: torch.Tensor) -> torch.Tensor:
        """Run one batch through conv1 -> conv2 -> conv3."""
        return self.conv3(self.conv2(self.conv1(data)))

    def forward(self, src_data: torch.Tensor,
                tar_data: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode the source and target batches with the same layers.

        The two batches go through separate calls.  In training mode, each
        BatchNorm layer therefore normalises each domain with that domain's
        own batch statistics, and updates its running statistics twice per
        forward.

        Returns:
            (src_feature, tar_feature)
        """
        src_feature = self._encode(src_data)
        tar_feature = self._encode(tar_data)
        return src_feature, tar_feature


class LabelClassifier(nn.Module):
    """Label classifier applied to both source and target features.

    fc1 is Flatten -> Linear(in_features, 256); fc2 is ReLU -> Linear(256, cls_num).

    Args:
        cls_num: number of output classes.
        in_features: flattened feature size.  The default 8192 = 32 * 256
            matches the ``Extractor`` output for length-2048 inputs.
    """

    def __init__(self, cls_num: int, in_features: int = 8192):
        super().__init__()
        self.fc1 = nn.Sequential(
            nn.Flatten(),  # (N, in_features), e.g. (8192,)
            nn.Linear(in_features=in_features, out_features=256),
        )
        self.fc2 = nn.Sequential(
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=cls_num),
        )

    def forward(self, src_feature: torch.Tensor, tar_feature: torch.Tensor
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Classify source and target features.

        Returns:
            (src_fc1, src_output, tar_fc1, tar_output).  ``*_fc1`` are the
            256-d fc1 activations (presumably used by the caller for an MMD
            term, judging by the variable names).  ``*_output`` are the raw,
            un-normalised class scores from fc2.
        """
        src_data_mmd1 = self.fc1(src_feature)
        src_output = self.fc2(src_data_mmd1)
        tar_data_mmd1 = self.fc1(tar_feature)
        tar_output = self.fc2(tar_data_mmd1)
        return src_data_mmd1, src_output, tar_data_mmd1, tar_output


class DomainClassifier(nn.Module):
    """Domain classifier: an MLP that outputs one sigmoid probability per sample.

    The layers are Flatten -> Linear(in_features, 512) -> ReLU -> Linear(512, 128)
    -> ReLU -> Linear(128, 1) -> Sigmoid, so the output shape is (N, 1) with
    values in [0, 1].

    Args:
        temp: stored as ``self.temp``.  Nothing in this class reads it; it is
            kept for interface compatibility.
        in_features: flattened feature size (default 8192 = 32 * 256).
    """

    def __init__(self, temp: float = 0.05, in_features: int = 8192):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=in_features, out_features=512),
            nn.ReLU(),
            nn.Linear(in_features=512, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=1),
            nn.Sigmoid(),
        )
        self.temp = temp

    def forward(self, x: torch.Tensor, reverse: bool = False,
                lambd: float = 1.0) -> torch.Tensor:
        """Return the domain probability for each sample in ``x``.

        Args:
            x: features to classify.
            reverse: if True, pass ``x`` through a gradient reversal layer
                first, so the gradient reaching ``x`` is multiplied by ``-lambd``.
            lambd: reversal strength.  It is only used when ``reverse`` is True.
        """
        if reverse:
            x = grad_reverse(x, lambd)
        return self.fc(x)


def weights_init_Extractor(m: nn.Module) -> None:
    """Initialise ``nn.Conv1d`` layers: Kaiming-normal weights (fan_in, ReLU), zero bias.

    Other module types are left untouched.  Presumably meant to be used as
    ``module.apply(weights_init_Extractor)``.
    """
    if isinstance(m, nn.Conv1d):
        nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
        if m.bias is not None:  # layers built with bias=False have no bias
            nn.init.constant_(m.bias, 0)


def weights_init_Classifier(m: nn.Module) -> None:
    """Initialise ``nn.Linear`` layers: Xavier-normal weights, zero bias.

    Other module types are left untouched.  Presumably meant to be used as
    ``module.apply(weights_init_Classifier)``.
    """
    if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:  # layers built with bias=False have no bias
            nn.init.constant_(m.bias, 0)