[BUG] `RenameTransform` of `ParallelEnv` is not the same as `ParallelEnv` of transformed environment
thomasbbrunner opened this issue · 0 comments
Describe the bug
In short:
transform(ParallelEnv(base_env)) != ParallelEnv(transform(base_env))
I'm aware that this is not supported in some cases, but I'd expect that this would work for the `RenameTransform`.
This is even stated in the documentation: "There are two equivalent ways of transforming parallen environments: in each process separately, or on the main process. It is even possible to do both."
To Reproduce
Simple script to reproduce the issue:
from torchrl.envs import RenameTransform, TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs import check_env_specs, ParallelEnv
def _make_env():
    return GymEnv("CartPole-v1")

def _transform_env(env):
    return TransformedEnv(
        env,
        RenameTransform(
            in_keys=[
                "terminated",
            ],
            out_keys=[
                ("stuff", "terminated"),
            ],
        )
    )

def _make_transformed_env():
    return _transform_env(_make_env())

if __name__ == "__main__":
    base_env = _make_env()
    transformed_env = _make_transformed_env()
    trans_parallel_env = _transform_env(
        ParallelEnv(
            1,
            create_env_fn=_make_env,
        )
    )
    parallel_trans_env = ParallelEnv(
        1,
        create_env_fn=_make_transformed_env,
    )
    # Works!
    check_env_specs(base_env)
    # Works!
    check_env_specs(transformed_env)
    # Works!
    check_env_specs(trans_parallel_env)
    # RuntimeError: Cannot infer the value of terminated when only done and truncated are present.
    check_env_specs(parallel_trans_env)
Expected behavior
The script above should run without errors.
System info
Ubuntu 22.04
Python 3.10.14
torch 2.4.1
torchrl 0.5.0
Reason and Possible fixes
I've tracked down the issue to `_set_properties` in the `BatchedEnvBase` class. When writing to the `self.done_spec` property, the `EnvBase.done_spec` setter does not respect the renamed keys.
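To illustrate how a setter that assumes default key names can clash with a rename, here is a toy model in plain Python. It is an assumption for illustration only, not torchrl's code: `complete_done_keys` is a hypothetical helper that adds `"done"` and `"terminated"` to every nesting level that already holds a done-like leaf, and nested dicts stand in for `Composite` specs.

```python
# Toy model (hypothetical, not torchrl's implementation).
DONE_LIKE = {"done", "terminated", "truncated"}


def complete_done_keys(spec):
    """Return a copy of `spec` in which every level that holds a
    done-like leaf also holds both "done" and "terminated"."""
    out = {}
    for key, value in spec.items():
        # Recurse into nested levels, copy leaves as they are.
        out[key] = complete_done_keys(value) if isinstance(value, dict) else value
    has_done_like_leaf = any(
        key in DONE_LIKE and not isinstance(value, dict)
        for key, value in spec.items()
    )
    if has_done_like_leaf:
        # Assumed behavior: fill in the default names at this level.
        out.setdefault("done", "bool")
        out.setdefault("terminated", "bool")
    return out


# Key layout once "terminated" has been renamed to ("stuff", "terminated").
renamed = {"done": "bool", "truncated": "bool", "stuff": {"terminated": "bool"}}
completed = complete_done_keys(renamed)
print(completed)
```

Under this assumed rule, the root level regains a `terminated` entry and the `stuff` level gains a `done` entry, even though the rename was meant to leave a single `("stuff", "terminated")` entry.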
Here's a comparison of the `full_done_spec` before and after calling the `done_spec` setter:
Before: Composite(
    done: Categorical(
        shape=torch.Size([1, 1]),
        space=CategoricalBox(n=2),
        device=cpu,
        dtype=torch.bool,
        domain=discrete),
    truncated: Categorical(
        shape=torch.Size([1, 1]),
        space=CategoricalBox(n=2),
        device=cpu,
        dtype=torch.bool,
        domain=discrete),
    stuff: Composite(
        terminated: Categorical(
            shape=torch.Size([1, 1]),
            space=CategoricalBox(n=2),
            device=cpu,
            dtype=torch.bool,
            domain=discrete),
        device=cpu,
        shape=torch.Size([1])),
    device=cpu,
    shape=torch.Size([1]))
After: Composite(
    done: Categorical(
        shape=torch.Size([1, 1]),
        space=CategoricalBox(n=2),
        device=cpu,
        dtype=torch.bool,
        domain=discrete),
    truncated: Categorical(
        shape=torch.Size([1, 1]),
        space=CategoricalBox(n=2),
        device=cpu,
        dtype=torch.bool,
        domain=discrete),
    stuff: Composite(
        terminated: Categorical(
            shape=torch.Size([1, 1]),
            space=CategoricalBox(n=2),
            device=cpu,
            dtype=torch.bool,
            domain=discrete),
        done: Categorical(
            shape=torch.Size([1, 1]),
            space=CategoricalBox(n=2),
            device=cpu,
            dtype=torch.bool,
            domain=discrete),
        device=cpu,
        shape=torch.Size([1])),
    terminated: Categorical(
        shape=torch.Size([1, 1]),
        space=CategoricalBox(n=2),
        device=cpu,
        dtype=torch.bool,
        domain=discrete),
    device=cpu,
    shape=torch.Size([1]))
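The difference between the two dumps is easier to see as a set of key paths. The sketch below transcribes only the key structure of the "Before" and "After" specs into nested dicts (leaf values are placeholders) and diffs them; `key_paths` is a hypothetical helper written for this comparison, not a torchrl function.

```python
def key_paths(spec, prefix=()):
    """Yield the tuple path of every leaf in a nested dict."""
    for key, value in spec.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            yield from key_paths(value, path)
        else:
            yield path


# Key structure transcribed from the two dumps above (leaves are placeholders).
before = {"done": None, "truncated": None, "stuff": {"terminated": None}}
after = {
    "done": None,
    "truncated": None,
    "stuff": {"terminated": None, "done": None},
    "terminated": None,
}

# Entries that exist only after the setter ran, and entries that were lost.
extra = sorted(set(key_paths(after)) - set(key_paths(before)))
missing = sorted(set(key_paths(before)) - set(key_paths(after)))
print("added by the setter:", extra)
print("removed by the setter:", missing)
```

The setter removes nothing, but it adds `("stuff", "done")` and a top-level `terminated`, neither of which the rename was supposed to leave in place.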
The correct spec should be:
>>> transformed_env.full_done_spec
Composite(
    done: Categorical(
        shape=torch.Size([1]),
        space=CategoricalBox(n=2),
        device=cpu,
        dtype=torch.bool,
        domain=discrete),
    truncated: Categorical(
        shape=torch.Size([1]),
        space=CategoricalBox(n=2),
        device=cpu,
        dtype=torch.bool,
        domain=discrete),
    stuff: Composite(
        terminated: Categorical(
            shape=torch.Size([1]),
            space=CategoricalBox(n=2),
            device=cpu,
            dtype=torch.bool,
            domain=discrete),
        device=None,
        shape=torch.Size([])),
    device=None,
    shape=torch.Size([]))
Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)