KL divergence with old policy in trpo training
Closed this issue · 2 comments
Hi,
I noticed that during TRPO training, we need to compute a Hessian-vector product to get the step direction. The code for the Hessian-vector product requires computing the KL divergence between the new and old policies:
```python
def hessian_vector_product(self, episodes, damping=1e-2):
    def _product(vector):
        kl = self.kl_divergence(episodes)
        ......  # use kl to do some computations
    return _product
```
Below is the implementation of `self.kl_divergence()`:
```python
def kl_divergence(self, episodes, old_pis=None):
    kls = []
    if old_pis is None:
        old_pis = [None] * len(episodes)
    for (train_episodes, valid_episodes), old_pi in zip(episodes, old_pis):
        params = self.adapt(train_episodes)
        pi = self.policy(valid_episodes.observations, params=params)
        # If old_pis is not provided, use pi as old_pi
        if old_pi is None:
            old_pi = detach_distribution(pi)
        ......
        kl = weighted_mean(kl_divergence(pi, old_pi), dim=0, weights=mask)
        kls.append(kl)
    return torch.mean(torch.stack(kls, dim=0))
```
Since `old_pis` is not provided inside `hessian_vector_product`, the KL divergence is computed between the new policy and itself. I am wondering whether this leads to `self.kl_divergence(episodes) == 0` consistently throughout training?
Would appreciate your insight. Thanks!
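For what it's worth, the zero *value* is easy to confirm with a toy distribution outside the repo. The sketch below is my own minimal example (it uses a plain `Normal`, not the repo's policy or its `detach_distribution` helper, which I mimic by detaching the parameters by hand):

```python
import torch
from torch.distributions import Normal
from torch.distributions.kl import kl_divergence

# A toy "policy" distribution with a learnable mean.
mu = torch.tensor([0.5], requires_grad=True)
pi = Normal(mu, torch.ones(1))

# Mimic detach_distribution: same distribution, parameters cut from the graph.
old_pi = Normal(pi.loc.detach(), pi.scale.detach())

kl = kl_divergence(pi, old_pi).sum()
print(kl.item())        # 0.0 -- KL of a distribution with itself
print(kl.requires_grad) # True -- but kl still depends on mu through pi
```

So the value is indeed zero throughout; the question is whether its derivatives are too.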
I realize that `old_pi` has been detached from `pi`. This implies that even though `kl == 0`, the gradient of `kl` with respect to `policy.parameters()` should not be zero.
However, when I looked into the KL gradient computed by the following line, the strange thing is that it gives a zero gradient:

```python
grads = torch.autograd.grad(kl, self.policy.parameters(), create_graph=True)
```
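The zero gradient can also be reproduced on a toy distribution. The sketch below is my own example (a scalar `Normal` standing in for the policy, not the repo's code): the first derivative of `KL(pi, pi.detach())` vanishes at `pi == old_pi`, since the KL is minimized there, but the second derivative (the Fisher term that the Hessian-vector product actually uses) is nonzero:

```python
import torch
from torch.distributions import Normal
from torch.distributions.kl import kl_divergence

mu = torch.tensor([0.5], requires_grad=True)
pi = Normal(mu, torch.ones(1))
old_pi = Normal(mu.detach(), torch.ones(1))  # stand-in for detach_distribution

kl = kl_divergence(pi, old_pi).sum()

# First derivative: zero, because KL(pi, old_pi) is minimized at pi == old_pi.
grads = torch.autograd.grad(kl, mu, create_graph=True)
print(grads[0].item())  # 0.0

# Second derivative (Hessian-vector product with v = 1): nonzero.
# For Normal(mu, 1) this is the Fisher information 1 / sigma^2 = 1.
v = torch.ones(1)
hvp = torch.autograd.grad(torch.dot(grads[0], v), mu)
print(hvp[0].item())  # 1.0
```

Because `create_graph=True` keeps the graph of the (zero-valued) gradient, the second `autograd.grad` call still recovers a nonzero curvature term.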