cbfinn/maml

Batch normalization


srpv commented

Hello, Chelsea.

The batch normalization documentation says that the update ops are not attached to the TensorFlow graph by default, so there are two ways to force the updates during training (see the sketch after this list):

  • explicitly make the train op depend on the ops in tf.GraphKeys.UPDATE_OPS, or
  • set the updates_collections parameter of batch_norm to None.
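For reference, here is a minimal TF 1.x sketch of both options (the input, loss, and optimizer are hypothetical placeholders, not this repo's code):

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 32])            # hypothetical input
h = tf.contrib.layers.fully_connected(x, 64)

# Option 1: with the default updates_collections, the moving-average
# update ops are placed in tf.GraphKeys.UPDATE_OPS, so the train op
# must be made to depend on them explicitly.
h = tf.contrib.layers.batch_norm(h, is_training=True)
loss = tf.reduce_mean(tf.square(h))                   # placeholder loss
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)

# Option 2: updates_collections=None forces the updates to run in place
# as part of the layer's forward pass; no explicit dependency is needed.
# h = tf.contrib.layers.batch_norm(h, is_training=True,
#                                  updates_collections=None)
```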

I see neither of those in the code. Maybe I'm missing something.

I haven't been able to make the first approach work because of the while loop inside map_fn, but the second modification is easy and seems to work. That said, I'm not sure I see any difference in performance.

cbfinn commented

I compute the test-time statistics using the test batch of data, instead of computing the average training statistics. This doesn't require keeping track of batch norm training statistics. [Note that train is always set to True when calling the batch_norm function, which means that TensorFlow will compute the statistics using the current batch.]

It's possible that it would work better by using training batch statistics, but I haven't tried it.
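A minimal sketch of that setup (not the repo's exact code): with is_training=True at both training and test time, batch_norm normalizes with the current batch's statistics, so the moving averages are never read and never need updating.

```python
import tensorflow as tf

def bn(inp, scope):
    # is_training=True at both train and test time: statistics come from
    # the current batch, so moving_mean/moving_variance are never used.
    return tf.contrib.layers.batch_norm(inp, is_training=True,
                                        scope=scope, reuse=tf.AUTO_REUSE)
```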

srpv commented

That's not the issue I'm talking about. See the note on the tf.contrib.layers.batch_norm page:

> Note: when training, the moving_mean and moving_variance need to be updated. By default the update ops are placed in tf.GraphKeys.UPDATE_OPS, so they need to be added as a dependency to the train_op.

In other words, without the steps I described in the issue, moving_mean and moving_variance don't update at all (even during training). Again, maybe I'm missing some other way you're updating them.

cbfinn commented

You only need to update moving_mean and moving_variance if you use them. In this case, the batch norm statistics are being computed using the batch data rather than a moving average of the statistics (so they don't need to be updated).

srpv commented

OK. Indeed, they're needed only during testing.
Thanks.

@cbfinn, as you mentioned before,

> I compute the test-time statistics using the test batch of data, instead of computing the average training statistics

This seems like a bit of cheating, especially at test time. In general, we should assume that evaluation may happen one sample at a time at test time, in which case there is no way to get proper statistics for batch_norm. It also means the test-set performance will partially depend on the batch size.
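A tiny numpy illustration of the batch-size dependence: with per-batch statistics and a batch of size one, the example is normalized against itself and the output collapses to zeros (before scale and shift).

```python
import numpy as np

def batch_stats_normalize(batch, eps=1e-5):
    # Normalize with the batch's own statistics, as is_training=True does.
    mean = batch.mean(axis=0)
    var = batch.var(axis=0)
    return (batch - mean) / np.sqrt(var + eps)

single = np.random.randn(1, 4)        # batch size 1 at test time
print(batch_stats_normalize(single))  # all zeros: no usable statistics
```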

Hi, I see your approach.
If I use a moving average of the statistics by adding the update ops to the train op, do I then need to set train=False when calling the batch_norm function at test time?
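For what it's worth, the conventional tf.contrib.layers pattern (outside this repo) is to pass a boolean is_training tensor and flip it at test time, alongside the UPDATE_OPS dependency during training; whether that fits the MAML setup is a separate question.

```python
import tensorflow as tf

is_training = tf.placeholder(tf.bool, [])       # feed True while training
x = tf.placeholder(tf.float32, [None, 32])      # hypothetical input
h = tf.contrib.layers.batch_norm(
    tf.contrib.layers.fully_connected(x, 64),
    is_training=is_training)  # True: batch stats (and updates collected)
                              # False: the stored moving averages
```

At test time you would feed {is_training: False} so the layer reads moving_mean/moving_variance instead of the batch statistics.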