diff --git a/TensorFlow_eaxmple/Model_train_test/model/ChannelAttention/DCT_channelAttention.py b/TensorFlow_eaxmple/Model_train_test/model/ChannelAttention/DCT_channelAttention.py
index 8752bc6..fd54eff 100644
--- a/TensorFlow_eaxmple/Model_train_test/model/ChannelAttention/DCT_channelAttention.py
+++ b/TensorFlow_eaxmple/Model_train_test/model/ChannelAttention/DCT_channelAttention.py
@@ -59,6 +59,12 @@
 from scipy.fftpack import dct
 import tensorflow as tf
 
+'''
+参考:
+[1] https://github.com/Zero-coder/FECAM/blob/main/layers/dctnet.py
+[2] https://arxiv.org/pdf/2212.01209v1.pdf
+'''
+
 
 def sdct_tf(signals, frame_length, frame_step, window_fn=tf.signal.hamming_window):
     """Compute Short-Time Discrete Cosine Transform of `signals`.
@@ -130,6 +136,7 @@ class DCTChannelAttention(layers.Layer):
         self.drop1 = Dropout(0.1)
         self.relu = ReLU(0.1)
         self.l2 = Dense(channel, use_bias=False)
+        self.bn = BatchNormalization()
 
     def call(self, inputs, **kwargs):
         batch_size, hidden, channel = inputs.shape
@@ -141,11 +148,13 @@
         # list.append(freq)
         # stack_dct = tf.stack(list, dim=1)
 
-        lr_weight = self.l1(stack_dct)
+        lr_weight = self.bn(stack_dct)
+        lr_weight = self.l1(lr_weight)
         lr_weight = self.drop1(lr_weight)
         lr_weight = self.relu(lr_weight)
         lr_weight = self.l2(lr_weight)
-        lr_weight = BatchNormalization()(lr_weight)
+        lr_weight = tf.sigmoid(lr_weight)
+        lr_weight = self.bn(lr_weight)
 
         return inputs * lr_weight
 