pytorch/rl

[Feature Request] Reward-to-go

vmoens opened this issue · 7 comments

Implement reward-to-go (as in here)
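For reference, a minimal sketch of the quantity in question, the per-trajectory reward-to-go $G_t = \sum_{k \ge t} \gamma^{k-t} r_k$ (purely illustrative; the function name and the undiscounted default are placeholders, not a proposed API):

import torch

def reward_to_go(rewards: torch.Tensor, gamma: float = 1.0) -> torch.Tensor:
    # rewards: [T, 1] rewards of a single trajectory.
    rtg = torch.zeros_like(rewards)
    running = torch.zeros_like(rewards[0])
    # Backward accumulation: G_t = r_t + gamma * G_{t+1}.
    for t in range(rewards.shape[0] - 1, -1, -1):
        running = rewards[t] + gamma * running
        rtg[t] = running
    return rtg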

@vmoens Is this issue still open? @BY571 Does #1038 close it?

Yes this is closed now!

Nice! @vmoens Do you know if this is meant to handle truncation correctly (e.g. as introduced in Gymnasium)? That is, does it allow giving a value for the truncation state, $V(s_{\text{truncation}})$, and bootstrapping all the reward-to-go values $G_t$ of that episode from it?
Or, if this is deferred to the learning logic, is there a flag for states belonging to truncated episodes so that one could bootstrap their values (e.g. $\nabla \log \pi(a_t|s_t)\,\big[G_t + \mathbb{1}_{\{s_t \in \text{truncated episode}\}}\, V(s_{\text{truncation}}) + \dots\big]$)?
I'm not aware of the updates since #403.

Similarly, for episodes that have not finished yet (typically the last episode in each concurrent environment), is there a way to find those and mask them out in the loss?
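To make the intent concrete, a rough sketch of the kind of masking I have in mind (purely illustrative, not an existing TorchRL API; the tensors are placeholders):

>>> import torch
>>> log_prob = torch.randn(4, 1)                        # stand-in for log pi(a_t | s_t)
>>> reward_to_go = torch.tensor([[2.], [1.], [2.], [1.]])
>>> finished = torch.tensor([[1.], [1.], [0.], [0.]])   # 1 = step of a finished episode
>>> per_step_loss = -log_prob * reward_to_go
>>> loss = (per_step_loss * finished).sum() / finished.sum().clamp(min=1)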
Thanks!

I think that -- provided that you pass the correct mask to the function -- truncation should be handled properly.
@BY571 can you confirm?

This is the transform. It looks at done or truncated here.
The functional handles both as a "done", but as you can see, upstream the transform does done = done | truncated.
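As a rough illustration of that upstream merge (a sketch, not the actual torchrl source), the two flags end up OR-ed together before the reward-to-go computation:

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict({"next": {"done": torch.zeros(4,1).to(dtype=bool), "truncated": torch.zeros(4,1).to(dtype=bool)}}, batch_size=())
>>> td["next"]["truncated"][-1] = True
>>> done = td["next"]["done"] | td["next"]["truncated"]  # done now also flags the truncated step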

Let us know if something is not clear!

BY571 commented

Yes: when an episode ends without done=True and truncated is instead set to True on that last state, the transform handles it as if that were the last state of the episode:

>>> from torchrl.envs.transforms import Reward2GoTransform
>>> import torch
>>> from tensordict import TensorDict
>>> r2g = Reward2GoTransform(in_keys=["reward"], out_keys=["reward_to_go"])
>>> td = TensorDict({"reward": torch.ones(4,1), "next": {"done":torch.zeros(4,1).to(dtype=bool), "truncated": torch.zeros(4,1).to(dtype=bool)}}, batch_size=())
>>> td["next"]["truncated"][-1]=True
>>> r2g._inv_call(td)["reward_to_go"]
tensor([[4.],
        [3.],
        [2.],
        [1.]])

If you want to mask these episodes out completely, you might have to set the states, rewards, actions, etc. to zero. Simply setting truncated to True for all those steps would not work: in that case the reward-to-go transform returns only the current reward per step, since it treats each step as a single episode of length 1:

>>> td = TensorDict({"reward": torch.ones(4,1), "next": {"done":torch.zeros(4,1).to(dtype=bool), "truncated": torch.zeros(4,1).to(dtype=bool)}}, batch_size=())
>>> td["next"]["truncated"][-3:]=True
>>> r2g._inv_call(td)["reward_to_go"]
tensor([[2.],
        [1.],
        [1.],
        [1.]])
      
>>> td = TensorDict({"reward": torch.ones(4,1), "next": {"done":torch.zeros(4,1).to(dtype=bool), "truncated": torch.zeros(4,1).to(dtype=bool)}}, batch_size=())
>>> td["next"]["truncated"][-3:]=True
>>> td["reward"][-3:]=0
>>> r2g._inv_call(td)["reward_to_go"]
tensor([[1.],
        [0.],
        [0.],
        [0.]])

Let me know if this helps clarify things.

Thanks both for the clarifications.

So I guess the answer to my question is that the transform is aware of truncation and handles it as termination.
It will not, however, bootstrap truncated episodes or mask out unfinished ones; that is left to the user.
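For reference, a rough sketch of what that user-side bootstrapping could look like (purely illustrative, not a TorchRL API; the discount and the value estimate are assumptions, the tensor is assumed to hold a single truncated episode, and gamma is assumed to match the one used by the transform):

>>> import torch
>>> gamma = 0.99                                           # assumed discount
>>> reward_to_go = torch.tensor([[4.], [3.], [2.], [1.]])  # output of the transform for one truncated episode
>>> v_truncation = torch.tensor(0.5)                       # assumed critic estimate V(s_truncation)
>>> T = reward_to_go.shape[0]
>>> # Exponents T, T-1, ..., 1: the last step bootstraps with gamma * V(s_truncation).
>>> exponents = torch.arange(T, 0, -1, dtype=torch.float32).unsqueeze(-1)
>>> bootstrapped = reward_to_go + gamma ** exponents * v_truncation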

Also, it does not expect the last state to have ["next"]["truncated"] = True or ["next"]["done"] = True; it only complains if there is no done or truncated anywhere in the batch.

>>> from torchrl.envs.transforms import Reward2GoTransform
>>> import torch
>>> from tensordict import TensorDict
>>> r2g = Reward2GoTransform(in_keys=["reward"], out_keys=["reward_to_go"])
>>> td = TensorDict({"reward": torch.ones(4,1), "next": {"done":torch.zeros(4,1).to(dtype=bool), "truncated": torch.zeros(4,1).to(dtype=bool)}}, batch_size=())
>>> td["next"]["truncated"][1]=True
>>> r2g._inv_call(td)["reward_to_go"]
tensor([[2.],	# Belongs to truncated episode.
        [1.],	# Belongs to truncated episode. Next is truncated.
        [2.],	# Belongs to unfinished episode.
        [1.]])	# Belongs to unfinished episode. Next is not truncated, nor done.

Otherwise, regarding truncation at the last step:

Yes: when an episode ends without done=True and truncated is instead set to True

Where does this happen? It doesn't seem to be done by a collector at the last frame of a batch. (I'm new to TorchRL and in the process of deciding whether I should adopt it!)