Add some Scala- and Python-related code
This commit is contained in:
parent a9c30c8a91
commit 14b7cb72a6
@ -59,6 +59,12 @@ from scipy.fftpack import dct
import tensorflow as tf


'''
References:
[1] https://github.com/Zero-coder/FECAM/blob/main/layers/dctnet.py
[2] https://arxiv.org/pdf/2212.01209v1.pdf
'''


def sdct_tf(signals, frame_length, frame_step, window_fn=tf.signal.hamming_window):
    """Compute Short-Time Discrete Cosine Transform of `signals`.
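The hunk is cut off at the docstring, so the body of `sdct_tf` is not visible here. For orientation only, a minimal short-time DCT along these lines, assuming `tf.signal.frame` for framing and `tf.signal.dct` (DCT-II) for the transform, could look like the sketch below; the name `sdct_tf_sketch` is illustrative and is not the committed code.

def sdct_tf_sketch(signals, frame_length, frame_step, window_fn=tf.signal.hamming_window):
    """Illustrative short-time DCT: frame the signal, window each frame, apply DCT-II."""
    # Split the last axis into overlapping frames -> (..., n_frames, frame_length)
    framed = tf.signal.frame(signals, frame_length, frame_step)
    if window_fn is not None:
        # Apply the analysis window to every frame
        framed = framed * window_fn(frame_length, dtype=framed.dtype)
    # Orthonormal DCT-II along the last (frame) axis
    return tf.signal.dct(framed, type=2, norm='ortho')

On a `[batch, time]` waveform this returns a `[batch, n_frames, frame_length]` tensor of per-frame DCT coefficients.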
@ -130,6 +136,7 @@ class DCTChannelAttention(layers.Layer):
        self.drop1 = Dropout(0.1)
        self.relu = ReLU(0.1)
        self.l2 = Dense(channel, use_bias=False)
        self.bn = BatchNormalization()

    def call(self, inputs, **kwargs):
        batch_size, hidden, channel = inputs.shape
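`self.l1`, used in `call` below, is created outside this hunk. A hedged reconstruction of the full constructor, assuming an FECAM-style bottleneck with a hypothetical `reduction` factor (the `l1` line and `reduction` are assumptions, not part of the commit):

from tensorflow.keras import layers
from tensorflow.keras.layers import Dense, Dropout, ReLU, BatchNormalization


class DCTChannelAttention(layers.Layer):
    def __init__(self, channel, reduction=4, **kwargs):
        super().__init__(**kwargs)
        # Bottleneck MLP over the per-channel frequency features; `l1` is assumed, not shown above
        self.l1 = Dense(channel // reduction, use_bias=False)
        self.drop1 = Dropout(0.1)
        self.relu = ReLU(0.1)   # as committed; note Keras ReLU's first argument is max_value, so ReLU(0.1) caps activations at 0.1
        self.l2 = Dense(channel, use_bias=False)
        self.bn = BatchNormalization()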
@ -141,11 +148,13 @@ class DCTChannelAttention(layers.Layer):
        # list.append(freq)
        # stack_dct = tf.stack(list, dim=1)

        lr_weight = self.bn(stack_dct)
        lr_weight = self.l1(stack_dct)
        lr_weight = self.drop1(lr_weight)
        lr_weight = self.relu(lr_weight)
        lr_weight = self.l2(lr_weight)
        lr_weight = tf.sigmoid(lr_weight)

        lr_weight = BatchNormalization()(lr_weight)
        lr_weight = self.bn(lr_weight)

        return inputs * lr_weight
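As committed, `lr_weight = self.bn(stack_dct)` is overwritten on the next line and the sigmoid output is batch-normalized twice, but the overall pattern is a sigmoid channel gate driven by per-channel DCT features. A self-contained sketch of that pattern is below; the per-channel summary (sum of DCT coefficients) and the function name are assumptions for illustration, not the committed implementation:

from tensorflow.keras.layers import Dense, ReLU


def dct_channel_gate_sketch(inputs, reduction=4):
    """Illustrative sigmoid channel gate over [batch, hidden, channel] inputs."""
    channel = inputs.shape[-1]
    # Per-channel DCT over the hidden axis, summed into one scalar per channel (summary choice assumed)
    freq = tf.signal.dct(tf.transpose(inputs, [0, 2, 1]), type=2, norm='ortho')
    summary = tf.reduce_sum(freq, axis=-1)                      # [batch, channel]
    # Bottleneck MLP + sigmoid, mirroring the l1 -> relu -> l2 -> sigmoid chain in the hunk
    w = Dense(channel // reduction, use_bias=False)(summary)
    w = ReLU()(w)
    w = Dense(channel, use_bias=False)(w)
    w = tf.sigmoid(w)
    # Broadcast the per-channel weights over the hidden axis and rescale the input
    return inputs * w[:, tf.newaxis, :]

This mirrors the committed chain minus the dropout and the duplicated batch normalization, and keeps the output shape equal to the input shape.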