self_example/pytorch_example/RUL/otherIdea/adaRNN/loss_transfer.py

65 lines
1.9 KiB
Python

# -*- encoding:utf-8 -*-
'''
@Author : dingjiawen
@Date : 2023/11/15 14:45
@Usage :
@Desc :
'''
from RUL.baseModel.loss import adv_loss, coral, kl_js, mmd, mutual_info, cos, pair_dist
class TransferLoss(object):
    """Compute a domain-adaptation (transfer) loss between two batches.

    The loss is selected by name. Accepted ``loss_type`` values (aliases in
    parentheses):

        mmd_lin (mmd), mmd_rbf, coral, cosine (cos), kl, js, mine, adv,
        pairwise
    """

    # Every name accepted by ``compute``; only used to build the error
    # message when an unknown ``loss_type`` is requested.
    _SUPPORTED_LOSS_TYPES = (
        'mmd_lin', 'mmd', 'mmd_rbf', 'coral', 'cosine', 'cos',
        'kl', 'js', 'mine', 'adv', 'pairwise',
    )

    def __init__(self, loss_type='cosine', input_dim=512):
        """
        Args:
            loss_type (str): name of the loss to compute (see the class
                docstring). It is not validated here; an unsupported name
                raises ``ValueError`` when ``compute`` is called.
            input_dim (int): passed as ``input_dim`` to the ``mine``
                estimator and to the ``adv`` loss; unused by the others.
        """
        self.loss_type = loss_type
        self.input_dim = input_dim

    def compute(self, X, Y):
        """Compute the adaptation loss between ``X`` and ``Y``.

        Arguments:
            X {tensor} -- source matrix
            Y {tensor} -- target matrix
        Returns:
            [tensor] -- transfer loss
        Raises:
            ValueError -- if ``self.loss_type`` is not a supported name.
        """
        loss_type = self.loss_type
        if loss_type in ('mmd_lin', 'mmd'):
            mmdloss = mmd.MMD_loss(kernel_type='linear')
            loss = mmdloss(X, Y)
        elif loss_type == 'coral':
            loss = coral.CORAL(X, Y)
        elif loss_type in ('cosine', 'cos'):
            # Presumably cos.cosine returns a similarity; 1 - similarity
            # turns it into a quantity to minimise.
            loss = 1 - cos.cosine(X, Y)
        elif loss_type == 'kl':
            loss = kl_js.kl_div(X, Y)
        elif loss_type == 'js':
            loss = kl_js.js(X, Y)
        elif loss_type == 'mine':
            # NOTE(review): a fresh estimator is constructed on every call,
            # so any parameters it holds are not kept between calls.
            # Confirm this is intended.
            mine_model = mutual_info.Mine_estimator(
                input_dim=self.input_dim, hidden_dim=60)
            loss = mine_model(X, Y)
        elif loss_type == 'adv':
            loss = adv_loss.adv(X, Y, input_dim=self.input_dim, hidden_dim=32)
        elif loss_type == 'mmd_rbf':
            mmdloss = mmd.MMD_loss(kernel_type='rbf')
            loss = mmdloss(X, Y)
        elif loss_type == 'pairwise':
            import torch  # only this branch needs torch directly

            pair_mat = pair_dist.pairwise_dist(X, Y)
            loss = torch.norm(pair_mat)
        else:
            # Without this branch `loss` would be unbound and the caller
            # would get an opaque UnboundLocalError.
            raise ValueError(
                f"Unsupported loss_type {loss_type!r}; expected one of: "
                f"{', '.join(self._SUPPORTED_LOSS_TYPES)}"
            )
        return loss
if __name__ == "__main__":
    # Smoke test: adversarial transfer loss between two random batches of
    # 5 rows x 512 features (512 matches TransferLoss's default input_dim).
    import torch

    source_batch = torch.randn(5, 512) * 10
    target_batch = torch.randn(5, 512) * 10
    adv_criterion = TransferLoss('adv')
    print(adv_criterion.compute(source_batch, target_batch))