From 14b7cb72a6cabd461f2ca7f4bb9dc6c703d94bf6 Mon Sep 17 00:00:00 2001
From: "dingjiawen@xiaomi.com" <dingjiawen@xiaomi.com>
Date: Wed, 16 Aug 2023 16:20:04 +0800
Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E4=B8=80=E4=BA=9Bscala?=
 =?UTF-8?q?=E5=92=8Cpython=E7=9B=B8=E5=85=B3?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../model/ChannelAttention/DCT_channelAttention.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/TensorFlow_eaxmple/Model_train_test/model/ChannelAttention/DCT_channelAttention.py b/TensorFlow_eaxmple/Model_train_test/model/ChannelAttention/DCT_channelAttention.py
index 8752bc6..fd54eff 100644
--- a/TensorFlow_eaxmple/Model_train_test/model/ChannelAttention/DCT_channelAttention.py
+++ b/TensorFlow_eaxmple/Model_train_test/model/ChannelAttention/DCT_channelAttention.py
@@ -59,6 +59,12 @@ from scipy.fftpack import dct
 import tensorflow as tf
 
+
+'''
+参考:
+[1] https://github.com/Zero-coder/FECAM/blob/main/layers/dctnet.py
+[2] https://arxiv.org/pdf/2212.01209v1.pdf
+'''
 
 
 def sdct_tf(signals, frame_length, frame_step, window_fn=tf.signal.hamming_window):
     """Compute Short-Time Discrete Cosine Transform of `signals`.
@@ -130,6 +136,7 @@ class DCTChannelAttention(layers.Layer):
         self.drop1 = Dropout(0.1)
         self.relu = ReLU(0.1)
         self.l2 = Dense(channel, use_bias=False)
+        self.bn = BatchNormalization()
 
     def call(self, inputs, **kwargs):
         batch_size, hidden, channel = inputs.shape
@@ -141,11 +148,13 @@
 #     list.append(freq)
 #     stack_dct = tf.stack(list, dim=1)
 
-        lr_weight = self.l1(stack_dct)
+        lr_weight = self.bn(stack_dct)
+        lr_weight = self.l1(lr_weight)
         lr_weight = self.drop1(lr_weight)
         lr_weight = self.relu(lr_weight)
         lr_weight = self.l2(lr_weight)
+        lr_weight = tf.sigmoid(lr_weight)
 
-        lr_weight = BatchNormalization()(lr_weight)
+        lr_weight = self.bn(lr_weight)
 
         return inputs * lr_weight