google/e3d_lstm

Cannot get your MSE score when I train for 80k iterations?

Closed this issue · 1 comment

I trained on the Moving MNIST dataset for 80k iterations, but when I test the checkpoint I saved, I cannot reproduce the reported score.
The scores I get are below:

2019-12-07 14:34:19 test...
mse per seq: 775.0606184674647
31.52545323113879
44.8356685294417
56.46880273685188
67.7944408103316
77.83094204308274
85.80292131045539
93.0013996116623
100.49140779575508
106.04341360944545
111.2661687892998
psnr per frame: 17.902477
21.60128
20.07348
19.040869
18.222927
17.590435
17.141094
16.781181
16.430117
16.174725
15.968664

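For reference, this is how I read those numbers (a minimal NumPy sketch, not the repository's exact test code; I am assuming frames normalized to [0, 1], per-frame MSE taken as the squared error summed over pixels and averaged over test sequences, "mse per seq" as the sum of the 10 per-frame values, and PSNR computed per frame from the per-pixel MSE with a peak value of 1.0):

import numpy as np

def frame_wise_metrics(gt, pred):
    # gt, pred: float arrays of shape [num_seq, num_frames, H, W, C] in [0, 1].
    err = np.square(gt - pred)                            # [N, T, H, W, C]
    # Per-frame MSE: squared error summed over pixels, averaged over sequences.
    mse_per_frame = err.sum(axis=(2, 3, 4)).mean(axis=0)  # [T]
    # Per-frame PSNR from the per-pixel MSE, assuming a peak value of 1.0.
    per_pixel_mse = err.mean(axis=(0, 2, 3, 4))           # [T]
    psnr_per_frame = 10.0 * np.log10(1.0 / np.maximum(per_pixel_mse, 1e-10))
    # "mse per seq" in the log looks like the sum of the per-frame values.
    return mse_per_frame, psnr_per_frame, mse_per_frame.sum()
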
I set the params in run.py as below:

FLAGS.DEFINE_string('train_data_paths', 'data/moving-mnist-example/moving-mnist-train.npz', 'train data paths.')
FLAGS.DEFINE_string('valid_data_paths', 'data/moving-mnist-example/moving-mnist-valid.npz', 'validation data paths.')
FLAGS.DEFINE_string('save_dir', 'checkpoints/_mnist_e3d_lstm', 'dir to store trained net.')
FLAGS.DEFINE_string('gen_frm_dir', 'results/_mnist_e3d_lstm', 'dir to store result.')
FLAGS.DEFINE_string('logdir', './Summary', 'dir to store summary.')

FLAGS.DEFINE_boolean('is_Training', True, 'training or testing')
FLAGS.DEFINE_string('dataset_name', 'mnist', 'The name of dataset.')
FLAGS.DEFINE_integer('input_length', 10, 'input length.')
FLAGS.DEFINE_integer('total_length', 20, 'total input and output length.')
FLAGS.DEFINE_integer('img_width', 64, 'input image width.')
FLAGS.DEFINE_integer('img_channel', 1, 'number of image channel.')
FLAGS.DEFINE_integer('patch_size', 4, 'patch size on one dimension.')
FLAGS.DEFINE_boolean('reverse_input', False,
                     'reverse the input/outputs during training.')

FLAGS.DEFINE_string('model_name', 'e3d_lstm', 'The name of the architecture.')
FLAGS.DEFINE_string('pretrained_model', '', '.ckpt file to initialize from.')
FLAGS.DEFINE_string('num_hidden', '64,64,64,64',
                    'COMMA separated number of units of e3d lstms.')
FLAGS.DEFINE_integer('filter_size', 5, 'filter of a e3d lstm layer.')
FLAGS.DEFINE_boolean('layer_norm', True, 'whether to apply tensor layer norm.')

FLAGS.DEFINE_boolean('scheduled_sampling', True, 'for scheduled sampling')
FLAGS.DEFINE_integer('sampling_stop_iter', 50000, 'for scheduled sampling.')
FLAGS.DEFINE_float('sampling_start_value', 1.0, 'for scheduled sampling.')
FLAGS.DEFINE_float('sampling_changing_rate', 0.00002, 'for scheduled sampling.')

FLAGS.DEFINE_float('lr', 0.001, 'learning rate.')
FLAGS.DEFINE_integer('batch_size', 4, 'batch size for training.')
FLAGS.DEFINE_integer('max_iterations', 80000, 'max num of steps.')
FLAGS.DEFINE_integer('display_interval', 1,
                     'number of iters showing training loss.')
FLAGS.DEFINE_integer('test_interval', 1000, 'number of iters for test.')
FLAGS.DEFINE_integer('snapshot_interval', 1000,
                     'number of iters saving models.')
FLAGS.DEFINE_integer('num_save_samples', 10, 'number of sequences to be saved.')
FLAGS.DEFINE_integer('n_gpu', 1,
                     'how many GPUs to distribute the training across.')
FLAGS.DEFINE_boolean('allow_gpu_growth', True, 'allow gpu growth')

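For what it's worth, this is how I understand the scheduled-sampling flags above to interact (a sketch of a plain linear decay, not necessarily the exact logic in the trainer): the probability of feeding the ground-truth frame starts at sampling_start_value, decreases by sampling_changing_rate per iteration, and reaches 0 at sampling_stop_iter = 50000, i.e. well before the 80k iterations I train for.

def scheduled_sampling_eta(itr,
                           start_value=1.0,        # FLAGS.sampling_start_value
                           changing_rate=0.00002,  # FLAGS.sampling_changing_rate
                           stop_iter=50000):       # FLAGS.sampling_stop_iter
    # Probability of feeding the ground-truth frame at training iteration `itr`.
    # With the values above, eta decays linearly from 1.0 to 0.0 over the first
    # 50k iterations and stays at 0.0 for the remaining 30k.
    if itr >= stop_iter:
        return 0.0
    return max(0.0, start_value - changing_rate * itr)
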
I tested your pretrained model and got the same MSE as in your paper:

2019-12-08 19:01:27 itr: 1
training loss: 1674.2725
2019-12-08 19:01:27 test...
mse per seq: 414.5043303178164
22.084744220267318
27.563711177848862
32.04450662723763
36.26114293711936
39.79543918693711
43.05908385832945
46.98905318724607
51.02225838013307
55.17859578562643
60.505794957071124
psnr per frame: 20.58732
23.137892
22.231098
21.548489
20.990143
20.562159
20.216097
19.843668
19.465357
19.1416
18.736696

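Dividing the two "mse per seq" values by the 10 predicted frames makes the gap easier to see (just arithmetic on the logs above; my run ends up roughly 1.9x worse per frame than the released pretrained model):

# Per-frame MSE derived from the 'mse per seq' lines above
# (each 'mse per seq' is the sum over the 10 output frames).
my_run = 775.0606184674647        # my 80k-iteration training run
pretrained = 414.5043303178164    # released pretrained model

print(my_run / 10)          # ~77.5 MSE per frame
print(pretrained / 10)      # ~41.5 MSE per frame
print(my_run / pretrained)  # ~1.87x gap
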
Here are some TensorBoard results from my e3d_lstm training run:
[TensorBoard screenshot: training loss curve]
The lowest train_loss I reached was about 1970, at around step 20.25k.

I find that I cannot reach the training loss of your pretrained model (training loss: 1674.2725).
So may I ask what params you used?

wyb15 commented

We noticed that there is a bug in the current code related to "global_memory" in "rnn_cell.py", which may cause bad training results on both datasets. We are working on fixing this issue and refreshing our pre-trained models. I will temporarily merge this issue into #1.