pytorch/rl

[BUG] `RenameTransform` of `ParallelEnv` is not the same as `ParallelEnv` of transformed environment

thomasbbrunner opened this issue

Describe the bug

In short:

transform(ParallelEnv(base_env)) != ParallelEnv(transform(base_env))
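Put differently, renaming should commute with batching. The toy model below is a hypothetical sketch, not torchrl code: a spec is a plain dict mapping key tuples to leaf shapes, batch() prepends the number of workers to every shape (which is what the dumps further down show ParallelEnv doing: [1] becomes [1, 1]), and rename() moves entries from in_keys to out_keys the way RenameTransform is meant to.

```python
from typing import Dict, List, Tuple

# Toy model, NOT torchrl's API: a spec is a flat dict mapping
# nested key tuples to leaf shapes.
Key = Tuple[str, ...]
Spec = Dict[Key, Tuple[int, ...]]


def batch(spec: Spec, num_workers: int) -> Spec:
    """Prepend a worker dimension to every leaf shape (stacking sub-env specs)."""
    return {key: (num_workers, *shape) for key, shape in spec.items()}


def rename(spec: Spec, in_keys: List[Key], out_keys: List[Key]) -> Spec:
    """Move each in_key entry to its matching out_key; other entries are untouched."""
    mapping = dict(zip(in_keys, out_keys))
    return {mapping.get(key, key): shape for key, shape in spec.items()}


# CartPole's done entries, each with leaf shape (1,) before batching.
base_done_spec: Spec = {("done",): (1,), ("terminated",): (1,), ("truncated",): (1,)}
in_keys: List[Key] = [("terminated",)]
out_keys: List[Key] = [("stuff", "terminated")]

# transform(ParallelEnv(base_env)): batch first, then rename on the main process
lhs = rename(batch(base_done_spec, 1), in_keys, out_keys)
# ParallelEnv(transform(base_env)): rename inside each worker, then batch
rhs = batch(rename(base_done_spec, in_keys, out_keys), 1)
print(lhs == rhs)  # True
```

In this idealized model both orders give the same spec; the bug report is that the real ParallelEnv path does not.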

I'm aware that this equivalence does not hold for every transform, but I'd expect it to hold for RenameTransform.

This is even stated in the documentation: "There are two equivalent ways of transforming parallen environments: in each process separately, or on the main process. It is even possible to do both."

To Reproduce

Simple script to reproduce the issue:

```python
from torchrl.envs import RenameTransform, TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs import check_env_specs, ParallelEnv

def _make_env():
    return GymEnv("CartPole-v1")

def _transform_env(env):
    # Move the "terminated" entry under the nested key ("stuff", "terminated")
    return TransformedEnv(
        env,
        RenameTransform(
            in_keys=["terminated"],
            out_keys=[("stuff", "terminated")],
        ),
    )


def _make_transformed_env():
    return _transform_env(_make_env())

if __name__ == "__main__":

    base_env = _make_env()
    transformed_env = _make_transformed_env()
    # Rename on the main process, after batching
    trans_parallel_env = _transform_env(
        ParallelEnv(1, create_env_fn=_make_env)
    )
    # Rename inside each worker, before batching
    parallel_trans_env = ParallelEnv(1, create_env_fn=_make_transformed_env)

    # Works!
    check_env_specs(base_env)
    # Works!
    check_env_specs(transformed_env)
    # Works!
    check_env_specs(trans_parallel_env)
    # RuntimeError: Cannot infer the value of terminated when only done and truncated are present.
    check_env_specs(parallel_trans_env)
```

Expected behavior

The script above should run without errors.

System info

Ubuntu 22.04
Python 3.10.14
torch 2.4.1
torchrl 0.5.0

Reason and Possible fixes

I've tracked the issue down to the _set_properties method of the BatchedEnvBase class. When it writes to the self.done_spec property, the EnvBase.done_spec setter does not respect the renamed keys: it re-adds a top-level terminated entry and creates a spurious done entry under stuff.
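One behaviour consistent with the before/after dumps is a setter that, in every group containing either done or terminated, adds whichever of the two is missing. The sketch models that rule on a flat key-to-shape dict. naive_set_done_spec is a hypothetical name, and the rule is inferred from the dumps, not taken from torchrl's source.

```python
from typing import Dict, Tuple

# Toy model, NOT torchrl's implementation: flat dict of key tuple -> leaf shape.
Key = Tuple[str, ...]
Spec = Dict[Key, Tuple[int, ...]]
DONE_LEAVES = ("done", "terminated")


def naive_set_done_spec(spec: Spec) -> Spec:
    """Hypothetical setter: any group holding 'done' or 'terminated' gets the
    missing sibling added, with the same shape as the existing one."""
    out = dict(spec)
    groups = {key[:-1] for key in spec if key[-1] in DONE_LEAVES}
    for prefix in groups:
        existing = next(prefix + (leaf,) for leaf in DONE_LEAVES if prefix + (leaf,) in spec)
        for leaf in DONE_LEAVES:
            # setdefault only inserts when the key is absent
            out.setdefault(prefix + (leaf,), spec[existing])
    return out


# Leaf keys and shapes of the "Before" dump.
before: Spec = {
    ("done",): (1, 1),
    ("truncated",): (1, 1),
    ("stuff", "terminated"): (1, 1),
}
after = naive_set_done_spec(before)
print(sorted(set(after) - set(before)))  # [('stuff', 'done'), ('terminated',)]
```

The two keys this rule adds are exactly the extra entries in the "After" dump, which is why a rename-unaware setter is a plausible culprit.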

Here's a comparison of the full_done_spec before and after calling the done_spec setter:

Before:

```
Composite(
    done: Categorical(
        shape=torch.Size([1, 1]),
        space=CategoricalBox(n=2),
        device=cpu,
        dtype=torch.bool,
        domain=discrete),
    truncated: Categorical(
        shape=torch.Size([1, 1]),
        space=CategoricalBox(n=2),
        device=cpu,
        dtype=torch.bool,
        domain=discrete),
    stuff: Composite(
        terminated: Categorical(
            shape=torch.Size([1, 1]),
            space=CategoricalBox(n=2),
            device=cpu,
            dtype=torch.bool,
            domain=discrete),
        device=cpu,
        shape=torch.Size([1])),
    device=cpu,
    shape=torch.Size([1]))
```

After:

```
Composite(
    done: Categorical(
        shape=torch.Size([1, 1]),
        space=CategoricalBox(n=2),
        device=cpu,
        dtype=torch.bool,
        domain=discrete),
    truncated: Categorical(
        shape=torch.Size([1, 1]),
        space=CategoricalBox(n=2),
        device=cpu,
        dtype=torch.bool,
        domain=discrete),
    stuff: Composite(
        terminated: Categorical(
            shape=torch.Size([1, 1]),
            space=CategoricalBox(n=2),
            device=cpu,
            dtype=torch.bool,
            domain=discrete),
        done: Categorical(
            shape=torch.Size([1, 1]),
            space=CategoricalBox(n=2),
            device=cpu,
            dtype=torch.bool,
            domain=discrete),
        device=cpu,
        shape=torch.Size([1])),
    terminated: Categorical(
        shape=torch.Size([1, 1]),
        space=CategoricalBox(n=2),
        device=cpu,
        dtype=torch.bool,
        domain=discrete),
    device=cpu,
    shape=torch.Size([1]))
```
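To compare the two dumps mechanically, it helps to flatten each nested Composite into its leaf keys. The sketch transcribes the dumps by hand into nested dicts (nested dicts standing in for nested Composites, leaves holding shapes); flatten_keys is a hypothetical helper. On the real objects, full_done_spec.keys(include_nested=True, leaves_only=True) should give an equivalent listing directly.

```python
from typing import Any, Dict, Iterator, Tuple


def flatten_keys(spec: Dict[str, Any], prefix: Tuple[str, ...] = ()) -> Iterator[Tuple[str, ...]]:
    """Yield the full key path of every leaf; nested dicts play the role of nested Composites."""
    for name, value in spec.items():
        if isinstance(value, dict):
            yield from flatten_keys(value, prefix + (name,))
        else:
            yield prefix + (name,)


# Hand transcription of the "Before" and "After" dumps (leaf shapes only).
before = {"done": (1, 1), "truncated": (1, 1), "stuff": {"terminated": (1, 1)}}
after = {
    "done": (1, 1),
    "truncated": (1, 1),
    "stuff": {"terminated": (1, 1), "done": (1, 1)},
    "terminated": (1, 1),
}

added = set(flatten_keys(after)) - set(flatten_keys(before))
removed = set(flatten_keys(before)) - set(flatten_keys(after))
print(sorted(added))    # [('stuff', 'done'), ('terminated',)]
print(sorted(removed))  # []
```

Nothing is lost by the setter; it only adds the top-level terminated key and the nested ("stuff", "done") key.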

The correct spec, as reported by the non-parallel transformed_env (hence without the leading batch dimension), should be:

```
>>> transformed_env.full_done_spec
Composite(
    done: Categorical(
        shape=torch.Size([1]),
        space=CategoricalBox(n=2),
        device=cpu,
        dtype=torch.bool,
        domain=discrete),
    truncated: Categorical(
        shape=torch.Size([1]),
        space=CategoricalBox(n=2),
        device=cpu,
        dtype=torch.bool,
        domain=discrete),
    stuff: Composite(
        terminated: Categorical(
            shape=torch.Size([1]),
            space=CategoricalBox(n=2),
            device=cpu,
            dtype=torch.bool,
            domain=discrete),
        device=None,
        shape=torch.Size([])),
    device=None,
    shape=torch.Size([]))
```
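Apart from the extra keys, the "Before" dump differs from this expected spec only by the leading worker dimension that ParallelEnv adds with a single sub-env: leaf shapes go from [1] to [1, 1]. The sketch checks that relationship on hand-transcribed key-to-shape maps; matches_batched is a hypothetical helper, not part of torchrl.

```python
from typing import Dict, Tuple

Key = Tuple[str, ...]
Spec = Dict[Key, Tuple[int, ...]]

# Leaf shapes transcribed from the dumps in this issue.
expected: Spec = {("done",): (1,), ("truncated",): (1,), ("stuff", "terminated"): (1,)}
before: Spec = {("done",): (1, 1), ("truncated",): (1, 1), ("stuff", "terminated"): (1, 1)}
after: Spec = {**before, ("stuff", "done"): (1, 1), ("terminated",): (1, 1)}


def matches_batched(batched: Spec, unbatched: Spec, num_workers: int) -> bool:
    """True if batched has exactly the keys of unbatched, each with a leading worker dim."""
    return batched == {key: (num_workers, *shape) for key, shape in unbatched.items()}


print(matches_batched(before, expected, 1))  # True
print(matches_batched(after, expected, 1))   # False: extra keys were added
```

So the spec is correct right before the done_spec setter runs and only goes wrong afterwards, which points at the setter rather than at the batching itself.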

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)