Don't pass envpool envs where vectorenvs are needed
MischaPanch opened this issue · 0 comments
MischaPanch commented
See the block comments in the test `test_venv_wrapper_envpool_gym_reset_return_info` and in the `Collector.reset_env` method (both quoted below). Somewhere a pure envpool env is passed instead of instances of `BaseVectorEnv`, and thus the interface is not followed.
This means we rely on the two interfaces accidentally kind-of coinciding. They already don't fully coincide, since envpool envs return the info as a single dict of arrays, whereas tianshou's `VectorEnv`s return an array of dicts.
@Trinkle23897 this issue might be of interest to you
@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
def test_venv_wrapper_envpool_gym_reset_return_info() -> None:
    """Check that ``reset`` on a wrapped envpool env yields batched obs and info.

    The observation and every non-dict info entry must have a leading
    dimension equal to the number of environments. Both info layouts are
    accepted: a single dict, or an iterable of dicts.
    """
    num_envs = 4
    env = VectorEnvNormObs(
        envpool.make_gymnasium("Ant-v3", num_envs=num_envs, gym_reset_return_info=True),
    )
    obs, info = env.reset()
    assert obs.shape[0] == num_envs

    # This is not actually unreachable b/c envpool does not return info in the right format.
    # Normalise both layouts to "an iterable of dicts" so one loop covers them.
    info_dicts = [info] if isinstance(info, dict) else info  # type: ignore[unreachable]
    for info_dict in info_dicts:
        for _key, value in info_dict.items():
            # Nested dicts are skipped; only leaf entries are checked.
            if not isinstance(value, dict):
                assert value.shape[0] == num_envs
def reset_env(
self,
gym_reset_kwargs: dict[str, Any] | None = None,
) -> None:
"""Reset the environments and the initial obs, info, and hidden state of the collector."""
gym_reset_kwargs = gym_reset_kwargs or {}
self._pre_collect_obs_RO, self._pre_collect_info_R = self.env.reset(**gym_reset_kwargs)
# TODO: hack, wrap envpool envs such that they don't return a dict
if isinstance(self._pre_collect_info_R, dict): # type: ignore[unreachable]
# this can happen if the env is an envpool env. Then the thing returned by reset is a dict
# with array entries instead of an array of dicts
# We use Batch to turn it into an array of dicts
self._pre_collect_info_R = _dict_of_arr_to_arr_of_dicts(self._pre_collect_info_R) # type: ignore[unreachable]
self._pre_collect_hidden_state_RH = None