# -*- encoding:utf-8 -*-
'''
@Author : dingjiawen
@Date : 2023/11/15 14:47
@Usage :
@Desc : Cosine similarity between the batch-averaged source and target tensors.
'''
import torch.nn as nn


def cosine(source, target):
    """Return the cosine similarity of the mean source and mean target tensors.

    Both inputs are first averaged over dim 0 (presumably the batch axis —
    confirm against callers). Cosine similarity is then taken along dim 0 of
    the averaged tensors, i.e. along dim 1 of the original inputs. For 2-D
    inputs of shape (N, D) this gives a single value. For higher-rank inputs,
    e.g. (N, C, ...), it gives one value per remaining position, and those
    values are averaged.

    Args:
        source: Tensor with at least 2 dimensions; averaged over dim 0.
        target: Tensor whose shape after ``mean(0)`` must be compatible with
            that of ``source`` for ``nn.CosineSimilarity``.

    Returns:
        0-d tensor in [-1, 1]. Higher means more similar.

    NOTE(review): the value is a *similarity*, not a distance. A caller that
    minimizes it as a loss pushes the two means apart. If alignment is the
    goal, the caller likely wants ``1 - cosine(...)`` — confirm intent.
    """
    # Reduce each side to its mean over dim 0.
    source_mean = source.mean(0)
    target_mean = target.mean(0)

    # Compare along what was dim 1 of the inputs. For rank > 2 this yields a
    # map of similarities rather than a scalar.
    similarity = nn.CosineSimilarity(dim=0)(source_mean, target_mean)

    # Collapse any remaining positions into one scalar. For 2-D inputs this
    # is a no-op on an already 0-d tensor.
    return similarity.mean()