Khrylx/PyTorch-RL

TRPO: KL Divergence Computation

sandeepnRES opened this issue · 1 comment

I see how KL divergence is computed here:

    def get_kl(self, x):
        action_prob1 = self.forward(x)
        action_prob0 = action_prob1.detach()
        kl = action_prob0 * (torch.log(action_prob0) - torch.log(action_prob1))
        return kl.sum(1, keepdim=True)

Isn't this wrong? Shouldn't the KL divergence be computed between the new policy and the old policy? Right now it seems that action_prob1 and action_prob0 are the same, so the KL divergence will always be zero, won't it?

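I'm not sure what the problem is. Before the update, the new policy is equal to the old policy, so the KL divergence is zero. Its first derivative with respect to the policy parameters is also zero, because the KL divergence reaches its minimum when the new policy equals the old policy. What we care about is the second derivative (the Hessian) of the KL divergence, which is not zero. Because action_prob0 is detached, it acts as a constant while gradients still flow through action_prob1, so differentiating the output of get_kl twice gives that Hessian.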
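
The point above can be checked numerically without any framework. The following is only a sketch, not the repository's code: it assumes a softmax policy over three actions, and the names `kl_along`, `old_logits`, and `direction` are made up for the illustration. The old distribution is held fixed, which plays the role of `.detach()` in `get_kl`, and the KL divergence is differentiated along one direction with finite differences.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl(old_logits, new_logits):
    """KL(old || new) between two categorical policies given by logits."""
    p0 = softmax(old_logits)  # fixed reference, like action_prob0 after .detach()
    p1 = softmax(new_logits)  # the distribution that depends on the parameters
    return sum(a * (math.log(a) - math.log(b)) for a, b in zip(p0, p1))


def kl_along(old_logits, direction, t):
    """KL between the old policy and the policy moved by t * direction."""
    new_logits = [z + t * d for z, d in zip(old_logits, direction)]
    return kl(old_logits, new_logits)


# Hypothetical values chosen only for this illustration.
old_logits = [0.2, -0.5, 1.0]
direction = [1.0, 0.0, -1.0]
h = 1e-3  # finite-difference step

f0 = kl_along(old_logits, direction, 0.0)
f_plus = kl_along(old_logits, direction, h)
f_minus = kl_along(old_logits, direction, -h)

# Central finite differences for the first and second directional derivatives.
first_derivative = (f_plus - f_minus) / (2 * h)
second_derivative = (f_plus - 2 * f0 + f_minus) / h ** 2

# For a softmax over logits the Fisher matrix is diag(p) - p p^T, so the
# quadratic form v^T F v equals the variance of v under p.
p = softmax(old_logits)
mean_v = sum(pi * vi for pi, vi in zip(p, direction))
fisher_quadratic = sum(pi * (vi - mean_v) ** 2 for pi, vi in zip(p, direction))

print("KL at the old policy:     ", f0)
print("first derivative (approx):", first_derivative)
print("second derivative (approx):", second_derivative)
print("Fisher quadratic form:    ", fisher_quadratic)
```

The value and the first derivative come out at (or numerically very close to) zero, while the second derivative is clearly positive and agrees with the Fisher quadratic form. That curvature is the quantity TRPO uses, and with automatic differentiation it is obtained by differentiating the KL expression twice instead of using finite differences.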