self_example/pytorch_example/RUL/baseModel/dctChannelAttention.py

# -*- encoding:utf-8 -*-
'''
@Author : dingjiawen
@Date : 2023/11/10 13:00
@Usage :
@Desc : Build some plug-and-play channel attention modules
'''
import torch.nn as nn
import math
import numpy as np
import torch
import torch_dct as dct
try:
    from torch import irfft
    from torch import rfft
except ImportError:
    # torch.rfft/torch.irfft were removed in newer PyTorch releases, so emulate
    # the old onesided=False behaviour with torch.fft: stack the real and
    # imaginary parts in a trailing dimension of size 2.
    def rfft(x, d):
        t = torch.fft.fft(x, dim=(-d))
        r = torch.stack((t.real, t.imag), -1)
        return r

    def irfft(x, d):
        t = torch.fft.ifft(torch.complex(x[:, :, 0], x[:, :, 1]), dim=(-d))
        return t.real
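# Editorial sketch (not part of the original file): a minimal check of the
# fallback above, assuming the ImportError branch is active (PyTorch >= 1.8)
# and that it should mimic the pre-1.8 `torch.rfft(x, 1, onesided=False)`
# layout, i.e. a full-length spectrum with real/imag stacked in the last dim.
def _check_rfft_fallback():
    x = torch.randn(4, 8)
    spec = rfft(x, 1)        # expected shape: (4, 8, 2)
    x_rec = irfft(spec, 1)   # inverse should recover the original signal
    assert spec.shape == (4, 8, 2)
    assert torch.allclose(x, x_rec, atol=1e-5)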
# def dct(x, norm=None):
#     """
#     Discrete Cosine Transform, Type II (a.k.a. the DCT)
#
#     For the meaning of the parameter `norm`, see:
#     https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
#
#     :param x: the input signal
#     :param norm: the normalization, None or 'ortho'
#     :return: the DCT-II of the signal over the last dimension
#     """
#     x_shape = x.shape
#     N = x_shape[-1]
#     x = x.contiguous().view(-1, N)
#
#     v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)
#
#     # Vc = torch.fft.rfft(v, 1, onesided=False)
#     Vc = rfft(v, 1)
#
#     k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
#     W_r = torch.cos(k)
#     W_i = torch.sin(k)
#
#     V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i
#
#     if norm == 'ortho':
#         V[:, 0] /= np.sqrt(N) * 2
#         V[:, 1:] /= np.sqrt(N / 2) * 2
#
#     V = 2 * V.view(*x_shape)
#
#     return V
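# Editorial note: the commented-out FFT-based DCT-II above is superseded in this
# file by `torch_dct.dct` (imported as `dct`), which likewise transforms the last
# dimension. A minimal round-trip sketch, assuming `torch_dct` also provides
# `idct` with the same `norm` argument:
def _check_torch_dct_roundtrip():
    x = torch.randn(4, 96)
    X = dct.dct(x, norm='ortho')       # DCT-II over the last dimension
    x_rec = dct.idct(X, norm='ortho')  # orthonormal inverse
    assert torch.allclose(x, x_rec, atol=1e-5)
    return X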
class dct_channel_block(nn.Module):
    def __init__(self, channel):
        super(dct_channel_block, self).__init__()
        # self.avg_pool = nn.AdaptiveAvgPool1d(1)  # innovation
        self.fc = nn.Sequential(
            nn.Linear(channel, channel * 2, bias=False),
            nn.Dropout(p=0.1),
            nn.ReLU(inplace=True),
            nn.Linear(channel * 2, channel, bias=False),
            nn.Sigmoid()
        )
        # self.dct_norm = nn.LayerNorm([512], eps=1e-6)
        self.dct_norm = nn.LayerNorm([channel], eps=1e-6)  # for lstm on length-wise
        # self.dct_norm = nn.LayerNorm([36], eps=1e-6)  # for lstm on length-wise on ill with input=36

    def forward(self, x):
        b, t, c = x.size()  # input is (B, T, C), e.g. (32, 96, 512)
        # list = []
        # for i in range(c):
        #     freq = dct.dct(x[:, :, i])
        #     list.append(freq)
        # stack_dct = torch.stack(list, dim=2)

        # DCT along the time axis: (B, T, C) -> (B, C, T) -> DCT on last dim -> back to (B, T, C)
        change = x.transpose(2, 1)
        stack_dct = dct.dct(change, norm='ortho')
        stack_dct = stack_dct.transpose(2, 1)
        # stack_dct = torch.tensor(stack_dct)
        '''
        for traffic mission: f_weight = self.dct_norm(f_weight.permute(0, 2, 1))  # matters for traffic datasets
        '''
        # Normalize the DCT coefficients, squeeze-and-excite over the channel dim,
        # then normalize the resulting gate before rescaling the input.
        lr_weight = self.dct_norm(stack_dct)
        lr_weight = self.fc(lr_weight)
        lr_weight = self.dct_norm(lr_weight)
        # print("lr_weight", lr_weight.shape)
        return x * lr_weight  # channel-attention-weighted input
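# Editorial usage sketch (not part of the original file): dct_channel_block takes
# (B, T, C) input and returns the same shape; `channel` must equal the last
# dimension C, since both the LayerNorm and the fc bottleneck act on that axis.
# The shapes below are illustrative only.
def _example_dct_channel_block():
    block = dct_channel_block(channel=512)
    x = torch.rand(32, 96, 512)  # (B, T, C)
    out = block(x)               # same shape as x: (32, 96, 512)
    assert out.shape == x.shape
    return out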
class dct_channel_block1(nn.Module):
    def __init__(self, channel):
        super(dct_channel_block1, self).__init__()
        # self.avg_pool = nn.AdaptiveAvgPool1d(1)  # innovation
        self.fc = nn.Sequential(
            nn.Linear(channel, channel * 2, bias=False),
            nn.Dropout(p=0.1),
            nn.ReLU(inplace=True),
            nn.Linear(channel * 2, channel, bias=False),
            nn.Sigmoid()
        )
        # self.dct_norm = nn.LayerNorm([512], eps=1e-6)
        self.dct_norm = nn.LayerNorm([96], eps=1e-6)  # for lstm on length-wise; hard-coded to L=96
        # self.dct_norm = nn.LayerNorm([36], eps=1e-6)  # for lstm on length-wise on ill with input=36

    def forward(self, x):
        b, c, l = x.size()  # input is (B, C, L), e.g. (8, 7, 96); `channel` must equal L here
        # y = self.avg_pool(x)             # (B,C,L) -> (B,C,1)
        # y = self.avg_pool(x).view(b, c)  # (B,C,L) -> (B,C)
        # print("y", y.shape)
        # y = self.fc(y).view(b, c, 96)

        # DCT each channel separately along the length axis, then restack to (B, C, L).
        dct_list = []
        for i in range(c):
            freq = dct.dct(x[:, i, :])
            # print("freq-shape:", freq.shape)
            dct_list.append(freq)
        stack_dct = torch.stack(dct_list, dim=1)
        # stack_dct = torch.tensor(stack_dct)  # redundant: torch.stack already returns a tensor
        '''
        for traffic mission: f_weight = self.dct_norm(f_weight.permute(0, 2, 1))  # matters for traffic datasets
        '''
        # Normalize the DCT coefficients, gate them with the fc bottleneck along L,
        # normalize again, and rescale the input.
        lr_weight = self.dct_norm(stack_dct)
        lr_weight = self.fc(lr_weight)
        lr_weight = self.dct_norm(lr_weight)
        # print("lr_weight", lr_weight.shape)
        return x * lr_weight  # channel-attention-weighted input
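# Editorial note: the two blocks differ in which axis they gate. dct_channel_block
# takes (B, T, C) and sizes its fc/LayerNorm by the channel dimension C, while
# dct_channel_block1 takes (B, C, L), DCTs each channel along L, and sizes its fc
# by L (with the LayerNorm hard-coded to 96), so its constructor argument must
# match the sequence length rather than the channel count.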
if __name__ == '__main__':
    # input_data = torch.Tensor([[1, 2, 3], [4, 5, 6]])  # [2, 3]
    x = torch.rand((8, 7, 96))  # (B, C, L)
    # The constructor argument must equal the length L = 96 (the fc and the
    # hard-coded LayerNorm both act on the last dimension), not the channel count.
    dct_model = dct_channel_block1(96)
    result = dct_model(x)
    print(result)
    # print(x.shape)
    # m = nn.Linear(64, 2)
    # output = m(x)
    # print(output.shape)  # [2, 2]