"""
@Author: miykah
@Email: miykah@163.com
@FileName: loss.py
@DateTime: 2022/7/21 14:31
"""
import torch
import numpy as np
import torch.nn.functional as functional
import torch.nn as nn
import example.mmd as mmd


# Difference between conditional (per-class) probability distributions.
# mmdc = mmd_condition


def _class_wise_mmd_sum(src_data, tar_data, shot, cls):
    """Sum ``mmd.mmd_linear`` over ``cls`` aligned groups of ``shot`` rows."""
    # Group i is rows [shot * i, shot * (i + 1)) in both source and target.
    return sum(
        mmd.mmd_linear(src_data[shot * i: shot * (i + 1)],
                       tar_data[shot * i: shot * (i + 1)])
        for i in range(cls)
    )


def cal_mmdc_loss(src_data_mmd1, src_data_mmd2, tar_data_mmd1, tar_data_mmd2,
                  shot=3, cls=10):
    """Return the class-conditional linear MMD between source and target features.

    The first dimension of ``src_data_mmd2`` and ``tar_data_mmd2`` is split
    into ``cls`` consecutive groups of ``shot`` rows each. Group ``i`` covers
    rows ``[shot * i, shot * (i + 1))``. Source group ``i`` is compared with
    target group ``i`` via ``mmd.mmd_linear``, and the results are averaged
    over the ``cls`` groups.

    Args:
        src_data_mmd1: Currently unused; kept for interface compatibility.
        src_data_mmd2: Source features, assumed ordered by class in blocks
            of ``shot`` rows (NOTE(review): confirm ordering with caller).
        tar_data_mmd1: Currently unused; kept for interface compatibility.
        tar_data_mmd2: Target features, in the same grouped order as
            ``src_data_mmd2``.
        shot: Rows per class group (defaults to the previously hard-coded 3).
        cls: Number of class groups (defaults to the previously hard-coded 10).

    Returns:
        The sum of the per-group ``mmd.mmd_linear`` values divided by ``cls``.
    """
    # NOTE: earlier code also computed the same quantity for the *_mmd1 pair
    # but discarded it (only the *_mmd2 average was returned). That dead work
    # is skipped here. To average both pairs instead, return
    # (_class_wise_mmd_sum(src_data_mmd1, tar_data_mmd1, shot, cls)
    #  + _class_wise_mmd_sum(src_data_mmd2, tar_data_mmd2, shot, cls)) / (2 * cls)
    return _class_wise_mmd_sum(src_data_mmd2, tar_data_mmd2, shot, cls) / cls


# Get the per-class source-domain features, used to compute mmdc.
def get_src_mean_feature(src_feature, shot, cls):
    """Return a list with the mean source feature of each class group.

    The first dimension of ``src_feature`` is split into ``cls`` consecutive
    groups of ``shot`` rows. Entry ``i`` of the result is ``torch.mean`` of
    group ``i`` over dim 0.
    """
    return [torch.mean(src_feature[shot * i: shot * (i + 1)], dim=0)
            for i in range(cls)]


def get_mmdc(src_feature, tar_feature, tar_pseudo_label, batch_size, shot, cls):
    """Average linear MMD between target samples and their pseudo-class source means.

    Args:
        src_feature: Source features grouped by class (see
            ``get_src_mean_feature``).
        tar_feature: Target features; row ``i`` is paired with
            ``tar_pseudo_label[i]``.
        tar_pseudo_label: Tensor of integer class indices for the target rows.
        batch_size: Number of leading target rows to use; must be > 0.
        shot: Rows per class group in ``src_feature``.
        cls: Number of class groups in ``src_feature``.

    Returns:
        The mean over the first ``batch_size`` target rows of
        ``mmd.mmd_linear(class_mean, target_row)``.
    """
    src_feature_list = get_src_mean_feature(src_feature, shot, cls)
    # Move labels to host so they can index the Python list of class means.
    pseudo_label = tar_pseudo_label.cpu().detach().numpy()
    mmdc = 0.0
    for i in range(batch_size):
        # Compare target row i with the mean source feature of its pseudo class.
        src_center = src_feature_list[pseudo_label[i]].reshape(1, -1)
        mmdc += mmd.mmd_linear(src_center, tar_feature[i].reshape(1, -1))
    return mmdc / batch_size


class BCE(nn.Module):
    """Pairwise binary cross-entropy on the inner product of two probability rows.

    For each row, ``P = sum(prob1 * prob2, dim=1)``. The per-row loss is
    ``-log(P * simi + [simi == -1] + eps)``, which gives:

    * ``simi == 1``: ``-log(P + eps)``
    * ``simi == -1``: ``-log(1 - P + eps)``
    * ``simi == 0``: ``-log(eps)``, a constant (the row is not excluded from
      the mean).

    The mean over rows is returned.
    """
    eps = 1e-7  # keeps log() finite when its argument is 0

    def forward(self, prob1, prob2, simi):
        # Out-of-place ops: the former ``prob1.mul_(prob2)`` overwrote the
        # caller's tensor. That fails for leaf tensors requiring grad and can
        # invalidate values autograd saved for backward (e.g. softmax outputs).
        P = (prob1 * prob2).sum(1)
        # type_as keeps P's dtype, as the former in-place mul_ did.
        P = P * simi.type_as(P) + simi.eq(-1).type_as(P)
        neglogP = -(P + BCE.eps).log()
        return neglogP.mean()


class BinaryCrossEntropyLoss(nn.Module):
    """Mean binary entropy of a tensor of probabilities.

    Despite the class name, no target/domain labels are used. Each element
    ``p`` of ``prob`` contributes
    ``-(p * log(p + eps) + (1 - p) * log(1 - p + eps))``, i.e. the binary
    cross-entropy of ``p`` against itself (its entropy). The mean over all
    elements is returned. A label-based variant would replace the leading
    ``p`` / ``1 - p`` factors with the domain labels.
    """
    eps = 1e-7  # keeps log() finite at p == 0 or p == 1

    def forward(self, prob):
        eps = self.eps
        neglogP = -(prob * torch.log(prob + eps)
                    + (1. - prob) * torch.log(1. - prob + eps))
        return neglogP.mean()