AI4Finance-Foundation/ElegantRL

A question about the code of the 'ActorSAC' class in net.py

Opened this issue · 1 comments

I'm confused about why we use 'logprob = dist.log_prob(a_avg)' instead of 'logprob = dist.log_prob(action)' in line 247 of elegantrl/agents/net.py. I think the latter is consistent with the original paper. Does using the former work better in experiments?

    def get_action_logprob(self, state):
        state = self.state_norm(state)
        s_enc = self.net_s(state)  # encoded state
        a_avg, a_std_log = self.net_a(s_enc).chunk(2, dim=1)
        a_std = a_std_log.clamp(-16, 2).exp()
        dist = Normal(a_avg, a_std)
        action = dist.rsample()
        action_tanh = action.tanh()
        logprob = dist.log_prob(a_avg)
        logprob -= (-action_tanh.pow(2) + 1.000001).log()  # fix logprob using the derivative of action.tanh()
        return action_tanh, logprob.sum(1)
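For reference, the paper-consistent variant mentioned in the question would evaluate the log-density at the sampled action rather than at the mean. A minimal sketch (the function name `get_action_logprob_paper` is hypothetical; the tanh correction here uses the textbook `log(1 - tanh(a)^2)` form):

```python
import torch
from torch.distributions import Normal

def get_action_logprob_paper(a_avg, a_std):
    """Log-prob of a tanh-squashed Gaussian, evaluated at the
    sampled action (as in the original SAC paper), not at a_avg."""
    dist = Normal(a_avg, a_std)
    action = dist.rsample()          # reparameterized pre-tanh sample
    action_tanh = action.tanh()
    logprob = dist.log_prob(action)  # <-- the sample, not the mean
    # change of variables: log pi(tanh(a)) = log N(a) - log(1 - tanh(a)^2)
    logprob -= (1.0 - action_tanh.pow(2) + 1e-6).log()
    return action_tanh, logprob.sum(1)
```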

Yes, it is better in experiments.

You can read the webpage below for more information.

Update tanh bijector with numerically stable formula.
tensorflow/probability@ef6bb17#diff-e120f70e92e6741bca649f04fcd907b7

    def log_abs_det_jacobian(self, x, y):
        # We use a formula that is more numerically stable, see details in the following link
        # https://github.com/tensorflow/probability/commit/ef6bb176e0ebd1cf6e25c6b5cecdd2428c22963f#diff-e120f70e92e6741bca649f04fcd907b7
        return 2. * (math.log(2.) - x - F.softplus(-2. * x))
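The identity behind this formula follows from `1 - tanh(x)^2 = sech(x)^2 = 4 / (e^x + e^-x)^2`, which gives `log(1 - tanh(x)^2) = 2 * (log 2 - x - softplus(-2x))`. A plain-Python check (a sketch using only the stdlib; the stable-softplus helper is my own):

```python
import math

def softplus(x):
    # numerically stable softplus: log(1 + exp(x))
    return math.log1p(math.exp(-abs(x))) + max(x, 0.0)

def log_abs_det_jacobian_stable(x):
    # 2 * (log 2 - x - softplus(-2x))
    return 2.0 * (math.log(2.0) - x - softplus(-2.0 * x))

def log_abs_det_jacobian_naive(x):
    # direct form: log(1 - tanh(x)^2)
    return math.log(1.0 - math.tanh(x) ** 2)
```

For moderate `x` the two agree to machine precision; the stable form only starts to matter when `tanh(x)^2` rounds to 1.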

https://github.com/AI4Finance-Foundation/ElegantRL/blob/dee9c6d095001bf8365c0359f0d04a021d8c1e22/elegantrl/agents/net.py

ElegantRL code comment

"""fix log_prob of action.tanh"""
log_prob += (
np.log(2.0) - a_noise - self.soft_plus(-2.0 * a_noise)
) * 2.0 # better than below
"""same as below:
epsilon = 1e-6
a_noise_tanh = a_noise.tanh()
log_prob = log_prob - (1 - a_noise_tanh.pow(2) + epsilon).log()
Thanks for:
https://github.com/denisyarats/pytorch_sac/blob/81c5b536d3a1c5616b2531e446450df412a064fb/agent/actor.py#L37
↑ MIT License, Thanks for https://www.zhihu.com/people/Z_WXCY 2ez4U
They use action formula that is more numerically stable, see details in the following link
https://pytorch.org/docs/stable/_modules/torch/distributions/transforms.html#TanhTransform
https://github.com/tensorflow/probability/commit/ef6bb176e0ebd1cf6e25c6b5cecdd2428c22963f
"""