coreylowman/dfdx

Examples or resources for autodiff with 2 networks?

Cobrand opened this issue · 0 comments

I'm looking at this example (since I'm also trying to implement PPO), as well as a few others, and I can't figure out how to set up autodiff with two interdependent networks.

In the case of PPO, the total loss is loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef (https://github.com/vwxyzjn/ppo-implementation-details/blob/main/ppo.py#L297).
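Written out with c_ent = args.ent_coef, c_vf = args.vf_coef, policy parameters θ and value parameters φ, the gradient split I'm after looks like this. It assumes the advantages inside the policy-gradient term are treated as constants (so that term contributes nothing to ∇_φ), and that the entropy term depends only on θ and the value loss only on φ:

```math
\begin{aligned}
\mathcal{L} &= \mathcal{L}_{\text{pg}} - c_{\text{ent}}\,\mathcal{L}_{\text{ent}} + c_{\text{vf}}\,\mathcal{L}_{v} \\
\nabla_{\theta}\mathcal{L} &= \nabla_{\theta}\mathcal{L}_{\text{pg}} - c_{\text{ent}}\,\nabla_{\theta}\mathcal{L}_{\text{ent}} \\
\nabla_{\phi}\mathcal{L} &= c_{\text{vf}}\,\nabla_{\phi}\mathcal{L}_{v}
\end{aligned}
```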

I can handle the entropy loss and the value loss just fine, but pg_loss depends on the ratio of post-softmax probabilities, which comes from the policy network, AND on the advantages, which depend on the value network. (Python implementation: pg_loss = torch.max(-mb_advantages * ratio, -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)).mean())
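For reference, here is a minimal Rust sketch of that clipped surrogate over plain f32 slices. It uses no autodiff and no dfdx types; the function name pg_loss and its signature are just for illustration. The probability ratio is computed as exp(new_log_prob - old_log_prob), which equals the ratio of the post-softmax probabilities. The advantages enter as ordinary numbers, which is the property I want: the policy loss should not push gradients back into the value network.

```rust
/// Clipped PPO policy-gradient loss, mirroring the Python line above.
/// `new_log_probs` / `old_log_probs`: log-probabilities of the taken actions
/// under the current policy and the rollout-time policy.
/// `advantages`: plain numbers (computed from the value network beforehand),
/// so nothing here depends on the value network's parameters.
fn pg_loss(new_log_probs: &[f32], old_log_probs: &[f32], advantages: &[f32], clip_coef: f32) -> f32 {
    assert!(!advantages.is_empty());
    assert_eq!(new_log_probs.len(), advantages.len());
    assert_eq!(old_log_probs.len(), advantages.len());

    let sum: f32 = new_log_probs
        .iter()
        .zip(old_log_probs)
        .zip(advantages)
        .map(|((&new_lp, &old_lp), &adv)| {
            // ratio = pi_new(a|s) / pi_old(a|s), computed in log space
            let ratio = (new_lp - old_lp).exp();
            let unclipped = -adv * ratio;
            let clipped = -adv * ratio.clamp(1.0 - clip_coef, 1.0 + clip_coef);
            // torch.max(...) of the two negated objectives
            unclipped.max(clipped)
        })
        .sum();

    // .mean()
    sum / advantages.len() as f32
}

fn main() {
    // Made-up numbers, only to exercise the function.
    let new_lp = [0.0f32, (1.5f32).ln()];
    let old_lp = [0.0f32, 0.0];
    let adv = [1.0f32, 2.0];
    let loss = pg_loss(&new_lp, &old_lp, &adv, 0.2);
    println!("pg_loss = {loss}");
}
```

With these inputs the second sample's ratio (1.5) exceeds 1 + clip_coef, so its term is clipped to -2.0 * 1.2.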

I have two separate networks for these. Is there any way to have autodiff compute gradients only with respect to the value network for one part of the loss, and only with respect to the policy network for the other? I know TensorFlow can do this somehow, but I don't know how to do it with this library.
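To make the question concrete, here is a toy sketch with hand-derived gradients (not dfdx code) of the behaviour I'm looking for, i.e. what tf.stop_gradient does in TensorFlow. The setup is hypothetical: a two-action policy with a single logit theta (action 0 was taken, so log π = ln σ(theta)), a value network V(s) = w * x, and an advantage A = ret - V(s) that is used as a constant inside pg_loss. The entropy term is left out for brevity. The policy parameter only receives gradient from pg_loss, and the value parameter only from the value loss.

```rust
// Toy example with hand-derived gradients; all names are hypothetical.
//   policy: 2 actions, logit `theta`, action 0 taken:
//           log pi(a|s) = ln(sigmoid(theta)), d/dtheta = 1 - sigmoid(theta)
//   value:  V(s) = w * x
//   advantage A = ret - V(s), treated as a constant ("stop-gradient") in pg_loss

fn sigmoid(z: f64) -> f64 {
    1.0 / (1.0 + (-z).exp())
}

#[derive(Debug)]
struct Grads {
    d_theta: f64, // gradient for the policy network
    d_w: f64,     // gradient for the value network
}

fn ppo_grads(theta: f64, old_log_prob: f64, w: f64, x: f64, ret: f64, clip_coef: f64, vf_coef: f64) -> Grads {
    // Value network forward pass.
    let value = w * x;
    // Advantage from the value net, then used as a plain number (detached).
    let adv = ret - value;

    // Policy network forward pass and clipped surrogate (single sample).
    let log_prob = sigmoid(theta).ln();
    let ratio = (log_prob - old_log_prob).exp();
    let unclipped = -adv * ratio;
    let clipped = -adv * ratio.clamp(1.0 - clip_coef, 1.0 + clip_coef);

    // d(pg_loss)/d(theta): when the clamp is active and wins the max, the loss
    // is constant in theta, so only the unclipped branch carries gradient.
    let dlogp_dtheta = 1.0 - sigmoid(theta);
    let d_theta = if unclipped >= clipped { -adv * ratio * dlogp_dtheta } else { 0.0 };

    // d(vf_coef * 0.5 * (value - ret)^2)/dw. Without the stop-gradient,
    // pg_loss would add an extra (ratio * x) term here; it is deliberately omitted.
    let d_w = vf_coef * (value - ret) * x;

    Grads { d_theta, d_w }
}

fn main() {
    let old_log_prob = sigmoid(0.0).ln();
    // Ratio = 1 (inside the clip range): the policy gets a gradient.
    println!("{:?}", ppo_grads(0.0, old_log_prob, 0.5, 2.0, 3.0, 0.2, 0.5));
    // Ratio ~ 1.76 > 1 + clip_coef with a positive advantage: no policy gradient.
    println!("{:?}", ppo_grads(2.0, old_log_prob, 0.5, 2.0, 3.0, 0.2, 0.5));
}
```

In an autodiff library the same effect would presumably come from computing the advantages without recording them on the gradient tape (in PyTorch, under torch.no_grad()), so the two networks' gradients stay separate even with a single combined loss.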

Thanks.