# -*- encoding:utf-8 -*-
'''
@Author : dingjiawen
@Date   : 2023/11/10 10:46
@Usage  :
@Desc   : Basic ConvLSTM-2D-style implementation: a linear LSTM cell with
          DCT channel attention
'''

import torch
import torch.nn as nn

from RUL.baseModel.dctAttention import dct_channel_block


class dctLSTMCell(nn.Module):

    def __init__(self, input_dim, hidden_dim, bias):
        """
        Initialize the dctLSTM cell.

        Parameters
        ----------
        input_dim: int
            Number of channels of the input tensor.
        hidden_dim: int
            Number of channels of the hidden state.
        bias: bool
            Whether or not to add the bias.

        Input:
            A tensor of size (B, C) at each time step
            B: batch_size
            C: channel
        """

        super(dctLSTMCell, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.bias = bias

        # A single linear layer produces all four gate pre-activations at once.
        self.hidden = nn.Linear(in_features=self.input_dim + self.hidden_dim,
                                out_features=4 * self.hidden_dim,
                                bias=self.bias)
        # Channel attention over the concatenated input and hidden state.
        self.attention = dct_channel_block(channel=self.input_dim + self.hidden_dim)

    def forward(self, input_tensor, cur_state):
        # shape: (b, c)
        h_cur, c_cur = cur_state

        combined = torch.cat([input_tensor, h_cur], dim=-1)  # concatenate along channel axis
        # apply channel attention before the gate projection
        combined = self.attention(combined)
        combined_linear = self.hidden(combined)
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_linear, self.hidden_dim, dim=-1)
        i = torch.sigmoid(cc_i)  # input gate
        f = torch.sigmoid(cc_f)  # forget gate
        o = torch.sigmoid(cc_o)  # output gate
        g = torch.tanh(cc_g)     # candidate cell state

        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)

        return h_next, c_next

    def init_hidden(self, batch_size):
        return (torch.zeros(batch_size, self.hidden_dim, device=self.hidden.weight.device),
                torch.zeros(batch_size, self.hidden_dim, device=self.hidden.weight.device))
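

# A minimal single-step usage sketch of the cell above (input is (B, input_dim),
# h and c are (B, hidden_dim)):
#   cell = dctLSTMCell(input_dim=64, hidden_dim=16, bias=True)
#   h, c = cell.init_hidden(batch_size=32)
#   h, c = cell(torch.rand(32, 64), cur_state=(h, c))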


class LSTM(nn.Module):
    """

    Parameters:
        input_dim: Number of channels in input
        hidden_dim: Number of hidden channels (an int, or a list with one entry per layer)
        num_layers: Number of LSTM layers stacked on each other
        batch_first: Whether or not dimension 0 is the batch or not
        bias: Bias or no bias in the linear layers
        return_all_layers: Return the list of computations for all layers

    Input:
        A tensor of size (B, T, C) or (T, B, C)
    Output:
        A tuple of two lists of length num_layers (or length 1 if return_all_layers is False).
            0 - layer_output_list is the list of output tensors of shape (B, T, hidden_dim), one per layer
            1 - last_state_list is the list of last states
                    each element of the list is a tuple (h, c) for hidden state and memory
    Example:
        >> x = torch.rand((32, 10, 64))
        >> lstm = LSTM(64, 16, 1, True, True, False)
        >> _, last_states = lstm(x)
        >> h = last_states[0][0]  # 0 for layer index, 0 for h index
    """

    def __init__(self, input_dim, hidden_dim, num_layers,
                 batch_first=False, bias=True, return_all_layers=False):
        super(LSTM, self).__init__()

        # Make sure that `hidden_dim` is a list having len == num_layers
        hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
        if not len(hidden_dim) == num_layers:
            raise ValueError('Inconsistent list length.')

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.bias = bias
        self.return_all_layers = return_all_layers

        cell_list = []

        for i in range(0, self.num_layers):
            # Each layer after the first consumes the previous layer's hidden state.
            cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]

            cell_list.append(
                dctLSTMCell(input_dim=cur_input_dim,
                            hidden_dim=self.hidden_dim[i],
                            bias=self.bias)
            )

        self.cell_list = nn.ModuleList(cell_list)

    def forward(self, input_tensor, hidden_state=None):
        """

        Parameters
        ----------
        input_tensor:
            3-D Tensor either of shape (t, b, c) or (b, t, c)
        hidden_state:
            None. todo implement stateful

        Returns
        -------
        layer_output_list, last_state_list
        """

        if not self.batch_first:
            # equivalent to a transpose:
            # (t, b, c) -> (b, t, c)
            input_tensor = input_tensor.permute(1, 0, 2)

        b, _, _ = input_tensor.size()

        # Implement stateful LSTM
        if hidden_state is not None:
            raise NotImplementedError()
        else:
            # Since the init is done in forward, the batch size can be passed here.
            hidden_state = self._init_hidden(batch_size=b)

        layer_output_list = []
        last_state_list = []

        timestamp = input_tensor.size(1)
        cur_layer_input = input_tensor

        for layer_idx in range(self.num_layers):

            h, c = hidden_state[layer_idx]
            output_inner = []
            for t in range(timestamp):
                # One cell step per time step; (h, c) is carried across steps.
                h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :],
                                                 cur_state=[h, c])
                output_inner.append(h)

            # Stack the per-step hidden states back into a (b, t, hidden) tensor.
            layer_output = torch.stack(output_inner, dim=1)

            # TODO: add a dct_attention between layers
            # layer_output = self.attention_list[layer_idx](layer_output)

            cur_layer_input = layer_output

            layer_output_list.append(layer_output)
            last_state_list.append([h, c])

        if not self.return_all_layers:
            layer_output_list = layer_output_list[-1:]
            last_state_list = last_state_list[-1:]

        return layer_output_list, last_state_list

    def _init_hidden(self, batch_size):
        init_states = []
        for i in range(self.num_layers):
            init_states.append(self.cell_list[i].init_hidden(batch_size))
        return init_states

    @staticmethod
    def _extend_for_multilayer(param, num_layers):
        # Allow a single int, repeating it once per layer.
        if not isinstance(param, list):
            param = [param] * num_layers
        return param
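

# A minimal multi-layer sketch: hidden_dim may be a list with one entry per
# layer, or a single int that _extend_for_multilayer repeats for every layer.
#   lstm = LSTM(input_dim=64, hidden_dim=[32, 16], num_layers=2,
#               batch_first=True, bias=True, return_all_layers=True)
#   outputs, states = lstm(torch.rand(8, 10, 64))
#   # len(outputs) == 2; outputs[0] is (8, 10, 32), outputs[1] is (8, 10, 16)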


class PredictModel(nn.Module):

    def __init__(self, input_dim):
        super(PredictModel, self).__init__()

        # Two-layer dctLSTM followed by a small regression head.
        self.lstm = LSTM(input_dim=input_dim, hidden_dim=[64, 64], num_layers=2, batch_first=True, bias=True,
                         return_all_layers=False)

        self.backbone = nn.Sequential(
            nn.Linear(in_features=64, out_features=64),
            nn.Linear(in_features=64, out_features=64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(in_features=64, out_features=1)
        )

    def forward(self, input_tensor):
        input_tensor = input_tensor.to(torch.float32)
        layer_output_list, last_states = self.lstm(input_tensor)
        # Hidden state of the last layer at the final time step, shape (b, 64).
        last_timestamp = last_states[0][0]
        predict = self.backbone(last_timestamp)
        return predict


if __name__ == '__main__':
    x = torch.rand((32, 10, 64))
    lstm = LSTM(input_dim=64, hidden_dim=16, num_layers=1, batch_first=True, bias=True,
                return_all_layers=False)
    layer_output_list, last_states = lstm(x)

    outputs = layer_output_list[0]  # (32, 10, 16): per-step hidden states of the last layer
    h = last_states[0][0]           # (32, 16): final hidden state of the last layer
    print(outputs.size())
    print(h.size())
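
    # A quick smoke test of the full PredictModel on the same toy input
    # (a sketch; the regression head should yield shape (32, 1)):
    model = PredictModel(input_dim=64)
    predict = model(x)
    print(predict.size())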