# -*- encoding:utf-8 -*-

'''
@Author : dingjiawen
@Date   : 2023/11/15 14:45
@Usage  : TransferLoss(loss_type, input_dim).compute(X, Y)
@Desc   : Selects one of the losses in RUL.baseModel.loss by name.
'''
from RUL.baseModel.loss import adv_loss, coral, kl_js, mmd, mutual_info, cos, pair_dist


class TransferLoss(object):
    """Dispatch to a transfer (adaptation) loss selected by name.

    The loss implementations themselves live in ``RUL.baseModel.loss``;
    this class only stores the configuration and routes ``compute`` to the
    matching implementation.
    """

    # Every name accepted by ``compute``; used for the error message only.
    SUPPORTED_LOSS_TYPES = (
        'mmd', 'mmd_lin', 'mmd_rbf', 'coral', 'cosine', 'cos',
        'kl', 'js', 'mine', 'adv', 'pairwise',
    )

    def __init__(self, loss_type='cosine', input_dim=512):
        """Store the loss configuration.

        Arguments:
            loss_type {str} -- one of: mmd (alias mmd_lin), mmd_rbf, coral,
                cosine (alias cos), kl, js, mine, adv, pairwise
            input_dim {int} -- forwarded as ``input_dim`` to the 'mine' and
                'adv' losses; unused by the other loss types
        """
        self.loss_type = loss_type
        self.input_dim = input_dim

    def compute(self, X, Y):
        """Compute adaptation loss

        Arguments:
            X {tensor} -- source matrix
            Y {tensor} -- target matrix

        Returns:
            [tensor] -- transfer loss

        Raises:
            ValueError -- if ``self.loss_type`` is not a supported name
        """
        loss_type = self.loss_type

        if loss_type in ('mmd_lin', 'mmd'):
            mmdloss = mmd.MMD_loss(kernel_type='linear')
            loss = mmdloss(X, Y)
        elif loss_type == 'coral':
            loss = coral.CORAL(X, Y)
        elif loss_type in ('cosine', 'cos'):
            # The helper's result is subtracted from 1 to form the loss.
            loss = 1 - cos.cosine(X, Y)
        elif loss_type == 'kl':
            loss = kl_js.kl_div(X, Y)
        elif loss_type == 'js':
            loss = kl_js.js(X, Y)
        elif loss_type == 'mine':
            # NOTE(review): a new Mine_estimator is instantiated on every
            # call; if it holds trainable parameters they are not shared
            # across calls -- confirm this is intended.
            mine_model = mutual_info.Mine_estimator(
                input_dim=self.input_dim, hidden_dim=60)
            loss = mine_model(X, Y)
        elif loss_type == 'adv':
            loss = adv_loss.adv(X, Y, input_dim=self.input_dim, hidden_dim=32)
        elif loss_type == 'mmd_rbf':
            mmdloss = mmd.MMD_loss(kernel_type='rbf')
            loss = mmdloss(X, Y)
        elif loss_type == 'pairwise':
            # Local import: torch is only needed by this branch.
            import torch
            pair_mat = pair_dist.pairwise_dist(X, Y)
            loss = torch.norm(pair_mat)
        else:
            # Previously an unknown name fell through to ``return loss`` and
            # surfaced as an UnboundLocalError; fail with a clear message.
            raise ValueError(
                f"Unsupported loss_type: {loss_type!r}. "
                f"Supported: {', '.join(self.SUPPORTED_LOSS_TYPES)}"
            )

        return loss


if __name__ == "__main__":
    import torch

    # Smoke test: adversarial loss on two random 5x512 batches.
    trans_loss = TransferLoss('adv')
    a = (torch.randn(5, 512) * 10)
    b = (torch.randn(5, 512) * 10)
    print(trans_loss.compute(a, b))