|
#-*- encoding:utf-8 -*-
|
|
|
|
'''
|
|
@Author : dingjiawen
|
|
@Date : 2023/11/15 14:47
|
|
@Usage :
|
|
@Desc : Cosine similarity between the batch means (mean over dim 0) of two tensors; see ``cosine``.
|
|
'''
|
|
import torch.nn as nn
|
|
|
|
def cosine(source, target):
    """Cosine similarity between the batch means of ``source`` and ``target``.

    Each input is first averaged over dimension 0. The two averages are then
    compared with cosine similarity along their (new) dimension 0. The result
    is averaged down to a single value.

    For 2-D inputs this is the cosine of the angle between the two mean
    feature vectors. The inputs are presumably ``(batch, features)``; confirm
    with callers. Because dimension 0 is averaged away, the two inputs may
    differ in size along it.

    Args:
        source: Tensor whose dimension 0 is averaged away.
        target: Tensor whose dimension 0 is averaged away.

    Returns:
        0-dim tensor, nominally in [-1, 1]. This is a *similarity* (1 means
        the same direction), not a distance. Use ``1 - cosine(...)`` when a
        quantity to minimise is needed.
    """
    src_center = source.mean(dim=0)
    tgt_center = target.mean(dim=0)

    # Same computation as nn.CosineSimilarity(dim=0), whose default eps is
    # 1e-8, but without instantiating a module on every call.
    similarity = nn.functional.cosine_similarity(
        src_center, tgt_center, dim=0, eps=1e-8
    )

    # Collapse to a scalar. This is a no-op for 2-D inputs; for higher-rank
    # inputs it averages the per-position similarities.
    return similarity.mean()