Writing an MPO example (help, I'm confused)
act65 opened this issue · 1 comment
act65 commented
I'm trying to write an example for MPO (for a categorical action space).
However, I'm confused.
Mainly, I'm confused about the `kl_constraints` arg to `mpo_loss`.
kl_constraints = [(rlax.categorical_kl_divergence(???, ???), lagrange_penalty)]
I don't understand what the two args to the KL divergence would be.
(Also, I don't understand why it's a list. How can there be more than one KL divergence?)
Afaik, this KL constraint is to be used for the M step.
So we should be doing something like:
However, this equation also doesn't make sense to me.
Aren't we evaluating the gradient of
What am I missing...? (something important it seems.)
act65 commented
nvm...
I can just read your tests.
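For anyone who lands here with the same question, here is a minimal sketch of how I now understand it. Treat it as my reading, not as the library's reference usage. My assumption is that the two args to the KL are the logits of the frozen target (old) policy and the logits of the online policy being trained, in that order, so the M step penalises the online policy for drifting away from the target. As far as I can tell, it's a list because Gaussian policies in MPO usually carry separate KL constraints (e.g. one for the mean and one for the covariance), whereas a categorical policy needs only one entry. The sketch is plain JAX: `categorical_kl`, `PenaltySketch`, the logits, and the `alpha`/`epsilon` values are all made up for illustration and are not rlax names.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


def categorical_kl(p_logits, q_logits):
    """KL(p || q) between categorical distributions given as logits."""
    p_log_probs = jax.nn.log_softmax(p_logits, axis=-1)
    q_log_probs = jax.nn.log_softmax(q_logits, axis=-1)
    return jnp.sum(jnp.exp(p_log_probs) * (p_log_probs - q_log_probs), axis=-1)


class PenaltySketch(NamedTuple):
    """Hypothetical stand-in for the Lagrange penalty paired with each KL."""
    alpha: jnp.ndarray  # Lagrange multiplier (learned in practice)
    epsilon: float      # bound on the KL


# Made-up logits: a batch of 2 states with 3 discrete actions each.
target_logits = jnp.array([[1.0, 0.0, -1.0], [0.5, 0.5, 0.5]])
online_logits = jnp.array([[0.8, 0.1, -0.9], [0.5, 0.5, 0.5]])

# Assumption: target policy first (treated as a constant), online policy second.
kl = categorical_kl(jax.lax.stop_gradient(target_logits), online_logits)

# A single (kl, penalty) pair, since a categorical policy has one constraint.
kl_constraints = [(kl, PenaltySketch(alpha=jnp.array(1.0), epsilon=0.01))]

# Gradient of the mean KL with respect to the online logits only.
kl_grad = jax.grad(
    lambda q: jnp.mean(categorical_kl(target_logits, q)))(online_logits)
```

The second state uses identical target and online logits, so its KL and its gradient row come out as zero, which is a quick sanity check that the argument order and the stop-gradient are doing what I expect.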