self_example/pytorch_example/RUL/baseModel/loss/adv_loss.py

57 lines
1.6 KiB
Python

#-*- encoding:utf-8 -*-
'''
@Author : dingjiawen
@Date : 2023/11/15 14:46
@Usage :
@Desc :
'''
import torch
import torch.nn as nn
import math
import numpy as np
import torch.nn.functional as F
from torch.autograd import Function
class ReverseLayerF(Function):
    """Gradient reversal layer (GRL).

    The forward pass is the identity. The backward pass negates the incoming
    gradient and scales it by ``alpha``; ``alpha`` itself gets no gradient.

    Usage: ``y = ReverseLayerF.apply(x, alpha)``.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        # Stash the scale factor for use in backward().
        ctx.alpha = alpha
        # view_as yields a distinct tensor object sharing x's storage
        # instead of handing the input object straight back.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        flipped = -grad_output
        # Reversed, scaled gradient for x; None because alpha is not a tensor input.
        return flipped * ctx.alpha, None
class Discriminator(nn.Module):
    """Two-layer MLP domain classifier.

    Maps features through a ReLU hidden layer to one sigmoid probability
    per input row (trailing output dimension of size 1).

    Args:
        input_dim: size of the last dimension of the input features.
        hidden_dim: width of the hidden layer.
    """

    def __init__(self, input_dim=256, hidden_dim=256):
        super().__init__()
        self.input_dim, self.hidden_dim = input_dim, hidden_dim
        # The attribute names dis1/dis2 define the state_dict keys; keep them.
        self.dis1 = nn.Linear(input_dim, hidden_dim)
        self.dis2 = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Return probabilities in (0, 1) with a trailing dimension of size 1."""
        hidden = F.relu(self.dis1(x))
        return torch.sigmoid(self.dis2(hidden))
def adv(source, target, input_dim=256, hidden_dim=512, alpha=1, adv_net=None):
    """Adversarial domain-confusion loss computed through a gradient reversal layer.

    A discriminator scores ``source`` rows against label 1 and ``target`` rows
    against label 0 with binary cross-entropy. Both inputs first pass through
    ``ReverseLayerF``, so gradients flowing back into the features are
    negated and scaled by ``alpha``.

    Args:
        source: source-domain features whose last dimension must equal the
            discriminator's input width (presumably ``(batch, input_dim)`` --
            confirm against callers).
        target: target-domain features with the same feature width.
        input_dim: discriminator input width; only used when ``adv_net`` is None.
        hidden_dim: discriminator hidden width; only used when ``adv_net`` is None.
        alpha: gradient reversal scale (the original behaviour is 1).
        adv_net: optional discriminator module producing probabilities in
            (0, 1) with a trailing dimension of size 1. When None (the
            original behaviour) a freshly initialised ``Discriminator`` is
            built on every call. That module is discarded after the call, so
            no optimizer ever updates it. Pass a module you own and optimize
            to get a trained discriminator. A caller-supplied module must
            already live on the features' device.

    Returns:
        Scalar tensor: mean BCE over source predictions plus mean BCE over
        target predictions.
    """
    if adv_net is None:
        # The original left this on CPU in float32, which breaks for CUDA or
        # float64 features; follow the features' device and dtype instead.
        adv_net = Discriminator(input_dim, hidden_dim).to(
            device=source.device, dtype=source.dtype
        )
    pred_src = adv_net(ReverseLayerF.apply(source, alpha))
    pred_tar = adv_net(ReverseLayerF.apply(target, alpha))
    # Domain labels: 1 = source domain, 0 = target domain. Built from the
    # predictions so they match them in device, dtype and shape.
    domain_src = torch.ones_like(pred_src)
    domain_tar = torch.zeros_like(pred_tar)
    domain_loss = nn.BCELoss()
    loss_s = domain_loss(pred_src, domain_src)
    loss_t = domain_loss(pred_tar, domain_tar)
    return loss_s + loss_t