# -*- coding: utf-8 -*-
'''
@Author : dingjiawen
@Date : 2023/11/15 14:46
@Usage :
@Desc :
'''
import torch
import torch.nn as nn
import math
import numpy as np
import torch.nn.functional as F
from torch.autograd import Function


class ReverseLayerF(Function):
    """Gradient reversal layer: identity in the forward pass, negated and
    alpha-scaled gradient in the backward pass."""

    @staticmethod
    def forward(ctx, x, alpha):
        # Store alpha for the backward pass; the forward pass is an identity.
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Flip the sign of the incoming gradient and scale it by alpha.
        output = grad_output.neg() * ctx.alpha
        return output, None


class Discriminator(nn.Module):
    """Two-layer MLP domain classifier producing a source-vs-target probability."""

    def __init__(self, input_dim=256, hidden_dim=256):
        super(Discriminator, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.dis1 = nn.Linear(input_dim, hidden_dim)
        self.dis2 = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        x = F.relu(self.dis1(x))
        x = self.dis2(x)
        # Sigmoid output: predicted probability that x comes from the source domain.
        x = torch.sigmoid(x)
        return x


def adv(source, target, input_dim=256, hidden_dim=512):
    """Adversarial domain-confusion loss between source and target feature batches."""
    domain_loss = nn.BCELoss()
    # !!! Pay attention to .cuda !!!
    # (move adv_net and the domain labels to the same device as the features when training on GPU)
    adv_net = Discriminator(input_dim, hidden_dim)
    domain_src = torch.ones(len(source))   # domain labels for the source batch
    domain_tar = torch.zeros(len(target))  # domain labels for the target batch
    domain_src, domain_tar = domain_src.view(domain_src.shape[0], 1), domain_tar.view(domain_tar.shape[0], 1)
    # Gradient reversal (alpha = 1) so that the upstream feature extractor is
    # trained to confuse the domain discriminator.
    reverse_src = ReverseLayerF.apply(source, 1)
    reverse_tar = ReverseLayerF.apply(target, 1)
    pred_src = adv_net(reverse_src)
    pred_tar = adv_net(reverse_tar)
    loss_s, loss_t = domain_loss(pred_src, domain_src), domain_loss(pred_tar, domain_tar)
    loss = loss_s + loss_t
    return loss
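

# ----------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). The random 256-dim
# feature batches below stand in for the outputs of a feature extractor; the
# names, shapes, and batch size are illustrative assumptions only.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    # Gradient-reversal sanity check: with alpha = 2.0 the gradient of a sum
    # through ReverseLayerF is -2 for every input element.
    x = torch.ones(3, requires_grad=True)
    ReverseLayerF.apply(x, 2.0).sum().backward()
    print(x.grad)  # tensor([-2., -2., -2.])

    # Adversarial domain loss between hypothetical source/target feature batches.
    src_feat = torch.randn(32, 256)
    tgt_feat = torch.randn(32, 256)
    print(adv(src_feat, tgt_feat, input_dim=256, hidden_dim=512).item())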