I think "with torch.no_grad():" is needed when calculating critic loss
dbsxdbsx opened this issue · 2 comments
dbsxdbsx commented
In sac_v2_lstm.py, the code below computes the target Q value without `with torch.no_grad():`:
```python
# Training Q Function
# I think `with torch.no_grad():` is needed here
predict_target_q1, _ = self.target_soft_q_net1(next_state, new_next_action, action, hidden_out)
predict_target_q2, _ = self.target_soft_q_net2(next_state, new_next_action, action, hidden_out)
target_q_min = torch.min(predict_target_q1, predict_target_q2) - self.alpha * next_log_prob
target_q_value = reward + (1 - done) * gamma * target_q_min  # if done == 1, only the reward remains
q_value_loss1 = self.soft_q_criterion1(predicted_q_value1, target_q_value.detach())  # detach: no gradients through the target
q_value_loss2 = self.soft_q_criterion2(predicted_q_value2, target_q_value.detach())
```
Although the algorithm does eventually converge in some simple environments, I am not sure that holds in complex ones. As far as I know, the parameters used to calculate target_q_value should not contribute any gradient to the model.
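A minimal sketch of the proposed change, using hypothetical stand-ins (plain linear layers instead of the repository's LSTM critics, and a single target network instead of two). The point is that everything inside the `no_grad` block produces a tensor with `requires_grad=False`, so no graph is recorded for the target network and no `.detach()` is needed:

```python
import torch

# Hypothetical stand-in critics (the real ones are LSTM-based).
soft_q_net1 = torch.nn.Linear(6, 1)
target_soft_q_net1 = torch.nn.Linear(6, 1)
gamma, alpha = 0.99, 0.2

state, next_state = torch.randn(8, 4), torch.randn(8, 4)
action, next_action = torch.randn(8, 2), torch.randn(8, 2)
reward, done = torch.randn(8, 1), torch.zeros(8, 1)
next_log_prob = torch.randn(8, 1)

predicted_q_value1 = soft_q_net1(torch.cat([state, action], dim=-1))

# Proposed change: build the whole target under no_grad, so autograd
# records nothing for the target network.
with torch.no_grad():
    target_q = target_soft_q_net1(torch.cat([next_state, next_action], dim=-1))
    target_q_value = reward + (1 - done) * gamma * (target_q - alpha * next_log_prob)

# target_q_value.requires_grad is already False: no .detach() needed.
q_value_loss1 = torch.nn.functional.mse_loss(predicted_q_value1, target_q_value)
q_value_loss1.backward()
```

After `backward()`, only `soft_q_net1` receives gradients; the target network's parameters stay untouched, which is exactly the behavior wanted for the critic update.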
quantumiracle commented
Hi,
Isn't that just what target_q_value.detach()
is achieving?
Best,
zihan
quantumiracle commented
Check here.
I think there is no difference, since only one variable (target_q_value) is involved here.
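A small self-contained check of this claim, with hypothetical linear critics standing in for the real networks: calling `.detach()` on the final target and wrapping the target computation in `with torch.no_grad():` yield identical gradients for the online critic. The practical difference is only that `no_grad` skips recording the (unused) graph through the target network, saving memory and compute:

```python
import torch

torch.manual_seed(0)
q_net = torch.nn.Linear(4, 1)        # hypothetical online critic
target_net = torch.nn.Linear(4, 1)   # hypothetical target critic
s, s_next = torch.randn(8, 4), torch.randn(8, 4)
reward = torch.randn(8, 1)

# Variant 1: detach the final target, as in the repository code.
target = reward + 0.99 * target_net(s_next)
loss1 = torch.nn.functional.mse_loss(q_net(s), target.detach())
loss1.backward()
grads_detach = [p.grad.clone() for p in q_net.parameters()]
q_net.zero_grad()

# Variant 2: compute the target under no_grad.
with torch.no_grad():
    target = reward + 0.99 * target_net(s_next)
loss2 = torch.nn.functional.mse_loss(q_net(s), target)
loss2.backward()
grads_nograd = [p.grad.clone() for p in q_net.parameters()]

# The critic gradients are identical in both variants.
print(all(torch.allclose(a, b) for a, b in zip(grads_detach, grads_nograd)))  # prints True
```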