# self_example/pytorch_example/example/model.py
# (109 lines, 3.2 KiB, Python)

"""
@Author: miykah
@Email: miykah@163.com
@FileName: model.py
@DateTime: 2022/7/20 21:18
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
class GradReverse(Function):
    """Gradient reversal layer (GRL).

    Forward pass: identity. Backward pass: the incoming gradient is
    multiplied by ``-lambd`` (``lambd`` defaults to 1.0, i.e. a plain sign
    flip).

    Written in the new-style autograd API (static ``forward``/``backward``
    taking a ``ctx`` object). The legacy form, with ``__init__`` and
    instance-method ``forward``/``backward``, is rejected by current
    PyTorch when invoked through ``GradReverse.apply``.
    """

    @staticmethod
    def forward(ctx, x, lambd=1.0):
        # Stash the scale for the backward pass.
        ctx.lambd = lambd
        # view_as yields a new tensor object sharing x's data, the usual way
        # to return an input unchanged from a custom Function.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward input: x receives the reversed, scaled
        # gradient; lambd (a plain Python number) receives None.
        return grad_output.neg() * ctx.lambd, None


def grad_reverse(x, lambd=1.0):
    """Apply gradient reversal to ``x``.

    Args:
        x: Input tensor; returned unchanged in value.
        lambd: Scale of the reversed gradient. The default 1.0 reproduces
            a plain ``-1`` multiplication of the gradient.

    Returns:
        A tensor equal to ``x`` whose gradient is ``-lambd * grad``.
    """
    return GradReverse.apply(x, lambd)
class Extractor(nn.Module):
    """Feature extractor shared by the source and target domains.

    Three identical stages, each Conv1d(kernel_size=13, padding='same') ->
    BatchNorm1d(32) -> ReLU -> MaxPool1d(2), map a single-channel signal to
    32 channels. Each stage halves the length (floor division). Per the
    original shape comments, a length-2048 input presumably ends up as
    32 x 256 features.

    Source and target batches go through the same layers in two separate
    calls. In training mode each batch is therefore normalized with its
    own BatchNorm batch statistics.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = self._stage(1)   # -> (N, 32, L/2), e.g. 32 x 1024
        self.conv2 = self._stage(32)  # -> (N, 32, L/4), e.g. 32 x 512
        self.conv3 = self._stage(32)  # -> (N, 32, L/8), e.g. 32 x 256

    @staticmethod
    def _stage(in_channels):
        """Build one Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d(2) stage."""
        return nn.Sequential(
            nn.Conv1d(in_channels=in_channels, out_channels=32, kernel_size=13, padding='same'),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.MaxPool1d(2),
        )

    def _encode(self, signal):
        """Run one batch through conv1, conv2 and conv3 in order."""
        for stage in (self.conv1, self.conv2, self.conv3):
            signal = stage(signal)
        return signal

    def forward(self, src_data, tar_data):
        """Encode the source and target batches independently.

        Args:
            src_data: Source batch; conv1 has in_channels=1, so presumably
                shaped (N, 1, L).
            tar_data: Target batch with the same layout.

        Returns:
            Tuple ``(src_feature, tar_feature)``, each with 32 channels and
            the length divided by 8.
        """
        return self._encode(src_data), self._encode(tar_data)
class LabelClassifier(nn.Module):
    """Label classifier head.

    ``fc1`` flattens the extracted features and projects them to 256 units
    with no activation. ``fc2`` applies ReLU and a final linear layer that
    produces ``cls_num`` raw scores (no softmax is applied here). The 256-d
    ``fc1`` output is also returned. Judging from the ``*_mmd1`` names it
    presumably feeds an MMD-style alignment loss; confirm against the
    training code.

    Args:
        cls_num: Number of output classes.
        in_features: Size of the flattened input. The default of 8192
            (32 channels x 256 steps) keeps the original behaviour. Pass
            e.g. ``32 * (L // 8)`` for other input lengths.
    """

    def __init__(self, cls_num, in_features=8192):
        super(LabelClassifier, self).__init__()
        self.fc1 = nn.Sequential(
            nn.Flatten(),  # (N, C, L) -> (N, C * L)
            nn.Linear(in_features=in_features, out_features=256),
        )
        self.fc2 = nn.Sequential(
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=cls_num),
        )

    def forward(self, src_feature, tar_feature):
        """Classify source and target features with the same head.

        Returns:
            ``(src_data_mmd1, src_output, tar_data_mmd1, tar_output)``, where
            ``*_mmd1`` is the (N, 256) output of ``fc1`` and ``*_output`` is
            the (N, cls_num) output of ``fc2``.
        """
        src_data_mmd1 = self.fc1(src_feature)
        src_output = self.fc2(src_data_mmd1)
        tar_data_mmd1 = self.fc1(tar_feature)
        tar_output = self.fc2(tar_data_mmd1)
        return src_data_mmd1, src_output, tar_data_mmd1, tar_output
class DomainClassifier(nn.Module):
    """Domain classifier (discriminator).

    An MLP (flatten -> 512 -> 128 -> 1, with ReLU between layers) ending in
    a Sigmoid, so each sample gets one score in [0, 1]. Presumably this is
    the probability of belonging to one of the two domains, for use with
    ``nn.BCELoss``; confirm against the training code.

    Args:
        temp: Stored as ``self.temp`` but not used anywhere in this class.
            Kept for API compatibility.
        in_features: Size of the flattened input. The default of 8192
            (32 channels x 256 steps) keeps the original behaviour.
    """

    def __init__(self, temp=0.05, in_features=8192):
        super(DomainClassifier, self).__init__()
        self.fc = nn.Sequential(
            nn.Flatten(),  # (N, C, L) -> (N, C * L)
            nn.Linear(in_features=in_features, out_features=512),
            nn.ReLU(),
            nn.Linear(in_features=512, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=1),
            nn.Sigmoid(),
        )
        # NOTE(review): unused inside this class; possibly read by external code.
        self.temp = temp

    def forward(self, x, reverse=False):
        """Score ``x``; optionally reverse gradients flowing back into ``x``.

        Args:
            x: Features to classify, flattened to ``in_features`` per sample.
            reverse: If True, ``x`` first passes through ``grad_reverse``, so
                that during backprop the layers producing ``x`` receive the
                negated gradient (adversarial domain training).

        Returns:
            Tensor of shape (N, 1) with values in [0, 1].
        """
        if reverse:
            x = grad_reverse(x)
        return self.fc(x)
def weights_init_Extractor(m):
    """Initialize network weights for the feature extractor.

    Takes a single module, so it is presumably meant for
    ``model.apply(weights_init_Extractor)``. For every ``nn.Conv1d`` it draws
    the weight from a Kaiming-normal distribution (fan_in mode, ReLU gain)
    and sets the bias to zero. Any other module type is left untouched.
    """
    if isinstance(m, nn.Conv1d):
        nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
        # A conv created with bias=False has m.bias = None; constant_ would fail.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
def weights_init_Classifier(m):
    """Initialize the weights of the classifier heads.

    Takes a single module, so it is presumably meant for
    ``model.apply(weights_init_Classifier)``. For every ``nn.Linear`` it draws
    the weight from a Xavier (Glorot) normal distribution and sets the bias
    to zero. Any other module type is left untouched.
    """
    if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        # A Linear created with bias=False has m.bias = None; constant_ would fail.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)