ikostrikov/pytorch-a2c-ppo-acktr-gail

Possible bug in the sign of the policy log prob. in the Fisher computation

Dear @ikostrikov,

While reading your code, I noticed that the Fisher matrix calculation uses the log prob. of a normal distribution for the value term, but the negative log prob. for the policy term. Comparing your code in [1] with the equivalent lines in stable baselines [2], one can see that there both contributions are log probs.: the policy part is minus the negative log prob. (in pg_fisher_loss), and the value function part is minus the mean squared error.

The original paper constructs the Fisher matrix from the gradient of the log prob. of the policy and the log prob. of a Gaussian around the value function (section 3.1 of [3]). I would therefore expect the two terms used for the Fisher matrix to follow the same sign convention, as is done in the stable baselines repository. The actual loss minimisation uses a negative log prob. for both terms (as you currently do, and as is done in the stable baselines repo), but in both the Fisher calculation and the loss, the signs of the two terms should be consistent with each other.
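To make the concern concrete, here is a minimal sketch in plain Python. It is not taken from either repository: the gradient values `g_pi` and `g_v` are made up, and it assumes the two terms are summed into a single loss over shared parameters before the gradient outer product is formed. Under those assumptions, flipping the sign of both terms leaves the outer product unchanged, while flipping only one of them changes the cross terms.

```python
def outer(u, v):
    """Outer product of two vectors given as plain lists."""
    return [[ui * vj for vj in v] for ui in u]


def add(u, v):
    """Element-wise sum of two vectors."""
    return [ui + vi for ui, vi in zip(u, v)]


def neg(u):
    """Element-wise negation of a vector."""
    return [-ui for ui in u]


# Hypothetical per-sample gradients w.r.t. two shared parameters
# (illustrative numbers only, not produced by any real network).
g_pi = [1.0, -2.0]  # gradient of the policy log prob.
g_v = [0.5, 3.0]    # gradient of the Gaussian log prob. around the value

# Both terms enter as log probs.
g_same = add(g_pi, g_v)
same_sign = outer(g_same, g_same)

# Both terms enter as negative log probs. (global sign flip).
g_flipped = add(neg(g_pi), neg(g_v))
both_flipped = outer(g_flipped, g_flipped)

# Only the policy term is negated (mixed convention).
g_mixed = add(neg(g_pi), g_v)
mixed_sign = outer(g_mixed, g_mixed)

print(same_sign == both_flipped)  # global flip: outer product unchanged
print(same_sign == mixed_sign)    # mixed signs: cross terms differ
```

If the policy and value networks share no parameters, the cross terms never appear and the sign would not matter in this sketch; the question only bites for shared layers.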

I therefore could not fully understand the reason for that sign in the Fisher matrix calculation. Is this a bug, or is there some deeper reason behind it?

Best regards,
Danilo

[1] https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail/blob/master/a2c_ppo_acktr/algo/a2c_acktr.py#L53
[2] https://stable-baselines.readthedocs.io/en/master/_modules/stable_baselines/acktr/acktr.html#ACKTR
[3] https://arxiv.org/pdf/1708.05144.pdf