[BUG] Issues with `TensorDictPrimer`
matteobettini opened this issue · 9 comments
Without the primer, the collector does not feed any hidden state to the policy
In the RNN tutorial it is stated that the primer is optional and is only used to store the hidden states in the buffer.
This is not true in practice. Without the primer, the collector does not feed the hidden states to the policy during execution, which silently causes the RNN to lose any recurrency.
To reproduce, comment out this line in rl/tutorials/sphinx-tutorials/dqn_with_rnn.py (line 269 at commit 0063741)
and print the policy input at this line in rl/torchrl/collectors/collectors.py (line 733 at commit 0063741).
You will see that no hidden state is fed to the RNN during execution, and no errors or warnings are raised.
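A toy model in plain Python may help illustrate the failure mode. It is not torchrl code: `toy_policy`, `toy_step_mdp` and `rollout` are hypothetical, and the sketch assumes that the collector only carries over from "next" the entries the env declares in its specs; the primer is what declares (and fills with a default at reset) the hidden state.

```python
# Toy model, NOT torchrl code: all names here are hypothetical.
# Assumption: only keys declared by the env (the primer adds "h") survive step_mdp.


def toy_policy(td):
    """Fake recurrent policy: reads hidden state "h" (0 if absent), writes the next one."""
    h = td.get("h", 0)
    out = dict(td)
    out["action"] = td["obs"] + h
    out["next"] = {"obs": td["obs"] + 1, "h": h + 1}
    return out


def toy_step_mdp(td, known_keys):
    """Move the "next" entries to the root, dropping keys the env does not know about."""
    return {k: v for k, v in td["next"].items() if k in known_keys}


def rollout(use_primer, steps=3):
    """Return the hidden state the policy receives at each step."""
    known_keys = {"obs", "h"} if use_primer else {"obs"}
    td = {"obs": 0}
    if use_primer:
        td["h"] = 0  # the primer writes a default (zero) hidden state at reset
    seen = []
    for _ in range(steps):
        seen.append(td.get("h"))
        td = toy_step_mdp(toy_policy(td), known_keys)
    return seen


print(rollout(use_primer=False))  # [None, None, None]: recurrency silently lost
print(rollout(use_primer=True))   # [0, 1, 2]
```

Nothing fails in the first case: the policy falls back to its default state at every step, which matches the silent behavior described above.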
The primer overwrites any nested spec
Consider an env with nested specs:
from torchrl.envs.libs.vmas import VmasEnv

env = VmasEnv(
    scenario="balance",
    num_envs=5,
)
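Add to it a primer for a nested hidden state:
from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
from torchrl.envs import TensorDictPrimer, TransformedEnv

env = TransformedEnv(
    env,
    TensorDictPrimer(
        {
            "agents": CompositeSpec(
                {
                    "h": UnboundedContinuousTensorSpec(
                        shape=(*env.shape, env.n_agents, 2, 128)
                    )
                },
                shape=(*env.shape, env.n_agents),
            )
        }
    ),
)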
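The primer code in rl/torchrl/envs/transforms/transforms.py (line 4649 at commit 0063741) overwrites the env's existing "agents" spec instead of updating it, so the nested entries already there are lost.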
The same result is obtained with:
env = TransformedEnv(
    env,
    TensorDictPrimer(
        {
            ("agents", "h"): UnboundedContinuousTensorSpec(
                shape=(*env.shape, env.n_agents, 2, 128)
            )
        }
    ),
)
Here, updating the nested spec instead of overwriting it should do the job.
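The difference between overwriting and updating a nested spec can be sketched with plain nested dicts standing in for `CompositeSpec` (strings stand in for leaf specs). `overwrite` and `update` are hypothetical helpers, not torchrl functions; an actual fix would perform the equivalent recursive merge on the spec objects.

```python
import copy


def overwrite(spec, primer):
    """Replace whole top-level entries: a nested "agents" entry loses its old leaves."""
    out = copy.deepcopy(spec)
    out.update(copy.deepcopy(primer))
    return out


def update(spec, primer):
    """Recursively merge nested entries, keeping leaves the primer does not mention."""
    out = copy.deepcopy(spec)
    for key, value in primer.items():
        if isinstance(value, dict) and isinstance(out.get(key), dict):
            out[key] = update(out[key], value)
        else:
            out[key] = copy.deepcopy(value)
    return out


env_spec = {"agents": {"observation": "obs_spec", "action": "act_spec"}}
primer = {"agents": {"h": "hidden_spec"}}

print(sorted(overwrite(env_spec, primer)["agents"]))  # ['h']
print(sorted(update(env_spec, primer)["agents"]))     # ['action', 'h', 'observation']
```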
The order of the primer in the transforms seems to have an impact
In the same VMAS environment as above, if I put the primer first and then the reward sum:
from torchrl.envs import Compose, RewardSum

env = TransformedEnv(
    env,
    Compose(
        TensorDictPrimer(
            {
                "agents": CompositeSpec(
                    {
                        "h": UnboundedContinuousTensorSpec(
                            shape=(*env.shape, env.n_agents, 2, 128)
                        )
                    },
                    shape=(*env.shape, env.n_agents),
                )
            }
        ),
        RewardSum(
            in_keys=[env.reward_key],
            out_keys=[("agents", "episode_reward")],
        ),
    ),
)
everything works,
but the opposite order
env = TransformedEnv(
    env,
    Compose(
        RewardSum(
            in_keys=[env.reward_key],
            out_keys=[("agents", "episode_reward")],
        ),
        TensorDictPrimer(
            {
                "agents": CompositeSpec(
                    {
                        "h": UnboundedContinuousTensorSpec(
                            shape=(*env.shape, env.n_agents, 2, 128)
                        )
                    },
                    shape=(*env.shape, env.n_agents),
                )
            }
        ),
    ),
)
causes:
Traceback (most recent call last):
File "/Users/Matteo/PycharmProjects/torchrl/sota-implementations/multiagent/mappo_ippo.py", line 302, in train
collector = SyncDataCollector(
^^^^^^^^^^^^^^^^^^
File "/Users/Matteo/PycharmProjects/torchrl/torchrl/collectors/collectors.py", line 644, in __init__
self._make_shuttle()
File "/Users/Matteo/PycharmProjects/torchrl/torchrl/collectors/collectors.py", line 661, in _make_shuttle
self._shuttle = self.env.reset()
^^^^^^^^^^^^^^^^
File "/Users/Matteo/PycharmProjects/torchrl/torchrl/envs/common.py", line 2143, in reset
tensordict_reset = self._reset(tensordict, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/Matteo/PycharmProjects/torchrl/torchrl/envs/transforms/transforms.py", line 814, in _reset
tensordict_reset = self.transform._reset(tensordict, tensordict_reset)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/Matteo/PycharmProjects/torchrl/torchrl/envs/transforms/transforms.py", line 1129, in _reset
tensordict_reset = t._reset(tensordict, tensordict_reset)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/Matteo/PycharmProjects/torchrl/torchrl/envs/transforms/transforms.py", line 4722, in _reset
value = self.default_value[key]
~~~~~~~~~~~~~~~~~~^^^^^
KeyError: ('agents', 'episode_reward')
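One plausible mechanism for this traceback, sketched under assumptions rather than read from the torchrl source: at reset, the primer iterates over the leaf keys of a spec that also contains entries added by earlier transforms (here RewardSum's `("agents", "episode_reward")`), but it only holds default values for the keys it created. `leaf_keys`, `reset_buggy` and `reset_fixed` are hypothetical names, and nested dicts stand in for specs.

```python
def leaf_keys(spec, prefix=()):
    """Yield the tuple key of every leaf in a nested dict spec."""
    for key, value in spec.items():
        if isinstance(value, dict):
            yield from leaf_keys(value, prefix + (key,))
        else:
            yield prefix + (key,)


# "agents" sub-spec as (assumed to be) seen by the primer when RewardSum comes first:
# RewardSum's entry and the primer's hidden state live side by side.
merged_spec = {"agents": {"episode_reward": "reward_spec", "h": "hidden_spec"}}
# The primer only holds default values for the keys it created.
default_value = {("agents", "h"): 0.0}


def reset_buggy(spec):
    """Fill every leaf of the spec: fails on keys owned by other transforms."""
    return {key: default_value[key] for key in leaf_keys(spec)}


def reset_fixed(spec):
    """Fill only the leaves the primer owns, leaving other transforms' keys alone."""
    return {key: default_value[key] for key in leaf_keys(spec) if key in default_value}


try:
    reset_buggy(merged_spec)
except KeyError as err:
    print("KeyError:", err)  # KeyError: ('agents', 'episode_reward')

print(reset_fixed(merged_spec))  # {('agents', 'h'): 0.0}
```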
For the first issue, I think we should go for a solution where the primer becomes optional and is needed only if users want the hidden states in the collection buffer.
Without the primer, users should still be able to use RNNs in collectors, following the logic that anything coming out of step_mdp is fed back to the policy.
Since there are multiple issues, I'd suggest opening a tracker.
I'll comment on the first here: it's optional in the sense that you can run the env without a primer as long as no other module is involved. If a ParallelEnv or a collector is used, things will indeed break.
Yeah, I'll eventually split them into separate issues.
If that is the case for the first one, I suggest we make it extra clear in the tutorial that the primer is not optional, since we are using a collector.
In general, is there really no way to make the collector work without the primer? It would be nice to have it optional in collectors for users who do not want the hidden states as part of the output buffer.
Adding to my previous comment, I think that to solve the first issue we could add the output of the policy (looking at the "next" key) to the shuttle, in rl/torchrl/collectors/collectors.py (line 733 at commit 371181c).
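A minimal sketch of that proposal in plain Python, assuming the policy writes its recurrent state under "next" and the collector runs the policy once at init. None of these names are torchrl APIs: `make_shuttle` probes the policy once to record which extra keys it produces, and `carry_over` (a step_mdp-like step) keeps them in addition to the env's own keys.

```python
def toy_policy(td):
    """Fake recurrent policy: reads "h" (0 if absent) and writes the next one under "next"."""
    h = td.get("h", 0)
    return {**td, "action": td["obs"] + h, "next": {"h": h + 1}}


def toy_env_step(td):
    """Fake env step: adds the next observation to whatever the policy put in "next"."""
    nxt = dict(td.get("next", {}))
    nxt["obs"] = td["obs"] + 1
    return {**td, "next": nxt}


def make_shuttle(reset_td, policy):
    """Probe the policy once and record the extra keys it writes under "next"."""
    probe = policy(dict(reset_td))
    extra_keys = set(probe.get("next", {})) - set(reset_td)
    return dict(reset_td), extra_keys


def carry_over(td, env_keys, extra_keys):
    """step_mdp-like: move "next" entries to the root, keeping env keys plus policy extras."""
    keep = env_keys | extra_keys
    return {k: v for k, v in td["next"].items() if k in keep}


def rollout(steps=3):
    """Return the hidden state the policy receives at each step."""
    reset_td = {"obs": 0}
    env_keys = set(reset_td)
    td, extra_keys = make_shuttle(reset_td, toy_policy)
    seen = []
    for _ in range(steps):
        seen.append(td.get("h"))
        td = carry_over(toy_env_step(toy_policy(td)), env_keys, extra_keys)
    return seen


print(rollout())  # [None, 1, 2]
```

The first step still receives no hidden state, since nothing writes a default at reset; in this design the policy has to handle that case itself.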
I would really like not to use the primer, as it is a huge pain in large projects.
I don't see a way of not having a primer: we need to let the env know about extra entries in the tensordict.
Is that such a "huge" pain though? We've worked hard with @albertbou92 to provide all the tooling to make this work as seamlessly as possible.
In the collector, we could automatically check if any primer is missing and append it. Or raise a warning. We can extract the expected primers from the actor.
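The check-and-warn option could look roughly like this. `check_primers` is a hypothetical helper, not part of torchrl; it assumes that the keys the actor expects (e.g. its recurrent-state in_keys) and the keys the env declares are available as sets of tuple keys, and it either calls an `append` callback or emits a warning.

```python
import warnings


def check_primers(expected_keys, env_keys, append=None):
    """Return the expected keys missing from the env; append them or warn."""
    missing = sorted(set(expected_keys) - set(env_keys))
    if missing:
        if append is not None:
            append(missing)  # e.g. build and append a primer for these keys
        else:
            warnings.warn(
                f"The policy reads keys the env does not declare: {missing}. "
                "Recurrent state will not be carried over between steps; "
                "consider adding a TensorDictPrimer."
            )
    return missing


env_keys = {("agents", "observation")}
policy_in_keys = {("agents", "observation"), ("agents", "h")}

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    missing = check_primers(policy_in_keys, env_keys)

print(missing)      # [('agents', 'h')]
print(len(caught))  # 1
```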
I don't find it that inconvenient to use the primer. I simply got used to adding:
from torchrl.modules.utils import get_primers_from_module

if primers := get_primers_from_module(actor):
    env.append_transform(primers)
However, if a user is not aware or does not remember to add the primer for some reason, silently not using recurrency can cause a lot of headaches.
What is the technical limitation that prevents us from reading the hidden state from the policy output? It seems to me that, since we are running the policy at collector init time, its outputs in the "next" tensordict could be captured and accounted for during the collector rollout (i.e., moved to the policy input at every step_mdp).
With this you might lose the possibility of obtaining zeroed states after resets, but the absence of the hidden state (i.e., it being None) should be possible to handle in the RNN.
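Handling an absent hidden state inside the recurrent module could look like the toy cell below: when no state is provided it starts from zeros, so a missing state is equivalent to a zeroed one. `rnn_step`, its fixed weights and `HIDDEN_SIZE` are hypothetical, and plain Python lists stand in for tensors.

```python
import math

HIDDEN_SIZE = 2


def rnn_step(x, h=None):
    """Toy Elman-style cell: h' = tanh(x + 0.5 * h), starting from zeros if h is None."""
    if h is None:
        h = [0.0] * HIDDEN_SIZE  # no primed state: fall back to a zero state
    return [math.tanh(x + 0.5 * h_i) for h_i in h]


h1 = rnn_step(1.0)      # first call: no hidden state available
h2 = rnn_step(1.0, h1)  # later calls: state carried over
print(rnn_step(1.0) == rnn_step(1.0, [0.0] * HIDDEN_SIZE))  # True
```

The trade-off mentioned above remains: without a reset signal, the cell cannot tell that a carried-over state belongs to a previous episode.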
The pain I am referring to is that, in the multi-agent setting, the primer requires knowing the number of agents, the group names and the hidden sizes, and all this information is not immediately available. Plus, IMO, making it optional would make the library easier to approach for new users and less bug-prone.
Sorry to ask again, can we split this issue?
I'd like to close the pieces that are solved
Yeah, we solved everything apart from the first one; I'll open an issue for that.