BY571/CQL

Potential typo in the CQL implementation

miladm12 opened this issue · 5 comments

Hi, I noticed a potential typo in your implementation of CQL for the Atari game. In the file "/CQL/CQL-DQN/agent.py", line 52, you subtract Q_a_s.mean() from the first term. However, according to the formulation in the original paper and the original TensorFlow implementation (https://github.com/aviralkumar2907/CQL/blob/master/atari/batch_rl/multi_head/quantile_agent.py), this term should use only the Q-values of the actions actually taken in the mini-batch. Since you already compute Q_expected, you just need to replace Q_a_s with Q_expected, so line 52 becomes:

cql_loss = torch.logsumexp(Q_a_s, dim=1).mean() - Q_expected.mean()

Please let me know if I'm wrong, but I double-checked this against the source code and the formulation in the paper.
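To make the difference concrete, here is a minimal pure-Python sketch of the two variants for a discrete-action batch. It assumes `q_values` plays the role of `Q_a_s` (one list of action values per state) and `actions` holds the dataset action index for each state; the function names are hypothetical, and the repo itself works on PyTorch tensors.

```python
import math

def logsumexp(xs):
    """Numerically stable log(sum(exp(x))) over one state's action values."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def cql_penalty_mean_all(q_values):
    """Original line 52: subtract the mean over ALL (state, action) entries."""
    lse = sum(logsumexp(q) for q in q_values) / len(q_values)
    flat = [x for q in q_values for x in q]
    return lse - sum(flat) / len(flat)

def cql_penalty_dataset(q_values, actions):
    """Proposed fix: subtract the mean Q-value of the dataset actions only."""
    lse = sum(logsumexp(q) for q in q_values) / len(q_values)
    q_expected = [q[a] for q, a in zip(q_values, actions)]
    return lse - sum(q_expected) / len(q_expected)

q_values = [[1.0, 2.0], [0.5, 0.5]]  # two states, two actions each
actions = [1, 0]                      # dataset action index per state
print(cql_penalty_mean_all(q_values))
print(cql_penalty_dataset(q_values, actions))
```

Both variants share the log-sum-exp term; they differ only in what is subtracted, which is exactly the change proposed for line 52.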

I have the same question as you.
I tried changing the code, but I got bad results.
[Screenshot, 2022-08-22 22:41:56]
I'm confused by the results.

I changed it to the following and it is working fine for me:

Q_a_s = self.net(states)
Q_expected = Q_a_s.gather(1, actions)
cql_loss = torch.logsumexp(Q_a_s, dim=1).mean() - Q_expected.mean()
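For readers less familiar with `torch.gather`: with `dim=1` and an index tensor of shape `(batch, 1)`, `Q_a_s.gather(1, actions)` picks `Q_a_s[i][actions[i][0]]` for each row `i`, and the result has the same shape as the index. The pure-Python sketch below mimics that indexing so the shapes are explicit; `gather_dim1` is a hypothetical helper, and in the real call `actions` must be an integer (`long`) tensor of shape `(batch, 1)`.

```python
def gather_dim1(q_values, actions):
    """Mimic Q_a_s.gather(1, actions) for a 2-D input.

    q_values: list of rows, one row of action values per state.
    actions:  list of index rows, e.g. [[1], [0]] for shape (batch, 1).
    Returns a nested list with the same shape as `actions`.
    """
    out = []
    for row, idx in zip(q_values, actions):
        out.append([row[a] for a in idx])
    return out

q_a_s = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
actions = [[2], [0]]
print(gather_dim1(q_a_s, actions))  # [[3.0], [4.0]]
```

Taking `.mean()` of this gathered result therefore averages only the Q-values of the dataset actions, not all actions.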

I ran it for more steps, and it seems to work.
[Screenshot, 2022-08-29 15:20:45]
In principle we should be right, but more training steps are needed.
Without the fix, the code behaves more like plain DQN.
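One way to see why the unfixed penalty is not conservative is to look at its gradient with respect to one state's Q-values. This is a minimal pure-Python sketch under simplifying assumptions (a single state, no TD-error term, no weighting factor); the function names are hypothetical. Gradient descent moves each Q-value against its gradient entry.

```python
import math

def softmax(xs):
    """Numerically stable softmax over one state's action values."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def grad_penalty_mean_all(q):
    """d/dQ[a] of logsumexp(q) - mean(q); does not depend on the dataset action."""
    n = len(q)
    return [p - 1.0 / n for p in softmax(q)]

def grad_penalty_dataset_action(q, a_data):
    """d/dQ[a] of logsumexp(q) - q[a_data]."""
    return [p - (1.0 if a == a_data else 0.0) for a, p in enumerate(softmax(q))]

q = [1.0, 3.0]   # one state with two actions
a_data = 1       # the dataset took the currently highest-valued action
print(grad_penalty_mean_all(q))
print(grad_penalty_dataset_action(q, a_data))
```

Here the `mean()` variant has a positive gradient on the dataset action, so training pushes that Q-value down, while the fixed variant pushes it up. The `mean()` variant only flattens the Q-values across actions, whatever the data contains.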


I also checked against the reference implementation.
This implementation computes the CQL loss incorrectly.
The log-sum-exp (softmax) term should include all actions, while the expected term should include only the actions found in the dataset.
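For reference, the regularizer under discussion (the CQL(H) term from the paper, written here without the TD-error term and the weight α) and its mini-batch estimate read roughly as follows, with $\mathcal{D}$ the dataset, $\hat\pi_\beta$ the behavior policy that generated it, and $(s_i, a_i)$ the $B$ transitions in a mini-batch:

$$
\mathcal{R}(Q) = \mathbb{E}_{s\sim\mathcal{D}}\Big[\log\sum_{a}\exp Q(s,a) - \mathbb{E}_{a\sim\hat\pi_\beta(\cdot\mid s)}\big[Q(s,a)\big]\Big]
\approx \frac{1}{B}\sum_{i=1}^{B}\Big[\log\sum_{a}\exp Q(s_i,a) - Q(s_i,a_i)\Big]
$$

The log-sum-exp term runs over every action, while the subtracted term uses only the dataset action $a_i$. Subtracting `Q_a_s.mean()` would instead subtract $\frac{1}{|\mathcal{A}|}\sum_{a} Q(s_i,a)$, which does not depend on $a_i$.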

This repo has quite a few stars, which makes it visible, so it needs to be corrected.

BY571 commented

Hey @habanoz @jialianchen @miladm12! Sorry for the late response. I just updated the CQL-DQN loss, and it should be correct now. I also tested its performance on CartPole-v0:
[Image: CartPole-v0 results]