opendilab/LightZero

`SampledEfficientZeroModel` does not pass `lstm_hidden_size` through `DynamicsNetwork`

ekiefl opened this issue · 2 comments

In the `__init__` method of `SampledEfficientZeroModel`, the `lstm_hidden_size` argument should be passed through to `DynamicsNetwork`. Otherwise, `DynamicsNetwork` falls back to its default `lstm_hidden_size` of 512 rather than the value the user specifies. When the configured size differs from 512 (here it is 128), this leads to the following error:

```
Traceback (most recent call last):
  File "zoo/pooltool/sum_to_three/config/sum_to_three_image_config.py", line 122, in <module>
    train_muzero(
  File "/Users/evan/Software/pooltool_ml/LightZero/lzero/entry/train_muzero.py", line 160, in train_muzero
    new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
  File "/Users/evan/Software/pooltool_ml/LightZero/lzero/worker/muzero_collector.py", line 411, in collect
    policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon)
  File "/Users/evan/Software/pooltool_ml/LightZero/lzero/policy/sampled_efficientzero.py", line 867, in _forward_collect
    self._mcts_collect.search(
  File "/Users/evan/Software/pooltool_ml/LightZero/lzero/mcts/tree_search/mcts_ctree_sampled.py", line 162, in search
    network_output = model.recurrent_inference(
  File "/Users/evan/Software/pooltool_ml/LightZero/lzero/model/sampled_efficientzero_model.py", line 305, in recurrent_inference
    next_latent_state, reward_hidden_state, value_prefix = self._dynamics(latent_state, reward_hidden_state, action)
  File "/Users/evan/Software/pooltool_ml/LightZero/lzero/model/sampled_efficientzero_model.py", line 427, in _dynamics
    next_latent_state, next_reward_hidden_state, value_prefix = self.dynamics_network(
  File "/Users/evan/anaconda3/envs/pooltool_ml/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/evan/anaconda3/envs/pooltool_ml/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/evan/Software/pooltool_ml/LightZero/lzero/model/efficientzero_model.py", line 576, in forward
    value_prefix, next_reward_hidden_state = self.lstm(x, reward_hidden_state)
  File "/Users/evan/anaconda3/envs/pooltool_ml/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/evan/anaconda3/envs/pooltool_ml/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/evan/anaconda3/envs/pooltool_ml/lib/python3.8/site-packages/torch/nn/modules/rnn.py", line 875, in forward
    self.check_forward_args(input, hx, batch_sizes)
  File "/Users/evan/anaconda3/envs/pooltool_ml/lib/python3.8/site-packages/torch/nn/modules/rnn.py", line 791, in check_forward_args
    self.check_hidden_size(hidden[0], self.get_expected_hidden_size(input, batch_sizes),
  File "/Users/evan/anaconda3/envs/pooltool_ml/lib/python3.8/site-packages/torch/nn/modules/rnn.py", line 256, in check_hidden_size
    raise RuntimeError(msg.format(expected_hidden_size, list(hx.size())))
RuntimeError: Expected hidden[0] size (1, 6, 512), got [1, 6, 128]
```
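To make the failure mode concrete, here is a minimal pure-Python sketch of the pattern described above. `BuggyModel`, `FixedModel`, `try_step`, and this simplified `DynamicsNetwork` are hypothetical stand-ins, not LightZero's actual classes or constructor signatures; only the default of 512, the configured size of 128, the batch size of 6, and the error text are taken from the report. The shape comparison imitates the hidden-state check that `torch.nn.LSTM` performs, without depending on PyTorch.

```python
# Hypothetical, simplified stand-ins for the LightZero classes.
# The real constructors take many more arguments.

class DynamicsNetwork:
    """Toy dynamics network whose LSTM width defaults to 512, as in the report."""

    def __init__(self, lstm_hidden_size: int = 512):
        self.lstm_hidden_size = lstm_hidden_size

    def forward(self, batch_size: int, hidden_shape: tuple) -> None:
        # Imitate torch.nn.LSTM's check: hidden state must be (num_layers, batch, hidden).
        expected = (1, batch_size, self.lstm_hidden_size)
        if hidden_shape != expected:
            raise RuntimeError(
                f"Expected hidden[0] size {expected}, got {list(hidden_shape)}"
            )


class BuggyModel:
    def __init__(self, lstm_hidden_size: int = 512):
        self.lstm_hidden_size = lstm_hidden_size
        # Bug: lstm_hidden_size is not forwarded, so the sub-network keeps its default of 512.
        self.dynamics_network = DynamicsNetwork()

    def initial_hidden_shape(self, batch_size: int) -> tuple:
        # The reward hidden state is sized from the model-level setting (e.g. 128).
        return (1, batch_size, self.lstm_hidden_size)


class FixedModel(BuggyModel):
    def __init__(self, lstm_hidden_size: int = 512):
        super().__init__(lstm_hidden_size)
        # Fix: pass the user's value through to the sub-network.
        self.dynamics_network = DynamicsNetwork(lstm_hidden_size=lstm_hidden_size)


def try_step(model_cls, lstm_hidden_size: int = 128, batch_size: int = 6) -> str:
    """Run one dynamics step and return "ok" or the error message."""
    model = model_cls(lstm_hidden_size=lstm_hidden_size)
    hidden_shape = model.initial_hidden_shape(batch_size)
    try:
        model.dynamics_network.forward(batch_size, hidden_shape)
        return "ok"
    except RuntimeError as err:
        return str(err)


print(try_step(BuggyModel))  # Expected hidden[0] size (1, 6, 512), got [1, 6, 128]
print(try_step(FixedModel))  # ok
```

The mismatch only appears when the configured size differs from the sub-network's default, which is why the model can run without error at 512 and fail at 128.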

Hello, thank you for your feedback. We have identified this bug and it has now been fixed in the latest commit 3823560. Best wishes!

Thank you @puyuan1996. You guys at the LightZero team are so quick and helpful.