# File: self_example/pytorch_example/RUL/baseModel/loss/cos.py
#
# (15 lines, 290 B, Python)

#-*- encoding:utf-8 -*-
'''
@Author : dingjiawen
@Date : 2023/11/15 14:47
@Usage :
@Desc : Cosine similarity between the dim-0 means of a source and a target tensor (see cosine()).
'''
import torch.nn as nn
def cosine(source, target):
    """Return the cosine similarity between the dim-0 means of two tensors.

    Each input is first averaged over dimension 0. The cosine similarity of
    the two mean tensors is then taken along *their* dimension 0, and the
    result is averaged to a single scalar.

    This is a *similarity* (1 for aligned means, -1 for opposite ones), not
    a distance. Callers that minimise it as a loss presumably use
    ``1 - cosine(...)``; confirm against the call sites.

    Args:
        source: Tensor whose dimension 0 is averaged away. It is presumably
            the batch dimension (TODO confirm).
        target: Tensor whose shape after dimension 0 should match
            ``source``'s. Its size along dimension 0 may differ.

    Returns:
        A 0-dim tensor holding the mean cosine similarity.
    """
    # Collapse dimension 0 of each input to a single mean tensor.
    source_mean = source.mean(0)
    target_mean = target.mean(0)
    # Same computation nn.CosineSimilarity(dim=0) performs (default eps=1e-8),
    # without instantiating a module on every call.
    similarity = nn.functional.cosine_similarity(source_mean, target_mean, dim=0)
    # For inputs with more than 2 dims the similarity is a tensor, not a
    # scalar; reduce it to one value.
    return similarity.mean()