
Won't replay RNN policies using scripts.visualize

NikEyX opened this issue · 5 comments

  File "/home/rl-starter-files/utils/", line 24, in __init__
    self.acmodel.load_state_dict(utils.get_model_state(model_dir))
  File "/home/miniconda3/envs/ml/lib/python3.7/site-packages/torch/nn/modules/", line 839, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for ACModel:
	Unexpected key(s) in state_dict: "memory_rnn.weight_ih", "memory_rnn.weight_hh", "memory_rnn.bias_ih", "memory_rnn.bias_hh".

Could you please explain the context? Which command did you run, and with which model? The error seems to come from something in how the scripts were run.

I can't reproduce it at the moment.

Hi, yes of course, here are the commands I'm running to reproduce the above error:

python3 -m scripts.train --algo ppo --env MiniGrid-RedBlueDoors-6x6-v0 --model RedBlueDoors3 --recurrence 4 --save-interval 10 --frames 100000

This appears to train and save a model (although 100k frames is obviously not enough to solve the environment):

Status saved
U 41 | F 083968 | FPS 1634 | D 52 | rR:μσmM 0.09 0.17 0.00 0.54 | F:μσmM 426.4 254.2 8.0 720.0 | H 1.941 | V 0.022 | pL 0.004 | vL 0.000 | ∇ 0.002
U 42 | F 086016 | FPS 1672 | D 53 | rR:μσmM 0.01 0.03 0.00 0.14 | F:μσmM 419.7 275.1 8.0 720.0 | H 1.944 | V 0.018 | pL 0.004 | vL 0.000 | ∇ 0.001
U 43 | F 088064 | FPS 1604 | D 54 | rR:μσmM 0.05 0.15 0.00 0.61 | F:μσmM 461.2 240.0 132.0 720.0 | H 1.941 | V 0.016 | pL -0.002 | vL 0.001 | ∇ 0.009
U 44 | F 090112 | FPS 1638 | D 56 | rR:μσmM 0.16 0.28 0.00 0.71 | F:μσmM 467.1 238.4 132.0 720.0 | H 1.929 | V 0.039 | pL -0.010 | vL 0.004 | ∇ 0.025
U 45 | F 092160 | FPS 1681 | D 57 | rR:μσmM 0.16 0.28 0.00 0.71 | F:μσmM 457.9 241.8 132.0 720.0 | H 1.937 | V 0.010 | pL 0.002 | vL 0.000 | ∇ 0.005
U 46 | F 094208 | FPS 1632 | D 58 | rR:μσmM 0.21 0.29 0.00 0.71 | F:μσmM 410.9 214.9 139.0 720.0 | H 1.935 | V 0.016 | pL -0.004 | vL 0.001 | ∇ 0.007
U 47 | F 096256 | FPS 1620 | D 59 | rR:μσmM 0.09 0.21 0.00 0.67 | F:μσmM 331.4 236.7 36.0 720.0 | H 1.935 | V 0.018 | pL -0.002 | vL 0.002 | ∇ 0.010
U 48 | F 098304 | FPS 1608 | D 61 | rR:μσmM 0.10 0.24 0.00 0.79 | F:μσmM 318.9 274.2 17.0 720.0 | H 1.936 | V 0.030 | pL -0.003 | vL 0.003 | ∇ 0.014
U 49 | F 100352 | FPS 1621 | D 62 | rR:μσmM 0.11 0.27 0.00 0.85 | F:μσmM 343.2 282.0 17.0 720.0 | H 1.929 | V 0.030 | pL -0.001 | vL 0.003 | ∇ 0.011

I then want to visualize the resulting model and run:
python3 -m scripts.visualize --env MiniGrid-RedBlueDoors-6x6-v0 --model RedBlueDoors3

which results in

Traceback (most recent call last):
  File "/home/xxx/apps/miniconda3/envs/ml/lib/python3.7/", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/home/xxx/apps/miniconda3/envs/ml/lib/python3.7/", line 85, in _run_code
    exec(code, run_globals)
  File "/home/xxx/projects/rl-starter-files/scripts/", line 47, in <module>
    agent = utils.Agent(env.observation_space, env.action_space, model_dir, device, args.argmax)
  File "/home/xxx/projects/rl-starter-files/utils/", line 24, in __init__
    self.acmodel.load_state_dict(utils.get_model_state(model_dir))
  File "/home/xxx/apps/miniconda3/envs/ml/lib/python3.7/site-packages/torch/nn/modules/", line 839, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for ACModel:
	Unexpected key(s) in state_dict: "memory_rnn.weight_ih", "memory_rnn.weight_hh", "memory_rnn.bias_ih", "memory_rnn.bias_hh". 

There does not seem to be a command-line option that makes the visualizer build a model with use_memory=True. Or is there something I'm missing? Can you reproduce the issue with these commands?
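For context, the traceback is consistent with how a strict state-dict load behaves: the checkpoint holds parameters that the freshly built model does not declare. Below is a minimal sketch in plain Python, assuming the visualizer builds ACModel without its memory module. The `actor.weight` and `critic.weight` names are hypothetical placeholders; only the four `memory_rnn.*` names come from the error message.

```python
def compare_state_keys(model_keys, checkpoint_keys):
    """Mimic the key check of a strict state-dict load.

    Returns (missing, unexpected) as sorted lists: `missing` are keys the
    model expects but the checkpoint lacks, `unexpected` are keys the
    checkpoint holds but the model does not declare.
    """
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


# Hypothetical key names for the non-recurrent part of the model.
base_keys = ["actor.weight", "critic.weight"]
# Key names taken from the error message in this thread.
memory_keys = [
    "memory_rnn.weight_ih",
    "memory_rnn.weight_hh",
    "memory_rnn.bias_ih",
    "memory_rnn.bias_hh",
]

# Checkpoint saved by a model that has the memory module ...
checkpoint_keys = base_keys + memory_keys
# ... compared against a model built without it.
missing, unexpected = compare_state_keys(base_keys, checkpoint_keys)
print("missing:", missing)
print("unexpected:", unexpected)
```

Under that assumption, nothing is missing and the four `memory_rnn.*` keys are reported as unexpected, which matches the error.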

Thank you for giving me these details! Indeed, there seems to be an issue in the library. I will try to fix it this weekend.

This has just been fixed in commit 6e717b8, thanks to @AMairesse.

To evaluate or visualize a model with memory or text, add --memory or --text to the command, e.g.:

python3 -m scripts.visualize --env MiniGrid-RedBlueDoors-6x6-v0 --model RedBlueDoors3 --memory
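If it is unclear whether a saved model needs the flag, one option is to look at the parameter names in its checkpoint. The sketch below is a hypothetical helper, not part of rl-starter-files. It assumes you already have the list of state-dict key names, and that a `memory_rnn.` prefix (as in the error above) marks a model trained with memory. The --text case is left out because this thread does not show the corresponding key names.

```python
def visualize_command(env, model, state_keys):
    """Build the scripts.visualize argument list, adding --memory if needed.

    Hypothetical helper: `state_keys` is the list of parameter names found
    in the saved checkpoint.
    """
    cmd = ["python3", "-m", "scripts.visualize", "--env", env, "--model", model]
    # Assumption: keys under "memory_rnn." mean the model was trained with memory.
    if any(key.startswith("memory_rnn.") for key in state_keys):
        cmd.append("--memory")
    return cmd


keys = ["memory_rnn.weight_ih", "memory_rnn.bias_ih"]
print(" ".join(visualize_command("MiniGrid-RedBlueDoors-6x6-v0", "RedBlueDoors3", keys)))
```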

Running the example found in the README reproduces this error:

python3 -m scripts.train --algo ppo --env MiniGrid-RedBlueDoors-6x6-v0 --model RedBlueDoors --recurrence 4 --save-interval 10 --frames 1000000