Curt-Park/rainbow-is-all-you-need

redundant max in double dqn

DongukJu opened this issue · 4 comments

In double DQN, I found an expression of the form max Q(~~~, argmax Q(~~~)).

Do we need the max when there is already an argmax inside Q?

I think the max is redundant.

Would you kindly check this to reduce confusion?

Hi. @DongukJu

As you said, equations 1 and 2 mean the same thing. Equation 2 is a rewriting of equation 1 that prepares the transformation into double Q-learning. If you look closely at equation 3, the two Q-values use two different parameter sets (theta and theta'). Double Q-learning reduces over-estimation by keeping two Q-value estimates that are updated using each other. To explain how the expression changes step by step, therefore, it would be better to keep it as it is.

equation 1.
image

equation 2.
image

equation 3.
image

Thank you!

Dear @MrSyee,

Thanks for your reply.

I agree with you that the variation is important.

What I meant was that,
image

max_a in eq 2 and eq 3 is redundant.

As far as I understand, the max does nothing there.

If you want to convey the same intuition to the readers, we might allow this redundancy, but it may still cause confusion.

The max looks like it should do something, but it does nothing.
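The redundancy can be spelled out as a short derivation. This is only a sketch in the usual Double DQN notation: the symbols S_{t+1}, a^*, theta, and theta' are assumed here and may not match the equations in the tutorial exactly. Once the action is fixed by the argmax, the remaining expression no longer depends on a, so a max over a leaves it unchanged:

```latex
\begin{align*}
a^{*} &= \operatorname*{argmax}_{a'} Q(S_{t+1}, a'; \theta) \\
\max_{a} Q(S_{t+1}, a^{*}; \theta') &= Q(S_{t+1}, a^{*}; \theta')
\end{align*}
```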

def _compute_dqn_loss(self, samples: Dict[str, np.ndarray]) -> torch.Tensor:
    """Return dqn loss."""
    device = self.device  # for shortening the following lines
    state = torch.FloatTensor(samples["obs"]).to(device)
    next_state = torch.FloatTensor(samples["next_obs"]).to(device)
    action = torch.LongTensor(samples["acts"].reshape(-1, 1)).to(device)
    reward = torch.FloatTensor(samples["rews"].reshape(-1, 1)).to(device)
    done = torch.FloatTensor(samples["done"].reshape(-1, 1)).to(device)
    
    # G_t   = r + gamma * v(s_{t+1})  if state != Terminal
    #       = r                       otherwise
    curr_q_value = self.dqn(state).gather(1, action)
    next_q_value = self.dqn_target(next_state).gather(  # Double DQN
        1, self.dqn(next_state).argmax(dim=1, keepdim=True)
    ).detach()
    mask = 1 - done
    target = (reward + self.gamma * next_q_value * mask).to(self.device)

    # calculate dqn loss
    loss = F.smooth_l1_loss(curr_q_value, target)

    return loss

Again, in your code, there is only one argmax for next_q_value, not max and argmax.

Would you kindly clarify this?
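The same point can be checked numerically. The following is a minimal NumPy sketch, not the repository's code: the arrays `online_q` and `target_q` are made-up values standing in for `self.dqn(next_state)` and `self.dqn_target(next_state)`, and `np.take_along_axis` plays the role of `gather`.

```python
import numpy as np

# Hypothetical Q-values for a batch of 2 next states and 3 actions.
# online_q stands in for self.dqn(next_state),
# target_q stands in for self.dqn_target(next_state).
online_q = np.array([[1.0, 5.0, 2.0],
                     [0.5, 0.1, 0.9]])
target_q = np.array([[4.0, 3.0, 6.0],
                     [0.2, 0.8, 0.3]])

# Action selection with the online network: argmax over the action axis.
best_action = online_q.argmax(axis=1)[:, None]  # shape (2, 1)

# Action evaluation with the target network (the Double DQN step).
double_q = np.take_along_axis(target_q, best_action, axis=1)  # shape (2, 1)

# A further max over the action axis only sees one column, so it is a no-op.
double_q_with_max = double_q.max(axis=1, keepdims=True)

# Plain DQN target for comparison: max over the target network's own Q-values.
plain_q = target_q.max(axis=1, keepdims=True)

print(double_q.ravel(), double_q_with_max.ravel(), plain_q.ravel())
```

In this sketch `double_q` and `double_q_with_max` are identical, while `plain_q` differs, which suggests the argmax alone carries the Double DQN behaviour and the outer max adds nothing.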

Dear @DongukJu

Oh, you're right. I'm sorry that I didn't spot the typo even after reading your comment.
I'll fix it.

Thanks for your insight.

The typo is fixed. Thanks.