DeepRec-AI/DeepRec

MaskNet: replace BatchNorm with LayerNorm

zippeurfou opened this issue · 6 comments

The MaskNet paper uses LayerNorm; however, the code implementation uses BatchNorm.
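
For reference, a rough sketch of the two options in question (illustrative only, not the actual MaskNet code; hidden stands in for a MaskBlock's feed-forward output, and the dims are made up):

import tensorflow as tf

hidden = tf.random.normal([8, 4])  # [batch_size, embedding_dim]

# What the paper specifies: LayerNorm over the feature axis of each sample.
ln_out = tf.keras.layers.LayerNormalization(axis=-1)(hidden)

# What the example code used instead: BatchNorm on the same tensor.
bn_out = tf.layers.batch_normalization(hidden, training=True)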

Hi Marc,

Thanks for bringing this up! This is indeed a bug, and we are fixing it.

Hi Marc,

Upon checking, this is not a bug. When BatchNorm is applied on the default axis (the last dim), it reduces to LayerNorm, and since the size of gamma/beta depends on the shape of the input tensor, the original implementation is still correct.
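
The parameter-shape point can be checked directly; here is a minimal sketch (arbitrary dims, using the Keras layers for brevity):

import tensorflow as tf

bn = tf.keras.layers.BatchNormalization(axis=-1)
ln = tf.keras.layers.LayerNormalization(axis=-1)

# Both layers size gamma/beta from the last dim of the input shape.
bn.build([None, 4])
ln.build([None, 4])

print(bn.gamma.shape, bn.beta.shape)  # (4,) (4,)
print(ln.gamma.shape, ln.beta.shape)  # (4,) (4,)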

However, for code clarity, we updated the example (ref PR #816).

Thanks for the comment!

I am not sure I am following; see this screenshot.
[screenshot dated 2023-04-18]
What am I missing?

Because your code isn't in training mode.

tf.layers.batch_normalization() dispatches to the class BatchNormalizationBase:

class BatchNormalizationBase(Layer):

tf.keras.layers.LayerNormalization() dispatches to the class LayerNormalization:

class LayerNormalization(Layer):

In LayerNormalization, mean and variance are computed by nn.moments:

mean, variance = nn.moments(inputs, self.axis, keep_dims=True)

and then nn.batch_normalization produces the result:

outputs = nn.batch_normalization(
    inputs,
    mean,
    variance,
    offset=offset,
    scale=scale,
    variance_epsilon=self.epsilon)
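
To confirm that this is really all the layer does, here is a small sketch (TF 2.x eager style for brevity, arbitrary numbers) that recomputes LayerNormalization from those two primitives:

import numpy as np
import tensorflow as tf

x = np.random.randn(8, 4).astype(np.float32)
y = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-3)(x)

# Same two primitives as the layer: per-sample stats over the last axis,
# then the normalization op (gamma=1, beta=0 at init, so scale/offset=None).
mean, variance = tf.nn.moments(tf.constant(x), axes=[-1], keepdims=True)
manual = tf.nn.batch_normalization(x, mean, variance, offset=None,
                                   scale=None, variance_epsilon=1e-3)
np.testing.assert_allclose(y.numpy(), manual.numpy(), rtol=1e-3, atol=1e-5)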

Without its extra features, BN is the same:

def _moments(self, inputs, reduction_axes, keep_dims):
  mean, variance = nn.moments(inputs, reduction_axes, keep_dims=keep_dims)
  # TODO(b/129279393): Support zero batch input in non DistributionStrategy
  # code as well.
  if self._support_zero_size_input():
    inputs_size = array_ops.size(inputs)
    mean = array_ops.where(inputs_size > 0, mean, K.zeros_like(mean))
    variance = array_ops.where(inputs_size > 0, variance,
                               K.zeros_like(variance))
  return mean, variance

mean, variance = self._moments(
    math_ops.cast(inputs, self._param_dtype),
    reduction_axes,
    keep_dims=keep_dims)

outputs = nn.batch_normalization(inputs,
                                 _broadcast(mean),
                                 _broadcast(variance),
                                 offset,
                                 scale,
                                 self.epsilon)
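
The same check works for BatchNorm in training mode (again a TF 2.x eager sketch with arbitrary numbers); note that reduction_axes here is every axis except the feature axis:

import numpy as np
import tensorflow as tf

x = np.random.randn(8, 4).astype(np.float32)
y = tf.keras.layers.BatchNormalization(axis=-1, epsilon=1e-3)(x, training=True)

# In training, BN reduces over the batch axis (axis 0 for a 2-D input) and
# feeds the batch statistics into the same nn.batch_normalization primitive.
mean, variance = tf.nn.moments(tf.constant(x), axes=[0], keepdims=True)
manual = tf.nn.batch_normalization(x, mean, variance, offset=None,
                                   scale=None, variance_epsilon=1e-3)
np.testing.assert_allclose(y.numpy(), manual.numpy(), rtol=1e-3, atol=1e-5)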

But the difference is that when you are not in training mode, the mean and variance of BN are replaced by the moving statistics:

mean = tf_utils.smart_cond(training,
                           lambda: mean,
                           lambda: ops.convert_to_tensor(moving_mean))
variance = tf_utils.smart_cond(
    training,
    lambda: variance,
    lambda: ops.convert_to_tensor(moving_variance))

You can pass the input parameter moving_mean_initializer='ones' (it defaults to 'zeros') and observe that the output changes.
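
A runnable version of that experiment (TF 2.x eager sketch, arbitrary numbers):

import numpy as np
import tensorflow as tf

x = np.random.randn(8, 4).astype(np.float32)

# Default: moving_mean initialized to zeros, moving_variance to ones.
y_default = tf.keras.layers.BatchNormalization(axis=-1)(x, training=False)

# Same layer config, but with moving_mean initialized to ones instead.
y_ones = tf.keras.layers.BatchNormalization(
    axis=-1, moving_mean_initializer='ones')(x, training=False)

# The inference output shifts, showing the moving statistics (not the
# batch statistics) are used when training=False.
assert not np.allclose(y_default.numpy(), y_ones.numpy())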

Thanks @Duyi-Wang, that makes sense. I was confused by it as well, but the doc clearly states it. Thanks for pointing out the code.
Adding a screenshot for posterity.
[screenshot dated 2023-04-19]
Feel free to close this one.