Q: Normal Distribution Compatibility with Reparameterization Trick?
I am interested in using Optax's stochastic gradient estimators and control variates with FlowJAX. In particular, I would like to know whether FlowJAX is compatible with the reparameterization gradient (aka the pathwise estimator).
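The reparameterization gradient requires the "reparameterization trick" to compute the gradient of an expectation. For a normally distributed variable $x \sim \mathcal{N}(\mu, \sigma^2)$, we can write $x = \mu + \sigma z$ with $z \sim \mathcal{N}(0, 1)$, so that $\nabla_{\mu, \sigma} \mathbb{E}_{x}[f(x)] = \mathbb{E}_{z}[\nabla_{\mu, \sigma} f(\mu + \sigma z)]$.
Are the `Normal` and `MultivariateNormal` distributions in FlowJAX compatible with the reparameterization trick by default when using `jax.grad`? I believe the answer is yes, because they both take the `StandardNormal` distribution (equivalent to $z$ above) and transform it with some (affine?) bijection.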
I was wondering if @danielward27 can please confirm if this is true. If so, it may be worth mentioning somewhere in the docs. Thank you!
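For reference, here is a minimal plain-JAX sketch of the trick itself, independent of FlowJAX. The test function `f(x) = x**2` and the helper `reparam_objective` are hypothetical choices for illustration. They were picked because $\mathbb{E}[x^2] = \mu^2 + \sigma^2$ has the exact gradient $(2\mu, 2\sigma)$ to compare the estimate against.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def f(x):
    # Hypothetical test function; E[x**2] = mu**2 + sigma**2 for x ~ N(mu, sigma**2)
    return x**2


def reparam_objective(params, key, num_samples=100_000):
    mu, sigma = params
    # Sample the parameter-free noise z ~ N(0, 1) ...
    z = jr.normal(key, (num_samples,))
    # ... then push it through the deterministic map x = mu + sigma * z
    x = mu + sigma * z
    # Monte Carlo estimate of E[f(x)], differentiable in (mu, sigma)
    return jnp.mean(f(x))


# Pathwise gradient estimate; the exact value here is (2 * mu, 2 * sigma) = (2.0, 4.0)
g_mu, g_sigma = jax.grad(reparam_objective)((1.0, 2.0), jr.key(0))
```

Once the key is fixed, `reparam_objective` is an ordinary deterministic function of `(mu, sigma)`. That is why `jax.grad` can differentiate through the sampling step.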
I haven't explicitly tested, but I would expect all the distributions thus far in FlowJAX to support reparameterized gradients, as the sampling operations become differentiable deterministic functions after setting the key. An example:
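```python
import jax.random as jr
from flowjax.distributions import Normal
import equinox as eqx

# Differentiate the sample w.r.t. the floating-point array leaves of the first argument (the distribution)
@eqx.filter_grad
def sample_model(dist, **kwargs):
    return dist.sample(**kwargs)

dist = Normal()
grad = sample_model(dist, key=jr.key(0))
# The sample is loc + scale * z, so its gradient w.r.t. loc is 1
assert grad.bijection.loc == 1
```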
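Here `loc` enters the sample additively, so its gradient is exactly 1. The gradient for the scale is less direct when the positive scale is stored in unconstrained form. The sketch below is plain JAX, not FlowJAX internals. It assumes a parameterization `scale = softplus(u)` and uses a hypothetical helper, `sample_from_unconstrained`, to show the gradient of a sample with respect to the unconstrained value `u`.

```python
import jax
import jax.random as jr


def sample_from_unconstrained(u, loc, z):
    # Assumed parameterization: softplus maps any real u to a positive scale
    scale = jax.nn.softplus(u)
    return loc + scale * z


z = jr.normal(jr.key(0))  # a single standard-normal draw
u = 0.5

# Gradient of the sample with respect to the unconstrained value u
g_u = jax.grad(sample_from_unconstrained)(u, 0.0, z)

# Chain rule: dx/du = z * softplus'(u) = z * sigmoid(u)
expected = z * jax.nn.sigmoid(u)
```

An optimizer can then update `u` freely, and the resulting scale stays positive.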
You can also get the gradient of the scale parameter with `grad.bijection.scale.args[0]`, but as the scale parameter is constrained to be positive, this corresponds to the gradient of the sample w.r.t. the unconstrained scale representation (prior to applying softplus). This is generally what we want for optimization, as it lets us perform updates without risking invalid values. It's also worth noting that JAX itself generally supports reparameterized gradients, e.g.:
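```python
import jax
import jax.random as jr

@jax.grad
def sample_beta(a, **kwargs):
    # Differentiates a single Beta(a, b) sample with respect to `a`
    return jr.beta(a=a, **kwargs)

sample_beta(0.1, b=0.2, key=jr.key(0))
```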
So generally, FlowJAX isn't doing anything clever to support it, but rather inherits this property from JAX. There are cases where reparameterized gradients are not possible (e.g. discrete distributions or non-differentiable functions), but AFAIK all the distributions (and flows) currently in FlowJAX are naturally compatible with reparameterized gradients.
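The discrete case can be seen directly in plain JAX. In this minimal sketch, the helper `bernoulli_mean` is hypothetical. `jax.random.bernoulli` produces each sample by comparing a uniform draw with `p`, so the samples are piecewise constant in `p`. As a result, `jax.grad` returns zero, even though the true derivative of $\mathbb{E}[x] = p$ is 1.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def bernoulli_mean(p, key, num_samples=10_000):
    # Each sample is 1.0 if a uniform draw falls below p, else 0.0
    x = jr.bernoulli(key, p, (num_samples,)).astype(jnp.float32)
    # Monte Carlo estimate of E[x] = p
    return jnp.mean(x)


key = jr.key(0)
estimate = bernoulli_mean(0.3, key)       # close to 0.3
g_p = jax.grad(bernoulli_mean)(0.3, key)  # 0.0, not the true derivative 1.0
```

The estimate of the expectation is fine, but the pathwise gradient carries no signal. Discrete distributions therefore need a different estimator, such as the score-function estimator.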
I'll leave this open for now, and would be happy to take a pull request for improving documentation. Maybe it's worth adding an example like the normal example above to the "Distributions and Bijections as PyTrees" section. I'll likely get around to it at some point regardless.
Thanks so much for the thorough response! Your explanation all makes sense.
Closing this now and thanks again for this incredible library.