Framework for building complex recurrent neural networks with Keras
The ability to iterate quickly over different neural network architectures is key to doing machine learning research. While deep learning libraries like Keras make it very easy to prototype new layers and models, writing custom recurrent neural networks is harder than it needs to be in almost all popular deep learning libraries available today. One key missing feature in these libraries is reusable RNN cells. Most libraries provide layers (such as LSTM and GRU) that can only be used as is and cannot easily be embedded in a bigger RNN. Writing the RNN logic itself can also be tiresome. For example, in Keras, information about the states (their shape and initial value) is provided by writing two separate functions, get_initial_states and reset_states (the latter for the stateful version).
There are many architectures whose implementation is not trivial using modern deep learning libraries, such as:
- Synchronising the states of all the layers in an RNN stack.
- Feeding the output of the last layer of an RNN stack back to the first layer at the next time step (readout).
- Decoders: RNNs that can look at the whole input sequence or vector at every time step.
- Teacher forcing: using the ground truth at time t-1 for predicting at time t during training.
- Nested RNNs.
- Initializing states with different distributions.
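Readout and teacher forcing differ only in what gets fed back into the next step, which is easiest to see in a framework-free loop. The sketch below is plain Python, not RecurrentShop code: the scalar `step` function, its weights and the `run` helper are hypothetical and exist only for illustration. With readout the model's own previous output is fed back; with teacher forcing the ground truth for time t-1 is fed back instead.

```python
import math


def step(x_t, y_tm1, h_tm1, w_x=0.5, w_y=0.3, w_h=0.8):
    """Hypothetical scalar RNN step using the input, a fed-back value and the previous state."""
    h_t = math.tanh(w_x * x_t + w_y * y_tm1 + w_h * h_tm1)
    y_t = h_t  # in this toy cell the output equals the hidden state
    return y_t, h_t


def run(xs, targets=None, h0=0.0, y0=0.0):
    """Run `step` over a sequence.

    targets is None -> readout: feed back the model's own previous output.
    targets given   -> teacher forcing: feed back the ground truth for time t-1.
    """
    h, y_prev = h0, y0
    outputs = []
    for t, x_t in enumerate(xs):
        y_t, h = step(x_t, y_prev, h)
        outputs.append(y_t)
        if targets is None:
            y_prev = y_t          # readout
        else:
            y_prev = targets[t]   # ground truth at t is used when predicting t+1
    return outputs


xs = [1.0, 0.0, -1.0]
readout_outputs = run(xs)
teacher_forced_outputs = run(xs, targets=[0.2, 0.4, 0.6])
```

Both modes produce the same first output, since nothing has been fed back yet; they diverge from the second step on.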
RecurrentShop addresses these issues by letting the user write RNNs of arbitrary complexity using Keras's functional API. In other words, the user builds a standard Keras model that defines the logic of the RNN for a single timestep, and RecurrentShop converts this model into a Recurrent instance, which is capable of processing sequences.
# The RNN logic is written using Keras's functional API,
# which means we use Keras layers instead of Theano/TensorFlow ops.
import numpy as np
from keras.layers import *
from keras.models import *
from recurrentshop import *
x_t = Input((5,))  # The input to the RNN at time t
h_tm1 = Input((10,)) # Previous hidden state
# Compute new hidden state
h_t = add([Dense(10)(x_t), Dense(10, use_bias=False)(h_tm1)])
# tanh activation
h_t = Activation('tanh')(h_t)
# Build the RNN
rnn = RecurrentModel(input=x_t, initial_states=[h_tm1], output=h_t, final_states=[h_t])
# rnn is a standard Keras `Recurrent` instance. RecurrentModel also accepts arguments such as unroll, return_sequences etc.
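# Run the RNN over a random sequence of 7 timesteps with 5 features each
x = Input((7, 5))
y = rnn(x)
model = Model(x, y)
# predict expects a batch, so the array shape is (batch_size, 7, 5)
model.predict(np.random.random((1, 7, 5)))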
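Conceptually, turning a single-timestep model into a sequence layer means looping the step over the time axis and threading the states from one step into the next. The plain-Python sketch below illustrates that idea with a hypothetical `scan_rnn` helper and a scalar analogue of the model above. It is not how RecurrentShop works internally (the library builds the loop with the Keras backend); `return_sequences` here only mimics the meaning of the Keras argument of the same name.

```python
import math
from typing import Callable, List, Sequence, Tuple

State = List[float]
StepFn = Callable[[float, State], Tuple[float, State]]


def simple_step(x_t: float, states: State) -> Tuple[float, State]:
    """Scalar analogue of the model above: h_t = tanh(w*x_t + u*h_tm1 + b)."""
    w, u, b = 0.7, 0.5, 0.1  # hypothetical fixed weights
    (h_tm1,) = states
    h_t = math.tanh(w * x_t + u * h_tm1 + b)
    return h_t, [h_t]  # output, and the new state list (the "final states")


def scan_rnn(step_fn: StepFn, xs: Sequence[float], initial_states: State,
             return_sequences: bool = False):
    """Apply step_fn at every timestep, feeding each step's states into the next."""
    states = list(initial_states)
    outputs = []
    for x_t in xs:
        y_t, states = step_fn(x_t, states)
        outputs.append(y_t)
    return outputs if return_sequences else outputs[-1]


seq = [0.5, -0.2, 0.9]
last = scan_rnn(simple_step, seq, [0.0])
all_steps = scan_rnn(simple_step, seq, [0.0], return_sequences=True)
```

The step function never sees the time axis; only the loop does, which is what allows the same step logic to be reused.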
An RNNCell is a layer that defines the computation of an RNN for a single timestep. It takes a list of tensors as input ([input, state1_tm1, state2_tm1, ...]) and outputs a list of tensors ([output, state1_t, state2_t, ...]). An RNNCell does not iterate over an input sequence; it works on a single timestep. So the shape of the input to an LSTMCell would be (batch_size, input_dim) rather than (batch_size, input_length, input_dim).
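The list-in, list-out contract can be illustrated with a single LSTM step in plain Python. This is a hypothetical sketch (scalar units, made-up weights, no batch axis) and not the LSTMCell shipped with RecurrentShop; it only shows that a cell maps [input, h_tm1, c_tm1] to [output, h_t, c_t] for one timestep and never loops over time.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


# Hypothetical scalar weights (w, u, b) per gate: input, forget, output, candidate.
WEIGHTS = {
    "i": (0.5, 0.4, 0.0),
    "f": (0.6, 0.3, 1.0),
    "o": (0.7, 0.2, 0.0),
    "g": (0.9, 0.5, 0.0),
}


def lstm_cell(inputs):
    """One LSTM timestep: [x_t, h_tm1, c_tm1] -> [output, h_t, c_t]."""
    x_t, h_tm1, c_tm1 = inputs

    def pre(gate):
        w, u, b = WEIGHTS[gate]
        return w * x_t + u * h_tm1 + b

    i = sigmoid(pre("i"))      # input gate
    f = sigmoid(pre("f"))      # forget gate
    o = sigmoid(pre("o"))      # output gate
    g = math.tanh(pre("g"))    # candidate cell value
    c_t = f * c_tm1 + i * g
    h_t = o * math.tanh(c_t)
    return [h_t, h_t, c_t]     # an LSTM cell's output is its new hidden state


output, h_t, c_t = lstm_cell([1.0, 0.0, 0.0])
```

Because the output and the states come back as one list, a caller can route each state to the next timestep and the output to the next layer independently.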
RecurrentShop comes with 3 built-in RNNCells: SimpleRNNCell, GRUCell and LSTMCell.
There are 2 versions of each of these cells: a basic version, which is more readable and can serve as a reference for writing custom RNNCells, and a customizable, recommended version, which has more options such as setting regularizers, constraints and activations.
An RNNCell can easily be converted to a Keras Recurrent layer:
from recurrentshop.cells import LSTMCell
lstm_cell = LSTMCell(10, input_dim=5)
lstm_layer = lstm_cell.get_layer()
# get_layer accepts arguments like return_sequences, unroll etc.:
lstm_layer = lstm_cell.get_layer(return_sequences=True, unroll=True)
RecurrentSequential is the Recurrent analog of Keras's Sequential model. It lets you stack RNNCells and other layers such as Dense and Activation to build a Recurrent layer:
from keras.layers import Dense
from recurrentshop import *
rnn = RecurrentSequential(unroll=False, return_sequences=False)
rnn.add(SimpleRNNCell(10, input_dim=5))
rnn.add(LSTMCell(12))
rnn.add(Dense(5))
rnn.add(GRUCell(8))  # a cell, not Keras's GRU layer
# rnn can now be used as a regular Keras Recurrent layer.
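Within one timestep, a stack like this passes each layer's output to the next layer, while every cell keeps its own states and stateless layers such as Dense keep none. The plain-Python sketch below shows that bookkeeping with two hypothetical scalar tanh cells and a hypothetical one-unit "dense" layer; it illustrates the idea only and is not RecurrentShop's implementation.

```python
import math
from typing import Callable, List, Tuple

# A layer is (fn, n_states): fn maps [x, *states_tm1] -> [y, *states_t].
Layer = Tuple[Callable[[List[float]], List[float]], int]


def tanh_cell(w: float, u: float) -> Layer:
    """Hypothetical one-state cell: h_t = tanh(w*x + u*h_tm1); the output is h_t."""
    def fn(inputs):
        x, h_tm1 = inputs
        h_t = math.tanh(w * x + u * h_tm1)
        return [h_t, h_t]
    return fn, 1


def dense(w: float, b: float) -> Layer:
    """Hypothetical stateless layer, like a one-unit Dense."""
    return (lambda inputs: [w * inputs[0] + b]), 0


def stack_step(layers: List[Layer], x_t: float, states: List[float]):
    """One timestep of a stack: outputs flow layer to layer, states stay per layer."""
    new_states: List[float] = []
    pos = 0
    y = x_t
    for fn, n in layers:
        out = fn([y] + states[pos:pos + n])
        y, layer_states = out[0], out[1:]
        new_states.extend(layer_states)
        pos += n
    return y, new_states


layers = [tanh_cell(0.8, 0.5), dense(2.0, 0.1), tanh_cell(0.3, 0.9)]
states = [0.0, 0.0]  # one state per stateful cell; the dense layer contributes none
for x_t in [1.0, 0.5]:
    y, states = stack_step(layers, x_t, states)
```

Keeping all states in one flat list, with each layer consuming its own slice, is one simple way to let a whole stack behave like a single recurrent unit.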
A RecurrentSequential (or any RecurrentModel) can be converted to a cell using the get_cell() method. This cell can then be added to another RecurrentSequential.
rnn1 = RecurrentSequential()
rnn1.add(....)
rnn1.add(....)
rnn1_cell = rnn1.get_cell()
rnn2 = RecurrentSequential()
rnn2.add(rnn1_cell)
rnn2.add(...)
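Nesting works because a whole stack, viewed from outside, is itself a cell whose state list is the concatenation of its inner cells' states. The plain-Python sketch below expresses that with a hypothetical `as_cell` helper (an illustration, not RecurrentShop's get_cell()), using the [input, *states] -> [output, *states] convention described for RNNCell.

```python
import math
from typing import Callable, List, Tuple

Cell = Tuple[Callable[[List[float]], List[float]], int]  # (fn, number of states)


def tanh_cell(w: float, u: float) -> Cell:
    """Hypothetical one-state cell: [x, h_tm1] -> [h_t, h_t]."""
    def fn(inputs):
        x, h_tm1 = inputs
        h_t = math.tanh(w * x + u * h_tm1)
        return [h_t, h_t]
    return fn, 1


def as_cell(cells: List[Cell]) -> Cell:
    """Wrap a stack of cells as one cell whose states are all inner states, in order."""
    n_total = sum(n for _, n in cells)

    def fn(inputs):
        y, states = inputs[0], inputs[1:]
        new_states: List[float] = []
        pos = 0
        for cell_fn, n in cells:
            out = cell_fn([y] + states[pos:pos + n])
            y = out[0]
            new_states.extend(out[1:])
            pos += n
        return [y] + new_states

    return fn, n_total


inner = as_cell([tanh_cell(0.8, 0.5), tanh_cell(0.4, 0.6)])  # plays the role of rnn1_cell
outer = as_cell([inner, tanh_cell(1.0, 0.2)])                # plays the role of rnn2
outer_fn, n_states = outer
result = outer_fn([1.0] + [0.0] * n_states)  # [output, state1, state2, state3]
```

Since a wrapped stack has the same shape as any other cell, it can be wrapped again, which is what makes arbitrarily deep nesting possible.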
Since an RNNCell is a regular Keras layer by inheritance, it can be used to build RecurrentModels with the functional API.
from recurrentshop import *
from keras.layers import *
from keras.models import Model
input = Input((5,))
state1_tm1 = Input((10,))
state2_tm1 = Input((10,))
state3_tm1 = Input((10,))
lstm_output, state1_t, state2_t = LSTMCell(10)([input, state1_tm1, state2_tm1])
gru_output, state3_t = GRUCell(10)([input, state3_tm1])
output = add([lstm_output, gru_output])
output = Activation('tanh')(output)
rnn = RecurrentModel(input=input, initial_states=[state1_tm1, state2_tm1, state3_tm1], output=output, final_states=[state1_t, state2_t, state3_t])
See the docs/ directory for more features.
To install, clone the repository and run the setup script:
git clone https://www.github.com/datalogai/recurrentshop.git
cd recurrentshop
python setup.py install
Contributions: pull requests are highly welcome.
Reporting issues: create an issue with a minimal script that reproduces the problem you are facing.
Questions: create an issue or drop me an email (fariz@datalog.ai).