ikostrikov/pytorch-a2c-ppo-acktr-gail

Why does PPO need to store action_log_probs instead of using stop_gradient for better efficiency?

Opened this issue · 1 comment

Hi,
I am looking at the PPO implementation, and I am curious about this part (many other implementations use the same workflow, so I am also curious to see whether I am missing anything).

So action_log_probs is created, has its gradient removed (it is stored with requires_grad=False), and is inserted into the storage buffer. It is generated by the following function and is later referred to as old_action_log_probs_batch in PPO:

def act(self, inputs, rnn_hxs, masks, deterministic=False):
    ...
    action_log_probs = dist.log_probs(action)

    return value, action, action_log_probs, rnn_hxs
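To make the premise concrete, here is a small self-contained toy version of this mechanism (the single linear layer, the Categorical distribution, and the variable names are my stand-ins, not the repo's actual model): the log-prob computed at rollout time carries no gradient, while the one recomputed at update time does, and before any parameter update the two values coincide.

import torch
from torch.distributions import Categorical

torch.manual_seed(0)

# Stand-in policy: a single linear layer producing logits for 3 discrete actions.
policy = torch.nn.Linear(4, 3)
obs = torch.randn(1, 4)

# Rollout time (act): sample an action and keep its log-prob without a graph.
with torch.no_grad():
    dist = Categorical(logits=policy(obs))
    action = dist.sample()
    old_action_log_probs = dist.log_prob(action)   # requires_grad is False

# Update time (evaluate_actions): recompute the log-prob with a graph attached.
action_log_probs = Categorical(logits=policy(obs)).log_prob(action)

print(old_action_log_probs.requires_grad)                      # False
print(action_log_probs.requires_grad)                          # True
print(torch.allclose(old_action_log_probs, action_log_probs))  # True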

In the PPO algorithm, the ratio is calculated as follows, where action_log_probs comes from evaluate_actions():

values, action_log_probs, dist_entropy, _ = self.actor_critic.evaluate_actions(
    obs_batch, recurrent_hidden_states_batch, masks_batch, actions_batch)
ratio = torch.exp(action_log_probs - old_action_log_probs_batch)
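For context, this ratio then enters PPO's clipped surrogate objective. A self-contained sketch of that objective (the function name, clip_param=0.2, and the placeholder tensors below are mine, not the repo's code):

import torch

def clipped_surrogate(action_log_probs, old_action_log_probs, advantages,
                      clip_param=0.2):
    # PPO clipped surrogate loss (sign flipped so it can be minimized).
    ratio = torch.exp(action_log_probs - old_action_log_probs)
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantages
    return -torch.min(surr1, surr2).mean()

# Placeholder batch just to show the call.
new_log_probs = torch.randn(8, requires_grad=True)
old_log_probs = new_log_probs.detach() + 0.1 * torch.randn(8)
advantages = torch.randn(8)
clipped_surrogate(new_log_probs, old_log_probs, advantages).backward()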

If I am not mistaken, evaluate_actions() and act() output the same action_log_probs, because they use the same actor_critic and call log_probs(action); the only difference is that old_action_log_probs_batch has its gradient removed, so backpropagation will not go through it.

So my question is: why do we bother to save old_action_log_probs_batch in the storage at all, when something like the following could be created on the fly?

values, action_log_probs, dist_entropy, _ = self.actor_critic.evaluate_actions(
    obs_batch, recurrent_hidden_states_batch, masks_batch, actions_batch)
old_action_log_probs_batch = action_log_probs.detach()
ratio = torch.exp(action_log_probs - old_action_log_probs_batch)

Thank you for your attention. I look forward to the discussion.

Regards,
Tian

In my understanding, the key point is that after sampling trajectories the agent parameters are updated several times (up to args.ppo_epoch). On the first update the situation is as you describe. From the second update onward, however, old_action_log_probs in the PPO implementation is still calculated with the original parameters, while old_action_log_probs in your implementation would be calculated with parameters that have already been updated.
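To illustrate the point: with the stored old_action_log_probs_batch, the ratio drifts away from 1 from the second epoch onward, so the clipping can actually engage, whereas exp(x - x.detach()) is numerically always 1, so the on-the-fly version reduces every epoch to an unclipped policy-gradient step on the current parameters. A toy sketch (linear policy, unclipped surrogate, and all names are my own simplification, not the repo's):

import torch
from torch.distributions import Categorical

torch.manual_seed(0)
policy = torch.nn.Linear(4, 3)
optimizer = torch.optim.SGD(policy.parameters(), lr=0.5)

# One batch of "rollout" data, with the old log-probs stored once, detached.
obs = torch.randn(16, 4)
with torch.no_grad():
    dist = Categorical(logits=policy(obs))
    actions = dist.sample()
    stored_old_log_probs = dist.log_prob(actions)
advantages = torch.randn(16)

for epoch in range(4):                  # plays the role of args.ppo_epoch
    log_probs = Categorical(logits=policy(obs)).log_prob(actions)

    ratio_stored = torch.exp(log_probs - stored_old_log_probs)
    ratio_on_the_fly = torch.exp(log_probs - log_probs.detach())  # identically 1

    print(epoch,
          ratio_stored.mean().item(),      # 1.0 at epoch 0, then drifts
          ratio_on_the_fly.mean().item())  # always 1.0, so clipping never triggers

    loss = -(ratio_stored * advantages).mean()   # unclipped surrogate, enough for the demo
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()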