tristandeleu/pytorch-maml-rl

KL divergence with old policy in TRPO training

Closed · 2 comments

Hi,

I noticed that during TRPO training, we need to compute a Hessian-vector product to get the step direction.

The code for the Hessian-vector product requires computing the KL divergence between the new and old policies:

def hessian_vector_product(self, episodes, damping=1e-2):
    def _product(vector):
        kl = self.kl_divergence(episodes)
        ...... # use kl to do some computations
    return _product
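
For reference, the elided part is the standard Pearlmutter-style trick for computing the product of the KL Hessian with a vector without ever forming the Hessian itself. A minimal sketch of that pattern (my own paraphrase, not necessarily the repo's exact code) looks like this:

import torch
from torch.nn.utils import parameters_to_vector

def hvp_sketch(kl, parameters, vector, damping=1e-2):
    parameters = list(parameters)
    # First-order gradient of the KL, kept in the graph (create_graph=True)
    # so it can be differentiated a second time.
    grads = torch.autograd.grad(kl, parameters, create_graph=True)
    flat_grad_kl = parameters_to_vector(grads)
    # Dot with the input vector, then differentiate once more:
    # d/dtheta (grad_kl . v) = H(kl) @ v
    grad_kl_v = torch.dot(flat_grad_kl, vector)
    grads2 = torch.autograd.grad(grad_kl_v, parameters)
    flat_grad2_kl = parameters_to_vector(grads2)
    # Damping keeps the subsequent conjugate-gradient solve well-conditioned.
    return flat_grad2_kl + damping * vector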

Below is the implementation of self.kl_divergence():

def kl_divergence(self, episodes, old_pis=None):
    kls = []
    if old_pis is None:
        old_pis = [None] * len(episodes)

    for (train_episodes, valid_episodes), old_pi in zip(episodes, old_pis):
        params = self.adapt(train_episodes)
        pi = self.policy(valid_episodes.observations, params=params)

        ## If old_pi is not provided, use a detached copy of pi as old_pi
        if old_pi is None:
            old_pi = detach_distribution(pi)
        ......
        kl = weighted_mean(kl_divergence(pi, old_pi), dim=0, weights=mask)
        kls.append(kl)

    return torch.mean(torch.stack(kls, dim=0))
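
For context on the detaching step (relevant below), here is a rough sketch of what detach_distribution amounts to for a Gaussian policy, assuming the actual helper (which I have not reproduced verbatim) behaves along these lines:

import torch
from torch.distributions import Normal

def detach_normal(pi):
    # Rebuild the same Normal, but with parameters cut off from the
    # autograd graph, so gradients only flow through the new policy pi.
    return Normal(loc=pi.loc.detach(), scale=pi.scale.detach())

The repo's helper presumably covers more distribution types; the point is just that old_pi carries the same parameter values as pi but no gradient path back to the policy.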

Since old_pis is not provided inside hessian_vector_product, the KL divergence is computed between the new policy and itself.

I am wondering whether this would lead to self.kl_divergence(episodes) == 0 consistently throughout training.

Would appreciate your insight. Thanks!

I realize that old_pi has been detached from pi. This implies that even though kl == 0, the gradient of kl with respect to policy.parameters() should not be zero.

However, when I looked into the KL gradient computed by the following line, the strange thing is that it gives a zero gradient:

grads = torch.autograd.grad(kl, self.policy.parameters(), create_graph=True)

After some more careful checks, I conclude that this is the normal behavior when the policy is a NormalMLP policy. More on calculating the KL divergence between normal distributions can be found here.
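
For anyone reading this later, here is a quick self-contained check (my own illustration, not code from the repo) of both points: with a Gaussian policy evaluated against a detached copy of itself, the first-order gradient of the KL is exactly zero, while the Hessian-vector product that TRPO actually needs is not:

import torch
from torch.distributions import Normal, kl_divergence

mu = torch.zeros(3, requires_grad=True)
pi = Normal(loc=mu, scale=torch.ones(3))
old_pi = Normal(loc=mu.detach(), scale=torch.ones(3))

kl = kl_divergence(pi, old_pi).mean()

# First-order gradient: zero, because KL(pi_theta || pi_theta_old) is
# minimized exactly at theta == theta_old.
(grad,) = torch.autograd.grad(kl, mu, create_graph=True)
print(grad)  # tensor([0., 0., 0.], grad_fn=...)

# Second-order term (the Hessian-vector product, before damping): non-zero.
# This is the Fisher-information curvature used by the conjugate-gradient step.
vector = torch.ones(3)
(hvp,) = torch.autograd.grad(torch.dot(grad, vector), mu)
print(hvp)  # tensor([0.3333, 0.3333, 0.3333])

So a zero KL value and a zero first-order gradient are expected; the curvature information that drives the TRPO step is still there.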