# -*- encoding:utf-8 -*-
'''
@Author  : dingjiawen
@Date    : 2023/11/10 13:00
@Usage   :
@Desc    : Plug-and-play channel-attention modules whose gating weights are
           computed from the discrete cosine transform (DCT) of the input.
'''
import torch.nn as nn
import math
import numpy as np
import torch

try:
    from torch import irfft
    from torch import rfft
except ImportError:
    # torch.rfft / torch.irfft may be unavailable (they were removed in later
    # PyTorch releases); these shims emulate the old layout, where complex
    # values are stored as a trailing (real, imag) dimension of size 2.
    def rfft(x, d):
        """Return the two-sided FFT of ``x`` along dim ``-d`` as (..., 2) real/imag pairs."""
        t = torch.fft.fft(x, dim=-d)
        return torch.stack((t.real, t.imag), -1)

    def irfft(x, d):
        """Invert ``rfft`` above: take (..., 2) real/imag pairs, return the real IFFT along dim ``-d``."""
        t = torch.fft.ifft(torch.complex(x[..., 0], x[..., 1]), dim=-d)
        return t.real


def _dct_ii(x, norm=None):
    """Return the DCT-II of ``x`` over its last dimension, in pure PyTorch.

    Same algorithm and scaling as ``torch_dct.dct``. With ``norm=None`` the
    result is ``2 * sum_n x[n] * cos(pi * k * (2n + 1) / (2N))``. With
    ``norm='ortho'`` the orthonormal DCT-II is returned.

    :param x: input tensor; the transform runs over ``x.shape[-1]``
    :param norm: None or 'ortho'
    :return: tensor with the same shape as ``x``
    """
    x_shape = x.shape
    n = x_shape[-1]
    flat = x.contiguous().view(-1, n)

    # Makhoul reordering: even-indexed samples followed by the odd-indexed
    # samples reversed, so one length-N FFT yields the DCT-II.
    v = torch.cat([flat[:, ::2], flat[:, 1::2].flip([1])], dim=1)
    spectrum = torch.fft.fft(v, dim=1)

    # Rotate bin k by exp(-i*pi*k/(2N)) and keep the real part.
    angle = -torch.arange(n, dtype=x.dtype, device=x.device)[None, :] * math.pi / (2 * n)
    out = spectrum.real * torch.cos(angle) - spectrum.imag * torch.sin(angle)

    if norm == 'ortho':
        # Out-of-place scaling so autograd never sees an in-place edit.
        scale = torch.full((n,), 1.0 / (2.0 * math.sqrt(n / 2)),
                           dtype=out.dtype, device=out.device)
        scale[0] = 1.0 / (2.0 * math.sqrt(n))
        out = out * scale

    return 2 * out.reshape(x_shape)


try:
    import torch_dct as dct
    _dct = dct.dct
except ImportError:
    # torch_dct is an optional third-party package; fall back to the
    # equivalent pure-torch implementation above.
    _dct = _dct_ii


class dct_channel_block(nn.Module):
    """Channel attention driven by the DCT of the input.

    The input's last dimension (of size ``channel``) is transformed with an
    orthonormal DCT-II. The result goes through a two-layer MLP
    (channel -> 2*channel -> channel, Sigmoid output) and then a LayerNorm.
    The normalised weights are multiplied element-wise with the input.

    NOTE(review): because the LayerNorm is applied after the Sigmoid, the
    final weights are not confined to [0, 1].

    Args:
        channel: size of the input's last dimension.

    Shape:
        input ``(*, channel)``; the output has the same shape as the input.
    """

    def __init__(self, channel):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(channel, channel * 2, bias=False),
            nn.Dropout(p=0.1),
            nn.ReLU(inplace=True),
            nn.Linear(channel * 2, channel, bias=False),
            nn.Sigmoid(),
        )
        # Normalises the gating weights over the last (channel) dimension.
        self.dct_norm = nn.LayerNorm([channel], eps=1e-6)

    def forward(self, x):
        """Return ``x`` scaled element-wise by DCT-derived weights.

        :param x: tensor whose last dimension equals ``channel``; any number
            of leading dimensions is accepted
        :return: tensor with the same shape as ``x``
        """
        # Orthonormal DCT-II along the last dimension.
        stack_dct = _dct(x, norm='ortho')
        # Alternative for the traffic task (noted as mattering for traffic
        # datasets): f_weight = self.dct_norm(f_weight.permute(0, 2, 1))
        lr_weight = self.dct_norm(self.fc(stack_dct))
        return x * lr_weight


class dct_channel_block_withConv(dct_channel_block):
    """Variant of ``dct_channel_block`` kept under its own name.

    NOTE(review): despite the name, no convolution is involved. Its layers,
    parameters and forward pass are currently identical to
    ``dct_channel_block``.
    """


if __name__ == '__main__':
    x = torch.rand((32, 10, 64))
    print(x.shape)  # torch.Size([32, 10, 64])
    m = nn.Linear(64, 2)
    output = m(x)
    print(output.shape)  # torch.Size([32, 10, 2])

    # Smoke-test the attention block on the same 3-D tensor.
    block = dct_channel_block(64)
    print(block(x).shape)  # torch.Size([32, 10, 64])