109 lines
3.2 KiB
Python
109 lines
3.2 KiB
Python
"""
@Author: miykah
@Email: miykah@163.com
@FileName: model.py
@DateTime: 2022/7/20 21:18
"""
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch.autograd import Function
|
|
|
|
class GradReverse(Function):
    """Gradient reversal function: identity in forward, negated gradient in backward.

    Written as a new-style autograd function (static ``forward``/``backward``
    that receive ``ctx``), because it is invoked through ``GradReverse.apply``;
    the legacy instance-method form is rejected by ``apply`` in current
    PyTorch releases.
    """

    @staticmethod
    def forward(ctx, x):
        """Return a view of ``x`` with the same shape and values.

        Args:
            ctx: Autograd context (nothing needs to be saved for backward).
            x: Input tensor.

        Returns:
            ``x.view_as(x)`` -- same data, new tensor object for the graph.
        """
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the incoming gradient multiplied by -1.

        Args:
            ctx: Autograd context (unused).
            grad_output: Gradient flowing back from the downstream layers.

        Returns:
            The sign-flipped gradient for the single input of ``forward``.
        """
        return grad_output * -1.0
|
|
|
|
def grad_reverse(x):
    """Apply ``GradReverse`` to ``x`` and return the resulting tensor.

    Args:
        x: Tensor to route through the gradient reversal function.

    Returns:
        The output of ``GradReverse.apply(x)``.
    """
    reversed_x = GradReverse.apply(x)
    return reversed_x
|
|
|
|
'''Feature extractor'''
|
|
class Extractor(nn.Module):
    """1-D convolutional feature extractor shared by source and target data.

    The network is three stages, ``conv1`` -> ``conv2`` -> ``conv3``. Each
    stage is Conv1d (kernel 13, ``'same'`` padding, 32 output channels) ->
    BatchNorm1d -> ReLU -> MaxPool1d(2), so every stage halves the length of
    the last dimension. ``conv1`` expects 1 input channel.
    """

    def __init__(self):
        super().__init__()
        # Shape notes carried over from the original comments; they presumably
        # assume an input length of 2048 -- TODO confirm against the data loader.
        self.conv1 = self._make_stage(1)   # (32 * 1024)
        self.conv2 = self._make_stage(32)  # (32 * 512)
        self.conv3 = self._make_stage(32)  # (32 * 256)

    @staticmethod
    def _make_stage(in_channels):
        """Build one Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d(2) stage."""
        return nn.Sequential(
            nn.Conv1d(in_channels=in_channels, out_channels=32, kernel_size=13, padding='same'),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.MaxPool1d(2),
        )

    def _run_stages(self, data):
        """Pass ``data`` through conv1, conv2 and conv3 in that order."""
        for stage in (self.conv1, self.conv2, self.conv3):
            data = stage(data)
        return data

    def forward(self, src_data, tar_data):
        """Extract features for a source batch and a target batch.

        The source batch is processed completely before the target batch,
        using the same three stages (shared weights).

        Args:
            src_data: Source-domain input, with 1 channel in dimension 1.
            tar_data: Target-domain input, with 1 channel in dimension 1.

        Returns:
            Tuple ``(src_feature, tar_feature)`` of the conv3 outputs.
        """
        src_feature = self._run_stages(src_data)
        tar_feature = self._run_stages(tar_data)
        return src_feature, tar_feature
|
|
|
|
'''Label classifier'''
|
|
class LabelClassifier(nn.Module):
    """Fully connected label classifier applied to source and target features.

    ``fc1`` flattens the incoming feature map and projects it to 256
    dimensions; ``fc2`` applies ReLU and maps those 256 values to ``cls_num``
    outputs. ``forward`` returns both the ``fc1`` output and the final output
    for each domain.

    Args:
        cls_num: Number of output units of the final linear layer.
        in_features: Size of each flattened input sample. Defaults to 8192,
            the value that used to be hard-coded, so existing callers are
            unaffected.
    """

    def __init__(self, cls_num, in_features=8192):
        super(LabelClassifier, self).__init__()
        self.fc1 = nn.Sequential(
            nn.Flatten(),  # (in_features,)
            nn.Linear(in_features=in_features, out_features=256),
        )
        self.fc2 = nn.Sequential(
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=cls_num),
        )

    def _classify(self, feature):
        """Return ``(fc1 output, fc2 output)`` for one batch of features."""
        hidden = self.fc1(feature)
        return hidden, self.fc2(hidden)

    def forward(self, src_feature, tar_feature):
        """Classify source and target features with the same layers.

        Args:
            src_feature: Source-domain features; flattened size per sample
                must equal ``in_features``.
            tar_feature: Target-domain features with the same per-sample size.

        Returns:
            Tuple ``(src_data_mmd1, src_output, tar_data_mmd1, tar_output)``:
            the 256-d ``fc1`` output and the ``cls_num``-d ``fc2`` output for
            the source batch, then the same pair for the target batch. The
            ``mmd`` naming suggests the ``fc1`` outputs feed an MMD loss --
            confirm against the training code.
        """
        src_data_mmd1, src_output = self._classify(src_feature)
        tar_data_mmd1, tar_output = self._classify(tar_feature)
        return src_data_mmd1, src_output, tar_data_mmd1, tar_output
|
|
|
|
'''Domain classifier'''
|
|
class DomainClassifier(nn.Module):
    """Fully connected domain discriminator with a single sigmoid output.

    The network flattens its input and applies
    Linear(in_features, 512) -> ReLU -> Linear(512, 128) -> ReLU ->
    Linear(128, 1) -> Sigmoid.

    Args:
        temp: Stored on the instance as ``self.temp``; it is not used
            anywhere inside this class.
        in_features: Size of each flattened input sample. Defaults to 8192,
            the value that used to be hard-coded, so existing callers are
            unaffected.
    """

    def __init__(self, temp=0.05, in_features=8192):
        super(DomainClassifier, self).__init__()
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=in_features, out_features=512),
            nn.ReLU(),
            nn.Linear(in_features=512, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=1),
            nn.Sigmoid(),
        )
        self.temp = temp

    def forward(self, x, reverse=False):
        """Return the sigmoid domain score for each sample in ``x``.

        Args:
            x: Input features; flattened size per sample must equal
                ``in_features``.
            reverse: When true, ``x`` is first passed through
                ``grad_reverse`` before the fully connected layers.

        Returns:
            Tensor with one value per sample (last dimension of size 1).
        """
        if reverse:
            x = grad_reverse(x)
        return self.fc(x)
|
|
|
|
'''Initialize network weights'''
|
|
def weights_init_Extractor(m):
    """Initialise a Conv1d module in place; ignore every other module type.

    Conv1d weights get Kaiming-normal initialisation (``fan_in`` mode, ReLU
    nonlinearity) and the bias, when the layer has one, is set to zero.
    The single-module signature fits use with ``nn.Module.apply``.

    Args:
        m: Any module; only ``nn.Conv1d`` instances are modified.
    """
    if not isinstance(m, nn.Conv1d):
        return
    nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
    # Layers built with bias=False expose ``bias`` as None; skip them
    # instead of failing inside nn.init.constant_.
    if m.bias is not None:
        nn.init.constant_(m.bias, 0)
|
|
|
|
def weights_init_Classifier(m):
    """Initialise a Linear module in place; ignore every other module type.

    Linear weights get Xavier-normal initialisation and the bias, when the
    layer has one, is set to zero. The single-module signature fits use with
    ``nn.Module.apply``.

    Args:
        m: Any module; only ``nn.Linear`` instances are modified.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_normal_(m.weight)
    # Layers built with bias=False expose ``bias`` as None; skip them
    # instead of failing inside nn.init.constant_.
    if m.bias is not None:
        nn.init.constant_(m.bias, 0)