# -*- encoding:utf-8 -*-
'''
@Author : dingjiawen
@Date   : 2023/11/10 10:46
@Usage  :
@Desc   : Basic LSTM implementation adapted from a 2D ConvLSTM (the gates use
          linear layers rather than convolutions), with DCT channel attention
          between layers.
'''
import torch
import torch.nn as nn

from RUL.baseModel.dctChannelAttention import dct_channel_block


class LSTMCell(nn.Module):

    def __init__(self, input_dim, hidden_dim, bias):
        """
        Initialize LSTM cell.

        Parameters
        ----------
        input_dim: int
            Number of channels of the input tensor.
        hidden_dim: int
            Number of channels of the hidden state.
        bias: bool
            Whether or not to add the bias.

        Input: a tensor of size (B, C)
            B: batch_size
            C: channels
        """
        super(LSTMCell, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.bias = bias

        # One linear layer computes all four gate pre-activations at once.
        self.hidden = nn.Linear(in_features=self.input_dim + self.hidden_dim,
                                out_features=4 * self.hidden_dim,
                                bias=self.bias)

    def forward(self, input_tensor, cur_state):
        # input_tensor: (B, C); h_cur, c_cur: (B, hidden_dim)
        h_cur, c_cur = cur_state

        combined = torch.cat([input_tensor, h_cur], dim=-1)  # concatenate along the channel axis

        combined_linear = self.hidden(combined)
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_linear, self.hidden_dim, dim=-1)
        i = torch.sigmoid(cc_i)  # input gate
        f = torch.sigmoid(cc_f)  # forget gate
        o = torch.sigmoid(cc_o)  # output gate
        g = torch.tanh(cc_g)     # candidate cell state

        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)

        return h_next, c_next

    def init_hidden(self, batch_size):
        return (torch.zeros(batch_size, self.hidden_dim, device=self.hidden.weight.device),
                torch.zeros(batch_size, self.hidden_dim, device=self.hidden.weight.device))
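
# The dct_channel_block imported above lives in the external RUL package. For
# readers without that package, the stand-in below is a minimal sketch of a
# DCT-based channel-attention block with the same interface inferred from its
# use in LSTM.forward (constructed with a channel count, maps (B, T, C) to the
# same shape). It is illustrative only; the real implementation may differ.
import math


class _DctChannelBlockSketch(nn.Module):

    def __init__(self, channel, reduction=4):
        super(_DctChannelBlockSketch, self).__init__()
        # Squeeze-and-excitation style gate over a per-channel frequency descriptor.
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            nn.ReLU(),
            nn.Linear(channel // reduction, channel),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # x: (B, T, C)
        b, t, c = x.size()
        # DCT-II basis over the time axis: basis[k, n] = cos(pi * (n + 0.5) * k / T)
        n = torch.arange(t, device=x.device, dtype=x.dtype)
        k = n.view(-1, 1)
        basis = torch.cos(math.pi * (n + 0.5) * k / t)       # (T, T)
        coeffs = torch.einsum('btc,ft->bfc', x, basis)       # DCT coefficients per channel
        descriptor = coeffs.abs().mean(dim=1)                # (B, C) frequency summary
        weights = self.fc(descriptor).unsqueeze(1)           # (B, 1, C) channel gates
        return x * weights                                   # rescale channels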
class LSTM(nn.Module):
    """
    A multi-layer LSTM with a DCT channel-attention block between layers.

    Parameters:
        input_dim: Number of channels in the input
        hidden_dim: Number of hidden channels (int, or a list with one entry per layer)
        num_layers: Number of LSTM layers stacked on each other
        batch_first: Whether or not dimension 0 is the batch
        bias: Whether or not to add a bias in the linear layers
        return_all_layers: Return the outputs and states of all layers

    Input:
        A tensor of size (B, T, C) or (T, B, C)
    Output:
        A tuple of two lists of length num_layers (or length 1 if return_all_layers is False).
            0 - layer_output_list: per-layer output tensors of shape (B, T, hidden_dim)
            1 - last_state_list: per-layer last states; each element is a tuple (h, c)
                for hidden state and memory
    Example:
        >> x = torch.rand((32, 10, 64))
        >> lstm = LSTM(64, 16, 1, batch_first=True, bias=True, return_all_layers=False)
        >> _, last_states = lstm(x)
        >> h = last_states[0][0]  # 0 for layer index, 0 for h index
    """

    def __init__(self, input_dim, hidden_dim, num_layers,
                 batch_first=False, bias=True, return_all_layers=False):
        super(LSTM, self).__init__()

        # Make sure `hidden_dim` is a list having len == num_layers.
        hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
        if not len(hidden_dim) == num_layers:
            raise ValueError('Inconsistent list length.')

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.bias = bias
        self.return_all_layers = return_all_layers

        cell_list = []
        attention_list = []
        for i in range(0, self.num_layers):
            cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]

            cell_list.append(LSTMCell(input_dim=cur_input_dim,
                                      hidden_dim=self.hidden_dim[i],
                                      bias=self.bias))
            attention_list.append(dct_channel_block(self.hidden_dim[i]))

        self.cell_list = nn.ModuleList(cell_list)
        self.attention_list = nn.ModuleList(attention_list)

    def forward(self, input_tensor, hidden_state=None):
        """
        Parameters
        ----------
        input_tensor:
            3-D tensor of shape (t, b, c) or (b, t, c)
        hidden_state:
            None. TODO: implement stateful operation

        Returns
        -------
        layer_output_list, last_state_list
        """
        if not self.batch_first:
            # Equivalent to a transpose: (t, b, c) -> (b, t, c)
            input_tensor = input_tensor.permute(1, 0, 2)

        b, _, _ = input_tensor.size()

        # TODO: implement stateful LSTM
        if hidden_state is not None:
            raise NotImplementedError()
        else:
            # Since the init is done in forward, the batch size is known here.
            hidden_state = self._init_hidden(batch_size=b)

        layer_output_list = []
        last_state_list = []

        timestamp = input_tensor.size(1)
        cur_layer_input = input_tensor

        for layer_idx in range(self.num_layers):

            h, c = hidden_state[layer_idx]
            output_inner = []
            for t in range(timestamp):
                h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :],
                                                 cur_state=[h, c])
                output_inner.append(h)

            layer_output = torch.stack(output_inner, dim=1)
            # Apply DCT channel attention between layers.
            layer_output = self.attention_list[layer_idx](layer_output)
            cur_layer_input = layer_output

            layer_output_list.append(layer_output)
            last_state_list.append([h, c])

        if not self.return_all_layers:
            layer_output_list = layer_output_list[-1:]
            last_state_list = last_state_list[-1:]

        return layer_output_list, last_state_list

    def _init_hidden(self, batch_size):
        init_states = []
        for i in range(self.num_layers):
            init_states.append(self.cell_list[i].init_hidden(batch_size))
        return init_states

    @staticmethod
    def _extend_for_multilayer(param, num_layers):
        if not isinstance(param, list):
            param = [param] * num_layers
        return param


class PredictModel(nn.Module):

    def __init__(self, input_dim):
        super(PredictModel, self).__init__()
        self.lstm = LSTM(input_dim=input_dim, hidden_dim=[64, 64], num_layers=2,
                         batch_first=True, bias=True, return_all_layers=False)
        self.backbone = nn.Sequential(
            nn.Linear(in_features=64, out_features=64),
            nn.Linear(in_features=64, out_features=64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(in_features=64, out_features=1)
        )

    def forward(self, input_tensor):
        input_tensor = input_tensor.to(torch.float32)
        layer_output_list, last_states = self.lstm(input_tensor)
        # Final hidden state of the last layer: (B, 64)
        last_hidden = last_states[0][0]
        predict = self.backbone(last_hidden)
        return predict


if __name__ == '__main__':
    x = torch.rand((32, 10, 64))
    lstm = LSTM(input_dim=64, hidden_dim=16, num_layers=1,
                batch_first=True, bias=True, return_all_layers=False)
    layer_output_list, last_states = lstm(x)
    outputs = layer_output_list[0]  # (32, 10, 16): all hidden states of the last layer
    h = last_states[0][0]           # (32, 16): final hidden state
    print(outputs.size())
    print(h.size())
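
    # Illustrative end-to-end check of PredictModel (a sketch; it assumes the
    # imported dct_channel_block maps a (B, T, C) tensor to the same shape, as
    # its use inside LSTM.forward implies).
    model = PredictModel(input_dim=64)
    model.eval()                # put BatchNorm/Dropout in inference mode
    predict = model(x)
    print(predict.size())       # expected: torch.Size([32, 1])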