65 lines
1.9 KiB
Python
65 lines
1.9 KiB
Python
# -*- encoding:utf-8 -*-
'''
@Author : dingjiawen
@Date : 2023/11/15 14:45
@Usage :
@Desc : TransferLoss, a name-based dispatcher over the transfer losses imported from RUL.baseModel.loss
'''
|
|
from RUL.baseModel.loss import adv_loss, coral, kl_js, mmd, mutual_info, cos, pair_dist
|
|
|
|
|
|
class TransferLoss(object):
    """Select a source/target transfer loss by name and evaluate it on demand.

    The name is stored at construction time and resolved on every call to
    ``compute``, which forwards the two inputs to the matching helper from
    ``RUL.baseModel.loss``.
    """

    def __init__(self, loss_type='cosine', input_dim=512):
        """
        Supported loss_type: mmd (mmd_lin), mmd_rbf, coral, cosine (cos),
        kl, js, mine, adv, pairwise

        Arguments:
            loss_type {str} -- name of the loss to use (default: 'cosine').
                It is not validated here; an unsupported name is reported
                by ``compute``.
            input_dim {int} -- forwarded as ``input_dim`` to the 'mine' and
                'adv' losses; the other losses do not read it (default: 512).
        """
        self.loss_type = loss_type
        self.input_dim = input_dim

    def compute(self, X, Y):
        """Compute adaptation loss

        Arguments:
            X {tensor} -- source matrix
            Y {tensor} -- target matrix

        Returns:
            [tensor] -- transfer loss

        Raises:
            ValueError -- if ``self.loss_type`` is not one of the supported
                names (previously this surfaced as an UnboundLocalError).
        """
        kind = self.loss_type

        if kind in ('mmd_lin', 'mmd'):
            loss = mmd.MMD_loss(kernel_type='linear')(X, Y)
        elif kind == 'mmd_rbf':
            loss = mmd.MMD_loss(kernel_type='rbf')(X, Y)
        elif kind == 'coral':
            loss = coral.CORAL(X, Y)
        elif kind in ('cosine', 'cos'):
            # The helper's result is subtracted from 1 to form the loss.
            loss = 1 - cos.cosine(X, Y)
        elif kind == 'kl':
            loss = kl_js.kl_div(X, Y)
        elif kind == 'js':
            loss = kl_js.js(X, Y)
        elif kind == 'mine':
            # NOTE(review): a new estimator is constructed on every call; if
            # it holds learnable parameters they will not persist between
            # calls -- confirm this is intended.
            mine_model = mutual_info.Mine_estimator(
                input_dim=self.input_dim, hidden_dim=60)
            loss = mine_model(X, Y)
        elif kind == 'adv':
            # NOTE(review): presumably adv() also builds its discriminator
            # per call from input_dim/hidden_dim -- verify in adv_loss.
            loss = adv_loss.adv(X, Y, input_dim=self.input_dim, hidden_dim=32)
        elif kind == 'pairwise':
            # torch is imported lazily: this is the only branch that needs it
            # directly and the module does not import it at file level.
            import torch
            pair_mat = pair_dist.pairwise_dist(X, Y)
            loss = torch.norm(pair_mat)
        else:
            # Fail loudly with the offending name instead of falling through
            # to ``return loss`` with ``loss`` never assigned.
            raise ValueError(
                "Unsupported loss_type: {!r}. Expected one of: mmd, mmd_lin, "
                "mmd_rbf, coral, cosine, cos, kl, js, mine, adv, "
                "pairwise".format(kind))

        return loss
|
|
|
|
|
|
if __name__ == "__main__":
    import torch

    # Manual smoke test: evaluate the 'adv' loss on two random batches.
    # The feature width of 512 matches TransferLoss's default input_dim.
    criterion = TransferLoss('adv')

    # Draw the source batch first, then the target batch; values are scaled
    # by 10 after sampling.
    source_batch = torch.randn(5, 512) * 10
    target_batch = torch.randn(5, 512) * 10

    result = criterion.compute(source_batch, target_batch)
    print(result)
|