# -*- encoding:utf-8 -*-
'''
@Author : dingjiawen
@Date : 2023/11/10 13:00
@Usage :
@Desc : Plug-and-play channel-attention blocks built on the discrete
        cosine transform (DCT).
'''
import torch.nn as nn
import math
import numpy as np
import torch

try:
    import torch_dct as dct
except ImportError:
    # torch_dct is optional: `_dct_last_dim` falls back to the local
    # FFT-based implementation below when the package is not installed.
    dct = None

try:
    from torch import irfft
    from torch import rfft
except ImportError:
    # These names are not importable from every torch release; when the
    # import fails, emulate them on top of the `torch.fft` module.
    def rfft(x, d):
        """Return the full FFT of `x` over dim `-d` as a real tensor.

        The result has one extra trailing axis of size 2 holding the real
        and imaginary parts.
        """
        t = torch.fft.fft(x, dim=(-d))
        r = torch.stack((t.real, t.imag), -1)
        return r

    def irfft(x, d):
        """Return the real part of the inverse FFT over dim `-d`.

        `x` is indexed as `x[:, :, 0]` (real) and `x[:, :, 1]` (imaginary),
        so it must be a 3-D tensor whose last axis has size 2.
        """
        t = torch.fft.ifft(torch.complex(x[:, :, 0], x[:, :, 1]), dim=(-d))
        return t.real


def _dct_fallback(x, norm=None):
    """Discrete Cosine Transform, Type II, over the last dimension.

    Used only when `torch_dct` is unavailable. For the meaning of `norm`
    see the scipy documentation of `scipy.fftpack.dct`.

    :param x: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the DCT-II of the signal over the last dimension
    """
    shape = x.shape
    n = shape[-1]
    rows = x.contiguous().view(-1, n)

    # Reorder each row as even-indexed samples followed by the reversed
    # odd-indexed samples so that a single length-n FFT yields the DCT.
    v = torch.cat([rows[:, ::2], rows[:, 1::2].flip([1])], dim=1)
    spectrum = torch.fft.fft(v, dim=1)

    k = -torch.arange(n, dtype=rows.dtype, device=rows.device)[None, :] * math.pi / (2 * n)
    out = spectrum.real * torch.cos(k) - spectrum.imag * torch.sin(k)

    if norm == 'ortho':
        scale = torch.full((n,), 1.0 / (math.sqrt(n / 2) * 2),
                           dtype=rows.dtype, device=rows.device)
        scale[0] = 1.0 / (math.sqrt(n) * 2)
        out = out * scale

    return 2 * out.view(*shape)


def _dct_last_dim(x, norm=None):
    """DCT-II of `x` over its last dimension, preferring `torch_dct`."""
    if dct is not None:
        return dct.dct(x, norm=norm)
    return _dct_fallback(x, norm=norm)


def _build_gate(channel):
    """Build the bottleneck MLP (expand x2, contract, sigmoid) shared by both blocks.

    The layer order is kept fixed so `state_dict` keys stay
    `fc.0.weight` / `fc.3.weight`.
    """
    return nn.Sequential(
        nn.Linear(channel, channel * 2, bias=False),
        nn.Dropout(p=0.1),
        nn.ReLU(inplace=True),
        nn.Linear(channel * 2, channel, bias=False),
        nn.Sigmoid()
    )


class dct_channel_block(nn.Module):
    """Frequency-domain attention for inputs laid out as (batch, length, channel).

    The orthonormal DCT is taken along dim 1; the gating MLP and the
    LayerNorm act on the last dim, whose size must equal `channel`.
    """

    def __init__(self, channel):
        """
        :param channel: size of the last dimension of the tensors passed to `forward`
        """
        super().__init__()
        self.fc = _build_gate(channel)
        self.dct_norm = nn.LayerNorm([channel], eps=1e-6)

    def forward(self, x):
        """
        :param x: tensor of shape (batch, length, channel)
        :return: `x` scaled element-wise by the attention weights, same shape as `x`
        """
        # Move the length axis last for the DCT, then restore the layout.
        stack_dct = _dct_last_dim(x.transpose(2, 1), norm='ortho').transpose(2, 1)

        # NOTE: original author's remark for the traffic task:
        #   f_weight = self.dct_norm(f_weight.permute(0, 2, 1))  # matters for traffic datasets
        # The weights come from the un-normalised spectrum; an earlier
        # `self.dct_norm(stack_dct)` call had its result overwritten and
        # was removed without changing the output.
        lr_weight = self.fc(stack_dct)
        lr_weight = self.dct_norm(lr_weight)

        return x * lr_weight


class dct_channel_block1(nn.Module):
    """Frequency-domain attention whose DCT, MLP and LayerNorm all act on the last dim.

    Despite its name, `channel` must equal the size of the LAST dimension
    of the input (e.g. 96 for a (batch, 7, 96) tensor).
    """

    def __init__(self, channel):
        """
        :param channel: size of the last dimension of the tensors passed to `forward`
        """
        super().__init__()
        self.fc = _build_gate(channel)
        # Previously hard-coded to 96, which only worked when channel == 96.
        self.dct_norm = nn.LayerNorm([channel], eps=1e-6)

    def forward(self, x):
        """
        :param x: tensor whose last dimension has size `channel`, e.g. (batch, C, channel)
        :return: `x` scaled element-wise by the attention weights, same shape as `x`
        """
        # One batched transform replaces the per-channel Python loop.
        stack_dct = _dct_last_dim(x)
        # The original rebuilt the tensor with `torch.tensor(...)`, which cut
        # the autograd graph (and warned). `detach()` keeps that gradient
        # behaviour: no gradient reaches `x` through the attention weights.
        stack_dct = stack_dct.detach()

        # NOTE: original author's remark for the traffic task:
        #   f_weight = self.dct_norm(f_weight.permute(0, 2, 1))  # matters for traffic datasets
        # The weights come from the un-normalised spectrum; an earlier
        # `self.dct_norm(stack_dct)` call had its result overwritten and
        # was removed without changing the output.
        lr_weight = self.fc(stack_dct)
        lr_weight = self.dct_norm(lr_weight)

        return x * lr_weight


if __name__ == '__main__':
    x = torch.rand((8, 7, 96))
    # The block gates along the last axis, so size it from that axis.
    dct_model = dct_channel_block1(x.shape[-1])
    result = dct_model(x)
    print(result)