Possible error in critic update in SAC-AE algorithm
Cerphilly opened this issue · 2 comments
Cerphilly commented
In the SAC-AE algorithm, critic 1 and critic 2 are updated as follows:
```python
target_q = tf.stop_gradient(
    rewards + not_dones * self.discount * (min_next_target_q - self.alpha * next_logps))

obs_features = self._encoder(obses, stop_q_grad=self._stop_q_grad)
current_q1 = self.qf1(obs_features, actions)
current_q2 = self.qf2(obs_features, actions)
td_loss_q1 = tf.reduce_mean((target_q - current_q1) ** 2)
td_loss_q2 = tf.reduce_mean((target_q - current_q2) ** 2)  # Eq.(6)

q1_grad = tape.gradient(td_loss_q1, self._encoder.trainable_variables + self.qf1.trainable_variables)
self.qf1_optimizer.apply_gradients(
    zip(q1_grad, self._encoder.trainable_variables + self.qf1.trainable_variables))

q2_grad = tape.gradient(td_loss_q2, self._encoder.trainable_variables + self.qf2.trainable_variables)
self.qf2_optimizer.apply_gradients(
    zip(q2_grad, self._encoder.trainable_variables + self.qf2.trainable_variables))
```
However, since the encoder is optimized together with `qf1` before `q2_grad` is computed and applied, `td_loss_q2` and `q2_grad` become inconsistent: the tape evaluates `q2_grad` at the encoder weights it recorded, but `apply_gradients` then applies it to the already-updated encoder. So I believe `q2_grad` has to be computed before optimizing `qf1` and the encoder.
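To illustrate the staleness concretely, here is a hypothetical minimal reproduction (the variables are stand-ins, not from the repo): a persistent tape evaluates gradients at the values it recorded, so a gradient taken after an `apply_gradients` step no longer matches the variables it is applied to.

```python
import tensorflow as tf

w = tf.Variable(2.0)  # stands in for a shared encoder weight
opt = tf.keras.optimizers.SGD(learning_rate=0.1)

with tf.GradientTape(persistent=True) as tape:
    loss1 = w ** 2        # plays the role of td_loss_q1
    loss2 = 3.0 * w ** 2  # plays the role of td_loss_q2

g1 = tape.gradient(loss1, [w])
opt.apply_gradients(zip(g1, [w]))  # w is updated: 2.0 -> 1.6

g2 = tape.gradient(loss2, [w])     # still 6 * 2.0 = 12.0, evaluated at the OLD w
opt.apply_gradients(zip(g2, [w]))  # but applied to the NEW w: a stale direction
del tape                           # release the persistent tape
```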
keiohta commented
Hi @Cerphilly, thanks for pointing this out! I agree that this might be an error (though I'm not sure how much impact it has). So, would you suggest something like the following?
```python
q_grad = tape.gradient(
    td_loss_q1 + td_loss_q2,
    self._encoder.trainable_variables + self.qf1.trainable_variables + self.qf2.trainable_variables)
self.qf_optimizer.apply_gradients(
    zip(q_grad, self._encoder.trainable_variables + self.qf1.trainable_variables + self.qf2.trainable_variables))
```
The above code simply sums the two TD losses and computes the gradients of the sum.
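This is only a sketch of how it might look in the full update; `self.qf_optimizer` is an assumed single shared optimizer, which the current code does not have (it keeps separate `qf1_optimizer` and `qf2_optimizer`):

```python
# Sketch only: assumes a single shared optimizer `self.qf_optimizer`
# created in __init__ (the current code has separate qf1/qf2 optimizers).
with tf.GradientTape() as tape:  # non-persistent: only one gradient call
    obs_features = self._encoder(obses, stop_q_grad=self._stop_q_grad)
    current_q1 = self.qf1(obs_features, actions)
    current_q2 = self.qf2(obs_features, actions)
    td_loss_q1 = tf.reduce_mean((target_q - current_q1) ** 2)
    td_loss_q2 = tf.reduce_mean((target_q - current_q2) ** 2)
    td_loss = td_loss_q1 + td_loss_q2

variables = (self._encoder.trainable_variables
             + self.qf1.trainable_variables
             + self.qf2.trainable_variables)
q_grad = tape.gradient(td_loss, variables)
self.qf_optimizer.apply_gradients(zip(q_grad, variables))
```

One consequence of this form is that the encoder receives a single optimizer step per update instead of two.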
Cerphilly commented
Thanks for the quick response!
I changed my code as follows:
```python
target_q = tf.stop_gradient(
    r + self.gamma * (1 - d) * (target_min_aq - self.alpha.numpy() * ns_logpi))

# Record both losses on one persistent tape so each gradient is taken
# w.r.t. the same (pre-update) encoder weights.
with tf.GradientTape(persistent=True) as tape1:
    critic1_loss = tf.reduce_mean(tf.square(self.critic1(self.encoder(s), a) - target_q))
    critic2_loss = tf.reduce_mean(tf.square(self.critic2(self.encoder(s), a) - target_q))

# Compute both gradients before applying either optimizer step.
critic1_gradients = tape1.gradient(
    critic1_loss, self.encoder.trainable_variables + self.critic1.trainable_variables)
critic2_gradients = tape1.gradient(
    critic2_loss, self.encoder.trainable_variables + self.critic2.trainable_variables)

self.critic1_optimizer.apply_gradients(
    zip(critic1_gradients, self.encoder.trainable_variables + self.critic1.trainable_variables))
self.critic2_optimizer.apply_gradients(
    zip(critic2_gradients, self.encoder.trainable_variables + self.critic2.trainable_variables))
del tape1  # release the persistent tape
```
This change seemed to achieve higher performance in RAD. But your suggestion also seems to work well without error.