google-deepmind/distrax

`nan` in MultivariateNormalDiag log prob


Hello, and thanks for this awesome repo! We ran into an issue where using distrax produces `nan` values (see vwxyzjn/cleanrl#300). The following script reproduces it:

from typing import Sequence

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax

# import pybullet_envs  # noqa
import tensorflow_probability
from flax.training.train_state import TrainState

tfp = tensorflow_probability.substrates.jax
tfd = tfp.distributions
jax.config.update("jax_platform_name", "cpu")
import distrax


class Actor(nn.Module):
    action_dim: int
    n_units: int = 256
    log_std_min: float = -20
    log_std_max: float = 2

    @nn.compact
    def __call__(self, x: jnp.ndarray):
        x = nn.Dense(self.n_units)(x)
        x = nn.relu(x)
        x = nn.Dense(self.n_units)(x)
        x = nn.relu(x)
        mean = nn.Dense(self.action_dim)(x)
        log_std = nn.Dense(self.action_dim)(x)
        log_std = jnp.clip(log_std, self.log_std_min, self.log_std_max)
        return mean, log_std


# @jax.jit
def custom_log_prob(
    mean: jnp.ndarray,
    log_std: jnp.ndarray,
    subkey: jax.Array,  # jax.random.KeyArray was removed in newer JAX releases
    gaussian_action: jnp.ndarray,
):
    std = jnp.exp(log_std)
    gaussian_action = mean + std * jax.random.normal(subkey, shape=mean.shape)
    log_prob = -0.5 * ((gaussian_action - mean) / std) ** 2 - 0.5 * jnp.log(2.0 * jnp.pi) - log_std
    log_prob = log_prob.sum(axis=1)
    # https://github.com/vwxyzjn/cleanrl/pull/300#issuecomment-1326285592
    log_prob -= jnp.sum(2.0 * (np.log(2.0) - gaussian_action - jax.nn.softplus(-2.0 * gaussian_action)), 1)
    return log_prob


if __name__ == "__main__":
    key = jax.random.PRNGKey(0)
    key, actor_key = jax.random.split(key, 2)
    # with open("test.npy", "rb") as f:
    #     obs = np.load(f)
    obs = jnp.array([[ -0.06284985,  -0.0164921 ,  -0.10846169,   0.28114545,
         -0.28463456,   0.4503281 ,   0.27488193,  -0.0666963 ,
          0.6118138 ,   0.34202537,  -1.262452  ,   0.7542422 ,
         13.809639  ,  -0.6205632 ,  -4.0013294 ,   5.3532414 ,
         11.587792  ],
       [ -0.15303956,   0.9534635 ,  -0.3092537 ,  -0.2033926 ,
          0.03336933,   0.6362027 ,   0.02348915,  -0.32627296,
         -0.29046476,   0.46484601,  -0.42002085,  -3.1616204 ,
          2.247283  ,  14.114895  ,   2.6248324 ,  -1.9809983 ,
        -12.693646  ],
       [ -0.07995494,   0.09804074,  -0.20460981,  -0.13476144,
          0.1701505 ,   0.05989099,  -0.06446445,  -0.22749065,
          0.39946172,   0.42318228,   2.5876977 ,   3.8510017 ,
         -8.23167   ,  -7.292657  ,   7.64345   ,  -9.558817  ,
         -1.9690503 ],
    ])
    # obs = obs[0:5]
    actor = Actor(action_dim=6)
    actor_state = TrainState.create(
        apply_fn=actor.apply,
        params=actor.init(actor_key, obs),
        tx=optax.adam(learning_rate=3e-4),
    )

    key, subkey = jax.random.split(key, 2)
    mean, log_std = actor.apply(actor_state.params, obs)
    action_std = jnp.exp(log_std)
    tfd_dist = tfd.TransformedDistribution(
        tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std)), bijector=tfp.bijectors.Tanh()
    )
    distrax_dist = distrax.Transformed(
        distrax.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std)), bijector=distrax.Block(distrax.Tanh(), 1)
    )

    # action generation
    gaussian_action = mean + action_std * jax.random.normal(subkey, shape=mean.shape)
    action_custom = jnp.tanh(gaussian_action)
    reverse_action_custom = jnp.arctanh(action_custom)
    action_tfp = tfd_dist.sample(seed=subkey)
    action_distrax = distrax_dist.sample(seed=subkey)

    print("action_custom.sum()", action_custom.sum())
    print("action_tfp.sum()", action_tfp.sum())
    print("action_distrax.sum()", action_distrax.sum())
    print("gaussian_action.sum()", gaussian_action.sum())
    print("reverse_action_custom.sum()", reverse_action_custom.sum())

    # log_prob
    for idx, (action, name) in enumerate(
        zip([action_custom, action_tfp, action_distrax], ["action_custom", "action_tfp", "action_distrax"])
    ):
        log_prob_custom = custom_log_prob(mean, log_std, subkey, jnp.arctanh(action))
        log_prob_tfp = tfd_dist.log_prob(action)
        log_prob_distrax = distrax_dist.log_prob(action)
        print(name)
        print("┣━━ log_prob_custom.sum()", log_prob_custom.sum())
        print("┣━━ log_prob_tfp.sum()", log_prob_tfp.sum())
        print("┣━━ log_prob_distrax.sum()", log_prob_distrax.sum())

Output:

action_custom.sum() 2.8352258
action_tfp.sum() 5.978534
action_distrax.sum() 2.8352258
gaussian_action.sum() 34.332348
reverse_action_custom.sum() inf
action_custom
┣━━ log_prob_custom.sum() 58.477264
┣━━ log_prob_tfp.sum() nan
┣━━ log_prob_distrax.sum() nan
action_tfp
┣━━ log_prob_custom.sum() 58.477264
┣━━ log_prob_tfp.sum() 60.565056
┣━━ log_prob_distrax.sum() nan
action_distrax
┣━━ log_prob_custom.sum() 58.477264
┣━━ log_prob_tfp.sum() nan
┣━━ log_prob_distrax.sum() nan

Also ran into Tanh bijector + Transformed causing NaNs. #7 has a workaround.

I did some digging into this because it was really bothering me, and it turns out the behaviour is somewhat expected; it's not really distrax's fault. My conclusion is pretty much what this comment in #7 describes as well, but it seems worth documenting here in greater detail since this issue is still open.

If you print the sampled actions in this code snippet rather than their sum, you will notice that the value at index [0, 0] is 1.0000001. The inverse of the Tanh bijector calls jnp.arctanh() on that value, which returns nan because the input lies outside (-1, 1).
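
A minimal sketch of that failure mode, using only jax.numpy. The three input values are hand-picked for illustration and are not taken from the reproduction script:

```python
import jax.numpy as jnp

# Hand-picked float32 inputs: inside the tanh range, exactly on its
# boundary, and one float32 step above 1 (which prints as 1.0000001).
actions = jnp.array([0.5, 1.0, 1.0000001], dtype=jnp.float32)

pre_tanh = jnp.arctanh(actions)
print(pre_tanh)  # a finite value, then inf, then nan
```

Note that a value of exactly 1.0 is also a problem: arctanh returns inf, and the log prob then combines -inf and +inf terms.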

Such a value is outside the range of tanh and should never occur, but it does because of the limited precision of float32. After switching to 64-bit precision with jax.config.update("jax_enable_x64", True), no such values appear and the code snippet works fine.
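
As a rough illustration of the precision difference, the sketch below pushes one pre-tanh value through tanh and back in both precisions. The value 10.0 is an arbitrary choice for this example, not something from the reproduction script:

```python
import jax

# Enable 64-bit types; JAX recommends setting this flag at startup.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

x64 = jnp.asarray(10.0, dtype=jnp.float64)  # illustrative pre-tanh value
y64 = jnp.tanh(x64)                         # stays strictly below 1 in float64
y32 = jnp.tanh(x64.astype(jnp.float32))     # float32 cannot resolve 1 - 4e-9

print(y64, jnp.arctanh(y64))  # the round trip is finite in float64
print(y32, jnp.arctanh(y32))  # in float32 the round trip is typically not finite
```

64-bit precision only moves the saturation point further out (tanh still rounds to 1 for large enough inputs), so it reduces the problem rather than eliminating it.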

As a side note, custom_log_prob() returns a value here only because it never actually uses the arctanh() of the sampled action. If you look closely at the snippet, the function discards the gaussian_action argument it takes and re-creates it by sampling from a normal distribution. That is wrong: it only gives the right answer if the same RNG key was used to sample the actions whose log prob is being computed. If you cut that line out, it returns nan too, just like tfp and distrax.
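
For reference, here is a sketch of the same computation with the resampling line removed, so that the function really evaluates the value it is given. The name tanh_gaussian_log_prob and the test inputs are made up for this example:

```python
import jax
import jax.numpy as jnp


def tanh_gaussian_log_prob(mean, log_std, gaussian_action):
    """Log prob of tanh(gaussian_action), evaluated at the value passed in."""
    std = jnp.exp(log_std)
    # Diagonal Gaussian log density, summed over the action dimensions.
    log_prob = -0.5 * ((gaussian_action - mean) / std) ** 2 - 0.5 * jnp.log(2.0 * jnp.pi) - log_std
    log_prob = log_prob.sum(axis=1)
    # Subtract log|d tanh(x)/dx| = 2 * (log 2 - x - softplus(-2x)).
    log_prob -= jnp.sum(2.0 * (jnp.log(2.0) - gaussian_action - jax.nn.softplus(-2.0 * gaussian_action)), axis=1)
    return log_prob


# Illustrative inputs: one ordinary action and one with a saturated component.
mean = jnp.zeros((1, 2), dtype=jnp.float32)
log_std = jnp.zeros((1, 2), dtype=jnp.float32)
ok_action = jnp.array([[0.3, -0.7]], dtype=jnp.float32)
saturated_action = jnp.array([[1.0, -0.7]], dtype=jnp.float32)

lp_ok = tanh_gaussian_log_prob(mean, log_std, jnp.arctanh(ok_action))
lp_bad = tanh_gaussian_log_prob(mean, log_std, jnp.arctanh(saturated_action))
print(lp_ok, lp_bad)  # finite for the first, nan for the second
```

With the saturated action, arctanh yields inf, the Gaussian term becomes -inf, the Jacobian term becomes -inf as well, and their difference is nan.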

Therefore, this isn't something that can be fixed on distrax's end. #7's workaround works because it computes the log prob of the sampled actions from the pre-tanh value, which is readily available since the operation includes a forward sampling pass, so numerical precision never becomes a problem. Calling log_prob() on previously sampled actions, where the pre-tanh value is no longer available, requires a call to arctanh() and runs into the problem above unless 64-bit precision is used.
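
The idea of working from the pre-tanh value can be sketched in plain JAX as follows. This is not the code from #7; the helper name sample_action_and_log_prob and the deliberately extreme mean of 20 are assumptions made for this example:

```python
import jax
import jax.numpy as jnp


def sample_action_and_log_prob(key, mean, log_std):
    """Sample a tanh-squashed action and its log prob without calling arctanh."""
    std = jnp.exp(log_std)
    pre_tanh = mean + std * jax.random.normal(key, shape=mean.shape)
    # Diagonal Gaussian log density of the pre-tanh sample.
    gaussian_lp = jnp.sum(
        -0.5 * ((pre_tanh - mean) / std) ** 2 - 0.5 * jnp.log(2.0 * jnp.pi) - log_std,
        axis=-1,
    )
    # log|det J| of tanh, computed from the pre-tanh value in a stable form.
    log_det = jnp.sum(2.0 * (jnp.log(2.0) - pre_tanh - jax.nn.softplus(-2.0 * pre_tanh)), axis=-1)
    return jnp.tanh(pre_tanh), gaussian_lp - log_det


key = jax.random.PRNGKey(0)
mean = jnp.full((4, 3), 20.0, dtype=jnp.float32)  # extreme on purpose
log_std = jnp.zeros((4, 3), dtype=jnp.float32)
actions, log_probs = sample_action_and_log_prob(key, mean, log_std)
print(actions)    # saturated at (or within rounding error of) 1.0 in float32
print(log_probs)  # still finite, because arctanh is never called
```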

To conclude, the ways around this that I can think of are:

  • use 64-bit precision;
  • avoid log_prob() and only use sample_and_log_prob() (depending on your use case this may be possible, e.g. in RL for SAC);
  • only use MultivariateNormalDiag, store the pre-tanh values, then compute the actions as actions = jnp.tanh(pre_tanh_actions) and the log probs as:
    log_prob = normal_dist.log_prob(pre_tanh_actions) - jnp.sum(2 * (jnp.log(2) - pre_tanh_actions - jax.nn.softplus(-2 * pre_tanh_actions)), axis=-1)
  • bound the mean and std so that the problem becomes much less likely, e.g. use a smaller log_std_max in your network or any other mechanism that keeps the mean and std in a range that produces more reasonable pre-tanh values.
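
To get a feel for the last option, the sketch below estimates how often a zero-mean tanh-squashed Gaussian saturates in float32 for two standard deviations. The helper saturated_fraction, the zero mean, and the tighter bound of log_std = 0 are all assumptions chosen for illustration:

```python
import jax
import jax.numpy as jnp


def saturated_fraction(key, log_std, num_samples=10_000):
    """Fraction of tanh-squashed samples that round to |action| >= 1 in float32."""
    z = jax.random.normal(key, shape=(num_samples,), dtype=jnp.float32)
    actions = jnp.tanh(jnp.exp(log_std) * z)  # zero mean, std = exp(log_std)
    return jnp.mean(jnp.abs(actions) >= 1.0)


key = jax.random.PRNGKey(0)
wide = saturated_fraction(key, log_std=2.0)    # std at the script's log_std_max
narrow = saturated_fraction(key, log_std=0.0)  # std under a hypothetical tighter bound
print(wide, narrow)
```

With the wider distribution a noticeable share of samples lands on the boundary, while the narrower one essentially never does. A large mean can still cause saturation, so bounding the std alone lowers the risk without removing it.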