
Discrete Metropolis-Adjusted Langevin Algorithm in JAX!

Primary language: Python. License: Apache License 2.0 (Apache-2.0).

DMALAX - Discrete Metropolis-Adjusted Langevin Algorithm in JAX

This repository contains a JAX-based implementation of the method from Zhang et al. (2022), "A Langevin-like Sampler for Discrete Distributions". Its design is heavily inspired by BlackJAX, and it borrows some code from the BlackJAX API. (I implemented this in a separate repository as part of learning how samplers for discrete distributions work, and I plan to send a PR to the official BlackJAX repository 🤞.)
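For orientation, here is a minimal, self-contained sketch of one DMALA step for binary variables x ∈ {0,1}^d with target π(x) ∝ exp(f(x)), written directly in JAX. It is not this repository's kernel API: the function names (`log_proposal`, `dmala_step`, `sample_chain`) and the `step_size` argument (α) are hypothetical. Following the paper's construction, the proposal factorizes over coordinates, q(x'|x) ∝ ∏_i exp(½ ∇f(x)_i (x'_i − x_i) − (x'_i − x_i)²/(2α)). For a binary coordinate this reduces to flipping x_i with probability sigmoid(½ ∇f(x)_i (1 − 2x_i) − 1/(2α)). A Metropolis-Hastings test that uses the reverse proposal then removes the bias of the gradient-based proposal.

```python
import jax
import jax.numpy as jnp


def log_proposal(x_from, x_to, grad_from, step_size):
    """Log q(x_to | x_from) for the factorized binary proposal (sketch)."""
    # Change a flip would cause in each coordinate: +1 if x_i = 0, -1 if x_i = 1.
    diff = 1.0 - 2.0 * x_from
    # Logit of flipping each coordinate: 0.5 * grad * diff - diff**2 / (2 * alpha),
    # where diff**2 == 1 for binary variables.
    flip_logits = 0.5 * grad_from * diff - 1.0 / (2.0 * step_size)
    flipped = x_to != x_from
    # log sigmoid(z) for flipped coordinates, log(1 - sigmoid(z)) = log sigmoid(-z) otherwise.
    per_coord = jnp.where(
        flipped, jax.nn.log_sigmoid(flip_logits), jax.nn.log_sigmoid(-flip_logits)
    )
    return jnp.sum(per_coord)


def dmala_step(key, x, log_prob_fn, step_size):
    """One hypothetical DMALA transition on a float array of 0s and 1s."""
    key_flip, key_accept = jax.random.split(key)
    logp_x, grad_x = jax.value_and_grad(log_prob_fn)(x)

    # Propose: flip each coordinate independently with its gradient-informed probability.
    diff = 1.0 - 2.0 * x
    flip_logits = 0.5 * grad_x * diff - 1.0 / (2.0 * step_size)
    flips = jax.random.bernoulli(key_flip, jax.nn.sigmoid(flip_logits))
    x_new = jnp.where(flips, 1.0 - x, x)

    # Metropolis-Hastings correction with forward and reverse proposal probabilities.
    logp_new, grad_new = jax.value_and_grad(log_prob_fn)(x_new)
    log_ratio = (
        logp_new
        - logp_x
        + log_proposal(x_new, x, grad_new, step_size)
        - log_proposal(x, x_new, grad_x, step_size)
    )
    accept = jnp.log(jax.random.uniform(key_accept)) < log_ratio
    return jnp.where(accept, x_new, x), accept


def sample_chain(key, x0, log_prob_fn, step_size, num_steps):
    """Run num_steps DMALA transitions with lax.scan; returns (samples, accepts)."""

    def body(x, k):
        x_next, accepted = dmala_step(k, x, log_prob_fn, step_size)
        return x_next, (x_next, accepted)

    keys = jax.random.split(key, num_steps)
    _, (samples, accepts) = jax.lax.scan(body, x0, keys)
    return samples, accepts
```

As a sanity check, for a toy independent target f(x) = b·x the long-run mean of coordinate i should approach sigmoid(b_i), since each coordinate is then an independent Bernoulli variable with logit b_i.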

Usage

See the notebooks in the examples directory for how to use the kernel.

Todo

  • Extend the kernel to categorical distributions. Currently only binary-valued distributions are supported.
  • Add more example notebooks that implement:
    • Potts model
    • Restricted Boltzmann Machine (RBM)
    • Bayesian Neural Network (BNN)

Other versions

The authors' own PyTorch implementation of the paper is available at https://github.com/ruqizhang/discrete-langevin/