keras-team/keras

[WIP] Recursive container

EderSantana opened this issue · 18 comments

What?

Over the weekend I worked on a solution for designing arbitrary RNNs with the Keras API. The result allows us to write a vanilla RNN as:

self.input_dim = 2
self.state_dim = 2

self.model = Recursive(return_sequences=True)

self.model.add_input('input', ndim=3)  # Input is 3D tensor
self.model.add_state('h', dim=self.state_dim)
self.model.add_node(Dense(self.input_dim, self.state_dim, init='one'), 
                    name='i2h', inputs=['input', ])
self.model.add_node(Dense(self.state_dim, self.state_dim, init='orthogonal'),
                    name='h2h', inputs=['h', ])
self.model.add_node(Activation('tanh'), name='rec', inputs=['i2h', 'h2h'],
                    merge_mode='sum', return_state='h', create_output=True)

Note that the class definition of SimpleRNN is much bigger than this, and it gives us no way to output intermediate values such as the input-to-hidden projection. This should make it easier to design new state-based models without having to dig into Theano code. I started the development on the repo where I usually put my Keras extensions. There is a test here showing how to use this new container (it still contains a lot of debug printing, which should be cleaned up soon). If there is general interest in this, I could just PR a new branch here.

How? (dev details)

The Recursive container is basically a Graph container with a few mods. The most important difference is the way we connect layers. Unlike in regular feedforward networks, we cannot use set_previous inside the add_node method. Everything has to be done inside a _step function, and we have to take care of the order in which we pass arguments to scan (I haven't explored the idea of using dictionaries as inputs to scan yet). In other words, the entire connection logic is moved from add_node (where it lives for Sequential and Graph) to _step.
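For intuition, here is a minimal hand-written Theano sketch (not the container's generated code; the names X, h0, W_i2h and W_h2h are placeholders) of the _step/scan pairing that the vanilla RNN example above has to boil down to:

import numpy as np
import theano
import theano.tensor as T

input_dim, state_dim = 2, 2

X = T.tensor3('X')  # (time, batch, input_dim), the 3D 'input' above
h0 = T.zeros((X.shape[1], state_dim))  # initial value of the 'h' state
W_i2h = theano.shared(np.ones((input_dim, state_dim)).astype(theano.config.floatX))
W_h2h = theano.shared(np.eye(state_dim).astype(theano.config.floatX))

def _step(x_t, h_tm1, W_i2h, W_h2h):
    # one time step: i2h(x_t) and h2h(h_tm1), merged with 'sum',
    # then the 'tanh' activation node that returns the new 'h'
    return T.tanh(T.dot(x_t, W_i2h) + T.dot(h_tm1, W_h2h))

# scan hands _step the sequences first, then the state taps, then the
# non_sequences; this is the argument-ordering bookkeeping mentioned above
h_seq, _ = theano.scan(_step,
                       sequences=X,
                       outputs_info=h0,
                       non_sequences=[W_i2h, W_h2h])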

Next Steps?

There is a lot of code cleanup and refactoring that could make the internals cleaner. For example, in a conversation with @fchollet, he suggested that we define the states as self.model.add_state('h', dim=self.state_dim, input='rec') instead of using a return_state option inside add_node.

Stateful?

Another interesting problem is how to handle stateful models, where the hidden states are not wiped out after each batch. In a previous experiment, I set the initial states of an RNN to be a shared variable and defined its update to be the last state returned by scan. I did that inside the get_output method. Since Keras collects each individual layer's self.updates, everything else was handled by the Model class. We could do the same here. The problem is that shared variables can't change size, so we have to make sure we always get batches of the same size; otherwise we would have to recompile the model. I would love to hear about alternatives for this.
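For reference, here is a minimal sketch of the shared-variable trick described above (reusing _step, X, W_i2h and W_h2h from the previous sketch; this is illustrative, not code from the fork):

import numpy as np
import theano

batch_size, state_dim = 32, 2

# persistent initial state: a shared variable survives across batches,
# but it also fixes batch_size, which is the recompilation caveat above
h0 = theano.shared(np.zeros((batch_size, state_dim), dtype=theano.config.floatX))

h_seq, _ = theano.scan(_step,
                       sequences=X,
                       outputs_info=h0,
                       non_sequences=[W_i2h, W_h2h])

# a layer would expose this as its self.updates (from inside get_output);
# Keras collects these and passes them to theano.function, so h0 gets
# overwritten with the last state after every batch
state_updates = [(h0, h_seq[-1])]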

Final words

Sorry for the long post, but hopefully it will get people interested in developing this API (and/or inspire new ones), which I believe will make our lives much easier when it comes to designing new RNNs.

I believe will make our lives much easier when it comes to designing new RNNs.

I think so too. Statefulness would be especially great to have with such a model.

Recursive

Shouldn't it be Recurrent?

Shouldn't it be Recurrent?

There is a Recurrent layer already. I thought Recursive would avoid confusion, but I would also prefer Recurrent.

They are two different concepts. We can definitely solve the name collision problem one way or another (e.g. use RecurrentLayer for the pre-existing class).

@EderSantana would you mind dropping me an email? I've got a few questions regarding your fork.

@fchollet ,

I'm trying to write a GRU using the new container. Do you know how to connect multiple outputs to a multi-input container? Can I do that with Graph.add_node? Something like:

model = Graph()
model.add_input('input', ndim=3)
model.add_node(Embedding(max_features, input_dim), name='x',
              input='input')
model.add_node(Dense(input_dim, state_dim, init='orthogonal'),
               name='x_h', input='x')
model.add_node(Dense(input_dim, state_dim, init='orthogonal'),
               name='x_r', input='x')
model.add_node(Dense(input_dim, state_dim, init='orthogonal'),
               name='x_z', input='x')
# So far so good, next layer is a multiple input container
model.add_node(rnn, connection_map={'x_h':'x_h', 'x_r':'x_r',
                                    'x_z':'x_z'}, merge_mode='join',
               name='rnn')
model.add_node(Dense(state_dim, 1), name='out', input='rnn',
               create_output=True)

Note that there is no connection_map option in add_node, so how can I do that in a way that lets me compile and train everything?

This is a case that should definitely be covered, and maybe it's time to think about a refresh of the Graph container API.

My first idea would be to allow passing the connection map to the input attribute. We also clearly need to get rid of the redundant inputs attribute.

Examples:

# single input
model.add_node(layer, input='my_input') 

# multi-input, concat / sum / multiply
model.add_node(layer, input=merge_concat(['input_1', 'input_2']))
model.add_node(layer, input=merge_sum(['input_1', 'input_2']))
# OR... which is better?
model.add_node(layer, input=['input_1', 'input_2'], merge_mode='concat')

# multi-io
model.add_node(layer, input={'my_layer_input_1': 'node_name_1', 'my_layer_input_2': 'node_name_2'})

For multi-input (e.g. concat, etc), I think this option is more elegant:

model.add_node(layer, input=merge_concat(['input_1', 'input_2']))

The downside is that it requires one more import:

from keras.layers.containers import merge_concat
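Purely to make the idea concrete, one possible (entirely hypothetical) shape for such helpers is a tiny descriptor object that add_node could inspect instead of a plain node name:

class MergeSpec(object):
    def __init__(self, node_names, mode):
        self.node_names = node_names  # names of the nodes to merge
        self.mode = mode              # 'concat', 'sum', 'mul', ...

def merge_concat(node_names):
    return MergeSpec(node_names, 'concat')

def merge_sum(node_names):
    return MergeSpec(node_names, 'sum')

# add_node(layer, input=...) would then check isinstance(input, MergeSpec)
# and build the corresponding merge before connecting the layer.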
pranv commented

Glad to see that you went ahead with the Recurrent Model.

Agree with @fchollet - it should be something that says recurrent and not recursive. Recurrent is a special case of recursive; recursive models work at different levels of a tree structure.

And, based on a quick glance, the _step method looks really complicated. Any ideas for simplifying it?

Also, having some abstraction of gating would help.

@pranv thanks for the feedback.
Indeed, _step is where all the connection logic goes. I'll simplify it with time (and inspiration).

Also, having some abstraction of gating would help.

What do you mean by that?

@EderSantana I thought about the container a lot. Here's what I think could be the most flexible API possible.

R = Recurrent()

R.add_states(['h1', 'h2', 'y']) # Optional!!!

R.add_input('X', is_sequence=True)

R.add_node(Identity(), name='merge', input=['X', 'y_tm1'])
R.add_node(RecurrentDense(256), name='h1', input={'prev_hid': 'h1_tm1', 'input': 'merge'})
R.add_node(RecurrentDense(256), name='h2', input={'prev_hid': 'h2_tm1', 'input': 'h1'})
R.add_node(RecurrentDense(256), name='y', input='h2')

R.build()
R.compile()

OK, so let me elaborate: I'd like the container to be as general as possible. For example, I'd like state h1 to have inputs from X and y. But y in turn depends on h1! How do we deal with that?

Another problem: I want to specify which layer will use the h1, and which will use h_tm1. Or maybe h_tm2?

How do we do that?

  • When add_state (which, by the way, is not required; it's just for convenience) or add_node is called,
    we do not create any layers, tensors, etc. We just collect information.
  • We then need to call build, which will check which layer uses which inputs, do shape inference, and allocate tensors for the recurrence.
  • Build sequences, non_sequences, etc.
  • Build scan.

It's super easy to parse input and detect where we use *_tm[0-9]+ and then use this information to create _step.
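As a toy example of that naming convention (not code from any branch):

import re

name = 'h1_tm2'  # i.e. state 'h1', taken two steps back
if re.match(r'.*_tm[0-9]+$', name):
    state = re.sub(r'_tm[0-9]+$', '', name)                # -> 'h1'
    tap = -int(re.search(r'_tm([0-9]+)$', name).group(1))  # -> -2
    print(state, tap)  # taps like this feed straight into scan's outputs_info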

Some other ideas:

  • Write RecurrentGRU, which will have transform_input, so that we can pre-compute the projections of X to improve performance (see the sketch below).
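As a sketch of that idea only (all names here are made up, not an existing Keras API): the input-to-hidden projections of a GRU do not depend on the recurrent state, so they can be applied to the whole (time, batch, dim) tensor in one shot outside scan, leaving only the state-dependent part inside _step.

import theano.tensor as T

def transform_input(X, W_z, W_r, W_h, b_z, b_r, b_h):
    # X: (time, batch, input_dim) -> three (time, batch, state_dim) tensors,
    # computed once for the whole sequence instead of once per time step
    return (T.dot(X, W_z) + b_z,
            T.dot(X, W_r) + b_r,
            T.dot(X, W_h) + b_h)

def _step(xz_t, xr_t, xh_t, h_tm1, U_z, U_r, U_h):
    z = T.nnet.sigmoid(xz_t + T.dot(h_tm1, U_z))  # update gate
    r = T.nnet.sigmoid(xr_t + T.dot(h_tm1, U_r))  # reset gate
    hh = T.tanh(xh_t + T.dot(r * h_tm1, U_h))     # candidate state
    return z * h_tm1 + (1 - z) * hh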

@EderSantana Also, do you have any idea how for or if are handled in TensorFlow?

@elanmart for is just for. See my examples here: https://github.com/EderSantana/TwistedFate/blob/master/Fibonacci.ipynb

TF has conditional control flow similar to Theano's: https://github.com/fchollet/keras/blob/backend/keras/backend/tensorflow_backend.py#L340-L345

Let me check your previous code now

We have to think of the container the way we think of the step function. It is just a simple, Theano/TF-free way to write the internals of the for-loop.

If you need y, and y is calculated elsewhere, you have to define which is updated first: the state that uses y or the one that generates y. Just like you showed; I don't see a problem there. I liked the idea of _tm1 being appended automatically, but for an outside reader that may be confusing. Also, what if we need something like h_tm2? I can think of a way of doing that with your proposed API, but it will not be that obvious: adding a pass-through layer that just proxies h_tm1 to a state called h_tm2; but then h_tm2 is generated explicitly while h_tm1 is not.

What do you think? Should we explicitly define states a priori and always have a rule for updating them? Or is the implicit approach still better?

liked the idea of _tm1 being appended automatically

Nah, it's not created by the container. It's what the user writes to let the container know that this layer needs input from the previous timestep. Let's forget about it for now.

We have to think of the container as we think of the step function. It is just a simple and Theano(TF)-less way to write the internals of the for-loop.

Let's say we want to implement Gated Feedback Recurrent Neural Networks.

Figure 1 will tell you everything, but long story short: given N layers, at each timestep the layer hi_t gets inputs from h1_tm1 ... hN_tm1.

How would you go about implementing this?

What I proposed is something like this:

def add_node(self, layer, name, inputs):
    self.nodes[name] = (layer, inputs)
    ...

def build(self):
    for layer, inputs in self.nodes.values():
        for var_name in inputs:
            if re.match(r'.*_tm[0-9]+', var_name):
                state_name = re.sub(r'_tm[0-9]+', '', var_name)
                tap = int(re.sub(r'.*_tm', '', var_name))
                states[state_name].add(-tap)  # taps are kept in a set
    ...
    for name in states:
        info = {
            'initial': alloc_zeros(self.nodes[name][0].output_shape[1], self.batch_size),
            'taps': list(states[name]),
        }
        outputs_info.append(info)

def step(self, *args):
    # We know the exact order in which args will come.
    # We know the name associated with each arg (e.g. h1).
    # Create a dict = {name: arg}

    _locals = self.name_args(args)

    for name, (layer, needs) in self.nodes.items():
        # `needs` lists the names of the tensors this node consumes
        inputs = [tensor for key, tensor in _locals.items() if key in needs]
        output = layer(inputs)  # pseudocode: apply the node's layer to its inputs

        # make the fresh output visible to the nodes updated after this one
        _locals[name] = output

Basically

  • Do not have to specify states, their shapes etc.
  • Any node can use any other node's output from any previous timestep
  • Even when the output for node N is computed, N_tm1 is still available to other nodes
  • This may not make any sense at all, I'd like to hear your ideas on the API we should use.

Let's say we want to implement Gated Feedback Recurrent Neural Networks.

I wrote an implementation of this with a regular layer (it's in my pipeline to upload to Seya). The essence of the input to each layer is the following:

# given the masked previous states h1_mask_tm1, h2_mask_tm1, h3_mask_tm1 and the input x

# Update layer 1
x1 = T.concatenate([x, h2_mask_tm1, h3_mask_tm1], axis=-1)
# get h1_t

# Update layer 2
x2 = T.concatenate([h1_mask_tm1, h1_t, h3_mask_tm1], axis=-1)
# get h2_t

# Update layer3
x3 = T.concatenate([h1_mask_tm1, h2_mask_tm1, h2_t], axis=-1)
# get h3_t

Thus, in the Recurrent container, we would have the _tm1 values as given states and the _t values as quantities calculated inside the loop. Each layer sees whatever is already available when its turn comes to update its states.
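Purely as a sketch, the same wiring could read like this under the Recurrent container API proposed above (the API, RecurrentDense, and the merge_mode='concat' option are all still hypothetical, and the gating/masking of the _tm1 states is glossed over here):

state_dim = 256  # arbitrary

R = Recurrent()
R.add_input('x', is_sequence=True)
# h1 is updated first, so it only sees the other layers at t-1
R.add_node(RecurrentDense(state_dim), name='h1',
           input=['x', 'h2_tm1', 'h3_tm1'], merge_mode='concat')
# h2 is updated next, so it can already see the fresh h1 (i.e. h1_t)
R.add_node(RecurrentDense(state_dim), name='h2',
           input=['h1_tm1', 'h1', 'h3_tm1'], merge_mode='concat')
# h3 is updated last and sees the fresh h2
R.add_node(RecurrentDense(state_dim), name='h3',
           input=['h1_tm1', 'h2_tm1', 'h2'], merge_mode='concat')
R.build()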

@EderSantana @fchollet @elanmart Is there a reason development stopped on this? It looked like this got put into a list of future features for Keras but no one followed up.

For anyone experimenting with new recurrent architectures, making Keras flexible enough to handle general streams of recurrent information, without writing a bunch of complex custom layers, seems absolutely necessary. I'm happy to join in on the development if we think we can still make it happen.

It looks like someone made a good attempt at a RecurrentContainer https://github.com/datalogai/recurrentshop

but it would still be nice to have this as a core part of Keras.