# -*- encoding:utf-8 -*-
'''
@Author : dingjiawen
@Date   : 2023/11/10 13:00
@Usage  :
@Desc   : Plug-and-play channel attention blocks.
'''

import math

import numpy as np
import torch
import torch.nn as nn
import torch_dct as dct

# torch.rfft / torch.irfft were removed in PyTorch 1.8, so fall back to
# shims built on the torch.fft module when they are unavailable.
try:
    from torch import irfft
    from torch import rfft
except ImportError:
    def rfft(x, d):
        # FFT over dimension -d; real and imaginary parts are stacked into a
        # trailing dimension of size 2, mimicking the layout of the old
        # torch.rfft with onesided=False.
        t = torch.fft.fft(x, dim=-d)
        r = torch.stack((t.real, t.imag), -1)
        return r

    def irfft(x, d):
        # Inverse of the rfft shim: rebuild the complex tensor from the
        # stacked real/imaginary parts, invert the FFT, keep the real part.
        t = torch.fft.ifft(torch.complex(x[:, :, 0], x[:, :, 1]), dim=-d)
        return t.real

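# A quick shape check for the shims (a minimal sketch, assuming 2-D input):
#   v = torch.rand(4, 8)
#   rfft(v, 1).shape            -> torch.Size([4, 8, 2])  (real/imag stacked)
#   irfft(rfft(v, 1), 1).shape  -> torch.Size([4, 8])     (recovers v up to
#                                                          floating-point error)
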
# def dct(x, norm=None):
#     """
#     Discrete Cosine Transform, Type II (a.k.a. the DCT)
#
#     For the meaning of the parameter `norm`, see:
#     https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
#
#     :param x: the input signal
#     :param norm: the normalization, None or 'ortho'
#     :return: the DCT-II of the signal over the last dimension
#     """
#     x_shape = x.shape
#     N = x_shape[-1]
#     x = x.contiguous().view(-1, N)
#
#     v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)
#
#     # Vc = torch.fft.rfft(v, 1, onesided=False)
#     Vc = rfft(v, 1)
#
#     k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
#     W_r = torch.cos(k)
#     W_i = torch.sin(k)
#
#     V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i
#
#     if norm == 'ortho':
#         V[:, 0] /= np.sqrt(N) * 2
#         V[:, 1:] /= np.sqrt(N / 2) * 2
#
#     V = 2 * V.view(*x_shape)
#
#     return V

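# For reference: torch_dct.dct, used in the block below, computes the same
# DCT-II as the commented-out function above,
#   X_k = 2 * sum_{n=0}^{N-1} x_n * cos(pi * k * (2 * n + 1) / (2 * N)),
# and norm='ortho' rescales the result so that the transform is orthonormal
# (see the SciPy documentation linked above).
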
class dct_channel_block(nn.Module):
    def __init__(self, channel):
        super(dct_channel_block, self).__init__()
        # self.avg_pool = nn.AdaptiveAvgPool1d(1)  # innovation
        # Two-layer MLP with a sigmoid gate that maps the DCT spectrum to
        # per-position attention weights in (0, 1).
        self.fc = nn.Sequential(
            nn.Linear(channel, channel * 2, bias=False),
            nn.Dropout(p=0.1),
            nn.ReLU(inplace=True),
            nn.Linear(channel * 2, channel, bias=False),
            nn.Sigmoid()
        )

        # self.dct_norm = nn.LayerNorm([512], eps=1e-6)
        self.dct_norm = nn.LayerNorm([channel], eps=1e-6)  # for LSTM, length-wise
        # self.dct_norm = nn.LayerNorm([36], eps=1e-6)  # for LSTM, length-wise, on the Illness dataset with input = 36

    def forward(self, x):
        b, c, l = x.size()  # (B, C, L), e.g. (32, 96, 512)

        # Earlier per-channel variant:
        # list = []
        # for i in range(c):
        #     freq = dct.dct(x[:, :, i])
        #     list.append(freq)
        # stack_dct = torch.stack(list, dim=2)

        # change = x.transpose(2, 1)
        # DCT-II over the last dimension of x
        stack_dct = dct.dct(x, norm='ortho')
        # stack_dct = stack_dct.transpose(2, 1)
        # stack_dct = torch.tensor(stack_dct)

        # For the traffic task: f_weight = self.dct_norm(f_weight.permute(0, 2, 1))  # matters for traffic datasets

        # Normalize the spectrum, pass it through the gating MLP, and
        # normalize the resulting weights again.
        lr_weight = self.dct_norm(stack_dct)
        lr_weight = self.fc(lr_weight)
        lr_weight = self.dct_norm(lr_weight)

        # print("lr_weight", lr_weight.shape)
        return x * lr_weight  # input reweighted by the learned attention

if __name__ == '__main__':
    # input_data = torch.Tensor([[1, 2, 3], [4, 5, 6]])  # [2, 3]
    x = torch.rand((32, 10, 64))
    print(x.shape)  # torch.Size([32, 10, 64])
    m = nn.Linear(64, 2)
    output = m(x)
    print(output.shape)  # torch.Size([32, 10, 2])
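
    # Minimal smoke test for dct_channel_block (a sketch; it assumes the last
    # input dimension equals the `channel` argument, since both the LayerNorm
    # and the gating MLP operate on that dimension).
    attn = dct_channel_block(channel=64)
    y = attn(x)
    print(y.shape)  # torch.Size([32, 10, 64]), same shape as the input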