# self_example/pytorch_example/example/loss.py
"""
@Author: miykah
@Email: miykah@163.com
@FileName: loss.py
@DateTime: 2022/7/21 14:31
"""
import torch
import numpy as np
import torch.nn.functional as functional
import torch.nn as nn
import example.mmd as mmd
'''
Compute the discrepancy between conditional probability distributions.
mmdc = mmd_condition
'''
def cal_mmdc_loss(src_data_mmd1, src_data_mmd2, tar_data_mmd1, tar_data_mmd2,
                  shot=3, cls=10):
    """Class-conditional MMD loss between source and target features.

    Each input is assumed to hold ``cls`` consecutive class groups of
    ``shot`` rows each (rows ``shot*i : shot*(i+1)`` belong to class ``i``)
    — TODO confirm against the caller's batch layout.

    Args:
        src_data_mmd1, tar_data_mmd1: first source/target feature pair.
            NOTE(review): the original code computed per-class MMD for this
            pair too (``mmdc1``) but discarded it (``return (mmdc2) / 10``),
            so that dead computation is dropped here; the parameters are kept
            for interface compatibility.
        src_data_mmd2, tar_data_mmd2: second source/target feature pair,
            the one the returned loss is actually computed from.
        shot: rows per class group (default 3, matching the original slices).
        cls: number of classes (default 10, matching the original).

    Returns:
        Mean over classes of ``mmd.mmd_linear`` between the per-class
        source and target slices of the ``*_mmd2`` pair.
    """
    mmdc2 = 0.0
    for i in range(cls):
        lo, hi = shot * i, shot * (i + 1)
        mmdc2 += mmd.mmd_linear(src_data_mmd2[lo:hi], tar_data_mmd2[lo:hi])
    return mmdc2 / cls
'''Get the per-class mean source-domain features used to compute mmdc.'''
def get_src_mean_feature(src_feature, shot, cls):
    """Mean feature vector of each source class.

    Assumes the rows of ``src_feature`` are grouped by class, ``shot``
    consecutive rows per class — TODO confirm against the sampler.

    Args:
        src_feature: 2-D tensor of stacked source features.
        shot: number of rows per class group.
        cls: number of classes.

    Returns:
        List of ``cls`` tensors, entry ``i`` being the mean over rows
        ``shot*i : shot*(i+1)``.
    """
    return [
        torch.mean(src_feature[c * shot: (c + 1) * shot], dim=0)
        for c in range(cls)
    ]
def get_mmdc(src_feature, tar_feature, tar_pseudo_label, batch_size, shot, cls):
    """Conditional MMD between target samples and their pseudo-labeled
    source class centers.

    Args:
        src_feature: stacked source features, grouped by class
            (``shot`` rows per class).
        tar_feature: target features, one row per target sample.
        tar_pseudo_label: predicted class index per target sample
            (tensor; moved to CPU/NumPy for indexing).
        batch_size: number of target samples to accumulate over.
        shot, cls: layout of ``src_feature`` (rows per class / class count).

    Returns:
        Mean of ``mmd.mmd_linear`` between each target row and the mean
        feature of its pseudo-labeled source class.
    """
    class_centers = get_src_mean_feature(src_feature, shot, cls)
    labels = tar_pseudo_label.cpu().detach().numpy()
    total = 0.0
    for idx in range(batch_size):
        center = class_centers[labels[idx]].reshape(1, -1)
        total += mmd.mmd_linear(center, tar_feature[idx].reshape(1, -1))
    return total / batch_size
class BCE(nn.Module):
    """Pairwise binary cross-entropy over similarity labels.

    ``forward`` takes two class-probability matrices and a per-pair
    similarity label ``simi`` (+1 similar, -1 dissimilar — presumably;
    confirm against the caller) and returns the mean negative log of the
    agreement probability.

    Fix vs. original: the original used in-place ops (``prob1.mul_(prob2)``
    etc.), which silently mutated the caller's ``prob1`` tensor and fails on
    leaf tensors that require grad. All ops below are out-of-place; the
    numeric result is identical.
    """
    eps = 1e-7  # guards log(0)

    def forward(self, prob1, prob2, simi):
        # Agreement probability: row-wise dot product of the two
        # probability vectors.
        P = (prob1 * prob2).sum(1)
        # simi == +1 keeps P; simi == -1 yields 1 - P (dissimilar pair).
        P = P * simi + simi.eq(-1).type_as(P)
        neglogP = -(P + BCE.eps).log()
        return neglogP.mean()
class BinaryCrossEntropyLoss(nn.Module):
    """Binary-entropy loss over a probability tensor.

    Despite the name, with no label argument this computes the mean binary
    *entropy* of ``prob`` itself: ``-(p*log(p) + (1-p)*log(1-p))``. The
    original carried commented-out code building domain labels (ones for
    source, zeros for target) and using them in place of ``prob`` — i.e. a
    true domain-discriminator BCE; that variant is intentionally disabled.

    Fix vs. original: the epsilon guard referenced the sibling class
    (``BCE.eps``) even though this class defines its own identical ``eps``;
    it now uses its own constant (same value, identical behavior) so the
    class stands alone.
    """
    eps = 1e-7  # guards log(0)

    def forward(self, prob):
        e = BinaryCrossEntropyLoss.eps
        neglogP = -(prob * torch.log(prob + e)
                    + (1. - prob) * torch.log(1. - prob + e))
        return neglogP.mean()