[BUG] `_complete_done` always sets missing terminated to `False`
jkrude opened this issue · 0 comments
Describe the bug
For a custom environment that only outputs the "done" key in `_step` (as in the example below), the automatically added "terminated" key is always `False` across all dimensions instead of mirroring the done tensor.
This happens regardless of whether the user defines the done spec or it is added automatically in `EnvBase._create_done_specs`.
To Reproduce
I tried to keep the example to a minimum.
The important part is in `_step`, where we set the "done" key to `torch.tensor([True, False])`.
from typing import Optional
import torch
from tensordict import TensorDictBase, TensorDict
from torchrl.data import (
CompositeSpec,
BinaryDiscreteTensorSpec,
UnboundedContinuousTensorSpec,
OneHotDiscreteTensorSpec,
)
from torchrl.envs import EnvBase
class CustomEnv(EnvBase):
    """Minimal batched env (batch of 2) whose ``_step`` writes only a "done" entry.

    Used to reproduce the report that the automatically added "terminated"
    entry is zero-filled instead of mirroring "done".
    """

    def __init__(
        self,
        *,
        device=None,
        batch_size: Optional[torch.Size] = torch.Size([2]),
        run_type_checks: bool = False,
        allow_done_after_reset: bool = False,
    ):
        # Every tensor below is built for exactly two sub-environments.
        assert batch_size == (2,)  # hardcoded for minimal example
        super().__init__(
            device=device,
            batch_size=batch_size,
            run_type_checks=run_type_checks,
            allow_done_after_reset=allow_done_after_reset,
        )
        # One scalar observation per sub-environment: shape (*batch_size, 1).
        obs_shape = torch.Size([*batch_size, 1])
        self.observation_spec = CompositeSpec(
            observation=UnboundedContinuousTensorSpec(shape=obs_shape),
            shape=batch_size,
        )
        self.action_spec = OneHotDiscreteTensorSpec(n=2, shape=batch_size)
        reward_shape = torch.Size([2, 1])
        self.reward_spec: BinaryDiscreteTensorSpec = BinaryDiscreteTensorSpec(
            n=1, dtype=torch.int8, shape=reward_shape
        )

    def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Ignore the action; report sub-env 0 as done and sub-env 1 as running.

        Only "done" is emitted (no "terminated"), which is what triggers the bug.
        """
        obs_shape = self.observation_spec["observation"].shape
        step_out = {
            "observation": torch.randn(obs_shape),
            "done": torch.tensor([True, False], dtype=torch.bool),
            "reward": torch.ones((2,)),
        }
        return TensorDict(step_out, batch_size=(2,))

    def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
        """Return a fresh random observation for both sub-environments."""
        obs_shape = self.observation_spec["observation"].shape
        return TensorDict({"observation": torch.randn(obs_shape)}, batch_size=(2,))

    def _set_seed(self, seed: Optional[int]):
        """Seeding is a no-op for this example env."""
# Run one reset/step cycle and check that "terminated" mirrors "done".
env = CustomEnv()
td = env.reset()
# The asserts below read from `td` without reassigning it, so this script
# expects rand_action/step to update `td` in place.
env.rand_action(td)
env.step(td)
assert env.done_keys == ["done", "terminated"]
assert torch.equal(td[("next", "done")],torch.tensor([[True], [False]]))
# This is the assertion that fails in the report: "terminated" comes back all False.
assert torch.equal(td[("next", "terminated")], torch.tensor([[True], [False]]))
terminated_env_bug.py:62 (test_)
def test_():
env = CustomEnv()
td = env.reset()
env.rand_action(td)
env.step(td)
assert env.done_keys == ["done", "terminated"]
assert torch.equal(td[("next", "done")], torch.tensor([[True], [False]]))
> assert torch.equal(td[("next", "terminated")], torch.tensor([[True], [False]]))
E assert False
E + where False = <built-in method equal of type object at 0x7fabbdc64800>(tensor([[False],\n [False]]), tensor([[ True],\n [False]]))
E + where <built-in method equal of type object at 0x7fabbdc64800> = torch.equal
E + and tensor([[ True],\n [False]]) = <built-in method tensor of type object at 0x7fabbdc64800>([[True], [False]])
E + where <built-in method tensor of type object at 0x7fabbdc64800> = torch.tensor
Expected behavior
The ("next", "terminated") entry should equal the ("next", "done") entry, as documented for `EnvBase._complete_done`.
System info
Using torchrl-nightly installed with pip. The same issue should also occur on the main branch, since the relevant code is identical.
# Print library and interpreter versions for the bug report.
import torchrl, numpy, sys
print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)
2024.7.3 2.0.0 3.11.5 (main, Sep 11 2023, 13:54:46) [GCC 11.2.0] linux
Reason and Possible fixes
The problem seems to come from `EnvBase._complete_done`.
The `for key, item in done_spec.items(False, True):` loop (line 1509) iterates over both the "done" and the "terminated" keys, but only "done" has a value present in `data`.
For the "done" key (processed first), `data.set("terminated", val)` (line 1537) correctly sets "terminated" to the values of `data["done"]`.
But then, for the "terminated" key, the `elif val is None:` branch is taken and `data["terminated"]` is overwritten with zeros (`False`).
...
for key, item in done_spec.items(False, True): # goes over done and terminated (order is important)
val = vals.get(key, None) # will be [[True], [False]] for "done" but None for "terminated"
if (
key == "done"
and val is not None
and "terminated" in done_spec_keys
and "terminated" not in data_keys
):
if "truncated" in data_keys:
raise RuntimeError(
"Cannot infer the value of terminated when only done and truncated are present."
)
data.set("terminated", val)
elif (
key == "terminated"
and val is not None
and "done" in done_spec_keys
and "done" not in data_keys
):
if "truncated" in data_keys:
done = val | data.get("truncated")
data.set("done", done)
else:
data.set("done", val)
elif val is None:
# we must keep this here: we only want to fill with 0s if we're sure
# done should not be copied to terminated or terminated to done
# in this case, just fill with 0s
data.set(key, item.zero(leading_dim)) # overrides the "terminated" key with False
return data
cc @kurtamohler