Potential typo in the CQL implementation
miladm12 opened this issue · 5 comments
Hi, I noticed a potential typo in your implementation of CQL for Atari. In the file "/CQL/CQL-DQN/agent.py", line 52, you subtract Q_a_s.mean() from the first term. According to the formulation in the original paper and the original TensorFlow implementation (https://github.com/aviralkumar2907/CQL/blob/master/atari/batch_rl/multi_head/quantile_agent.py), the subtracted term should use only the Q-values of the actions actually taken in the mini-batch, not the Q-values of all actions. Since you already compute Q_expected, you just need to replace Q_a_s with Q_expected, so line 52 becomes:
cql_loss = torch.logsumexp(Q_a_s, dim=1).mean() - Q_expected.mean()
Please let me know if I'm wrong, but I double-checked this against the source code and the formulation in the paper.
I changed it to the following and it is working fine for me:
Q_a_s = self.net(states)
Q_expected = Q_a_s.gather(1, actions)
cql_loss = torch.logsumexp(Q_a_s, dim=1).mean() - Q_expected.mean()
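For anyone who wants to compare the two variants outside the repo, here is a minimal self-contained sketch. The function name `cql_regularizer` and the toy tensors are made up for illustration and are not part of agent.py; it assumes Q-values of shape (batch, n_actions) and int64 actions of shape (batch, 1).

```python
import torch


def cql_regularizer(q_values: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: logsumexp over all actions minus Q of the dataset actions.

    q_values: (batch, n_actions) Q-values for every action.
    actions:  (batch, 1) int64 indices of the actions stored in the mini-batch.
    """
    # Q-value of the action actually taken in each transition, shape (batch, 1).
    q_taken = q_values.gather(1, actions)
    # The first term covers all actions; the second only the dataset actions.
    return torch.logsumexp(q_values, dim=1).mean() - q_taken.mean()


# Toy batch (made-up numbers): 2 states, 3 actions.
q = torch.tensor([[1.0, 2.0, 3.0], [0.5, 0.5, 0.5]])
a = torch.tensor([[2], [0]])

fixed = cql_regularizer(q, a)
# The variant discussed in this issue, which averages over all actions instead.
old = torch.logsumexp(q, dim=1).mean() - q.mean()
print(fixed.item(), old.item())
```

With the gathered Q-values the term cannot go negative, because logsumexp over a row is at least as large as any single entry of that row.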
I also checked with the reference implementation.
This implementation calculates the CQL loss incorrectly.
The softmax (logsumexp) part should include all actions, while the expected part should include only the actions found in the dataset.
This repo has some stars, which makes it visible, so this needs to be corrected.
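Written out, the regularizer as I understand it from the paper looks roughly like this, where $\mathcal{D}$ is the offline dataset and $\alpha$ is the trade-off weight (how $\alpha$ is applied in this repo is not something I have checked):

$$
\alpha \left( \mathbb{E}_{s \sim \mathcal{D}} \left[ \log \sum_{a} \exp Q(s, a) \right] - \mathbb{E}_{(s, a) \sim \mathcal{D}} \left[ Q(s, a) \right] \right)
$$

The sum inside the log runs over every action, while the second expectation only sees the (s, a) pairs stored in the dataset.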
Hey @habanoz @jialianchen @miladm12! Sorry for the late response. I just updated the CQL-DQN loss and it should be correct now. I also tested its performance on CartPole-v0: