quantumiracle/Popular-RL-Algorithms

I think "with torch.no_grad():" is needed when calculating critic loss

dbsxdbsx opened this issue · 2 comments

In the file sac_v2_lstm.py, the code below does not use `with torch.no_grad():` when calculating the target Q value:

    # Training Q Function
    # I think `with torch.no_grad():` is needed here
    predict_target_q1, _ = self.target_soft_q_net1(next_state, new_next_action, action, hidden_out)
    predict_target_q2, _ = self.target_soft_q_net2(next_state, new_next_action, action, hidden_out)
    target_q_min = torch.min(predict_target_q1, predict_target_q2) - self.alpha * next_log_prob
    target_q_value = reward + (1 - done) * gamma * target_q_min  # if done==1, only the reward remains

    q_value_loss1 = self.soft_q_criterion1(predicted_q_value1, target_q_value.detach())  # detach: no gradients through the target
    q_value_loss2 = self.soft_q_criterion2(predicted_q_value2, target_q_value.detach())

Although the algorithm does eventually converge in some simple environments, I am not sure that holds in more complex ones. As far as I know, the parameters used to compute target_q_value should not contribute any gradients to the model.
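
For illustration, this is roughly what the suggested change would look like (a sketch based on the snippet above, not the repository's actual code):

    with torch.no_grad():  # no graph is built for the target networks at all
        predict_target_q1, _ = self.target_soft_q_net1(next_state, new_next_action, action, hidden_out)
        predict_target_q2, _ = self.target_soft_q_net2(next_state, new_next_action, action, hidden_out)
        target_q_min = torch.min(predict_target_q1, predict_target_q2) - self.alpha * next_log_prob
        target_q_value = reward + (1 - done) * gamma * target_q_min

    # detach() then becomes redundant (but harmless)
    q_value_loss1 = self.soft_q_criterion1(predicted_q_value1, target_q_value)
    q_value_loss2 = self.soft_q_criterion2(predicted_q_value2, target_q_value)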

Hi,
Isn't that exactly what target_q_value.detach() is achieving?
Best,
zihan

Check here.
I think there is no difference here, since only one variable (target_q_value) is involved, and it is already detached.
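
A small self-contained check (toy networks and hypothetical names, not the repository's code) that illustrates this: detaching the final target tensor and wrapping the whole target computation in torch.no_grad() give the same gradients for the online network; the practical difference is that no_grad() never builds the graph for the target pass, which saves some memory and compute.

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    online_q = nn.Linear(4, 1)   # stand-in for the online soft Q network
    target_q = nn.Linear(4, 1)   # stand-in for the target soft Q network
    state = torch.randn(8, 4)
    next_state = torch.randn(8, 4)
    reward = torch.randn(8, 1)
    gamma = 0.99

    def critic_loss(use_no_grad):
        if use_no_grad:
            with torch.no_grad():  # target pass builds no graph at all
                target = reward + gamma * target_q(next_state)
        else:
            # graph is built through target_q, then cut by detach()
            target = (reward + gamma * target_q(next_state)).detach()
        return ((online_q(state) - target) ** 2).mean()

    grads = []
    for use_no_grad in (True, False):
        online_q.zero_grad()
        critic_loss(use_no_grad).backward()
        grads.append(online_q.weight.grad.clone())

    print(torch.allclose(grads[0], grads[1]))  # True: same gradients for the online network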