aviralkumar2907/CQL

Potential mismatch between math and code for CQL(rho)

zhihanyang2022 opened this issue · 2 comments

This is a question about how CQL(rho) is implemented in the code 😊.

In the CQL section (starting from line 235) within /CQL/d4rl/rlkit/torch/sac/cql.py, we first computed:

cat_q1 = torch.cat(
    [q1_rand, q1_pred.unsqueeze(1), q1_next_actions, q1_curr_actions], 1
)
cat_q2 = torch.cat(
    [q2_rand, q2_pred.unsqueeze(1), q2_next_actions, q2_curr_actions], 1
)

and then used them to compute:

min_qf1_loss = torch.logsumexp(cat_q1 / self.temp, dim=1,).mean() * self.min_q_weight * self.temp
min_qf2_loss = torch.logsumexp(cat_q2 / self.temp, dim=1,).mean() * self.min_q_weight * self.temp
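To make the two lines above concrete, here is a plain-Python sketch of what they compute. Each row of `cat_q` holds the Q-values of every sampled action for one state. The penalty is `temp * logsumexp(row / temp)`, averaged over states and scaled by `min_q_weight`. This is a minimal sketch that assumes Python lists stand in for PyTorch tensors. The helper names `logsumexp` and `cql_logsumexp_penalty` are hypothetical and not part of the repository.

```python
import math
from typing import List


def logsumexp(values: List[float]) -> float:
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def cql_logsumexp_penalty(cat_q: List[List[float]], temp: float, min_q_weight: float) -> float:
    """Plain-Python analogue of
    torch.logsumexp(cat_q / temp, dim=1).mean() * min_q_weight * temp.

    Each row of cat_q holds the Q-values of all sampled actions
    (random, dataset, and policy actions) for one state.
    """
    # Temperature-scaled log-sum-exp over the actions of each state.
    per_state = [logsumexp([q / temp for q in row]) for row in cat_q]
    # Mean over the batch of states, then rescale by weight and temperature.
    return sum(per_state) / len(per_state) * min_q_weight * temp


if __name__ == "__main__":
    cat_q = [[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]]  # hypothetical batch of 2 states
    print(cql_logsumexp_penalty(cat_q, temp=1.0, min_q_weight=1.0))
```

As `temp` shrinks, `temp * logsumexp(q / temp)` approaches `max(q)`. The term therefore acts as a soft maximum over the sampled actions' Q-values. With `temp = 1`, it is exactly the log-sum-exp.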

I'm a bit confused about why the Q-values of actions drawn from three different distributions can be combined to compute this quantity:

  • q1_rand: uniform distribution
  • q1_pred: dataset distribution
  • q1_curr_actions and q1_next_actions: last-iteration policy

Here are my questions:

  • In the CQL(rho) part of Appendix A, isn't the expectation taken with respect to the rho distribution only (which we chose to be the last-iteration policy)?
  • Why is log-sum-exp used here when the corresponding (first) term in Equation 7 of the paper contains no log at all?

I fully understand how CQL(H) works in the codebase, though.

I think the codebase only implements CQL(H): min_q_version is always set to 3, which corresponds to CQL(H). The log-sum-exp equation appears in Appendix F (Additional Experimental Setup and Implementation Details).

Equation 7 omits a term, namely the KL-divergence regularizer. Once you add that term back, you can derive the log-sum-exp.
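The reply's claim can be checked numerically for a single state with a finite action set. Assume the regularizer is the negative KL divergence to rho. Then the objective E_{a~mu}[Q(s,a)] - KL(mu || rho) is maximized by mu*(a) ∝ rho(a) exp(Q(s,a)), and its maximum value is log sum_a rho(a) exp(Q(s,a)), which is a log-sum-exp. Dropping the KL term and evaluating only E_{mu*}[Q] gives a log-free expression whose value is larger by exactly KL(mu* || rho). This is an illustrative sketch. The function names and example numbers are hypothetical and do not come from the repository or the paper.

```python
import math
from typing import List


def kl_divergence(mu: List[float], rho: List[float]) -> float:
    """KL(mu || rho) for discrete distributions over the same actions."""
    return sum(m * math.log(m / r) for m, r in zip(mu, rho) if m > 0)


def expected_q(mu: List[float], q: List[float]) -> float:
    """E_{a~mu}[Q(s, a)] for one state."""
    return sum(m * qa for m, qa in zip(mu, q))


def regularized_objective(mu: List[float], q: List[float], rho: List[float]) -> float:
    """E_{a~mu}[Q(s, a)] - KL(mu || rho) for one state."""
    return expected_q(mu, q) - kl_divergence(mu, rho)


def optimal_mu(q: List[float], rho: List[float]) -> List[float]:
    """Closed-form maximizer: mu*(a) proportional to rho(a) * exp(Q(s, a))."""
    weights = [r * math.exp(qa) for r, qa in zip(rho, q)]
    z = sum(weights)
    return [w / z for w in weights]


def log_partition(q: List[float], rho: List[float]) -> float:
    """log sum_a rho(a) * exp(Q(s, a)): the log-sum-exp term."""
    return math.log(sum(r * math.exp(qa) for r, qa in zip(rho, q)))


if __name__ == "__main__":
    q = [1.0, 2.0, 0.5]    # hypothetical Q-values of three actions
    rho = [0.2, 0.5, 0.3]  # hypothetical proposal distribution rho(a|s)
    mu_star = optimal_mu(q, rho)
    print("objective at mu*:", regularized_objective(mu_star, q, rho))
    print("log-sum-exp     :", log_partition(q, rho))
    print("E_mu*[Q] only   :", expected_q(mu_star, q))
```

With a uniform rho over n actions, log sum_a rho(a) exp(Q(s,a)) = logsumexp(Q) - log n. The expression therefore reduces to the CQL(H)-style log-sum-exp up to a constant.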