self_example/pytorch_example/RUL/baseModel/loss/ffd.py

41 lines
826 B
Python

# -*- encoding:utf-8 -*-
'''
@Author : dingjiawen
@Date : 2023/11/20 15:44
@Usage :
@Desc : fft 频域Loss
'''
import torch
import torch.nn as nn
import torch_dct as dct
import torch.fft as fft
def fft_mse(source, target):
    """MSE between the batch-averaged FFT amplitude spectra of two tensors.

    Each input is transformed with a real FFT (``torch.fft.rfft``) along its
    last dimension. The result is converted to amplitudes ``|rfft(x) / length|``,
    averaged over dimension 0, and the two averaged spectra are compared with
    ``nn.MSELoss`` (mean reduction).

    Args:
        source: Real-valued tensor; the FFT is taken over its last dimension.
            Presumably laid out as (batch, ..., length) -- TODO confirm
            against callers.
        target: Real-valued tensor, expected to have the same shape as
            ``source``.

    Returns:
        A 0-d tensor holding the mean-squared error between the two spectra.
    """
    # Normalize amplitudes by the length of the transformed (last) axis so the
    # spectrum scale does not depend on sequence length. Using shape[-1]
    # (instead of unpacking a 2-tuple) lets inputs with more than two
    # dimensions work.
    # 1-D inputs keep the original behaviour of no normalization (length 1).
    # NOTE(review): that leaves 1-D spectra unnormalized, unlike the N-D case;
    # confirm whether this is intended.
    length = source.shape[-1] if len(source.shape) >= 2 else 1

    source_amp = torch.abs(fft.rfft(source) / length)
    target_amp = torch.abs(fft.rfft(target) / length)

    # Average the spectra over dimension 0 before comparing them.
    source_amp, target_amp = source_amp.mean(0), target_amp.mean(0)

    # MSELoss already reduces to a scalar mean; no further reduction needed.
    return nn.MSELoss()(source_amp, target_amp)
def dct_mse(source, target):
    """MSE between the batch-averaged DCT coefficients of two tensors.

    Each input is transformed with ``torch_dct.dct`` (imported as ``dct``).
    The coefficients are averaged over dimension 0, and the two averages are
    compared with ``nn.MSELoss`` (mean reduction).

    Args:
        source: Real-valued tensor. The DCT is presumably taken along its last
            dimension (torch_dct's documented default) -- confirm.
        target: Real-valued tensor, expected to have the same shape as
            ``source``.

    Returns:
        A 0-d tensor holding the mean-squared error between the two averaged
        coefficient tensors.
    """
    source_coef = dct.dct(source)
    target_coef = dct.dct(target)

    # Average the coefficients over dimension 0 before comparing them.
    source_coef, target_coef = source_coef.mean(0), target_coef.mean(0)

    # MSELoss already reduces to a scalar mean; no further reduction needed.
    return nn.MSELoss()(source_coef, target_coef)