In an attempt to learn TensorFlow, I have implemented Recurrent Batch Normalization for pixel-by-pixel MNIST classification using TensorFlow 1.13.
- A batch normalization operation for each time step has been implemented based on the discussion.
- A ModelRunner class controls the training and evaluation pipeline: performance on the test set is evaluated only when a new lowest-so-far validation loss has been achieved.
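The evaluation rule in the second point can be sketched in a few lines of plain Python. This is only an illustration of the idea, not the actual ModelRunner code: the names `BestValidationGate`, `run`, and `evaluate_test` are hypothetical and do not come from this repository.

```python
class BestValidationGate:
    """Tracks the lowest validation loss seen so far (hypothetical helper)."""

    def __init__(self):
        self.best_val_loss = float("inf")

    def should_evaluate_test(self, val_loss):
        """Return True only if val_loss is a new minimum, and record it."""
        if val_loss < self.best_val_loss:
            self.best_val_loss = val_loss
            return True
        return False


def run(val_losses, evaluate_test):
    """Call evaluate_test(epoch) only on epochs that set a new best validation loss.

    val_losses: one validation loss per epoch (stand-in for real training).
    evaluate_test: callable returning the test metric for that epoch.
    """
    gate = BestValidationGate()
    test_results = {}
    for epoch, val_loss in enumerate(val_losses, start=1):
        if gate.should_evaluate_test(val_loss):
            test_results[epoch] = evaluate_test(epoch)
    return test_results
```

With validation losses `[0.9, 0.7, 0.8, 0.6, 0.6]`, the test set would be evaluated after epochs 1, 2 and 4 only; a tie with the current best does not trigger another evaluation in this sketch.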
Suppose we want to run 50 epochs and use TensorBoard to visualize the process:
cd bn_lstm
python main.py --write_summary True --max_epoch 50
To check the description of all flags:
python main.py --helpfull
To open TensorBoard:
tensorboard --logdir=path
where path can be found in the training log, which shows the relative directory where the model is saved, e.g. logs/ModelWrapper/lr-0.001_dim-32/20190912-230850/saved_model/tfb_dir.
The comparison of accuracy on the test set is shown in the graph below.
The red curve corresponds to the accuracy of the batch normalized model.
The orange curve corresponds to the training loss of the batch normalized model.
# Apply batch norm
python main.py --write_summary True --max_epoch 50 --rnn_dim 32
# Apply no batch norm
python main.py --write_summary True --max_epoch 50 --rnn_dim 32 --batch_norm False
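For intuition about what the batch normalized run changes, here is a minimal pure-Python sketch of batch normalization with separate statistics for each time step. It is not the TensorFlow code used in this repository: the class name `PerStepBatchNorm` and its parameters are hypothetical, the scale and shift are fixed rather than learned, and the default scale of 0.1 follows the initialization suggested in the Recurrent Batch Normalization paper (whether this repository uses the same value is an assumption).

```python
import math


class PerStepBatchNorm:
    """Batch norm that keeps separate running statistics for each time step."""

    def __init__(self, num_steps, dim, momentum=0.9, eps=1e-3, gamma=0.1):
        self.momentum = momentum
        self.eps = eps
        self.gamma = [gamma] * dim  # scale; learned in a real model, fixed here
        self.beta = [0.0] * dim     # shift; learned in a real model, fixed here
        # One row of moving averages per time step.
        self.running_mean = [[0.0] * dim for _ in range(num_steps)]
        self.running_var = [[1.0] * dim for _ in range(num_steps)]

    def __call__(self, batch, step, training):
        """Normalize batch (a list of feature rows) with the statistics of `step`."""
        if training:
            n = len(batch)
            columns = list(zip(*batch))
            mean = [sum(col) / n for col in columns]
            var = [sum((v - mu) ** 2 for v in col) / n
                   for col, mu in zip(columns, mean)]
            # Update only the moving averages that belong to this time step.
            decay = self.momentum
            self.running_mean[step] = [
                decay * old + (1 - decay) * new
                for old, new in zip(self.running_mean[step], mean)]
            self.running_var[step] = [
                decay * old + (1 - decay) * new
                for old, new in zip(self.running_var[step], var)]
        else:
            # At inference time, use the stored statistics of this time step.
            mean, var = self.running_mean[step], self.running_var[step]
        return [
            [g * (v - mu) / math.sqrt(s + self.eps) + b
             for v, mu, s, g, b in zip(row, mean, var, self.gamma, self.beta)]
            for row in batch
        ]
```

In a recurrent model, one such object would be called with `step=t` at every time step `t`, so that the statistics of early and late steps are never mixed.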
tensorflow==1.13.1
Although I have not tested it, it should also work under TensorFlow 1.12 and 1.14.