thu-ml/tianshou

Don't pass envpool envs where vectorenvs are needed

MischaPanch opened this issue · 0 comments

See the block comments in the test and in the Collector method below. Somewhere a raw envpool env is passed where an instance of BaseVectorEnv is expected, so the BaseVectorEnv interface is not actually followed.

This means we rely on the two interfaces coinciding more or less by accident. They already don't fully coincide: envpool envs return info as a single dict of arrays, whereas tianshou's VectorEnvs return an array of dicts.

@Trinkle23897 this issue might be of interest to you

@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
def test_venv_wrapper_envpool_gym_reset_return_info() -> None:
    """Check that resetting a wrapped envpool env yields per-env obs and info.

    An envpool gymnasium env (created with ``gym_reset_return_info=True``) is
    wrapped in ``VectorEnvNormObs``. After ``reset`` the leading dimension of
    the observation and of every non-dict info entry must equal ``num_envs``.
    """
    num_envs = 4
    env = VectorEnvNormObs(
        envpool.make_gymnasium("Ant-v3", num_envs=num_envs, gym_reset_return_info=True),
    )
    obs, info = env.reset()
    assert obs.shape[0] == num_envs

    # The type checker considers the dict case unreachable, but envpool hands back
    # one dict of arrays rather than an array of dicts, so it does occur at runtime.
    # Normalize both layouts into "an iterable of mappings" and check them uniformly.
    if isinstance(info, dict):  # type: ignore[unreachable]
        info_mappings = [info]  # type: ignore[unreachable]
    else:
        info_mappings = info

    for info_mapping in info_mappings:
        for entry in info_mapping.values():
            if isinstance(entry, dict):
                # Nested dicts carry no shape of their own; skip them.
                continue
            assert entry.shape[0] == num_envs
    def reset_env(
        self,
        gym_reset_kwargs: dict[str, Any] | None = None,
    ) -> None:
        """Reset the environments and the initial obs, info, and hidden state of the collector."""
        gym_reset_kwargs = gym_reset_kwargs or {}
        self._pre_collect_obs_RO, self._pre_collect_info_R = self.env.reset(**gym_reset_kwargs)
        # TODO: hack, wrap envpool envs such that they don't return a dict
        if isinstance(self._pre_collect_info_R, dict):  # type: ignore[unreachable]
            # this can happen if the env is an envpool env. Then the thing returned by reset is a dict
            # with array entries instead of an array of dicts
            # We use Batch to turn it into an array of dicts
            self._pre_collect_info_R = _dict_of_arr_to_arr_of_dicts(self._pre_collect_info_R)  # type: ignore[unreachable]

        self._pre_collect_hidden_state_RH = None