kazewong/flowMC

Update dependency versions

kazewong opened this issue · 2 comments

Currently, some of the jax libraries are pinned to these versions:

jax==0.4.1
jaxlib==0.4.1
flax==0.6.3

I think the last time I checked, there were some conflicts between newer versions of these packages; in particular, distrax is not synced with newer versions of JAX. It might be worth checking whether flowMC can be updated to JAX versions newer than 0.4.8.
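One way to start the check suggested above is to list the installed releases and compare the JAX version against the 0.4.8 threshold. This is a minimal sketch that uses only the standard library (`importlib.metadata`). The helper names `parse_release`, `installed_versions` and `newer_than` are hypothetical, and the version parsing is a deliberate simplification that keeps only the leading numeric components.

```python
from importlib.metadata import PackageNotFoundError, version

# Packages discussed in this issue; distrax is included as the suspected conflict.
PACKAGES = ("jax", "jaxlib", "flax", "distrax")


def parse_release(v: str) -> tuple:
    """Turn a release string like '0.4.11' into (0, 4, 11).

    Simplification (assumption): only leading numeric components are kept, so
    suffixes such as 'rc1' or '.dev2023' are dropped ('0.4.11rc1' -> (0, 4, 11)).
    """
    parts = []
    for piece in v.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break  # non-numeric component such as 'dev2023': stop here
        parts.append(int(digits))
        if len(digits) != len(piece):
            break  # component had a suffix like '11rc1': keep 11, then stop
    return tuple(parts)


def installed_versions(names=PACKAGES) -> dict:
    """Map each package name to its installed version string, or None if absent."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


def newer_than(v: str, threshold: str = "0.4.8") -> bool:
    """True if release v is strictly newer than threshold (numeric comparison)."""
    return parse_release(v) > parse_release(threshold)


if __name__ == "__main__":
    versions = installed_versions()
    for name, v in versions.items():
        print(f"{name:8s} {v or 'not installed'}")
    if versions["jax"] is not None:
        print("jax newer than 0.4.8:", newer_than(versions["jax"]))
```

Comparing integer tuples rather than raw strings matters here: as plain strings, "0.4.10" sorts before "0.4.8". For full PEP 440 handling, `packaging.version.Version` is the more robust choice.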

I managed to make it run with jax==0.4.11, jaxlib==0.4.10 (needed for running JAX on Metal on Apple Silicon) and flax==0.6.10. In my tests, older versions of flax (<0.6.5) do not work with recent JAX releases. The only change I had to make in the package was in the RQSpline class: I had to remove the jit from the definition of the RQSpline.vmap_call method. The example runs without problems with RealNVP.
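The thread does not show the RQSpline code, so the following is only a hypothetical illustration of the shape of that workaround: a class whose `vmap_call` batches a per-sample `__call__` with `jax.vmap`, with no `jax.jit` applied at definition time, and compilation applied by the caller if it is wanted. `ToyBijector` and its affine map are invented for illustration and are not flowMC's or distrax's API. This issue also does not establish why the jit in the original definition failed.

```python
import jax
import jax.numpy as jnp


class ToyBijector:
    """Hypothetical stand-in for a flow layer; NOT flowMC's RQSpline."""

    def __init__(self, scale: float, shift: float):
        self.scale = scale
        self.shift = shift

    def __call__(self, x):
        # Per-sample forward map: x has shape (n_dim,).
        return self.scale * x + self.shift

    def vmap_call(self, x):
        # Batched forward map over the leading axis. There is deliberately no
        # jax.jit here, mirroring the workaround of dropping jit from the definition.
        return jax.vmap(self.__call__)(x)


layer = ToyBijector(scale=2.0, shift=1.0)
batch = jnp.ones((4, 3))

# Eager (uncompiled) batched call.
out = layer.vmap_call(batch)

# If compilation is wanted, jit can be applied at the call site instead,
# closing over `layer` so that `self` is never passed to jit as an argument.
fast_call = jax.jit(lambda x: layer.vmap_call(x))
out_jit = fast_call(batch)
```

Applying `jax.jit` at the call site keeps the layer definition independent of how and when compilation happens, which is one way to sidestep a jit-at-definition problem without giving up compilation entirely.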

@charlyandral Great to know! RQSpline uses the bijector from distrax, whereas RealNVP is custom-built, so this confirms that distrax is indeed the problem here.