41 lines
826 B
Python
41 lines
826 B
Python
# -*- encoding:utf-8 -*-
|
|
|
|
'''
|
|
@Author : dingjiawen
|
|
@Date : 2023/11/20 15:44
|
|
@Usage :
|
|
@Desc : FFT / DCT frequency-domain (频域) loss functions
|
|
'''
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
import torch_dct as dct
|
|
import torch.fft as fft
|
|
|
|
|
|
def fft_mse(source, target):
    """Frequency-domain MSE between batch-averaged FFT magnitude spectra.

    Both inputs are transformed with a real FFT (``torch.fft.rfft``, which
    operates on the last dimension by default). The spectra are scaled by the
    signal length and reduced to magnitudes. They are then averaged over
    dim 0, and the two averaged spectra are compared with mean-squared error.

    Args:
        source: Real-valued tensor. For ``ndim >= 2`` the last dimension is
            the signal axis and dim 0 is averaged away, e.g.
            ``(batch, length)`` or ``(batch, channels, length)``.
        target: Real-valued tensor; assumed to have the same shape as
            ``source`` (TODO confirm with callers).

    Returns:
        0-d tensor holding the mean-squared error.
    """
    # Scale by the length of the transformed (last) axis so magnitudes do not
    # grow with the window size. Using shape[-1] instead of unpacking a 2-tuple
    # lets inputs with more than two dimensions work.
    # NOTE(review): a 1-D input keeps the historical "no scaling" behaviour and
    # its spectrum is then collapsed to a scalar by mean(0) — confirm intended.
    length = 1 if source.dim() < 2 else source.shape[-1]

    source_mag = torch.abs(fft.rfft(source) / length)
    target_mag = torch.abs(fft.rfft(target) / length)

    # Average the magnitude spectra over dim 0 before comparing them.
    source_mag = source_mag.mean(0)
    target_mag = target_mag.mean(0)

    # reduction='mean' (the default) already returns a 0-d tensor.
    return nn.functional.mse_loss(source_mag, target_mag)
|
|
|
|
|
|
def dct_mse(source, target):
    """Mean-squared error between batch-averaged DCT coefficients.

    Each input is passed through ``torch_dct.dct``. The torch-dct package
    documents this as transforming the last dimension; confirm for the
    installed version. The coefficients are averaged over dim 0, and the two
    averaged tensors are compared with MSE.

    Args:
        source: Tensor to transform; dim 0 is averaged away.
        target: Tensor assumed to have the same shape as ``source``
            (TODO confirm with callers).

    Returns:
        0-d tensor holding the mean-squared error.
    """
    avg_source_coeffs = dct.dct(source).mean(0)
    avg_target_coeffs = dct.dct(target).mean(0)
    return nn.functional.mse_loss(avg_source_coeffs, avg_target_coeffs)
|