vwxyzjn/cleanrl

Potential bug in PPO+RND?


In ppo_rnd_envpool.py, why is line 368:

predict_next_feature = rnd_model.predictor(rnd_next_obs)

not under the torch.no_grad() context?

Like the policy, the RND network should compute no gradients during rollout collection. I could be wrong, though; I just wanted to make sure.
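For illustration, here is a minimal sketch of what the question proposes: computing the RND intrinsic reward entirely inside `torch.no_grad()`. The tiny `nn.Linear` target and predictor networks, the `curiosity_reward` helper, and the observation shapes are all hypothetical stand-ins (the real script uses convolutional networks on Atari frames); the halved squared prediction error is an assumption modeled on common RND implementations.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the RND target and predictor networks.
target = nn.Linear(8, 4)
predictor = nn.Linear(8, 4)
for p in target.parameters():
    p.requires_grad_(False)  # the RND target network stays fixed


def curiosity_reward(next_obs: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: half the squared prediction error per sample."""
    # Nothing inside this block is recorded by autograd.
    with torch.no_grad():
        target_feature = target(next_obs)
        predict_feature = predictor(next_obs)
        return (target_feature - predict_feature).pow(2).sum(1) / 2


obs = torch.randn(16, 8)
rewards = curiosity_reward(obs)
print(rewards.shape, rewards.requires_grad)  # torch.Size([16]) False
```

The rewards come out as a plain tensor with no `grad_fn`, so they can be stored in the rollout buffer without keeping any graph alive.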

Thanks!

Hi Roger, predict_next_feature itself is indeed tracked by autograd, but it is only used to compute curiosity_rewards, which reads only the .data of the result. That detaches the rewards from the graph, so for gradient purposes it is equivalent to requires_grad=False or torch.no_grad(): no gradients are collected during the rollout. Does that answer your question?
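The two patterns discussed in this thread can be compared side by side in a small sketch. As before, the `nn.Linear` networks and shapes are hypothetical stand-ins, not the ones in ppo_rnd_envpool.py. The sketch checks that the `.data` route and the `torch.no_grad()` route yield identical reward values, and that neither result carries gradient tracking.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in networks, not the ones in ppo_rnd_envpool.py.
target = nn.Linear(8, 4).requires_grad_(False)
predictor = nn.Linear(8, 4)
obs = torch.randn(16, 8)

# Pattern from the reply: the forward pass is tracked, then .data is read.
predict_feature = predictor(obs)  # has a grad_fn: a graph is recorded
via_data = ((target(obs) - predict_feature).pow(2).sum(1) / 2).data

# Pattern from the question: no graph is recorded at all.
with torch.no_grad():
    via_no_grad = (target(obs) - predictor(obs)).pow(2).sum(1) / 2

print(predict_feature.grad_fn is not None)                # True
print(via_data.requires_grad, via_no_grad.requires_grad)  # False False
print(torch.allclose(via_data, via_no_grad))              # True
```

One practical difference remains: without `torch.no_grad()`, autograd still records the predictor's forward graph (visible as `predict_feature.grad_fn`) until that tensor is freed, so the `no_grad` route can save some memory and time during rollouts. In current PyTorch, `.detach()` is generally preferred over `.data` when a detached tensor is wanted.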