Question on SAC for discrete actions: temperature loss "alpha"
dbsxdbsx opened this issue · 1 comment
dbsxdbsx commented
According to the paper SOFT ACTOR-CRITIC FOR DISCRETE ACTION SETTINGS, the temperature loss formula (12) has a pi_t(s_t) term at the very outside of the expression, which for discrete actions is the vector of action probabilities (action_probs). But according to the code here:
# from SAC_Discrete.py
def calculate_actor_loss(self, state_batch):
    """Calculates the loss for the actor. This loss includes the additional entropy term"""
    action, (action_probabilities, log_action_probabilities), _ = self.produce_action_and_action_info(state_batch)
    qf1_pi = self.critic_local(state_batch)
    qf2_pi = self.critic_local_2(state_batch)
    min_qf_pi = torch.min(qf1_pi, qf2_pi)
    inside_term = self.alpha * log_action_probabilities - min_qf_pi
    policy_loss = action_probabilities * inside_term
    policy_loss = policy_loss.mean()
    log_action_probabilities = torch.sum(log_action_probabilities * action_probabilities, dim=1)
    return policy_loss, log_action_probabilities
# from SAC.py
def calculate_entropy_tuning_loss(self, log_pi):
    """Calculates the loss for the entropy temperature parameter. This is only relevant if
    self.automatic_entropy_tuning is True."""
    alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
    return alpha_loss
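For reference, my reading of formula (12) in the paper is roughly the following, where H_bar denotes the target entropy and the action probabilities multiply the whole bracket:

J(alpha) = E_{s_t} [ pi_t(s_t)^T ( -alpha * ( log(pi_t(s_t)) + H_bar ) ) ]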
Here, in log_action_probabilities = torch.sum(log_action_probabilities * action_probabilities, dim=1), it seems that only log_action_probabilities is weighted by action_probabilities before being passed into the temperature loss, while self.target_entropy is not weighted at all. Why?
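To make the question concrete, below is a minimal sketch of what a direct transcription of formula (12) would look like, with the whole bracket weighted by the action probabilities. This is not code from this repo; the method name is hypothetical, and it just reuses action_probabilities / log_action_probabilities from produce_action_and_action_info above and self.target_entropy as H_bar:

# hypothetical sketch, not from SAC.py or SAC_Discrete.py
def calculate_entropy_tuning_loss_discrete(self, action_probabilities, log_action_probabilities):
    # pi_t(s_t)^T (log pi_t(s_t) + H_bar): weight BOTH terms by the action probabilities,
    # then sum over the action dimension
    weighted_term = torch.sum(action_probabilities * (log_action_probabilities + self.target_entropy), dim=1)
    # as in calculate_entropy_tuning_loss above, only self.log_alpha should receive gradient
    alpha_loss = -(self.log_alpha * weighted_term.detach()).mean()
    return alpha_loss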
dbsxdbsx commented
I've solved it here: yining043/SAC-discrete#2 (comment)