增加一些scala和python相关

This commit is contained in:
dingjiawen@xiaomi.com 2023-08-16 16:20:04 +08:00
parent a9c30c8a91
commit 14b7cb72a6
1 changed files with 10 additions and 1 deletions

View File

@ -59,6 +59,12 @@ from scipy.fftpack import dct
import tensorflow as tf
'''
参考
[1] https://github.com/Zero-coder/FECAM/blob/main/layers/dctnet.py
[2] https://arxiv.org/pdf/2212.01209v1.pdf
'''
def sdct_tf(signals, frame_length, frame_step, window_fn=tf.signal.hamming_window):
"""Compute Short-Time Discrete Cosine Transform of `signals`.
@ -130,6 +136,7 @@ class DCTChannelAttention(layers.Layer):
self.drop1 = Dropout(0.1)
self.relu = ReLU(0.1)
self.l2 = Dense(channel, use_bias=False)
self.bn = BatchNormalization()
def call(self, inputs, **kwargs):
batch_size, hidden, channel = inputs.shape
@ -141,11 +148,13 @@ class DCTChannelAttention(layers.Layer):
# list.append(freq)
# stack_dct = tf.stack(list, dim=1)
lr_weight = self.bn(stack_dct)
lr_weight = self.l1(stack_dct)
lr_weight = self.drop1(lr_weight)
lr_weight = self.relu(lr_weight)
lr_weight = self.l2(lr_weight)
lr_weight = tf.sigmoid(lr_weight)
lr_weight = BatchNormalization()(lr_weight)
lr_weight = self.bn(lr_weight)
return inputs * lr_weight