Some modified RNN variants
This commit is contained in: parent d3b671b006, commit 418b6e6ee1
@@ -1,5 +1,6 @@
 package com.atguigu.spark.core.rdd.operator.transform
 
+import org.apache.spark.rdd.RDD
 import org.apache.spark.{SparkConf, SparkContext}
 
 object Spark01_RDD_Operator_Transform {
@@ -11,7 +12,7 @@ object Spark01_RDD_Operator_Transform {
     val sc = new SparkContext(sparkConf)
 
     // TODO operator - map
-    val rdd = sc.makeRDD(
+    val rdd: RDD[Int] = sc.makeRDD(
       List(1,2,3,4)
     )
 
@@ -0,0 +1,18 @@
# Linear RNN Variants

Implementations of three fast, parallelizable RNN variants in bert4keras: LRU, SLRU, and RWKV.

## Introduction

- Blog post (in Chinese): https://kexue.fm/archives/9554
- LRU paper: https://arxiv.org/abs/2303.06349
- RWKV repository: https://github.com/BlinkDL/RWKV-LM

## Parallelization

Linear RNNs admit parallel evaluation: the O(L) sequential recurrence can be reduced to O(log L) parallel steps. This project parallelizes the RNN with the "Upper/Lower" algorithm for the prefix-sum problem; a sketch of the idea follows.

For details, see the "[Parallelization](https://kexue.fm/archives/9554#%E5%B9%B6%E8%A1%8C%E5%8C%96)" section of the blog post.
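
As a concrete illustration, here is a minimal NumPy sketch of this kind of prefix scan for the diagonal linear recurrence `x[t] = lamb * x[t-1] + u[t]` (the function name and shapes are illustrative, not this repo's API):

```python
import numpy as np

def parallel_scan(u, lamb):
    """Solve x[t] = lamb * x[t-1] + u[t] in log2(L) merge rounds.
    u: (L, D) with L a power of 2; lamb: (D,) per-channel decay."""
    x, L = u.copy(), len(u)
    l = 2
    while l <= L:  # each round doubles the solved block length
        x = x.reshape(-1, l, x.shape[-1])
        x1, x2 = x[:, :l // 2], x[:, l // 2:]
        pos = np.arange(1, l // 2 + 1)
        # carry the last state of each left half into its right half
        x2 = x2 + lamb ** pos[:, None] * x1[:, -1:]
        x = np.concatenate([x1, x2], axis=1).reshape(-1, x.shape[-1])
        l *= 2
    return x

# agrees with the serial recurrence
u, lamb = np.random.randn(8, 4), np.full(4, 0.9)
x = np.zeros_like(u)
for t in range(8):
    x[t] = lamb * (x[t - 1] if t else 0) + u[t]
assert np.allclose(parallel_scan(u, lamb), x)
```

Each round merges adjacent already-solved blocks in parallel, so the sequential depth is log2(L) rather than L.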
## Contact

QQ group: 808623966; for the WeChat group, add the bot account spaces_ac_cn.
@@ -0,0 +1,8 @@
#-*- encoding:utf-8 -*-

'''
@Author : dingjiawen
@Date   : 2023/6/13 19:13
@Usage  :
@Desc   :
'''
@@ -0,0 +1,117 @@
#! -*- coding: utf-8 -*-
# Linear Recurrent Unit (LRU)
# Tested with tensorflow 1.15 + bert4keras 0.11.4

from bert4keras.layers import *


class LRU(Layer):
    """Linear Recurrent Unit
    Reference 1: https://arxiv.org/abs/2303.06349
    Reference 2: https://kexue.fm/archives/9554
    """
    def __init__(
        self,
        units,
        activation='linear',
        use_bias=True,
        unroll=True,  # unrolling speeds up training but uses more GPU memory
        kernel_initializer='glorot_uniform',
        **kwargs
    ):
        super(LRU, self).__init__(**kwargs)
        self.units = units
        self.activation = activations.get(activation)
        self.use_bias = use_bias
        self.unroll = unroll
        self.kernel_initializer = initializers.get(kernel_initializer)

    @integerize_shape
    def build(self, input_shape):
        super(LRU, self).build(input_shape)
        hidden_size = input_shape[-1]
        self.i_dense = Dense(
            units=self.units * 2,
            use_bias=self.use_bias,
            kernel_initializer=self.kernel_initializer
        )
        self.o_dense = Dense(
            units=hidden_size,
            use_bias=self.use_bias,
            activation=self.activation,
            kernel_initializer=self.kernel_initializer
        )

        def initializer(shape, dtype=None):
            # initialize eigenvalues lamb = exp(-exp(nu_log) + i*exp(theta_log))
            # with |lamb| uniform in [r_min, r_max], as in the LRU paper
            r_min, r_max = 0.9, 0.999
            u1 = np.random.random(size=shape[1])
            u2 = np.random.random(size=shape[1])
            nu_log = np.log(
                -0.5 * np.log(u1 * (r_max**2 - r_min**2) + r_min**2)
            )
            theta_log = np.log(u2 * np.pi * 2)
            gamma_log = np.log(np.sqrt(1 - np.exp(-np.exp(nu_log))**2))
            return np.array([nu_log, theta_log, gamma_log])

        self.params_log = self.add_weight(
            name='params_log', shape=(3, self.units), initializer=initializer
        )

    @recompute_grad
    def call(self, inputs, mask=None):
        u = self.i_dense(inputs)
        params = K.exp(self.params_log)
        nu, theta, gamma = params[0], params[1], params[2]

        if self.unroll:
            L_in = K.int_shape(u)[1]
            assert L_in is not None, 'input_length cannot be None when unroll=True'
            log2_L = int(np.ceil(np.log2(L_in)))
        else:
            L_in = K.shape(u)[1]
            log2_L = K.log(K.cast(L_in, K.floatx())) / K.log(2.)
            log2_L = K.cast(tf.ceil(log2_L), 'int32')

        # pack the 2*units real channels into units complex channels,
        # then pad the length up to the next power of 2
        u = tf.complex(u[..., ::2], u[..., 1::2])
        u = tf.pad(u, [[0, 0], [0, 2**log2_L - K.shape(u)[1]], [0, 0]])
        B, L, D = K.shape(u)[0], K.shape(u)[1], K.int_shape(u)[-1]

        def lru(i, x):
            # merge step of the Upper/Lower prefix scan: blocks of length 2**i
            # are solved by carrying the last state of each left half into
            # the corresponding right half
            l = 2**i
            x = K.reshape(x, [B * L // l, l, D])
            x1, x2 = x[:, :l // 2], x[:, l // 2:]

            pos = K.arange(1, l // 2 + 1, dtype=K.floatx())
            nus = tf.einsum('n,d->nd', pos, nu)
            thetas = tf.einsum('n,d->nd', pos, theta)
            lambs = K.exp(tf.complex(-nus, thetas))

            x2 = x2 + lambs * x1[:, -1:]
            x = K.concatenate([x1, x2], axis=1)
            if (not self.unroll) and K.int_shape(u)[1] is not None:
                x = K.reshape(x, [B, L, D])

            return i + 1, x

        if self.unroll:
            x = u
            for i in range(log2_L):
                _, x = lru(i + 1, x)
        else:
            _, x = tf.while_loop(lambda i, x: i <= log2_L, lru, [1, u])

        x = x[:, :L_in] * tf.complex(gamma, 0.)
        x = K.concatenate([tf.real(x), tf.imag(x)], axis=-1)
        return self.o_dense(x)

    def get_config(self):
        config = {
            'units': self.units,
            'activation': activations.serialize(self.activation),
            'use_bias': self.use_bias,
            'unroll': self.unroll,
            'kernel_initializer':
                initializers.serialize(self.kernel_initializer),
        }
        base_config = super(LRU, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
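
# A minimal usage sketch (assumptions: the toy shapes, and that bert4keras is
# backed by plain keras here; under TF_KERAS=1, import from tensorflow.keras):
from keras.layers import Input
from keras.models import Model

x_in = Input(shape=(128, 768))         # (length, hidden); length must be static for unroll=True
x = LRU(units=384, unroll=True)(x_in)  # o_dense projects back to 768, so output matches the input shape
model = Model(x_in, x)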
@@ -0,0 +1,80 @@
#! -*- coding: utf-8 -*-
# RNN-α model implementation
# Tested with tensorflow 1.15 + bert4keras 0.11.4

from bert4keras.models import *
from lru import LRU
from slru import SLRU
from rwkv import RWKV

RNN = LRU  # alternatives: SLRU, RWKV


class RNN_alpha(RoFormerV2):
    """RNN-α
    Change: the basic attention block is replaced with an RNN
    """
    def initializer(self, shape, dtype=None, order=2, gain=1.0):
        return super(RNN_alpha, self).initializer(shape, dtype, order, gain)

    def apply_main_layers(self, inputs, index):
        """The main body of RNN-α is an RNN-based block
        Order: RNN --> Add --> LN --> FFN --> Add --> LN
        """
        x = inputs
        rnn_name = 'Transformer-%d-RNN' % index
        ffn_name = 'Transformer-%d-FFN' % index

        xi = x
        x = self.apply(
            inputs=x,
            layer=RNN,
            units=(2 if RNN is SLRU else 1) * self.hidden_size,
            use_bias=False,
            kernel_initializer=self.initializer,
            name=rnn_name
        )
        x = self.apply(
            inputs=x,
            layer=Dropout,
            rate=self.dropout_rate,
            name='%s-Dropout' % rnn_name
        )
        x = self.apply(inputs=[xi, x], layer=Add, name='%s-Add' % rnn_name)
        x = self.apply(
            inputs=x,
            layer=LayerNormalization,
            zero_mean=False,
            scale=False,
            offset=False,
            epsilon=1e-12,
            name='%s-Norm' % rnn_name
        )

        xi = x
        x = self.apply(
            inputs=x,
            layer=FeedForward,
            units=self.intermediate_size,
            kernel_initializer=self.initializer,
            use_bias=False,
            name=ffn_name
        )
        x = self.apply(
            inputs=x,
            layer=Dropout,
            rate=self.dropout_rate,
            name='%s-Dropout' % ffn_name
        )
        x = self.apply(inputs=[xi, x], layer=Add, name='%s-Add' % ffn_name)
        x = self.apply(
            inputs=x,
            layer=LayerNormalization,
            zero_mean=False,
            scale=False,
            offset=False,
            epsilon=1e-12,
            name='%s-Norm' % ffn_name
        )

        return x
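
# A hypothetical build sketch: 'config.json' is a placeholder path, and this
# assumes a bert4keras version whose build_transformer_model accepts a model
# class (rather than a registered name string) for the `model` argument:
from bert4keras.models import build_transformer_model

model = build_transformer_model(
    config_path='config.json',  # RoFormerV2-style config (placeholder)
    model=RNN_alpha,
    return_keras_model=True,
)
model.summary()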
@@ -0,0 +1,111 @@
#! -*- coding: utf-8 -*-
# RWKV
# Tested with tensorflow 1.15 + bert4keras 0.11.4

from bert4keras.layers import *


class RWKV(Layer):
    """RWKV
    Reference 1: https://github.com/BlinkDL/RWKV-LM
    Reference 2: https://kexue.fm/archives/9554
    """
    def __init__(
        self,
        units,
        use_bias=True,
        unroll=True,  # unrolling speeds up training but uses more GPU memory
        kernel_initializer='glorot_uniform',
        **kwargs
    ):
        super(RWKV, self).__init__(**kwargs)
        self.units = units
        self.use_bias = use_bias
        self.unroll = unroll
        self.kernel_initializer = initializers.get(kernel_initializer)

    @integerize_shape
    def build(self, input_shape):
        super(RWKV, self).build(input_shape)
        hidden_size = input_shape[-1]
        self.rkv_dense = Dense(
            units=self.units * 3,
            use_bias=self.use_bias,
            kernel_initializer=self.kernel_initializer
        )
        self.o_dense = Dense(
            units=hidden_size,
            use_bias=self.use_bias,
            kernel_initializer=self.kernel_initializer
        )

        def initializer(shape, dtype=None):
            # decay rates exp(-exp(nu_log)) initialized uniformly in [r_min, r_max]
            r_min, r_max = 0.9, 0.999
            u = np.random.random(size=shape)
            return np.log(-0.5 * np.log(u * (r_max**2 - r_min**2) + r_min**2))

        self.nu_log = self.add_weight(
            name='nu_log', shape=(self.units,), initializer=initializer
        )
        self.gamma_log = self.add_weight(
            name='gamma_log', shape=(self.units,), initializer='zeros'
        )

    @recompute_grad
    def call(self, inputs, mask=None):
        rkv = self.rkv_dense(inputs)
        r, k, v = tf.split(rkv, 3, axis=-1)
        r, k = K.sigmoid(r), K.exp(k)
        kv = k * v
        # scan the k*v and k streams together, sharing one decay rate
        u = K.concatenate([kv, k], axis=-1)
        nu = K.exp(K.concatenate([self.nu_log, self.nu_log], axis=0))
        gamma = K.exp(self.nu_log + self.gamma_log) - 1  # extra weight for the current token

        if self.unroll:
            L_in = K.int_shape(u)[1]
            assert L_in is not None, 'input_length cannot be None when unroll=True'
            log2_L = int(np.ceil(np.log2(L_in)))
        else:
            L_in = K.shape(u)[1]
            log2_L = K.log(K.cast(L_in, K.floatx())) / K.log(2.)
            log2_L = K.cast(tf.ceil(log2_L), 'int32')

        u = tf.pad(u, [[0, 0], [0, 2**log2_L - K.shape(u)[1]], [0, 0]])
        B, L, D = K.shape(u)[0], K.shape(u)[1], K.int_shape(u)[-1]

        def rwkv(i, x):
            # merge step of the Upper/Lower prefix scan (see lru.py)
            l = 2**i
            x = K.reshape(x, [B * L // l, l, D])
            x1, x2 = x[:, :l // 2], x[:, l // 2:]

            pos = K.arange(1, l // 2 + 1, dtype=K.floatx())
            nus = tf.einsum('n,d->nd', pos, nu)
            lambs = K.exp(-nus)

            x2 = x2 + lambs * x1[:, -1:]
            x = K.concatenate([x1, x2], axis=1)
            if (not self.unroll) and K.int_shape(u)[1] is not None:
                x = K.reshape(x, [B, L, D])

            return i + 1, x

        if self.unroll:
            for i in range(log2_L):
                _, u = rwkv(i + 1, u)
        else:
            _, u = tf.while_loop(lambda i, x: i <= log2_L, rwkv, [1, u])

        u1, u2 = tf.split(u[:, :L_in], 2, axis=-1)
        u = tf.math.divide_no_nan(u1 + gamma * kv, u2 + gamma * k) * r
        return self.o_dense(u)

    def get_config(self):
        config = {
            'units': self.units,
            'use_bias': self.use_bias,
            'unroll': self.unroll,
            'kernel_initializer':
                initializers.serialize(self.kernel_initializer),
        }
        base_config = super(RWKV, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
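
# To make the scanned quantity concrete, a serial NumPy reference for what
# `call` computes per example (a sketch; r, k, v stand for the
# post-activation tensors of shape (L, D), all names illustrative):
import numpy as np

def rwkv_serial(r, k, v, nu, gamma):
    """out[t] = r[t] * (sum_{s<=t} exp(-nu*(t-s)) * k[s]*v[s] + gamma*k[t]*v[t])
                     / (sum_{s<=t} exp(-nu*(t-s)) * k[s]      + gamma*k[t])"""
    num, den = np.zeros(k.shape[-1]), np.zeros(k.shape[-1])
    out, lamb = np.zeros_like(v), np.exp(-nu)
    for t in range(len(k)):
        num = lamb * num + k[t] * v[t]  # decayed running sum of k*v
        den = lamb * den + k[t]         # decayed running sum of k
        # gamma gives the current token RWKV's extra "bonus" weight
        out[t] = r[t] * (num + gamma * k[t] * v[t]) / (den + gamma * k[t])
    return out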
@@ -0,0 +1,110 @@
#! -*- coding: utf-8 -*-
# Simplified Linear Recurrent Unit (SLRU)
# Tested with tensorflow 1.15 + bert4keras 0.11.4

from bert4keras.layers import *


class SLRU(Layer):
    """Real-valued Linear Recurrent Unit
    Reference 1: https://arxiv.org/abs/2303.06349
    Reference 2: https://kexue.fm/archives/9554
    """
    def __init__(
        self,
        units,
        activation='linear',
        use_bias=True,
        unroll=True,  # unrolling speeds up training but uses more GPU memory
        kernel_initializer='glorot_uniform',
        **kwargs
    ):
        super(SLRU, self).__init__(**kwargs)
        self.units = units
        self.activation = activations.get(activation)
        self.use_bias = use_bias
        self.unroll = unroll
        self.kernel_initializer = initializers.get(kernel_initializer)

    @integerize_shape
    def build(self, input_shape):
        super(SLRU, self).build(input_shape)
        hidden_size = input_shape[-1]
        self.i_dense = Dense(
            units=self.units,
            use_bias=self.use_bias,
            kernel_initializer=self.kernel_initializer
        )
        self.o_dense = Dense(
            units=hidden_size,
            use_bias=self.use_bias,
            activation=self.activation,
            kernel_initializer=self.kernel_initializer
        )

        def initializer(shape, dtype=None):
            # real decay rates exp(-exp(nu_log)) uniform in [r_min, r_max]
            r_min, r_max = 0.9, 0.999
            u = np.random.random(size=shape[1])
            nu_log = np.log(-0.5 * np.log(u * (r_max**2 - r_min**2) + r_min**2))
            gamma_log = np.log(np.sqrt(1 - np.exp(-np.exp(nu_log))**2))
            return np.array([nu_log, gamma_log])

        self.params_log = self.add_weight(
            name='params_log', shape=(2, self.units), initializer=initializer
        )

    @recompute_grad
    def call(self, inputs, mask=None):
        u = self.i_dense(inputs)
        params = K.exp(self.params_log)
        nu, gamma = params[0], params[1]

        if self.unroll:
            L_in = K.int_shape(u)[1]
            assert L_in is not None, 'input_length cannot be None when unroll=True'
            log2_L = int(np.ceil(np.log2(L_in)))
        else:
            L_in = K.shape(u)[1]
            log2_L = K.log(K.cast(L_in, K.floatx())) / K.log(2.)
            log2_L = K.cast(tf.ceil(log2_L), 'int32')

        u = tf.pad(u, [[0, 0], [0, 2**log2_L - K.shape(u)[1]], [0, 0]])
        B, L, D = K.shape(u)[0], K.shape(u)[1], K.int_shape(u)[-1]

        def lru(i, x):
            # merge step of the Upper/Lower prefix scan (see lru.py)
            l = 2**i
            x = K.reshape(x, [B * L // l, l, D])
            x1, x2 = x[:, :l // 2], x[:, l // 2:]

            pos = K.arange(1, l // 2 + 1, dtype=K.floatx())
            nus = tf.einsum('n,d->nd', pos, nu)
            lambs = K.exp(-nus)

            x2 = x2 + lambs * x1[:, -1:]
            x = K.concatenate([x1, x2], axis=1)
            if (not self.unroll) and K.int_shape(u)[1] is not None:
                x = K.reshape(x, [B, L, D])

            return i + 1, x

        if self.unroll:
            x = u
            for i in range(log2_L):
                _, x = lru(i + 1, x)
        else:
            _, x = tf.while_loop(lambda i, x: i <= log2_L, lru, [1, u])

        x = x[:, :L_in] * gamma
        return self.o_dense(x)

    def get_config(self):
        config = {
            'units': self.units,
            'activation': activations.serialize(self.activation),
            'use_bias': self.use_bias,
            'unroll': self.unroll,
            'kernel_initializer':
                initializers.serialize(self.kernel_initializer),
        }
        base_config = super(SLRU, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
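
# For reference, the serial recurrence that SLRU parallelizes, as a NumPy
# sketch (per-example shapes, names illustrative):
import numpy as np

def slru_serial(u, nu, gamma):
    """x[t] = exp(-nu) * x[t-1] + u[t]; the layer then returns gamma * x."""
    x = np.zeros_like(u)
    for t in range(len(u)):
        x[t] = np.exp(-nu) * (x[t - 1] if t else 0) + u[t]
    return gamma * x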