danielward27/flowjax

Autodiff problem with block_neural_autoregressive_flow


Very nice package :-)

While playing around with different optimization objectives I ran into an autodiff issue.
The following always returns exactly zeros, which I think isn't correct. This might be related to the bisection search (I think that's what's used here?), but if the while_loop in there were the problem, I would have expected an error rather than an incorrect result.

import jax
import jax.numpy as jnp
import numpy as np

import flowjax.distributions
import flowjax.flows

flow_key = jax.random.PRNGKey(0)
point = np.random.randn(5)
cotan = np.random.randn(5)

base_dist = flowjax.distributions.Normal(jnp.zeros(5))
flow = flowjax.flows.block_neural_autoregressive_flow(flow_key, base_dist=base_dist, invert=True)

# Pull back a random cotangent through (transform, log_det).
out, pull_grad_fn = jax.vjp(lambda x: flow.bijection.transform_and_log_det(x), point)
pullback = pull_grad_fn((cotan, 1.))
pullback
# (Array([0., 0., 0., 0., 0.], dtype=float32),)

Interesting, I would initially have assumed it would error too. Here's an example that gets to the root
of the problem (in both senses of the word):

import jax.numpy as jnp
from jax import lax
import jax

def _bisection_search(func, *, lower, upper, tol: float, max_iter: int):
    """Find a root of func by repeatedly halving the bracket [lower, upper]."""

    def cond_fn(state):
        lower, upper, iterations = state
        return jnp.logical_and((upper - lower) > 2 * tol, iterations < max_iter)

    def body_fn(state):
        lower, upper, iterations = state
        midpoint = (lower + upper) / 2
        sign = jnp.sign(func(midpoint))
        # Assuming func is increasing, keep the half of the bracket that still contains the root.
        lower = jnp.where(sign == 1, lower, midpoint)
        upper = jnp.where(sign == 1, midpoint, upper)
        return lower, upper, iterations + 1

    init_state = (lower, upper, 0)
    lower, upper, iterations = lax.while_loop(cond_fn, body_fn, init_state)
    root = (lower + upper) / 2
    return root, iterations

def get_root(x):
    return _bisection_search(
        func=lambda arr: arr + x,
        lower=-10.0,
        upper=10.0,
        tol=1e-7,
        max_iter=100,
    )[0]

Note the gradient is actually zero everywhere it is defined, because the result is a piecewise-constant (step) function of x. We can see that visually by plotting it over a very small interval:

x = jnp.linspace(-1e-6, 1e-6, 1000)
roots = jax.vmap(get_root)(x)

import matplotlib.pyplot as plt
plt.plot(x, roots)
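
Checking the gradient directly at an arbitrary point (a quick check I've added, using the toy get_root above) gives exactly zero rather than an error, matching the original report:

jax.grad(get_root)(jnp.array(0.3))
# Array(0., dtype=float32)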

Regardless, I think that an error when differentiating through _bisection_search might be better than returning a gradient of zero, because it's likely a mistake. I presume this could be done with jax.custom_jvp?
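
For the toy example above, one rough sketch of how that could look (the names are made up and this hasn't been tested against the real flowjax internals):

@jax.custom_jvp
def get_root_no_grad(x):
    return get_root(x)

@get_root_no_grad.defjvp
def _get_root_no_grad_jvp(primals, tangents):
    # The JVP rule only runs when JAX tries to differentiate, so raising here
    # turns the silent zero gradient into a loud error, while plain evaluation
    # (and vmap) of get_root_no_grad still works.
    raise RuntimeError(
        "Differentiating through the bisection search is not supported: "
        "the root is piecewise constant, so its gradient is zero almost everywhere."
    )

# get_root_no_grad(0.3)            # works as before
# jax.grad(get_root_no_grad)(0.3)  # raises RuntimeError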

Probably a better solution would be to use the implicit function theorem, although I'm unlikely to be able to have a look at that soon (pull requests would be welcome!)
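
For the scalar toy problem, an implicit-function-theorem version could look roughly like this (my own sketch, not flowjax code; solve_root and its signature are made up, and f must be passed the differentiated values explicitly rather than closing over them, since it is marked non-differentiable):

from functools import partial

@partial(jax.custom_jvp, nondiff_argnums=(0,))
def solve_root(f, x):
    # Forward pass: plain (non-differentiable) bisection on r -> f(r, x).
    return _bisection_search(
        lambda r: f(r, x), lower=-10.0, upper=10.0, tol=1e-7, max_iter=100
    )[0]

@solve_root.defjvp
def _solve_root_jvp(f, primals, tangents):
    (x,), (x_dot,) = primals, tangents
    root = solve_root(f, x)
    # Implicit function theorem: f(root(x), x) == 0, so
    # df/dr * droot + df/dx * dx = 0  =>  droot/dx = -(df/dx) / (df/dr).
    df_dr = jax.grad(f, argnums=0)(root, x)
    df_dx = jax.grad(f, argnums=1)(root, x)
    return root, -(df_dx / df_dr) * x_dot

# The root of r + x == 0 is r = -x, so the derivative should be -1:
# jax.grad(lambda x: solve_root(lambda r, xx: r + xx, x))(0.3)  # -> -1.0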