# -*- encoding:utf-8 -*-
'''
@Author  : dingjiawen
@Date    : 2023/11/15 14:46
@Usage   :
@Desc    : Gradient reversal layer, a small domain discriminator, and an
           adversarial domain-confusion loss (``adv``) built from them.
'''
import torch
import torch.nn as nn
import math
import numpy as np
import torch.nn.functional as F
from torch.autograd import Function


class ReverseLayerF(Function):
    """Gradient reversal layer.

    Forward pass: identity (returns ``x.view_as(x)``).
    Backward pass: the incoming gradient is multiplied by ``-alpha``;
    ``alpha`` itself receives no gradient.

    Usage: ``ReverseLayerF.apply(x, alpha)``.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        # Stash the scale for the backward pass.
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Reverse (and scale) the gradient; None is the gradient for `alpha`.
        output = grad_output.neg() * ctx.alpha
        return output, None


class Discriminator(nn.Module):
    """Domain discriminator: Linear -> ReLU -> Linear -> sigmoid.

    Args:
        input_dim: size of the last dimension of the input features.
        hidden_dim: width of the hidden layer.

    ``forward`` maps the last dimension to size 1 and returns sigmoid
    outputs in [0, 1].
    """

    def __init__(self, input_dim=256, hidden_dim=256):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.dis1 = nn.Linear(input_dim, hidden_dim)
        self.dis2 = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        x = F.relu(self.dis1(x))
        x = self.dis2(x)
        x = torch.sigmoid(x)
        return x


def adv(source, target, input_dim=256, hidden_dim=512, adv_net=None, alpha=1.0):
    """Adversarial domain-confusion loss using gradient reversal.

    Source features are labelled 1 and target features 0. Both pass through
    ``ReverseLayerF`` before the discriminator. The discriminator therefore
    receives normal gradients from the BCE loss, while the gradients that
    flow back into ``source`` / ``target`` are multiplied by ``-alpha``.

    Args:
        source: source-domain features whose last dimension is
            ``input_dim`` (presumably ``(n_src, input_dim)`` — TODO confirm).
        target: target-domain features, same layout as ``source``.
        input_dim: discriminator input width. Used only when ``adv_net``
            is None.
        hidden_dim: discriminator hidden width. Used only when ``adv_net``
            is None.
        adv_net: optional discriminator module to reuse across calls. Pass
            a persistent module that your optimiser updates if the
            discriminator should learn.
        alpha: gradient-reversal scale. The default of 1.0 matches the
            previous hard-coded value.

    Returns:
        Scalar tensor: mean BCE on source (target 1) plus mean BCE on
        target (target 0).
    """
    domain_loss = nn.BCELoss()

    if adv_net is None:
        # NOTE: this discriminator is local to the call and never returned,
        # so no optimiser can update it; pass `adv_net` to train one.
        # Place it on the features' device instead of always the CPU.
        adv_net = Discriminator(input_dim, hidden_dim).to(source.device)

    reverse_src = ReverseLayerF.apply(source, alpha)
    reverse_tar = ReverseLayerF.apply(target, alpha)
    pred_src = adv_net(reverse_src)
    pred_tar = adv_net(reverse_tar)

    # Domain labels: 1 = source domain, 0 = target domain. The *_like
    # constructors match the predictions' shape, dtype and device, which
    # avoids CPU/CUDA mismatches.
    domain_src = torch.ones_like(pred_src)
    domain_tar = torch.zeros_like(pred_tar)

    loss_s = domain_loss(pred_src, domain_src)
    loss_t = domain_loss(pred_tar, domain_tar)
    return loss_s + loss_t