Cannot run test() or reproduce the results?
renqianluo opened this issue · 7 comments
@dukebw Hi, thanks for your code. I downloaded and ran it, but ran into three main problems:
- It seems the code cannot reproduce the results in the paper. I ran it with the default run.sh, and the eval perplexity stays around 80~100 through the end of training (150 epochs).
- There is no test function in the Trainer class. I added one using the evaluation method (passing self.test_data as the argument). However, the perplexity is around 1500. Even when I pass self.train_data, self.eval_data, or self.valid_data, it is also around 1500.
- After training is done, when I call either test() or derive() and pass the --load_path argument, self.shared.load_state_dict in load_model() throws the error "KeyError: unexpected key batch_norm.weight in state_dict". Moreover, I printed self.shared.state_dict().keys() and the contents loaded by torch.load from the checkpoint, and found that the checkpoint contains four parameters related to batch normalization ("batch_norm.weight", "batch_norm.bias", "batch_norm.running_mean", "batch_norm.running_var"), while the keys from self.shared.state_dict() do not include them.
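For anyone hitting the same KeyError, here is a minimal sketch for diagnosing it. `diff_state_dicts` is a hypothetical helper, not part of the repo, and it assumes the checkpoint file is a raw state_dict as described above:

```python
import torch

def diff_state_dicts(model, checkpoint_path):
    """Print the keys that differ between a model and a saved checkpoint."""
    checkpoint = torch.load(checkpoint_path)
    model_keys = set(model.state_dict().keys())
    ckpt_keys = set(checkpoint.keys())
    print('in checkpoint only:', sorted(ckpt_keys - model_keys))
    print('in model only:', sorted(model_keys - ckpt_keys))
```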
It would be great if you could help check these.
Hi Ren Qian,
To address 1. and 2., it is necessary to re-train the model from scratch after discovering a good architecture. See Section 4.5 of this early revision of the ENAS paper: https://openreview.net/references/pdf?id=BykVS-WC-. This re-training is not implemented yet.
My validation results for the shared model are similar to yours; the best validation perplexity I have seen is 82.
The third point you raised sounds like a genuine bug. I think it has to do with these lines in models/shared_rnn.py (lines 185 to 188 at commit 5736a24).
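For context, the failure mode is consistent with the batch norm module being registered only on one branch of a conditional. The following is an illustrative reconstruction of that pattern, not the exact code at those lines:

```python
import torch.nn as nn

class SharedModel(nn.Module):
    """Illustrative reconstruction of the suspected pattern, not the real code."""
    def __init__(self, hidden_size, use_batch_norm):
        super(SharedModel, self).__init__()
        if use_batch_norm:
            # Registered only on this branch, so batch_norm.* entries end
            # up in checkpoints saved during training...
            self.batch_norm = nn.BatchNorm1d(hidden_size)
        else:
            # ...but a model rebuilt on this branch has no batch_norm keys,
            # so load_state_dict() raises "unexpected key batch_norm.weight".
            self.batch_norm = None
```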
Thank you for pointing out the issues.
@dukebw Hi Brendan, thanks for your response.
Yes, I just found that the derived final model has to be trained from scratch, which is not implemented yet. I also found the cause of point 3 and located the problem in the code you listed. When I remove that if-else branch so that batch norm is added in every case, the reload method works, as expected. What is more interesting is that, with this change, the test perplexity is about 91 rather than 1500 as in point 2. This is curious, since the paper says that removing batch norm when a fixed cell is sampled by the controller does not hurt performance. Or does the paper mean that batch norm should only be removed during the re-training from scratch after the best model is sampled? BTW, I am running your latest code now, but it seems much slower than the previous version?
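For reference, another way around the mismatch, without editing the model, is to drop the extra checkpoint entries before loading. This is a sketch under the assumption that the checkpoint is a raw state_dict; `load_ignoring_extra_keys` is a hypothetical helper, and recent PyTorch versions can achieve the same with `load_state_dict(..., strict=False)`:

```python
import torch

def load_ignoring_extra_keys(model, checkpoint_path):
    """Load a checkpoint, skipping entries the model does not expect."""
    checkpoint = torch.load(checkpoint_path)
    model_keys = set(model.state_dict().keys())
    filtered = {k: v for k, v in checkpoint.items() if k in model_keys}
    model.load_state_dict(filtered)
```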
Yes, I am sure what they meant was that the fixed cell, once derived by the controller, can be retrained without batch normalization. In my view, batch normalization addresses the "covariate shift" of the features as seen by the language model's decoder.
The latest code will be much slower, because I have set the default policy batch size to M = 10 (it was M = 1 before), as mentioned in this comment: https://openreview.net/forum?id=ByQZjx-0-&noteId=rJWmCYxyM. You can set the batch size using the --policy_batch_size flag.
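For readers unfamiliar with the policy batch size: M is the number of architectures sampled from the controller per policy update, with the REINFORCE loss averaged over those samples. A rough sketch of the idea, where `controller.sample` and `evaluate_reward` are illustrative names rather than the actual ENAS-pytorch API:

```python
import torch

def controller_step(controller, evaluate_reward, optimizer, policy_batch_size=10):
    """One REINFORCE update averaged over M sampled architectures (sketch)."""
    optimizer.zero_grad()
    total_loss = 0
    for _ in range(policy_batch_size):
        dag, log_probs = controller.sample()   # one architecture + its log-probs
        reward = evaluate_reward(dag)          # e.g. derived from validation ppl
        total_loss = total_loss - reward * log_probs.sum()  # REINFORCE objective
    (total_loss / policy_batch_size).backward()  # average the gradient over M
    optimizer.step()
```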
@dukebw If it works with M = 10, then I think it is better to keep M = 10. BTW, I ran your latest code and the training perplexity becomes NaN at epoch 56.
@dukebw: You say above: "This re-training is not implemented yet."
That was from March 14; has the retraining been implemented since then?
Also, regarding the last reply: it's not clear to me from reading the paper whether picking the best model from the samples and then retraining happens every epoch or only at the end of all the epochs.