"""
@Author: miykah
@Email: miykah@163.com
@FileName: loss.py
@DateTime: 2022/7/21 14:31
"""
|
||
|
||
import torch
|
||
import numpy as np
|
||
import torch.nn.functional as functional
|
||
import torch.nn as nn
|
||
import example.mmd as mmd
|
||
|
||
def cal_mmdc_loss(src_data_mmd1, src_data_mmd2, tar_data_mmd1, tar_data_mmd2,
                  shot=3, cls=10, *, include_first=False):
    """Compute the difference between conditional distributions (mmdc = MMD, conditional).

    Each input is cut into ``cls`` consecutive slices of ``shot`` rows, and
    ``mmd.mmd_linear`` is evaluated between the source slice and the target
    slice at the same position. The per-slice values are summed and then
    averaged over ``cls``.

    NOTE(review): this assumes rows are grouped class by class, ``shot`` rows
    per class, in both source and target inputs -- confirm against the caller.

    Args:
        src_data_mmd1, tar_data_mmd1: source/target inputs of the first pair.
            They are only used when ``include_first`` is True. (The original
            code computed their MMD and then discarded it.)
        src_data_mmd2, tar_data_mmd2: source/target inputs of the second pair.
        shot: rows per slice (previously hard-coded to 3).
        cls: number of slices (previously hard-coded to 10).
        include_first: if True, return ``(mmdc1 + mmdc2) / (2 * cls)``, which
            averages over both pairs. This was the commented-out alternative
            in the original code.

    Returns:
        ``mmdc2 / cls`` by default: the mean per-slice linear MMD of the second
        pair.
    """
    def _sliced_mmd_sum(src_data, tar_data):
        # Sum of mmd_linear over the `cls` aligned slices of `shot` rows each.
        # Starting from 0 and adding left to right matches the original
        # `mmd_x0 + mmd_x1 + ... + mmd_x9` expression exactly.
        total = 0
        for i in range(cls):
            rows = slice(shot * i, shot * (i + 1))
            total = total + mmd.mmd_linear(src_data[rows], tar_data[rows])
        return total

    mmdc2 = _sliced_mmd_sum(src_data_mmd2, tar_data_mmd2)
    if not include_first:
        # Default path: the first pair does not contribute to the result, so
        # its MMD is not computed at all.
        return mmdc2 / cls

    mmdc1 = _sliced_mmd_sum(src_data_mmd1, tar_data_mmd1)
    return (mmdc1 + mmdc2) / (2 * cls)
|
||
|
||
def get_src_mean_feature(src_feature, shot, cls):
    """Return the per-class mean source-domain feature, used for computing mmdc.

    ``src_feature`` is cut into ``cls`` consecutive chunks of ``shot`` rows.
    Each chunk is averaged along dim 0.

    NOTE(review): assumes source rows are grouped class by class, ``shot``
    rows per class -- confirm against the caller.

    Args:
        src_feature: source features (a tensor indexed along dim 0).
        shot: number of rows per class chunk.
        cls: number of class chunks.

    Returns:
        A list of ``cls`` tensors. Entry ``k`` is the mean of rows
        ``[k * shot, (k + 1) * shot)``.
    """
    return [
        torch.mean(src_feature[k * shot:(k + 1) * shot], dim=0)
        for k in range(cls)
    ]
|
||
|
||
def get_mmdc(src_feature, tar_feature, tar_pseudo_label, batch_size, shot, cls):
    """Average per-sample MMD between target samples and their source class means.

    Per-class source means come from ``get_src_mean_feature``. Each of the
    first ``batch_size`` target rows is compared with the mean of the class
    its pseudo-label selects, using ``mmd.mmd_linear`` on single-row
    (1, -1) reshapes.

    Args:
        src_feature: source features, passed with ``shot`` and ``cls`` to
            ``get_src_mean_feature``.
        tar_feature: target features; row ``i`` is paired with label ``i``.
        tar_pseudo_label: pseudo-label tensor for the target rows. It is moved
            to CPU, detached and converted to numpy before its entries are used
            as list indices.
        batch_size: number of target rows to use; also the divisor.
        shot, cls: forwarded to ``get_src_mean_feature``.

    Returns:
        The summed per-sample MMD divided by ``batch_size``.
    """
    class_means = get_src_mean_feature(src_feature, shot, cls)
    labels = tar_pseudo_label.cpu().detach().numpy()

    per_sample = (
        mmd.mmd_linear(class_means[labels[idx]].reshape(1, -1),
                       tar_feature[idx].reshape(1, -1))
        for idx in range(batch_size)
    )
    # Start the sum at 0.0 to mirror the original float accumulator.
    return sum(per_sample, 0.0) / batch_size
|
||
|
||
class BCE(nn.Module):
    """Binary cross-entropy on the agreement of two row-wise probability tensors.

    For each row, ``P = sum(prob1 * prob2, dim=1)``. The row's loss depends on
    the matching ``simi`` value:

    * ``simi == 1``  -> ``-log(P + eps)``
    * ``simi == -1`` -> ``-log(1 - P + eps)``
    * ``simi == 0``  -> ``-log(eps)`` (the term is not masked out)

    ``forward`` returns the mean over rows. The inputs are no longer modified.
    """

    eps = 1e-7  # keeps log() finite when the probability is exactly 0

    def forward(self, prob1, prob2, simi):
        # Out-of-place product. The previous `prob1.mul_(prob2)` overwrote the
        # caller's `prob1` and raised under autograd when `prob1` was a leaf
        # requiring grad, or a tensor autograd had saved for backward.
        P = (prob1 * prob2).sum(1)
        # simi=1 keeps P; simi=-1 gives -P + 1 = 1 - P; simi=0 gives 0.
        P = P * simi + simi.eq(-1).type_as(P)
        neglogP = -torch.log(P + BCE.eps)
        return neglogP.mean()
|
||
|
||
class BinaryCrossEntropyLoss(nn.Module):
    """Mean binary entropy of a probability tensor.

    Computes ``-(p * log(p + eps) + (1 - p) * log(1 - p + eps))`` element-wise
    and returns its mean. No labels are involved: ``prob`` is scored against
    itself. The value is therefore largest at ``p = 0.5`` (about ``log 2``) and
    close to 0 when ``p`` is 0 or 1.
    """

    eps = 1e-7  # keeps log() finite when prob is exactly 0 or 1

    def forward(self, prob):
        # Use this class's own eps; the original read BCE.eps (same value),
        # which coupled this loss to an unrelated class.
        eps = self.eps
        neglogP = -(prob * torch.log(prob + eps)
                    + (1. - prob) * torch.log(1. - prob + eps))
        return neglogP.mean()