# -*- encoding:utf-8 -*-

'''
@Author : dingjiawen
@Date : 2023/11/20 15:44
@Usage :
@Desc : Frequency-domain (FFT / DCT) MSE losses
'''
import torch
import torch.nn as nn
import torch.fft as fft

try:
    import torch_dct as dct
except ImportError:
    # torch_dct is only needed by dct_mse; keep fft_mse usable without it.
    dct = None


def _batch_mean_mse(source: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Average both tensors over dim 0, then return their scalar MSE."""
    return nn.MSELoss()(source.mean(0), target.mean(0))


def fft_mse(source: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """MSE between the dim-0-averaged rFFT amplitude spectra of two tensors.

    The real FFT is taken along the last dimension. For inputs with two or
    more dimensions the spectrum is divided by the size of that last
    dimension before the magnitude is taken; 1-D inputs are left unscaled
    (divisor 1), matching the historical behavior of this function.

    :param source: real-valued tensor; presumably (batch, length) -- the
        caller is not visible here, so confirm the layout.
    :param target: tensor compared against ``source``; the normalization
        length is read from ``source`` only, so both are assumed to share
        a shape.
    :return: scalar tensor holding the loss.
    """
    # Use the last axis (the one rfft transforms) rather than unpacking
    # exactly two dims, so tensors with more than 2 dimensions also work.
    length = source.shape[-1] if source.dim() >= 2 else 1

    source_amp = torch.abs(fft.rfft(source) / length)
    target_amp = torch.abs(fft.rfft(target) / length)
    return _batch_mean_mse(source_amp, target_amp)


def dct_mse(source: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """MSE between the dim-0-averaged ``torch_dct.dct`` outputs of two tensors.

    :param source: tensor passed to ``torch_dct.dct`` with default arguments.
    :param target: tensor compared against ``source``.
    :return: scalar tensor holding the loss.
    :raises ImportError: if the ``torch_dct`` package is not installed.
    """
    if dct is None:
        raise ImportError(
            "dct_mse requires the 'torch_dct' package (pip install torch-dct)"
        )
    return _batch_mean_mse(dct.dct(source), dct.dct(target))