self_example/pytorch_example/RUL/otherIdea/dctEmbedLSTM/model.py

244 lines
7.7 KiB
Python

# -*- encoding:utf-8 -*-
'''
@Author : dingjiawen
@Date : 2023/11/10 10:46
@Usage :
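@Desc : LSTM (non-convolutional) whose cells apply DCT channel attention (dct_channel_block) to [x_t, h_{t-1}] before the gate projection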
'''
import torch.nn as nn
import torch
from RUL.baseModel.dctAttention import dct_channel_block
class dctLSTMCell(nn.Module):
    """
    Single LSTM cell whose gate input is re-weighted by DCT channel attention.

    At each step the cell concatenates the current input with the previous
    hidden state along the last axis, passes that through
    ``dct_channel_block``, and feeds the result to one linear layer that emits
    all four gate pre-activations (i, f, o, g) at once.
    """

    def __init__(self, input_dim, hidden_dim, bias):
        """
        Build the cell.

        Parameters
        ----------
        input_dim: int
            Number of features in the per-step input tensor.
        hidden_dim: int
            Number of features in the hidden state and the cell state.
        bias: bool
            Whether the gate projection (``nn.Linear``) has a bias term.
        """
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.bias = bias
        gate_in = input_dim + hidden_dim
        # One projection produces i, f, o and g side by side (4 * hidden_dim).
        # Keep it created before `attention`: construction order fixes the
        # parameter / state_dict order.
        self.hidden = nn.Linear(in_features=gate_in,
                                out_features=4 * hidden_dim,
                                bias=bias)
        self.attention = dct_channel_block(channel=gate_in)

    def forward(self, input_tensor, cur_state):
        """
        Advance the cell by one time step.

        Parameters
        ----------
        input_tensor: Tensor of shape (B, input_dim)
        cur_state: pair (h, c), each of shape (B, hidden_dim)

        Returns
        -------
        (h_next, c_next), each of shape (B, hidden_dim)
        """
        h_prev, c_prev = cur_state
        # Channel attention over [x_t, h_{t-1}]. Assumes dct_channel_block keeps
        # the last dimension at input_dim + hidden_dim, which the gate
        # projection requires -- TODO confirm against dctAttention.
        stacked = self.attention(torch.cat([input_tensor, h_prev], dim=-1))
        gate_i, gate_f, gate_o, gate_g = self.hidden(stacked).chunk(4, dim=-1)
        c_next = torch.sigmoid(gate_f) * c_prev + torch.sigmoid(gate_i) * torch.tanh(gate_g)
        h_next = torch.sigmoid(gate_o) * torch.tanh(c_next)
        return h_next, c_next

    def init_hidden(self, batch_size):
        """Return zero (h, c), each (batch_size, hidden_dim), on the weights' device."""
        device = self.hidden.weight.device
        shape = (batch_size, self.hidden_dim)
        return torch.zeros(shape, device=device), torch.zeros(shape, device=device)
class LSTM(nn.Module):
    """
    Stack of dctLSTMCell layers unrolled over the time axis.

    Parameters:
        input_dim: Number of features in the input at each time step.
        hidden_dim: Hidden size of each layer; an int (shared by all layers)
            or a list/tuple with exactly one entry per layer.
        num_layers: Number of stacked LSTM layers.
        batch_first: If True the input is (B, T, C), otherwise (T, B, C).
        bias: Whether the gate projections use a bias term.
        return_all_layers: If True, return outputs/states of every layer;
            otherwise only those of the last layer.

    Input:
        A 3-D tensor of shape (B, T, C) if batch_first else (T, B, C).

    Output:
        A tuple of two lists, each of length num_layers (or 1 if
        return_all_layers is False):
        0 - layer_output_list: per-layer outputs of shape (B, T, hidden_dim[i]).
            Note: these are always batch-first, even when batch_first=False.
        1 - last_state_list: per-layer [h, c] at the final time step, each of
            shape (B, hidden_dim[i]).

    Example:
        >> x = torch.rand((32, 10, 64))
        >> lstm = LSTM(64, 16, num_layers=1, batch_first=True)
        >> _, last_states = lstm(x)
        >> h = last_states[0][0]  # 0 for layer index, 0 for h index
    """

    def __init__(self, input_dim, hidden_dim, num_layers,
                 batch_first=False, bias=True, return_all_layers=False):
        super(LSTM, self).__init__()
        # Normalise hidden_dim to a list with one entry per layer.
        hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
        if len(hidden_dim) != num_layers:
            raise ValueError('Inconsistent list length.')
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.bias = bias
        self.return_all_layers = return_all_layers

        cell_list = []
        for i in range(self.num_layers):
            # Layer 0 reads the raw input; deeper layers read the previous layer's h.
            cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]
            cell_list.append(dctLSTMCell(input_dim=cur_input_dim,
                                         hidden_dim=self.hidden_dim[i],
                                         bias=self.bias))
        self.cell_list = nn.ModuleList(cell_list)

    def forward(self, input_tensor, hidden_state=None):
        """
        Run the stacked cells over the whole sequence.

        Parameters
        ----------
        input_tensor: Tensor
            3-D tensor of shape (T, B, C) or (B, T, C), according to batch_first.
        hidden_state: sequence of (h, c) pairs, or None
            Initial state for each layer (one pair per layer), e.g. the
            ``last_state_list`` of a previous call made with
            ``return_all_layers=True``. Each h/c has shape (B, hidden_dim[i]).
            None starts every layer from zeros.

        Returns
        -------
        layer_output_list, last_state_list (see the class docstring)

        Raises
        ------
        ValueError
            If ``hidden_state`` does not contain one entry per layer.
        """
        if not self.batch_first:
            # (t, b, c) -> (b, t, c)
            input_tensor = input_tensor.permute(1, 0, 2)
        batch_size, timestamp, _ = input_tensor.size()

        if hidden_state is None:
            hidden_state = self._init_hidden(batch_size=batch_size)
        elif len(hidden_state) != self.num_layers:
            raise ValueError(
                f'hidden_state must provide one (h, c) pair per layer: '
                f'expected {self.num_layers}, got {len(hidden_state)}')

        layer_output_list = []
        last_state_list = []
        cur_layer_input = input_tensor
        for cell, (h, c) in zip(self.cell_list, hidden_state):
            output_inner = []
            for t in range(timestamp):
                h, c = cell(input_tensor=cur_layer_input[:, t, :],
                            cur_state=[h, c])
                output_inner.append(h)
            layer_output = torch.stack(output_inner, dim=1)
            # TODO: optionally apply a dct_channel_block between layers.
            cur_layer_input = layer_output
            layer_output_list.append(layer_output)
            last_state_list.append([h, c])

        if not self.return_all_layers:
            layer_output_list = layer_output_list[-1:]
            last_state_list = last_state_list[-1:]
        return layer_output_list, last_state_list

    def _init_hidden(self, batch_size):
        """Return zero (h, c) pairs, one per layer."""
        return [cell.init_hidden(batch_size) for cell in self.cell_list]

    @staticmethod
    def _extend_for_multilayer(param, num_layers):
        """Return ``param`` as a new list; a scalar is repeated ``num_layers`` times."""
        if isinstance(param, (list, tuple)):
            return list(param)
        return [param] * num_layers
class PredictModel(nn.Module):
    """
    Regression model: a stacked dct-attention LSTM followed by an MLP head.

    Parameters:
        input_dim: Number of features per time step.
        hidden_dim: Hidden sizes of the stacked LSTM layers, one entry per
            layer. Defaults to (512, 256), the previously hard-coded sizes.

    Input:
        A tensor of shape (B, T, input_dim); it is cast to float32.
    Output:
        A tensor of shape (B, 1) computed from the last layer's final hidden state.
    """

    def __init__(self, input_dim, hidden_dim=(512, 256)):
        super(PredictModel, self).__init__()
        hidden_dim = list(hidden_dim)
        if not hidden_dim:
            raise ValueError('hidden_dim must contain at least one layer size.')
        self.lstm = LSTM(input_dim=input_dim, hidden_dim=hidden_dim, num_layers=len(hidden_dim),
                         batch_first=True, bias=True, return_all_layers=False)
        # NOTE(review): the layer sequence below (ReLU directly after a ReLU'd
        # BatchNorm path, and two Linears with no activation between them) is
        # kept unchanged so existing state_dict checkpoints still load.
        # BatchNorm1d needs batch size > 1 in training mode.
        self.backbone = nn.Sequential(
            nn.Linear(in_features=hidden_dim[-1], out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.BatchNorm1d(64),
            nn.Linear(in_features=64, out_features=32),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Linear(in_features=32, out_features=16),
            nn.Linear(in_features=16, out_features=1)
        )

    def forward(self, input_tensor):
        """Predict one value per sequence from a (B, T, input_dim) batch."""
        input_tensor = input_tensor.to(torch.float32)
        _, last_states = self.lstm(input_tensor)
        # return_all_layers=False -> last_states holds only the final layer;
        # index 0 of its [h, c] pair is the hidden state, shape (B, hidden_dim[-1]).
        last_hidden = last_states[-1][0]
        return self.backbone(last_hidden)
if __name__ == '__main__':
    # Smoke test: a single-layer LSTM on a random (B=32, T=10, C=64) batch.
    x = torch.rand((32, 10, 64))
    lstm = LSTM(input_dim=64, hidden_dim=16, num_layers=1, batch_first=True, bias=True,
                return_all_layers=False)
    layer_output_list, last_states = lstm(x)
    all_outputs = layer_output_list[0]  # per-step outputs of the last layer: (32, 10, 16)
    h = last_states[0][0]  # final hidden state of the last layer: (32, 16)
    print(all_outputs.size())
    print(h.size())