This repository contains a JAX-based implementation of the sampler from Zhang et al. (2022), "A Langevin-like Sampler for Discrete Distributions". Its design is heavily inspired by BlackJAX, and it borrows some code from the BlackJAX API. (I implemented this code in a separate repo while learning how samplers for discrete distributions work, and I plan to send a PR to the official BlackJAX repo 🤞)
Please check the notebooks in the `examples` directory to see how to use the kernel.
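The paper's kernel proposes every coordinate in parallel from a single gradient evaluation: for binary x, coordinate i flips with logit ½∇U(x)ᵢ(1 − 2xᵢ) − 1/(2α), where U is the unnormalized log-probability and α the step size, and a Metropolis-Hastings test corrects the result. The following is a minimal standalone sketch of that Metropolis-adjusted step, not this repo's API. It assumes `x` is a float array of 0s and 1s and that `log_prob_fn` is differentiable; `dmala_step`, `run_chain`, and the toy target are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp


def dmala_step(key, x, log_prob_fn, step_size):
    """One Metropolis-adjusted discrete Langevin step for x in {0, 1}^d.

    Hypothetical helper: x is a float array of 0s and 1s, and log_prob_fn is a
    differentiable unnormalized log-probability U(x).
    """
    key_flip, key_accept = jax.random.split(key)

    def flip_logits(z):
        # Gradient of U, treating z as a point in continuous space.
        grad = jax.grad(log_prob_fn)(z)
        # Flipping coordinate i changes it by (1 - 2 z_i), whose square is 1.
        return 0.5 * grad * (1.0 - 2.0 * z) - 1.0 / (2.0 * step_size)

    def log_proposal(logits, flipped):
        # Coordinates flip independently with probability sigmoid(logit).
        return jnp.sum(jnp.where(flipped,
                                 jax.nn.log_sigmoid(logits),
                                 jax.nn.log_sigmoid(-logits)))

    # Propose: decide for every coordinate in parallel whether to flip it.
    logits_x = flip_logits(x)
    flipped = jax.random.bernoulli(key_flip, jax.nn.sigmoid(logits_x))
    x_new = jnp.where(flipped, 1.0 - x, x)

    # Metropolis-Hastings correction; the reverse move flips the same coordinates.
    logits_new = flip_logits(x_new)
    log_ratio = (log_prob_fn(x_new) - log_prob_fn(x)
                 + log_proposal(logits_new, flipped)
                 - log_proposal(logits_x, flipped))
    accept = jnp.log(jax.random.uniform(key_accept)) < log_ratio
    return jnp.where(accept, x_new, x), accept


# Hypothetical toy target: independent bits with logits b, so U(x) = b . x.
b = jnp.array([2.0, -2.0, 0.0])


def log_prob_fn(x):
    return jnp.dot(b, x)


def run_chain(key, x0, n_steps, step_size=0.5):
    """Run n_steps of dmala_step with lax.scan; returns samples and accept flags."""
    def body(x, k):
        x, accepted = dmala_step(k, x, log_prob_fn, step_size)
        return x, (x, accepted)

    keys = jax.random.split(key, n_steps)
    _, (samples, accepts) = jax.lax.scan(body, x0, keys)
    return samples, accepts
```

For this toy target, the per-bit means of the samples after a short burn-in should approach `jax.nn.sigmoid(b)`, which is a quick sanity check that the MH correction preserves the intended stationary distribution.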
TODO:
- Extend the kernel to categorical distributions. Currently only binary-valued distributions are supported.
- Add more example notebooks that implement:
  - Potts model
  - Restricted Boltzmann Machine (RBM)
  - Bayesian Neural Network (BNN)
The authors' own PyTorch implementation of the paper is available at https://github.com/ruqizhang/discrete-langevin/