blackjax-devs/blackjax

Sampling book SMC not running with latest JAX

Describe the issue as clearly as possible:

The sampling book notebook "Tuning inner kernel parameters of SMC" does not run with the latest version of JAX, because it is not aware of the refactor in #657.

Would it be possible to update it? (I'm trying to work my way through it and will post an update here if I figure it out.)

Steps/code to reproduce the bug:

Run the adaptation in https://blackjax-devs.github.io/sampling-book/algorithms/TemperedSMCWithOptimizedInnerKernel.html

Expected result:

For the notebook to run :)

Error message:

It starts with `Cannot import inner_kernel_tuning`. Once the import is fixed, the notebook also needs to use the new callable methods introduced in #657.



### Blackjax/JAX/jaxlib/Python version information:

```python
>>> import blackjax; blackjax.__version__
'1.2.3'
>>> import sys; sys.version
'3.12.4 | packaged by conda-forge | (main, Jun 17 2024, 10:13:44) [Clang 16.0.6 ]'
>>> import jax; jax.__version__
'0.4.31'
>>> import jaxlib; jaxlib.__version__
'0.4.31'
```

Context for the issue:

This makes it hard to jump into blackjax via the SMC examples.

The sampling book is run against the Blackjax `main` branch HEAD, which is ahead of the latest release. I will cut a release to fix this.

This should now be fixed by https://github.com/blackjax-devs/blackjax/releases/tag/1.2.4. Try reinstalling from pip and rerunning the notebook.
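If you are unsure whether the reinstall actually picked up the fixed release, a quick check like the one below may help before rerunning the notebook. This is a minimal sketch using only the standard library: `importlib.metadata.version` reads the installed distribution's version, and the hypothetical helpers `parse_release` and `blackjax_is_new_enough` compare its first three numeric components against 1.2.4. It deliberately drops pre-release and local suffixes (e.g. `.dev3+g...`, `rc1`), so it is a rough guard rather than a full PEP 440 comparison; `packaging.version.Version` is the stricter option if that package is available.

```python
from __future__ import annotations

from importlib.metadata import PackageNotFoundError, version

# Release the maintainers point to as containing the fix.
MIN_VERSION = (1, 2, 4)


def parse_release(v: str) -> tuple[int, ...]:
    """Return the leading numeric components, e.g. '1.2.4.dev3+gabc' -> (1, 2, 4).

    Non-numeric tails such as 'dev3' or 'rc1' are dropped, so this is a rough
    comparison, not a full PEP 440 parser.
    """
    parts = []
    for piece in v.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


def blackjax_is_new_enough(installed: str | None = None) -> bool:
    """True if the given (or installed) blackjax version is at least MIN_VERSION."""
    if installed is None:
        try:
            installed = version("blackjax")
        except PackageNotFoundError:
            return False  # blackjax is not installed at all
    # Tuple comparison, so '1.10.0' correctly ranks above '1.2.4'.
    return parse_release(installed) >= MIN_VERSION


if __name__ == "__main__":
    # After `pip install --upgrade blackjax`, this should print True
    # once 1.2.4 or later is installed.
    print(blackjax_is_new_enough())
```

Comparing integer tuples rather than raw strings matters here: as strings, `"1.10.0" < "1.2.4"`, which would wrongly report a newer release as too old.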