Integration with Inference Gym
reubenharry opened this issue · 6 comments
Since Blackjax is a repository of inference algorithms with a fairly uniform interface, it would be nice if there were a fairly automated procedure for benchmarking and comparing different algorithms, using the models from inference gym.
What I'm envisioning is a helper function that would take a set of `SamplingAlgorithm`s, let's say, run each on all of the inference gym problems, and then report some useful metrics, e.g. bias vs wallclock time (or number of gradient calls), plotted in a graph, so that it's as easy as possible to assess the performance of a given sampling algorithm.
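For concreteness, a minimal sketch of the interface I have in mind (purely illustrative: `benchmark`, `BenchmarkResult`, the `run(model, num_samples)` callables, and the `second_moment` attribute are placeholder names, not existing Blackjax or inference gym API):

```python
from typing import Callable, NamedTuple

import jax.numpy as jnp


class BenchmarkResult(NamedTuple):
    model_name: str
    algorithm_name: str
    bias: float            # e.g. squared error on E[x^2]
    gradient_calls: int    # proxy for wallclock time


def benchmark(
    algorithms: dict[str, Callable],  # name -> fn(model, num_samples) -> (samples, gradient_calls)
    models: dict[str, object],        # name -> inference gym model with a known second moment
    num_samples: int = 10_000,
) -> list[BenchmarkResult]:
    """Run every algorithm on every model and collect comparison metrics."""
    results = []
    for model_name, model in models.items():
        for algorithm_name, run in algorithms.items():
            samples, gradient_calls = run(model, num_samples)
            # Compare the empirical second moment against the model's known value
            # (`model.second_moment` is an assumed attribute, not inference gym API).
            error = jnp.mean((jnp.mean(samples**2, axis=0) - model.second_moment) ** 2)
            results.append(
                BenchmarkResult(model_name, algorithm_name, float(error), gradient_calls)
            )
    return results
```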
Does something like this sound of interest? Has it already been done? If yes to the former, and no to the latter, it's something I'd be interested in contributing (a similar bit of code is in https://github.com/JakobRobnik/MicroCanonicalHMC/tree/master and could be ported over).
Why have this in Blackjax
- It would incentivize researchers to implement their algorithms in Blackjax, since algorithm comparison would then be streamlined.
- It would be useful to have more benchmarks for the existing algorithms
Follow-up questions
Is inference gym available in JAX? https://github.com/JakobRobnik/MicroCanonicalHMC manually ports it, but that wouldn't stay up to date if inference-gym were changed.
+1 to the suggestion. I think we should set up a new repository for that.
OK, I'll look into that. For now I might proceed by working in this repo, and then we can discuss splitting it out down the road based on the PR.
Basically what I'm thinking of is porting @JakobRobnik's code (https://github.com/JakobRobnik/MicroCanonicalHMC/blob/master/benchmarks/error.py) and making a function `assess_sampling_algorithm` that gives ESS per sample across the different models. Nothing too fancy.
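A hedged sketch of what that could look like, assuming each algorithm is wrapped as a `run_algorithm(model, num_samples)` callable returning samples of shape `(num_chains, num_samples, ndim)`, and using Blackjax's built-in ESS diagnostic (the ported error.py may define ESS differently, e.g. via a bias threshold):

```python
import jax.numpy as jnp

from blackjax.diagnostics import effective_sample_size


def assess_sampling_algorithm(run_algorithm, models, num_samples=10_000):
    """Return ESS per sample for one algorithm across a collection of models."""
    ess_per_sample = {}
    for name, model in models.items():
        samples = run_algorithm(model, num_samples)  # (num_chains, num_samples, ndim)
        # Cross-chain ESS per dimension; report the worst dimension to be conservative.
        ess = effective_sample_size(samples, chain_axis=0, sample_axis=1)
        ess_per_sample[name] = float(jnp.min(ess)) / num_samples
    return ess_per_sample
```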
One reason to do this is that currently in Blackjax, it's not totally clear where to look to find the optimal strategy for both tuning and running a given algorithm. So for example, it would be nice for this to serve as a place to see how to run each algorithm with the best possible tuning.
For what it is worth, I'm very interested in this for https://github.com/jax-ml/bayeux, which follows some heuristics, but would welcome specifics.
OK, having investigated a little more, some notes (for myself, mainly). What I want is a function of type `InferenceGymModel -> ArrayOfSamples`, i.e. a function that takes an inference gym model and produces samples (of the appropriate dimensions, although Python's type system is too puny to enforce this), with the recommended adaptation, preconditioning, etc.
Then I want each of the sampling algorithms that is going to be benchmarked in Blackjax to have a corresponding function of this type, in an `example_usage` directory or similar.
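To make the type concrete, here's a hedged sketch of one such function for NUTS with window adaptation. It assumes the JAX backend of inference gym and a model whose `unnormalized_log_prob` takes a flat array and whose `event_shape` is one-dimensional; real targets often need the event-space bijector applied, which I'm glossing over here.

```python
import jax
import jax.numpy as jnp

import blackjax


def nuts_samples(rng_key, model, num_warmup=1_000, num_samples=10_000):
    """InferenceGymModel -> ArrayOfSamples, with standard window adaptation."""
    logdensity_fn = model.unnormalized_log_prob
    initial_position = jnp.zeros(model.event_shape[0])

    # Stan-style warmup: adapts the step size and (diagonal) mass matrix.
    warmup_key, sample_key = jax.random.split(rng_key)
    warmup = blackjax.window_adaptation(blackjax.nuts, logdensity_fn)
    (state, parameters), _ = warmup.run(warmup_key, initial_position, num_steps=num_warmup)

    kernel = blackjax.nuts(logdensity_fn, **parameters).step

    def one_step(state, key):
        state, _ = kernel(key, state)
        return state, state.position

    keys = jax.random.split(sample_key, num_samples)
    _, positions = jax.lax.scan(one_step, state, keys)
    return positions  # (num_samples, ndim)
```

Collecting one such function per algorithm would also double as the "how to run each algorithm with the best possible tuning" reference mentioned above.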
Then I'll have a function `(InferenceGymModel, ArrayOfSamples) -> ESS_Estimate` (already written), and it will be straightforward to compute ESS for each pair of inference method and model, which I will plot in a pretty graph (ideally run by CI).
It's possible this overlaps a little with Bayeux, in which case maybe there's some common ground to exploit.
Some todos as Divij and I work on this:
- add NUTS as an inference method to SamplingAlgorithms
- work out the best way to add various expectations (like <x^2>) to the inference gym classes
- work out the best way to add the ability to sample from the distribution to each inference gym model (basically as done in https://github.com/JakobRobnik/MicroCanonicalHMC/blob/master/benchmarks/targets.py but compatible with the inference gym classes, so maybe via a subclass; see the sketch after this list)
- prepare a set of inference gym models (e.g. corresponding to https://github.com/JakobRobnik/MicroCanonicalHMC/blob/master/benchmarks/targets.py)
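A hedged sketch of what attaching that ground truth could look like. It uses a small wrapper rather than a true subclass, just to illustrate the data that needs to travel with each model; all names here (`BenchmarkTarget`, `second_moment`, `exact_sample`) are placeholders, not inference gym API.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional

import jax
import jax.numpy as jnp


@dataclass
class BenchmarkTarget:
    """An inference gym model plus the ground truth needed for benchmarking."""

    model: Any                                # the underlying inference gym target
    second_moment: jax.Array                  # known E[x^2] per dimension
    exact_sample: Optional[Callable] = None   # (rng_key, num_samples) -> samples, if available


# Purely illustrative registration for a 100-dimensional standard Gaussian.
standard_gaussian = BenchmarkTarget(
    model=None,  # e.g. the corresponding inference gym Gaussian target
    second_moment=jnp.ones(100),
    exact_sample=lambda key, n: jax.random.normal(key, (n, 100)),
)
```

Whether this ends up as a wrapper or a proper subclass of the inference gym classes is exactly the design question in the todo above.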