Notes on ConvLSTM
cxxixi opened this issue · 0 comments
Prescription
I have been implementing the ConvLSTM model for a precipitation estimation project, and a couple of confusing points came up while reading the original paper by Shi et al. and the code based on it. These notes explain the principles behind the code and how the code implements the ideas presented by Shi et al.
import tensorflow as tf

# A new class that inherits from tf.nn.rnn_cell.RNNCell
class ConvLSTMCell(tf.nn.rnn_cell.RNNCell):
    def __init__(self, shape, filters, kernel, forget_bias=1.0, activation=tf.tanh, normalize=True, peehole=True, data_format='channel_last', reuse=None):
        # Pass reuse to the parent class so the cell's variables can be shared across calls.
        super(ConvLSTMCell, self).__init__(_reuse=reuse)
        self._kernel = kernel
        self._filters = filters
        self._forget_bias = forget_bias
        self._activation = activation
        self._normalize = normalize
        self._peehole = peehole  # whether the gates also look at the cell state (peephole connections)
        if data_format == 'channel_last':
            # Set _size to [spatial shape] + [num_filters]. E.g., if every input is a 64*64 image and the number of filters is 4, then _size is [64, 64, 4].
            self._size = tf.TensorShape(shape + [self._filters])
            # ndims returns the rank of the shape, e.g. 3 for [64, 64, 4]. Once the batch dimension is prepended, index 3 is the channel axis.
            self._feature_axis = self._size.ndims
            self._data_format = None
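        elif data_format == 'channel_first':
            # Channels come before the spatial dimensions, e.g. [4, 64, 64].
            self._size = tf.TensorShape([self._filters] + shape)
            # With the batch dimension prepended, the channel axis is index 1.
            self._feature_axis = 1
            self._data_format = 'NC'
        else:
            raise ValueError("Unknown data format")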
According to the official documentation of tf.nn.convolution:
data_format: A string or None. Specifies whether the channel dimension of the input and output is the last dimension (default, or if data_format does not start with "NC"), or the second dimension (if data_format starts with "NC"). For N=1, the valid values are "NWC" (default) and "NCW". For N=2, the valid values are "NHWC" (default) and "NCHW". For N=3, the valid values are "NDHWC" (default) and "NCDHW".
Returns: A Tensor with the same type as input of shape [batch_size] + output_spatial_shape + [out_channels] if data_format is None or does not start with "NC", or [batch_size, out_channels] + output_spatial_shape if data_format starts with "NC".
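The shape rule quoted above can be sketched in plain Python. The helper name `conv_output_shape` is hypothetical and not part of TensorFlow; it assumes 'SAME' padding with stride 1, so the spatial shape is unchanged.

```python
def conv_output_shape(batch_size, spatial_shape, out_channels, data_format=None):
    """Return the output shape described in the tf.nn.convolution docs.

    Assumption: 'SAME' padding and stride 1, so the spatial shape is unchanged.
    """
    spatial = list(spatial_shape)
    if data_format is not None and data_format.startswith("NC"):
        # Channels-first: [batch_size, out_channels] + output_spatial_shape
        return [batch_size, out_channels] + spatial
    # Channels-last (default): [batch_size] + output_spatial_shape + [out_channels]
    return [batch_size] + spatial + [out_channels]


print(conv_output_shape(8, [64, 64], 16))        # [8, 64, 64, 16]
print(conv_output_shape(8, [64, 64], 16, "NC"))  # [8, 16, 64, 64]
```

This is why the cell passes None for channels-last data and 'NC' for channels-first data: only the "NC" prefix matters to the convolution.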
    # Override the properties inherited from the parent class (RNNCell).
    @property
    def state_size(self):
        return tf.nn.rnn_cell.LSTMStateTuple(self._size, self._size)

    @property
    def output_size(self):
        return self._size
    # Override the main method: call
    def call(self, x, state):
        # x and state are supplied by the RNN driver (e.g. tf.nn.dynamic_rnn) at each time step.
        c, h = state  # state is a tuple; c is the cell state, h is the hidden state (the cell's output)
        # Concatenate the input and the previous hidden state along the channel axis.
        x = tf.concat([x, h], axis=self._feature_axis)
        n = x.shape[-1].value  # n: num_input_channels (assumes a channels-last layout)
        m = 4 * self._filters if self._filters > 1 else 4  # m: num_output_channels; the four parts (j, i, f, o) are computed together, so we multiply the number of filters by 4.
        W = tf.get_variable('kernel', self._kernel + [n, m])  # e.g. shape = [3, 3, input_channels, output_channels]
        # Compute the sum of N-D convolutions; see https://www.tensorflow.org/versions/master/api_docs/python/tf/nn/convolution
        # x: input, W: filters
        y = tf.nn.convolution(x, W, 'SAME', data_format=self._data_format)
The forget gate, input gate and output gate all have a similar structure: each takes in X and H(t-1), the previous hidden state. The author therefore concatenates these two tensors and treats the result as a new X.
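The effect of this concatenation on the kernel shape can be sketched in plain Python. The helper name `convlstm_kernel_shape` is hypothetical and not part of TensorFlow; it only reproduces the shape arithmetic from the code above, assuming a channels-last layout.

```python
def convlstm_kernel_shape(kernel, in_channels, filters):
    """Shape of the single fused kernel applied to concat([x, h])."""
    n = in_channels + filters  # channels of x plus channels of h
    m = 4 * filters            # j, i, f and o stacked along the channel axis
    return list(kernel) + [n, m]


shape = convlstm_kernel_shape([3, 3], in_channels=1, filters=4)
print(shape)  # [3, 3, 5, 16]
```

Because the convolution output has 4 * filters channels, splitting it into four equal parts gives one tensor with `filters` channels for each of j, i, f and o.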
Notice that tf.nn.convolution is the major change Shi et al. made to the original LSTM model. This operation captures both temporal and spatial information, which is the main idea proposed by Shi et al.
The only difference between the original LSTM and ConvLSTM is shown in the following picture.
Every operation between the weights W and the input [X, H(t-1)] in the FC-LSTM has been replaced with a convolution.
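For reference, the ConvLSTM update is usually written as below, where * is convolution and ∘ is the element-wise (Hadamard) product. This is a transcription of the standard formulation, so the notation may differ slightly from the paper. The code in these notes fuses the input and hidden kernels into one kernel applied to the concatenated [X, H(t-1)], and it drops the bias terms when layer normalization is on.

```latex
\begin{aligned}
i_t &= \sigma\left(W_{xi} * X_t + W_{hi} * H_{t-1} + W_{ci} \circ C_{t-1} + b_i\right) \\
f_t &= \sigma\left(W_{xf} * X_t + W_{hf} * H_{t-1} + W_{cf} \circ C_{t-1} + b_f\right) \\
C_t &= f_t \circ C_{t-1} + i_t \circ \tanh\left(W_{xc} * X_t + W_{hc} * H_{t-1} + b_c\right) \\
o_t &= \sigma\left(W_{xo} * X_t + W_{ho} * H_{t-1} + W_{co} \circ C_t + b_o\right) \\
H_t &= o_t \circ \tanh\left(C_t\right)
\end{aligned}
```

Replacing every convolution with a matrix product gives back the FC-LSTM.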
        # Bias: added only when layer normalization is off, since layer_norm learns its own offset.
        if not self._normalize:
            y += tf.get_variable('bias', [m], initializer=tf.zeros_initializer())  # initialized to zeros
        # Split the tensor into sub-tensors.
        # The shape of y is [batch_size] + output_spatial_shape + [out_channels] (or [batch_size, out_channels] + output_spatial_shape for channels-first data), so we split out_channels into four equal parts along the feature axis declared earlier.
        j, i, f, o = tf.split(y, 4, axis=self._feature_axis)
        # j: candidate input to the cell state; i: input gate; f: forget gate; o: output gate
        if self._peehole:
            # c is C(t-1), the previous cell state. c.shape[0] is the batch dimension, so c.shape[1:] drops it.
            i += tf.get_variable('W_ci', c.shape[1:]) * c
            f += tf.get_variable('W_fi', c.shape[1:]) * c
If peehole is true, the gates can access the previous cell state C(t-1).
Here, i and f are updated by adding the corresponding term W_ci * c or W_fi * c, where * is an element-wise product.
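A minimal sketch of a peephole gate on a flat list of values, using only the standard library. The names `sigmoid` and `peephole_gate` are hypothetical helpers for illustration. The real cell does the same thing element-wise on tensors, and also adds forget_bias for the forget gate.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def peephole_gate(pre_activation, peephole_weights, cell_state):
    """Element-wise gate: sigmoid(pre + w * c) at every position."""
    return [
        sigmoid(p + w * c)
        for p, w, c in zip(pre_activation, peephole_weights, cell_state)
    ]


# With zero peephole weights the cell state has no influence on the gate.
print(peephole_gate([0.0], [0.0], [5.0]))  # [0.5]
# With a positive weight, a large cell state pushes the gate towards 1.
print(peephole_gate([0.0], [1.0], [5.0])[0] > 0.99)  # True
```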
        # Apply layer normalization.
        if self._normalize:
            j = tf.contrib.layers.layer_norm(j)
            i = tf.contrib.layers.layer_norm(i)
            f = tf.contrib.layers.layer_norm(f)  # see https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/layers/layer_norm
        f = tf.sigmoid(f + self._forget_bias)
        i = tf.sigmoid(i)
        # New cell state: keep part of the old state and add the gated candidate.
        c = c * f + i * self._activation(j)
        if self._peehole:
            # The output gate looks at the updated cell state C(t).
            o += tf.get_variable('W_oi', c.shape[1:]) * c
        if self._normalize:
            o = tf.contrib.layers.layer_norm(o)
            c = tf.contrib.layers.layer_norm(c)
        o = tf.sigmoid(o)
        h = o * self._activation(c)
        state = tf.nn.rnn_cell.LSTMStateTuple(c, h)
        return h, state  # the output is the hidden state, not the cell state
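The order of operations at the end of call can be checked with a scalar sketch that uses only the standard library. The function name `cell_update` is hypothetical. It assumes the pre-activations j, i, f and o have already been computed, and it leaves out the peephole terms and layer normalization.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def cell_update(j, i, f, o, c_prev, forget_bias=1.0):
    """One scalar LSTM state update, in the same order as the cell's call."""
    f_gate = sigmoid(f + forget_bias)             # how much of the old state to keep
    i_gate = sigmoid(i)                           # how much of the candidate to add
    c = c_prev * f_gate + i_gate * math.tanh(j)   # new cell state
    o_gate = sigmoid(o)                           # how much of the state to expose
    h = o_gate * math.tanh(c)                     # new hidden state, i.e. the output
    return c, h


# Forget gate wide open and input gate shut: the cell state is carried over.
c, h = cell_update(j=0.0, i=-50.0, f=50.0, o=50.0, c_prev=2.0)
print(abs(c - 2.0) < 1e-9)             # True
print(abs(h - math.tanh(2.0)) < 1e-9)  # True
```

The default forget_bias of 1.0 nudges the forget gate towards keeping the old state at the start of training.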