araffin/sbx

[Feature Request] Recurrent policies

jamesheald opened this issue · 16 comments

There are recurrent (LSTM) policy options for sb3 (e.g. RecurrentPPO). It would be great to have recurrent PPO implemented for sbx.

Hello,
are you willing to contribute the implementation?

Yes, definitely. I'm new to the Stable Baselines ecosystem and am currently implementing some more basic things for what I need, but once my project reaches the stage where a recurrent version of PPO would be useful, I'd be happy to implement this.

Hello, I'd like to contribute to the implementation if possible.

Would you prefer that I share code that builds on the current implementation of PPO? Or should I add a new recurrentppo dir in https://github.com/araffin/sbx/tree/master/sbx ?

Probably a new folder would be cleaner.

Ok, I'll try that!

Hello, thanks for asking, because I did indeed have a few questions about this implementation. I first started by reading the LSTM section of this blog post about PPO in order to better understand the specifics of the algorithm. In that implementation, the author adds an LSTM component to the agent without changing the actor and critic networks, and only saves the LSTM state at the beginning of a rollout. He then uses it to reconstruct the probability distributions (used during the rollouts) at update time (see ppo_atari_lstm.py from CleanRL).
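
To make this concrete, here is a minimal sketch in JAX/flax of that idea (not CleanRL's or sbx's actual code; MaskedLSTMCell and LSTMAgent are placeholder names, and it assumes flax's nn.OptimizedLSTMCell / nn.scan API): a single LSTM is shared by the actor and critic heads, only the carry at the start of a rollout needs to be stored, and the sequence is replayed during the update, resetting the carry wherever an episode ended.

import jax
import flax.linen as nn


class MaskedLSTMCell(nn.Module):
    # LSTM cell that resets its carry wherever an episode ended (done == 1),
    # similar in spirit to CleanRL's `(1.0 - done) * hidden`.
    hidden_size: int

    @nn.compact
    def __call__(self, carry, inputs):
        obs, done = inputs  # obs: (n_envs, obs_dim), done: (n_envs,) float/bool
        carry = jax.tree_util.tree_map(lambda c: c * (1.0 - done)[:, None], carry)
        return nn.OptimizedLSTMCell(self.hidden_size)(carry, obs)


class LSTMAgent(nn.Module):
    # One LSTM shared by the actor and critic heads; the carry at the start of
    # the rollout is the only recurrent state that needs to be stored.
    hidden_size: int = 128
    n_actions: int = 4

    @nn.compact
    def __call__(self, carry, obs_seq, done_seq):
        # obs_seq: (n_steps, n_envs, obs_dim), done_seq: (n_steps, n_envs)
        scan = nn.scan(
            MaskedLSTMCell,
            variable_broadcast="params",
            split_rngs={"params": False},
            in_axes=0,
            out_axes=0,
        )
        carry, hiddens = scan(self.hidden_size)(carry, (obs_seq, done_seq))
        logits = nn.Dense(self.n_actions)(hiddens)  # actor head
        values = nn.Dense(1)(hiddens)               # critic head
        return carry, logits, values

During the update, the stored carry from the start of the rollout would be passed back in and the whole sequence replayed, which is what allows the log-probabilities to be reconstructed.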

Then I looked at the implementation of RecurrentPPO in SB3-Contrib, which does something slightly different for the agent's networks. There, the actor and the critic each incorporate their own LSTM component (see sb3_contrib/common/recurrent/policies.py).
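
Translated into JAX terms, the state layout is roughly the following (a schematic sketch with assumed names, not the sb3_contrib code; the shapes follow the snippet quoted further below):

from typing import NamedTuple, Tuple

import jax.numpy as jnp


class RNNStates(NamedTuple):
    # One (hidden, cell) pair for the actor LSTM and one for the critic LSTM.
    pi: Tuple[jnp.ndarray, jnp.ndarray]
    vf: Tuple[jnp.ndarray, jnp.ndarray]


def init_rnn_states(num_layers: int, n_envs: int, hidden_size: int) -> RNNStates:
    shape = (num_layers, n_envs, hidden_size)
    return RNNStates(
        pi=(jnp.zeros(shape), jnp.zeros(shape)),
        vf=(jnp.zeros(shape), jnp.zeros(shape)),
    )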

But the major difference is that all the LSTM states obtained during the rollouts are added to a buffer (see sb3_contrib/ppo_recurrent/ppo_recurrent.py). This buffer implements a padding-and-masking mechanism: episodes of varying lengths are padded to the same length within the buffer, and a mask indicates where each episode ends. The mask is then used in the train function to account for episode endings when updating the networks. Otherwise, these updates seem to be pretty similar to the ones in vanilla PPO from SB3.

For the implementation of the rollout buffer in JAX, there was this issue suggesting it might be easier to handle the rollout data sequentially (instead of using the mask and padding mechanisms) and just jit the whole function (which should work because the rollout data will always have the same shape).
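
To make the padding/masking part more concrete, here is a minimal sketch (assumed names and shapes, not the sb3_contrib code) of a clipped PPO policy loss where the padded steps simply get zero weight:

import jax.numpy as jnp


def masked_mean(values: jnp.ndarray, mask: jnp.ndarray) -> jnp.ndarray:
    # Average over the valid (non-padded) steps only.
    return (values * mask).sum() / jnp.maximum(mask.sum(), 1.0)


def masked_policy_loss(ratio, advantages, mask, clip_range=0.2):
    # ratio, advantages, mask: flattened (batch,) arrays; mask is 1.0 for real
    # steps and 0.0 for the padding added to equalize episode lengths.
    loss_1 = advantages * ratio
    loss_2 = advantages * jnp.clip(ratio, 1.0 - clip_range, 1.0 + clip_range)
    return -masked_mean(jnp.minimum(loss_1, loss_2), mask)

The value and entropy terms would presumably be averaged with the same mask.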

Additionally, I am not sure I understand this part of the file:

single_hidden_state_shape = (lstm.num_layers, self.n_envs, lstm.hidden_size)
# hidden and cell states for actor and critic
self._last_lstm_states = RNNStates(
    (
        th.zeros(single_hidden_state_shape, device=self.device),
        th.zeros(single_hidden_state_shape, device=self.device),
    ),
    (
        th.zeros(single_hidden_state_shape, device=self.device),
        th.zeros(single_hidden_state_shape, device=self.device),
    ),
)

Why are both the hidden and cell states of shape (2, single_hidden_state_shape), whereas in the CleanRL implementation the shape of each is only (single_hidden_state_shape)? Apart from that, am I missing something in this implementation of recurrent PPO (e.g. gradients being computed in a different manner here compared to vanilla PPO)?

Sorry, I didn't think this comment was going to be that long...
But from what I understand, I should probably try to do something in the spirit of SB3-Contrib (but maybe with a simpler implementation of the buffer, something like the sketch below?). What do you think @araffin @jamesheald ?
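
For instance, a first version of the buffer could just store fixed-shape sequences plus the LSTM state at the start of the rollout, so that the whole update can be jitted (hypothetical names, only a sketch):

from typing import NamedTuple, Tuple

import jax.numpy as jnp


class RecurrentRolloutData(NamedTuple):
    # Everything is stored sequentially with shape (n_steps, n_envs, ...),
    # so the arrays always have the same (static) shape and the whole
    # train function can be jitted.
    observations: jnp.ndarray
    actions: jnp.ndarray
    log_probs: jnp.ndarray
    values: jnp.ndarray
    advantages: jnp.ndarray
    returns: jnp.ndarray
    dones: jnp.ndarray                           # used to reset the LSTM carry
    init_carry: Tuple[jnp.ndarray, jnp.ndarray]  # LSTM state at the rollout start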

@corentinlger sorry, I was at the RL conference until today; let me try to answer in the coming days when I'm back ;)

In short: recurrent PPO in SB3 contrib is overly complex (and I'm not happy about it, so I would be glad if we can find a cleaner solution).

Hello, no problem! I'll also try to think about a simpler solution (at least for a first minimal implementation).