CQL Loss
Hi,
In the DQN version there are two losses, one of which is commented out.
I guess the commented-out one is what is used in the reference implementation, but I am not sure.
Could you please explain this a little?
That's just an old comment; I will delete it so it doesn't confuse others :)
Thanks, but wasn't it equivalent to the loss implemented in the "aviralkumar2907" repo (https://github.com/aviralkumar2907/CQL/blob/master/atari/batch_rl/multi_head/quantile_agent.py, lines 222-233)?
I'm not well versed in TensorFlow, but I thought they were the same in that they use only the chosen action for the regularization, not all actions.
Also, in `q1_loss = cql1_loss + 0.5 * bellmann_error`, shouldn't the 0.5 be 0.1?
I guess, based on Eq. 4 in the paper for Q-learning, since the behavior policy is greedy and selects the best action with probability 1, the CQL loss should be this:
`cql1_loss = (torch.logsumexp(Q_a_s, dim=1) - Q_expected).mean()`
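A minimal PyTorch sketch of that variant, for illustration only: the function, network, and batch tensor names (`cql_dqn_loss`, `qnet`, `target_net`, `states`, `actions`, etc.) and the weight `alpha` are placeholders, not the repository's actual code.

```python
import torch
import torch.nn.functional as F

def cql_dqn_loss(qnet, target_net, states, actions, rewards, next_states, dones,
                 gamma=0.99, alpha=0.5):
    """CQL(H) loss for DQN as discussed above: logsumexp over all actions minus
    the Q-value of the action in the batch, plus the usual Bellman error.
    Placeholder names, not the repository's exact code."""
    Q_a_s = qnet(states)                      # [batch, n_actions]
    Q_expected = Q_a_s.gather(1, actions)     # [batch, 1]; actions: int64 [batch, 1]

    # Standard DQN target (no gradient through the target network)
    with torch.no_grad():
        Q_targets_next = target_net(next_states).max(dim=1, keepdim=True)[0]
        Q_targets = rewards + gamma * Q_targets_next * (1 - dones)
    bellmann_error = F.mse_loss(Q_expected, Q_targets)

    # CQL regularizer: push down logsumexp over all actions, push up the data action
    cql1_loss = (torch.logsumexp(Q_a_s, dim=1, keepdim=True) - Q_expected).mean()

    # alpha weights the Bellman term (0.5 in the snippet above; 0.1 was also discussed)
    return cql1_loss + alpha * bellmann_error
```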
> Also, in `q1_loss = cql1_loss + 0.5 * bellmann_error`, shouldn't the 0.5 be 0.1?
Is it really 0.1? I probably missed that; where do you get 0.1 from?
> I guess, based on Eq. 4 in the paper for Q-learning, since the behavior policy is greedy and selects the best action with probability 1, the CQL loss should be this: `cql1_loss = (torch.logsumexp(Q_a_s, dim=1) - Q_expected).mean()`
Here is a comparison: the run with q_expected (orange) versus the other version. I need to revisit the paper to give you a more detailed answer, but to me the result is clear.
I could be wrong, of course, but I thought the loss would be `(10 * cql1_loss + bellmann_error)`; however, that does not seem right!
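For reference (this is just arithmetic, not from the thread itself): `cql1_loss + 0.1 * bellmann_error` is `10 * cql1_loss + bellmann_error` divided by 10, so the two weightings prefer the same Q-function; only the overall gradient scale, and hence the effective step size, differs.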
As for using q_expected, please do let me know your opinion based on the paper, but as you said, the results obviously show I am wrong!
Thank you very much.
Hi @merv22 and @BY571, I have the same confusion about the implementation of the CQL loss in the DQN version. According to the CQL paper, I also think it should subtract Q_expected.
> I guess, based on Eq. 4 in the paper for Q-learning, since the behavior policy is greedy and selects the best action with probability 1, the CQL loss should be this: `cql1_loss = (torch.logsumexp(Q_a_s, dim=1) - Q_expected).mean()`
Hi, it seems that the implementation is wrong.
There is more info in issue #5.
Apparently more training is needed with the correct loss.
Hey @yueyang130 @merv22, I updated the CQL loss, sorry for the late reply. The DQN-CQL loss should be correct now. I also tested the changes on CartPole-v0: