`SampledEfficientZeroModel` does not pass `lstm_hidden_size` through `DynamicsNetwork`
ekiefl opened this issue · 2 comments
ekiefl commented
In the `__init__` method of `SampledEfficientZeroModel`, the argument `lstm_hidden_size` should be passed through to the `DynamicsNetwork`. Otherwise the `DynamicsNetwork` will have the default `lstm_hidden_size` of 512, rather than what the user specifies. This leads to the following error:
Traceback (most recent call last):
  File "zoo/pooltool/sum_to_three/config/sum_to_three_image_config.py", line 122, in <module>
    train_muzero(
  File "/Users/evan/Software/pooltool_ml/LightZero/lzero/entry/train_muzero.py", line 160, in train_muzero
    new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
  File "/Users/evan/Software/pooltool_ml/LightZero/lzero/worker/muzero_collector.py", line 411, in collect
    policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon)
  File "/Users/evan/Software/pooltool_ml/LightZero/lzero/policy/sampled_efficientzero.py", line 867, in _forward_collect
    self._mcts_collect.search(
  File "/Users/evan/Software/pooltool_ml/LightZero/lzero/mcts/tree_search/mcts_ctree_sampled.py", line 162, in search
    network_output = model.recurrent_inference(
  File "/Users/evan/Software/pooltool_ml/LightZero/lzero/model/sampled_efficientzero_model.py", line 305, in recurrent_inference
    next_latent_state, reward_hidden_state, value_prefix = self._dynamics(latent_state, reward_hidden_state, action)
  File "/Users/evan/Software/pooltool_ml/LightZero/lzero/model/sampled_efficientzero_model.py", line 427, in _dynamics
    next_latent_state, next_reward_hidden_state, value_prefix = self.dynamics_network(
  File "/Users/evan/anaconda3/envs/pooltool_ml/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/evan/anaconda3/envs/pooltool_ml/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/evan/Software/pooltool_ml/LightZero/lzero/model/efficientzero_model.py", line 576, in forward
    value_prefix, next_reward_hidden_state = self.lstm(x, reward_hidden_state)
  File "/Users/evan/anaconda3/envs/pooltool_ml/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/evan/anaconda3/envs/pooltool_ml/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/evan/anaconda3/envs/pooltool_ml/lib/python3.8/site-packages/torch/nn/modules/rnn.py", line 875, in forward
    self.check_forward_args(input, hx, batch_sizes)
  File "/Users/evan/anaconda3/envs/pooltool_ml/lib/python3.8/site-packages/torch/nn/modules/rnn.py", line 791, in check_forward_args
    self.check_hidden_size(hidden[0], self.get_expected_hidden_size(input, batch_sizes),
  File "/Users/evan/anaconda3/envs/pooltool_ml/lib/python3.8/site-packages/torch/nn/modules/rnn.py", line 256, in check_hidden_size
    raise RuntimeError(msg.format(expected_hidden_size, list(hx.size())))
RuntimeError: Expected hidden[0] size (1, 6, 512), got [1, 6, 128]
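
The 512-versus-128 mismatch in the last line is what you would expect when an outer model records the user's `lstm_hidden_size` but builds its inner network with the default. Below is a minimal pure-Python sketch of that pattern. All class names (`ToyDynamicsNetwork`, `BuggyModel`, `FixedModel`) are hypothetical stand-ins, not the LightZero classes, and the size check only imitates the one PyTorch performs. It assumes the model creates the hidden state from its own `lstm_hidden_size`, which is a guess based on the traceback rather than something confirmed from the source code.

```python
# Hypothetical stand-ins that illustrate the bug pattern; not LightZero code.


class ToyDynamicsNetwork:
    """Inner network whose recurrent layer is sized by lstm_hidden_size."""

    def __init__(self, lstm_hidden_size: int = 512):
        self.lstm_hidden_size = lstm_hidden_size

    def forward(self, hidden_size: int) -> int:
        # Imitates the hidden-state size check an LSTM performs.
        if hidden_size != self.lstm_hidden_size:
            raise RuntimeError(
                f"Expected hidden size {self.lstm_hidden_size}, got {hidden_size}"
            )
        return hidden_size


class BuggyModel:
    """Outer model that accepts lstm_hidden_size but does not forward it."""

    def __init__(self, lstm_hidden_size: int = 512):
        self.lstm_hidden_size = lstm_hidden_size
        # Bug: the inner network silently falls back to its default of 512.
        self.dynamics_network = ToyDynamicsNetwork()

    def step(self) -> int:
        # The hidden state is sized from the user's value (assumption).
        return self.dynamics_network.forward(self.lstm_hidden_size)


class FixedModel(BuggyModel):
    """Same model, with the argument forwarded to the inner network."""

    def __init__(self, lstm_hidden_size: int = 512):
        self.lstm_hidden_size = lstm_hidden_size
        self.dynamics_network = ToyDynamicsNetwork(
            lstm_hidden_size=lstm_hidden_size
        )


if __name__ == "__main__":
    try:
        BuggyModel(lstm_hidden_size=128).step()
    except RuntimeError as err:
        print("buggy:", err)
    print("fixed:", FixedModel(lstm_hidden_size=128).step())
```

With the default of 512 both variants behave identically, which is why a bug like this only shows up once a non-default value is configured.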
puyuan1996 commented
Hello, thank you for your feedback. We have identified this bug and it has now been fixed in the latest commit 3823560. Best wishes!
ekiefl commented
Thank you @puyuan1996. You guys at the LightZero team are so quick and helpful.