pytorch/rl

[BUG] Issues with `TensorDictPrimer`

matteobettini opened this issue · 9 comments

Without the primer, the collector does not feed any hidden state to the policy

In the RNN tutorial it is stated that the primer is optional and is used just to store the hidden states in the buffer.

This is not true in practice. Not adding the primer results in the collector not feeding the hidden states to the policy during execution, which silently causes the RNN to lose all recurrence.

To reproduce, comment out this line

env.append_transform(lstm.make_tensordict_primer())

and print the policy input at this line

policy_output = self.policy(policy_input)

You will see that no hidden state is fed to the RNN during execution, and no errors or warnings are thrown.
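
As a quick way to confirm this at that breakpoint, one can check the policy input for the recurrent-state entries. This is a minimal sketch; `"recurrent_state_h"`/`"recurrent_state_c"` are the `LSTMModule` defaults and are an assumption here, adapt them to the in_keys of your recurrent module:

def check_recurrent_input(policy_input):
    # Hedged debugging sketch: report whether the recurrent-state entries reach the policy.
    keys = set(policy_input.keys(include_nested=True, leaves_only=True))
    for key in ("recurrent_state_h", "recurrent_state_c"):
        print(f"{key} present in policy input:", key in keys)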

The primer overwrites any nested spec

Consider an env with nested specs

env = VmasEnv(
    scenario="balance",
    num_envs=5,
)

add to it a primer for a nested hidden state

env = TransformedEnv(
    env,
    TensorDictPrimer(
        {
            "agents": CompositeSpec(
                {
                    "h": UnboundedContinuousTensorSpec(
                        shape=(*env.shape, env.n_agents, 2, 128)
                    )
                },
                shape=(*env.shape, env.n_agents),
            )
        }
    ),
)

The primer code at

observation_spec[key] = self.primers[key] = spec.to(device)

will overwrite the observation spec instead of updating it, resulting in the loss of all the spec keys that were previously in the "agents" spec.

The same result is obtained with

env = TransformedEnv(
    env,
    TensorDictPrimer(
        {
            ("agents", "h"): UnboundedContinuousTensorSpec(
                shape=(*env.shape, env.n_agents, 2, 128)
            )
        }
    ),
)

Here, updating the spec instead of overwriting it should do the job.
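
A minimal sketch of the intended behaviour using only spec objects (the specs below are illustrative stand-ins, not the actual env specs or the `TensorDictPrimer` code):

from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec

# Stand-in for the env's existing nested spec: 5 vectorized envs, 4 agents.
observation_spec = CompositeSpec(
    {
        "agents": CompositeSpec(
            {"observation": UnboundedContinuousTensorSpec(shape=(5, 4, 18))},
            shape=(5, 4),
        )
    },
    shape=(5,),
)
# Stand-in for the primer's nested spec adding ("agents", "h").
primer_spec = CompositeSpec(
    {"h": UnboundedContinuousTensorSpec(shape=(5, 4, 2, 128))},
    shape=(5, 4),
)

# Overwriting (what the primer does today) would drop the sibling keys:
#     observation_spec["agents"] = primer_spec
# Updating keeps them:
observation_spec["agents"].update(primer_spec)
assert ("agents", "observation") in observation_spec.keys(True, True)
assert ("agents", "h") in observation_spec.keys(True, True)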

The order of the primer in the transforms seems to have an impact

In the same VMAS environment as above, if I put the primer first and then the RewardSum:

env = TransformedEnv(
    env,
    Compose(
        TensorDictPrimer(
            {
                "agents": CompositeSpec(
                    {
                        "h": UnboundedContinuousTensorSpec(
                            shape=(*env.shape, env.n_agents, 2, 128)
                        )
                    },
                    shape=(*env.shape, env.n_agents),
                )
            }
        ),
        RewardSum(
            in_keys=[env.reward_key],
            out_keys=[("agents", "episode_reward")],
        ),
    ),
)

all works well

but the opposite order

env = TransformedEnv(
    env,
    Compose(
        RewardSum(
            in_keys=[env.reward_key],
            out_keys=[("agents", "episode_reward")],
        ),
        TensorDictPrimer(
            {
                "agents": CompositeSpec(
                    {
                        "h": UnboundedContinuousTensorSpec(
                            shape=(*env.shape, env.n_agents, 2, 128)
                        )
                    },
                    shape=(*env.shape, env.n_agents),
                )
            }
        ),
    ),
)

causes the following error at collector init, when the env is reset:

Traceback (most recent call last):
  File "/Users/Matteo/PycharmProjects/torchrl/sota-implementations/multiagent/mappo_ippo.py", line 302, in train
    collector = SyncDataCollector(
                ^^^^^^^^^^^^^^^^^^
  File "/Users/Matteo/PycharmProjects/torchrl/torchrl/collectors/collectors.py", line 644, in __init__
    self._make_shuttle()
  File "/Users/Matteo/PycharmProjects/torchrl/torchrl/collectors/collectors.py", line 661, in _make_shuttle
    self._shuttle = self.env.reset()
                    ^^^^^^^^^^^^^^^^
  File "/Users/Matteo/PycharmProjects/torchrl/torchrl/envs/common.py", line 2143, in reset
    tensordict_reset = self._reset(tensordict, **kwargs)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/Matteo/PycharmProjects/torchrl/torchrl/envs/transforms/transforms.py", line 814, in _reset
    tensordict_reset = self.transform._reset(tensordict, tensordict_reset)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/Matteo/PycharmProjects/torchrl/torchrl/envs/transforms/transforms.py", line 1129, in _reset
    tensordict_reset = t._reset(tensordict, tensordict_reset)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/Matteo/PycharmProjects/torchrl/torchrl/envs/transforms/transforms.py", line 4722, in _reset
    value = self.default_value[key]
            ~~~~~~~~~~~~~~~~~~^^^^^
KeyError: ('agents', 'episode_reward')

For the first issue, I think we should go for a solution where the Primer becomes optional and is needed only if users want the hidden states in the collection buffer.

But without the primer, users should still be able to use RNNs in collectors, with the logic that anything coming out of step_mdp is re-fed to the policy.

Since there are multiple issues, I'd suggest opening a tracker.
I'll comment on the first here: it's optional in the sense that you can make the env run without a primer if no other module is involved. If a ParallelEnv or a collector is used, things will indeed break.

Yeah, I'll eventually split them into separate issues.

If that is the case for the first one, I suggest we make it extra clear in the tutorial that the Primer is not optional, as we are using a collector.

In general, do we really have no way to make the collector work without the primer? It would be nice to have it optional in collectors, for users that do not want the hidden states as part of the output buffer.

Adding to my previous comment, I think that to solve the first issue we could add the output of the policy (looking at the "next" key) to the shuttle:

policy_output = self.policy(policy_input)

I would really like not to use the Primer, as it is a huge pain in large projects.

I don't see a way of not having a primer: we need to let the env know about extra entries in the tensordict.
Is it really such a "huge" pain though? We've worked hard with @albertbou92 to provide all the tooling to make this work as seamlessly as possible.

In the collector, we could automatically check if any primer is missing and append it. Or raise a warning. We can extract the expected primers from the actor.
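
A rough sketch of what such a check could look like (illustrative only, not actual collector code; it assumes the `get_primers_from_module` helper mentioned in the next comment lives under `torchrl.modules.utils`, and that `TensorDictPrimer.primers` exposes the primed entries as a composite spec):

import warnings

from torchrl.modules.utils import get_primers_from_module  # assumed import path


def warn_if_primers_missing(env, policy):
    """Illustrative helper: warn when the policy expects primed entries
    that the env's observation spec does not provide."""
    primers = get_primers_from_module(policy)
    if not primers:
        return
    env_keys = set(env.observation_spec.keys(True, True))
    # Assumption: primers.primers is the composite spec holding the primed entries.
    missing = [key for key in primers.primers.keys(True, True) if key not in env_keys]
    if missing:
        warnings.warn(
            f"The policy expects primed entries {missing} that the env does not "
            f"provide; consider env.append_transform(primers)."
        )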

I don't find it that inconvenient to use the primer. I simply got used to adding:

if primers := get_primers_from_module(actor):
    env.append_transform(primers)

However, if a user is not aware of it or forgets to add the primer for some reason, silently losing recurrence can cause a lot of headaches.

What is the technical limitation preventing us from reading the hidden state from the policy output? It seems to me that, since we are running the policy at collector init time, its outputs in the "next" tensordict could be captured and accounted for during the collector rollout (i.e., moved to the policy input at every step_mdp).

With this you would maybe lose the possibility of obtaining zeroed states after resets, but the absence of the hidden state (i.e., it being None) should be possible to handle in the RNN.
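
A toy sketch of that idea with plain tensordicts (all names are illustrative; this is not the actual collector logic):

import torch
from tensordict import TensorDict


def carry_hidden_state(shuttle, policy_output, hidden_keys):
    """Copy hidden states written by the policy under ("next", key) back to the
    root of the shuttle so they reach the policy at the next step (toy version)."""
    for key in hidden_keys:  # keys are assumed to be tuples here
        value = policy_output.get(("next", *key), None)
        if value is not None:
            shuttle[key] = value
    return shuttle


# Toy data mimicking a policy output with a nested hidden state.
hidden_keys = [("agents", "h")]
policy_output = TensorDict(
    {"next": {"agents": {"h": torch.zeros(5, 4, 2, 128)}}},
    batch_size=[5],
)
shuttle = TensorDict({}, batch_size=[5])
carry_hidden_state(shuttle, policy_output, hidden_keys)
assert ("agents", "h") in shuttle.keys(True, True)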

The pain I am referring to is that, in the multi-agent setting, the Primer requires knowing the number of agents, the group names, and the hidden sizes. All this information is not immediately available. Plus, having it optional would make the library easier to approach for new users IMO, and less bug-prone.

Sorry to ask again, but can we split this issue?
I'd like to close the pieces that are solved.

Yeah, we solved everything apart from the first one; I'll make a new issue for that.