Sampling book SMC not running with latest Jax
Describe the issue as clearly as possible:
The sampling book notebook "Tuning inner kernel parameters of SMC" does not run with the latest version of Blackjax, because it has not been updated for the refactor in #657.
Would it be possible to update it? (I'm trying to work through it myself and will post an update if I figure it out.)
Steps/code to reproduce the bug:
Run the adaptation in https://blackjax-devs.github.io/sampling-book/algorithms/TemperedSMCWithOptimizedInnerKernel.html
Expected result:
For the notebook to run :)
Error message:
It starts with `Cannot import inner_kernel_tuning`. After fixing the import, the notebook still needs to be changed to use the new callable methods introduced in #657.
### Blackjax/JAX/jaxlib/Python version information:
```python
>>> import blackjax; blackjax.__version__
'1.2.3'
>>> import sys; sys.version
'3.12.4 | packaged by conda-forge | (main, Jun 17 2024, 10:13:44) [Clang 16.0.6 ]'
>>> import jax; jax.__version__
'0.4.31'
>>> import jaxlib; jaxlib.__version__
'0.4.31'
```
Context for the issue:
This makes it hard to jump into blackjax via the SMC examples.
The sampling book is run against Blackjax main branch HEAD. I will cut a release to fix this.
This should now be fixed by https://github.com/blackjax-devs/blackjax/releases/tag/1.2.4. Try reinstalling from pip and rerunning the notebook.
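A minimal sketch, using only the standard library, of checking that the reinstalled Blackjax is at least the 1.2.4 release mentioned above before rerunning the notebook (for example after `pip install --upgrade blackjax`). The helper names `parse_release` and `has_smc_fix` are hypothetical and not part of Blackjax; the parser is deliberately simple and assumes dotted numeric versions.

```python
from importlib.metadata import PackageNotFoundError, version

# Release that, per the maintainer comment, contains the notebook fix.
FIXED_RELEASE = (1, 2, 4)


def parse_release(ver: str) -> tuple:
    """Turn a version string like '1.2.3' into (1, 2, 3).

    Hypothetical helper: only the leading numeric components are kept,
    so suffixes such as '.dev3+gabc' are dropped and a dev build is
    treated as its base release.
    """
    parts = []
    for piece in ver.split("."):
        digits = ""
        for ch in piece:
            if ch.isdigit():
                digits += ch
            else:
                break
        if not digits:
            break  # stop at the first non-numeric component
        parts.append(int(digits))
    return tuple(parts)


def has_smc_fix(ver: str) -> bool:
    """True if `ver` is at or after the release that fixed the notebook."""
    return parse_release(ver) >= FIXED_RELEASE


if __name__ == "__main__":
    try:
        installed = version("blackjax")
    except PackageNotFoundError:
        print("blackjax is not installed")
    else:
        status = "OK" if has_smc_fix(installed) else "too old, please upgrade"
        print(f"blackjax {installed}: {status}")
```

With the version reported in this issue, `has_smc_fix("1.2.3")` is `False`, while `has_smc_fix("1.2.4")` is `True`. For full PEP 440 handling (pre-releases, local versions), a dedicated version parser would be more robust than this sketch.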