Some modified RNN variants

kevinding1125 2023-06-13 19:17:33 +08:00
parent d3b671b006
commit 418b6e6ee1
7 changed files with 446 additions and 1 deletion


@@ -1,5 +1,6 @@
 package com.atguigu.spark.core.rdd.operator.transform
 
+import org.apache.spark.rdd.RDD
 import org.apache.spark.{SparkConf, SparkContext}
 
 object Spark01_RDD_Operator_Transform {
@@ -11,7 +12,7 @@ object Spark01_RDD_Operator_Transform {
     val sc =new SparkContext(sparkConf)
 
     //TODO operator - map
-    val rdd = sc.makeRDD(
+    val rdd: RDD[Int] = sc.makeRDD(
       List(1,2,3,4)
     )
 


@@ -0,0 +1,18 @@
# Linear RNN variants
Three fast, parallelizable RNN variants (LRU, SLRU and RWKV) implemented with bert4keras.
## Introduction
- Chinese blog post: https://kexue.fm/archives/9554
- LRU paper: https://arxiv.org/abs/2303.06349
- RWKV repository: https://github.com/BlinkDL/RWKV-LM
## Parallelization
Linear RNNs admit a parallel evaluation scheme that reduces the O(L) sequential cost to O(log L) depth. This project uses the "upper/lower" algorithm for the prefix-sum problem to parallelize the RNN.
For details, see the "[并行化](https://kexue.fm/archives/9554#%E5%B9%B6%E8%A1%8C%E5%8C%96)" (parallelization) section of the blog post.
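To make the "upper/lower" idea concrete, here is a minimal NumPy sketch (illustrative, not part of the repository) for the per-channel linear recurrence x[t] = λ·x[t-1] + u[t]: scan the two halves independently, then fold the last state of the lower half into the upper half with the appropriate powers of λ. The layers in this commit implement the same merge iteratively, bottom-up, on blocks of doubling length.

```python
import numpy as np

def linear_scan(u, lamb):
    """Sequential reference: x[t] = lamb * x[t-1] + u[t], with x[-1] = 0."""
    x = np.zeros_like(u)
    state = np.zeros(u.shape[-1])
    for t in range(len(u)):
        state = lamb * state + u[t]
        x[t] = state
    return x

def upper_lower_scan(u, lamb):
    """Upper/lower prefix scan: O(log L) recursion depth (length must be a power of 2)."""
    L = len(u)
    if L == 1:
        return u
    lower = upper_lower_scan(u[:L // 2], lamb)
    upper = upper_lower_scan(u[L // 2:], lamb)
    # fold the last state of the lower half into every position of the upper half
    decay = lamb ** np.arange(1, L // 2 + 1)[:, None]
    return np.concatenate([lower, upper + decay * lower[-1]], axis=0)

u = np.random.randn(8, 4)   # (length, dim), length a power of 2
lamb = np.full(4, 0.9)      # per-dim decay, |lamb| < 1
assert np.allclose(linear_scan(u, lamb), upper_lower_scan(u, lamb))
```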
## Contact
QQ group: 808623966. For the WeChat group, add the bot's WeChat ID: spaces_ac_cn.


@@ -0,0 +1,8 @@
# -*- coding: utf-8 -*-
'''
@Author : dingjiawen
@Date : 2023/6/13 19:13
@Usage :
@Desc :
'''


@@ -0,0 +1,117 @@
#! -*- coding: utf-8 -*-
# Linear Recurrent Unit (LRU)
# Tested with tensorflow 1.15 + bert4keras 0.11.4
from bert4keras.layers import *
class LRU(Layer):
"""线性循环单元
链接1https://arxiv.org/abs/2303.06349
链接2https://kexue.fm/archives/9554
"""
def __init__(
self,
units,
activation='linear',
use_bias=True,
        unroll=True,  # unroll speeds up training but increases memory usage
kernel_initializer='glorot_uniform',
**kwargs
):
super(LRU, self).__init__(**kwargs)
self.units = units
self.activation = activations.get(activation)
self.use_bias = use_bias
self.unroll = unroll
self.kernel_initializer = initializers.get(kernel_initializer)
@integerize_shape
def build(self, input_shape):
super(LRU, self).build(input_shape)
hidden_size = input_shape[-1]
self.i_dense = Dense(
units=self.units * 2,
use_bias=self.use_bias,
kernel_initializer=self.kernel_initializer
)
self.o_dense = Dense(
units=hidden_size,
use_bias=self.use_bias,
activation=self.activation,
kernel_initializer=self.kernel_initializer
)
        def initializer(shape, dtype=None):
            # Stable ring initialization (as in the LRU paper): the recurrence
            # eigenvalues are lambda = exp(-exp(nu_log) + i * exp(theta_log)),
            # with |lambda|^2 drawn uniformly from [r_min^2, r_max^2].
            r_min, r_max = 0.9, 0.999
            u1 = np.random.random(size=shape[1])
            u2 = np.random.random(size=shape[1])
            nu_log = np.log(
                -0.5 * np.log(u1 * (r_max**2 - r_min**2) + r_min**2)
            )
            theta_log = np.log(u2 * np.pi * 2)
            # gamma = sqrt(1 - |lambda|^2) normalizes the variance of the hidden state
            gamma_log = np.log(np.sqrt(1 - np.exp(-np.exp(nu_log))**2))
            return np.array([nu_log, theta_log, gamma_log])
self.params_log = self.add_weight(
name='params_log', shape=(3, self.units), initializer=initializer
)
@recompute_grad
def call(self, inputs, mask=None):
u = self.i_dense(inputs)
params = K.exp(self.params_log)
nu, theta, gamma = params[0], params[1], params[2]
if self.unroll:
L_in = K.int_shape(u)[1]
assert L_in is not None, 'input_length can not be None while unroll=True'
log2_L = int(np.ceil(np.log2(L_in)))
else:
L_in = K.shape(u)[1]
log2_L = K.log(K.cast(L_in, K.floatx())) / K.log(2.)
log2_L = K.cast(tf.ceil(log2_L), 'int32')
u = tf.complex(u[..., ::2], u[..., 1::2])
u = tf.pad(u, [[0, 0], [0, 2**log2_L - K.shape(u)[1]], [0, 0]])
B, L, D = K.shape(u)[0], K.shape(u)[1], K.int_shape(u)[-1]
def lru(i, x):
l = 2**i
x = K.reshape(x, [B * L // l, l, D])
x1, x2 = x[:, :l // 2], x[:, l // 2:]
pos = K.arange(1, l // 2 + 1, dtype=K.floatx())
nus = tf.einsum('n,d->nd', pos, nu)
thetas = tf.einsum('n,d->nd', pos, theta)
lambs = K.exp(tf.complex(-nus, thetas))
x2 = x2 + lambs * x1[:, -1:]
x = K.concatenate([x1, x2], axis=1)
if (not self.unroll) and K.int_shape(u)[1] is not None:
x = K.reshape(x, [B, L, D])
return i + 1, x
if self.unroll:
x = u
for i in range(log2_L):
_, x = lru(i + 1, x)
else:
_, x = tf.while_loop(lambda i, x: i <= log2_L, lru, [1, u])
x = x[:, :L_in] * tf.complex(gamma, 0.)
x = K.concatenate([tf.real(x), tf.imag(x)], axis=-1)
return self.o_dense(x)
def get_config(self):
config = {
'units': self.units,
'activation': activations.serialize(self.activation),
'use_bias': self.use_bias,
'unroll': self.unroll,
'kernel_initializer':
initializers.serialize(self.kernel_initializer),
}
base_config = super(LRU, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
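A minimal usage sketch for the layer above (illustrative, not part of the commit), assuming bert4keras's default standalone-Keras backend; the shapes are placeholders, and with unroll=True the sequence length must be static:

```python
from keras.layers import Input
from keras.models import Model
from lru import LRU  # the layer defined above

x_in = Input(shape=(128, 64))          # (sequence_length, hidden_size); static length needed for unroll=True
x = LRU(units=32, unroll=True)(x_in)   # o_dense maps back to the input's hidden size (64)
model = Model(x_in, x)
model.summary()
```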


@@ -0,0 +1,80 @@
#! -*- coding: utf-8 -*-
# RNN-α model implementation
# Tested with tensorflow 1.15 + bert4keras 0.11.4
from bert4keras.models import *
from lru import LRU
from slru import SLRU
from rwkv import RWKV
RNN = LRU  # alternatives: SLRU, RWKV
class RNN_alpha(RoFormerV2):
"""RNN-α
改动基本模块换成RNN
"""
def initializer(self, shape, dtype=None, order=2, gain=1.0):
return super(RNN_alpha, self).initializer(shape, dtype, order, gain)
def apply_main_layers(self, inputs, index):
"""RNN-α 的主体是基于RNN的模块
顺序RNN --> Add --> LN --> FFN --> Add --> LN
"""
x = inputs
rnn_name = 'Transformer-%d-RNN' % index
ffn_name = 'Transformer-%d-FFN' % index
xi = x
x = self.apply(
inputs=x,
layer=RNN,
units=(2 if RNN is SLRU else 1) * self.hidden_size,
use_bias=False,
kernel_initializer=self.initializer,
name=rnn_name
)
x = self.apply(
inputs=x,
layer=Dropout,
rate=self.dropout_rate,
name='%s-Dropout' % rnn_name
)
x = self.apply(inputs=[xi, x], layer=Add, name='%s-Add' % rnn_name)
x = self.apply(
inputs=x,
layer=LayerNormalization,
zero_mean=False,
scale=False,
offset=False,
epsilon=1e-12,
name='%s-Norm' % rnn_name
)
xi = x
x = self.apply(
inputs=x,
layer=FeedForward,
units=self.intermediate_size,
kernel_initializer=self.initializer,
use_bias=False,
name=ffn_name
)
x = self.apply(
inputs=x,
layer=Dropout,
rate=self.dropout_rate,
name='%s-Dropout' % ffn_name
)
        x = self.apply(inputs=[xi, x], layer=Add, name='%s-Add' % ffn_name)
x = self.apply(
inputs=x,
layer=LayerNormalization,
zero_mean=False,
scale=False,
offset=False,
epsilon=1e-12,
name='%s-Norm' % ffn_name
)
return x
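A hypothetical build sketch (not part of the commit). It assumes build_transformer_model accepts a model class in place of a registered name string, that this file is saved as rnn_alpha.py, and that the config values below are placeholders:

```python
from bert4keras.models import build_transformer_model
from rnn_alpha import RNN_alpha  # hypothetical module name for this file

model = build_transformer_model(
    model=RNN_alpha,          # pass the class instead of a name string
    vocab_size=8000,          # placeholder config values
    hidden_size=384,
    num_hidden_layers=6,
    num_attention_heads=6,    # still required by the Transformer base class, even though attention is replaced
    intermediate_size=1536,
    hidden_act='relu',
    dropout_rate=0.1,
    sequence_length=128,      # unroll=True in the RNN layers requires a static length
)
model.summary()
```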


@@ -0,0 +1,111 @@
#! -*- coding: utf-8 -*-
# RWKV
# Tested with tensorflow 1.15 + bert4keras 0.11.4
from bert4keras.layers import *
class RWKV(Layer):
"""RWKV
    Link 1: https://github.com/BlinkDL/RWKV-LM
    Link 2: https://kexue.fm/archives/9554
"""
def __init__(
self,
units,
use_bias=True,
unroll=True,
kernel_initializer='glorot_uniform',
**kwargs
):
super(RWKV, self).__init__(**kwargs)
self.units = units
self.use_bias = use_bias
self.unroll = unroll
self.kernel_initializer = initializers.get(kernel_initializer)
@integerize_shape
def build(self, input_shape):
super(RWKV, self).build(input_shape)
hidden_size = input_shape[-1]
self.rkv_dense = Dense(
units=self.units * 3,
use_bias=self.use_bias,
kernel_initializer=self.kernel_initializer
)
self.o_dense = Dense(
units=hidden_size,
use_bias=self.use_bias,
kernel_initializer=self.kernel_initializer
)
        def initializer(shape, dtype=None):
            # initialize nu_log so that the decay exp(-exp(nu_log)) lies in
            # [r_min, r_max] (its square is uniform in [r_min^2, r_max^2])
            r_min, r_max = 0.9, 0.999
            u = np.random.random(size=shape)
            return np.log(-0.5 * np.log(u * (r_max**2 - r_min**2) + r_min**2))
self.nu_log = self.add_weight(
name='nu_log', shape=(self.units,), initializer=initializer
)
self.gamma_log = self.add_weight(
name='gamma_log', shape=(self.units,), initializer='zeros'
)
@recompute_grad
def call(self, inputs, mask=None):
rkv = self.rkv_dense(inputs)
r, k, v = tf.split(rkv, 3, axis=-1)
r, k = K.sigmoid(r), K.exp(k)
kv = k * v
u = K.concatenate([kv, k], axis=-1)
nu = K.exp(K.concatenate([self.nu_log, self.nu_log], axis=0))
gamma = K.exp(self.nu_log + self.gamma_log) - 1
if self.unroll:
L_in = K.int_shape(u)[1]
assert L_in is not None, 'input_length can not be None while unroll=True'
log2_L = int(np.ceil(np.log2(L_in)))
else:
L_in = K.shape(u)[1]
log2_L = K.log(K.cast(L_in, K.floatx())) / K.log(2.)
log2_L = K.cast(tf.ceil(log2_L), 'int32')
u = tf.pad(u, [[0, 0], [0, 2**log2_L - K.shape(u)[1]], [0, 0]])
B, L, D = K.shape(u)[0], K.shape(u)[1], K.int_shape(u)[-1]
def rwkv(i, x):
l = 2**i
x = K.reshape(x, [B * L // l, l, D])
x1, x2 = x[:, :l // 2], x[:, l // 2:]
pos = K.arange(1, l // 2 + 1, dtype=K.floatx())
nus = tf.einsum('n,d->nd', pos, nu)
lambs = K.exp(-nus)
x2 = x2 + lambs * x1[:, -1:]
x = K.concatenate([x1, x2], axis=1)
if (not self.unroll) and K.int_shape(u)[1] is not None:
x = K.reshape(x, [B, L, D])
return i + 1, x
if self.unroll:
for i in range(log2_L):
_, u = rwkv(i + 1, u)
else:
_, u = tf.while_loop(lambda i, x: i <= log2_L, rwkv, [1, u])
u1, u2 = tf.split(u[:, :L_in], 2, axis=-1)
u = tf.math.divide_no_nan(u1 + gamma * kv, u2 + gamma * k) * r
return self.o_dense(u)
def get_config(self):
config = {
'units': self.units,
'use_bias': self.use_bias,
'unroll': self.unroll,
'kernel_initializer':
initializers.serialize(self.kernel_initializer),
}
base_config = super(RWKV, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
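As a sanity check on what the parallel scan computes, here is a sequential NumPy reference (illustrative, not part of the commit) for the mixing step in call(), i.e. everything between the rkv projection and o_dense; the function and argument names are placeholders:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def rwkv_reference(r_raw, k_raw, v, nu, gamma):
    """Sequential reference for call() between the rkv projection and o_dense.

    r_raw, k_raw, v: [L, D] slices of the rkv projection;
    nu = exp(nu_log), gamma = exp(nu_log + gamma_log) - 1, as in the layer.
    """
    r, k = sigmoid(r_raw), np.exp(k_raw)
    decay = np.exp(-nu)                  # per-dim decay in (0, 1)
    num = np.zeros_like(nu)              # running sum of exp(-(t-s)*nu) * k[s] * v[s]
    den = np.zeros_like(nu)              # running sum of exp(-(t-s)*nu) * k[s]
    out = np.zeros_like(v)
    for t in range(len(v)):
        num = decay * num + k[t] * v[t]  # the current token enters with weight 1
        den = decay * den + k[t]
        # the current token receives an extra weight gamma (1 + gamma in total)
        out[t] = r[t] * (num + gamma * k[t] * v[t]) / (den + gamma * k[t])
    return out
```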


@@ -0,0 +1,110 @@
#! -*- coding: utf-8 -*-
# Simplified Linear Recurrent Unit (SLRU)
# Tested with tensorflow 1.15 + bert4keras 0.11.4
from bert4keras.layers import *
class SLRU(Layer):
"""实数版线性循环单元
链接1https://arxiv.org/abs/2303.06349
链接2https://kexue.fm/archives/9554
"""
def __init__(
self,
units,
activation='linear',
use_bias=True,
        unroll=True,  # unroll speeds up training but increases memory usage
kernel_initializer='glorot_uniform',
**kwargs
):
super(SLRU, self).__init__(**kwargs)
self.units = units
self.activation = activations.get(activation)
self.use_bias = use_bias
self.unroll = unroll
self.kernel_initializer = initializers.get(kernel_initializer)
@integerize_shape
def build(self, input_shape):
super(SLRU, self).build(input_shape)
hidden_size = input_shape[-1]
self.i_dense = Dense(
units=self.units,
use_bias=self.use_bias,
kernel_initializer=self.kernel_initializer
)
self.o_dense = Dense(
units=hidden_size,
use_bias=self.use_bias,
activation=self.activation,
kernel_initializer=self.kernel_initializer
)
        def initializer(shape, dtype=None):
            # the decay exp(-exp(nu_log)) lies in [r_min, r_max];
            # gamma = sqrt(1 - decay**2) normalizes the variance of the hidden state
            r_min, r_max = 0.9, 0.999
            u = np.random.random(size=shape[1])
            nu_log = np.log(-0.5 * np.log(u * (r_max**2 - r_min**2) + r_min**2))
            gamma_log = np.log(np.sqrt(1 - np.exp(-np.exp(nu_log))**2))
            return np.array([nu_log, gamma_log])
self.params_log = self.add_weight(
name='params_log', shape=(2, self.units), initializer=initializer
)
@recompute_grad
def call(self, inputs, mask=None):
u = self.i_dense(inputs)
params = K.exp(self.params_log)
nu, gamma = params[0], params[1]
if self.unroll:
L_in = K.int_shape(u)[1]
assert L_in is not None, 'input_length can not be None while unroll=True'
log2_L = int(np.ceil(np.log2(L_in)))
else:
L_in = K.shape(u)[1]
log2_L = K.log(K.cast(L_in, K.floatx())) / K.log(2.)
log2_L = K.cast(tf.ceil(log2_L), 'int32')
u = tf.pad(u, [[0, 0], [0, 2**log2_L - K.shape(u)[1]], [0, 0]])
B, L, D = K.shape(u)[0], K.shape(u)[1], K.int_shape(u)[-1]
def lru(i, x):
l = 2**i
x = K.reshape(x, [B * L // l, l, D])
x1, x2 = x[:, :l // 2], x[:, l // 2:]
pos = K.arange(1, l // 2 + 1, dtype=K.floatx())
nus = tf.einsum('n,d->nd', pos, nu)
lambs = K.exp(-nus)
x2 = x2 + lambs * x1[:, -1:]
x = K.concatenate([x1, x2], axis=1)
if (not self.unroll) and K.int_shape(u)[1] is not None:
x = K.reshape(x, [B, L, D])
return i + 1, x
if self.unroll:
x = u
for i in range(log2_L):
_, x = lru(i + 1, x)
else:
_, x = tf.while_loop(lambda i, x: i <= log2_L, lru, [1, u])
x = x[:, :L_in] * gamma
return self.o_dense(x)
def get_config(self):
config = {
'units': self.units,
'activation': activations.serialize(self.activation),
'use_bias': self.use_bias,
'unroll': self.unroll,
'kernel_initializer':
initializers.serialize(self.kernel_initializer),
}
base_config = super(SLRU, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
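For reference, a sequential NumPy sketch (illustrative, not part of the commit) of the recurrence that the scan above evaluates between i_dense and o_dense; the names are placeholders:

```python
import numpy as np

def slru_reference(u, nu, gamma):
    """Sequential reference: x[t] = exp(-nu) * x[t-1] + u[t], output gamma * x[t].

    u: [L, units] output of i_dense; nu = exp(params_log[0]), gamma = exp(params_log[1]).
    """
    decay = np.exp(-nu)          # per-dim decay in (0, 1)
    state = np.zeros_like(nu)
    x = np.zeros_like(u)
    for t in range(len(u)):
        state = decay * state + u[t]
        x[t] = gamma * state     # gamma is initialized to sqrt(1 - decay**2)
    return x                     # o_dense then maps back to the input hidden size
```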