google-deepmind/rlax

Writing an MPO example (help, I'm confused)

act65 opened this issue · 1 comments

act65 commented

I'm trying to write an example for MPO (for a categorical action space).
However, I'm confused.

Mainly, I'm confused about the `kl_constraints` argument to `mpo_loss`.

```python
kl_constraints = [(rlax.categorical_kl_divergence(???, ???), lagrange_penalty)]
```

I don't understand what the two arguments to the KL divergence would be.
(Also, I don't understand why it's a list. How can there be more than one KL divergence?)
See the sketch below for roughly what I'm imagining.
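For concreteness, here is roughly what I imagine the construction looks like for a categorical policy. This is just my guess: the shapes, the `target_logits`/`online_logits` names, and the `alpha`/`epsilon` values are placeholders I made up, and I'm assuming `rlax.LagrangePenalty` takes `alpha` and `epsilon` fields.

```python
import jax.numpy as jnp
import rlax

# Made-up shapes: a batch of states, categorical policy over num_actions.
batch_size, num_actions = 8, 4

# My guess at the two KL arguments: logits of the (fixed) target policy
# pi(a|s, theta_i) and of the online policy pi(a|s, theta).
target_logits = jnp.zeros((batch_size, num_actions))  # pi(a|s, theta_i)
online_logits = jnp.zeros((batch_size, num_actions))  # pi(a|s, theta)

# Per-state KL between the two categorical distributions.
kl = rlax.categorical_kl_divergence(target_logits, online_logits)

# Placeholder penalty; I'm assuming LagrangePenalty holds a learnable alpha
# and the KL bound epsilon (the values here are arbitrary).
lagrange_penalty = rlax.LagrangePenalty(alpha=jnp.array(1.0), epsilon=1e-2)

kl_constraints = [(kl, lagrange_penalty)]
```

Is that the intended usage, or do the two arguments go the other way around?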


As far as I know, this KL constraint is used in the M-step, so it should be doing something like:

$$ J(\theta) = \dots + \mathrm{KL}\big(\pi(a \mid s, \theta_i) \,\|\, \pi(a \mid s, \theta)\big) $$

However, this equation also doesn't make sense to me.
Aren't we evaluating the gradient of $J$ at $\theta_i$, so the KL term would be 0?

What am I missing...? (Something important, it seems.)

act65 commented

Never mind, I can just read your tests.