carlthome/tensorflow-convlstm-cell

Evaluate layernorm importance

carlthome opened this issue · 1 comment

It feels to me like layernorm is extremely important when you stack several ConvLSTM layers, but I'm not sure. It would be interesting to compare versions with a regular bias vs. layernorm. I also wonder whether it makes a difference if you have skip connections between layers (e.g. tf.nn.rnn_cell.ResidualWrapper) or simply stack them.
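For the skip-connection question, a minimal sketch of the two stacking variants (my own sketch, assuming this repo's ConvLSTMCell and an arbitrary three layers; note ResidualWrapper needs the input and output channel counts to match, so filters must equal the input channels):

import tensorflow as tf
from cell import ConvLSTMCell  # this repo's cell

shape, kernel, filters, timesteps = [32, 32], [3, 3], 12, 10
inputs = tf.placeholder(tf.float32, [None, timesteps] + shape + [filters])

# Variant 1: plain stack, each layer feeds directly into the next.
stacked = tf.nn.rnn_cell.MultiRNNCell(
    [ConvLSTMCell(shape, filters, kernel) for _ in range(3)])

# Variant 2: a skip connection around every layer, so each cell
# learns a residual on top of its input.
residual = tf.nn.rnn_cell.MultiRNNCell(
    [tf.nn.rnn_cell.ResidualWrapper(ConvLSTMCell(shape, filters, kernel))
     for _ in range(3)])

# Swap in `stacked` here to compare the two variants.
outputs, state = tf.nn.dynamic_rnn(residual, inputs, dtype=inputs.dtype)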

In terms of computation time, layernorm is quite a bit slower (roughly 30%), but that should be properly benchmarked too.
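Something like this could serve as a starting point for the benchmark (my sketch; it assumes the cell exposes a normalize flag that toggles layernorm on and off):

import time
import tensorflow as tf
from cell import ConvLSTMCell  # this repo's cell

def benchmark(normalize, steps=50):
    tf.reset_default_graph()
    shape, kernel, filters = [32, 32], [3, 3], 12
    inputs = tf.random_normal([8, 10] + shape + [filters])
    cell = ConvLSTMCell(shape, filters, kernel, normalize=normalize)
    outputs, _ = tf.nn.dynamic_rnn(cell, inputs, dtype=inputs.dtype)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(outputs)  # warm-up run, so graph setup isn't timed
        start = time.time()
        for _ in range(steps):
            sess.run(outputs)
        return (time.time() - start) / steps

print('layernorm: %.4fs/step' % benchmark(True))
print('bias only: %.4fs/step' % benchmark(False))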

It seems to me that layer norm helps stabilize training but hurts the final test accuracy. In the layer norm paper, the authors warn against using it in CNNs because it explicitly forces all channels to have similar importance. This runs contrary to the notion of CNN layers doing feature extraction.

I have personally found that replacing layer norm with group norm retains the training stability aspects of layer norm while increasing test accuracy. Group norm enables groups of channels to have different importance.

I implemented it like so:

import tensorflow as tf

def group_norm(x, scope, G=8, eps=1e-5):
    with tf.variable_scope('{}_norm'.format(scope)):
        # transpose: [bs, h, w, c] to [bs, c, h, w] following the paper
        x = tf.transpose(x, [0, 3, 1, 2])
        N, C, H, W = x.get_shape().as_list()
        G = min(G, C)  # the reshape below requires C to be divisible by G
        # normalize over the channels within each group, plus height and width
        x = tf.reshape(x, [-1, G, C // G, H, W])
        mean, var = tf.nn.moments(x, [2, 3, 4], keep_dims=True)
        x = (x - mean) / tf.sqrt(var + eps)
        # per-channel gamma and beta; created with get_variable so they are
        # shared correctly when the variable scope is reused
        gamma = tf.get_variable('gamma', [C], dtype=tf.float32,
                                initializer=tf.ones_initializer())
        beta = tf.get_variable('beta', [C], dtype=tf.float32,
                               initializer=tf.zeros_initializer())
        gamma = tf.reshape(gamma, [1, C, 1, 1])
        beta = tf.reshape(beta, [1, C, 1, 1])

        output = tf.reshape(x, [-1, C, H, W]) * gamma + beta
        # transpose back: [bs, c, h, w] to [bs, h, w, c]
        output = tf.transpose(output, [0, 2, 3, 1])
    return output

and

r = group_norm(r, "gates_r", G=6, eps=1e-5)
u = group_norm(u, "gates_u", G=6, eps=1e-5)
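One caveat (my addition, not part of the original comment): the reshape inside group_norm requires the channel count C to be divisible by G, so G=6 here assumes the gate tensors have a multiple of six channels. A quick sanity check on a dummy NHWC tensor:

x = tf.placeholder(tf.float32, [None, 16, 16, 12])  # 12 channels, divisible by G=6
y = group_norm(x, "sanity", G=6, eps=1e-5)          # output shape: [bs, 16, 16, 12]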