This code is adapted from [github.com/google/wide_bnn_sampling](https://github.com/google/wide_bnn_sampling).
Our contribution to the code is the addition of the preconditioned Crank-Nicolson (pCN) and preconditioned Crank-Nicolson Langevin (pCNL) samplers; see `samplers.py` for the implementation.
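The following is a minimal, self-contained sketch of the pCN step, written in plain Python rather than JAX for readability. It assumes an isotropic Gaussian prior N(0, prior_std² I) on the parameters and a user-supplied negative log-likelihood Φ; `pcn_step` and `run_pcn` are hypothetical names and do not correspond to the functions in `samplers.py`. The proposal `v = sqrt(1 - beta**2) * u + beta * xi`, with `xi` drawn from the prior, leaves the prior invariant, so the acceptance probability min(1, exp(Φ(u) − Φ(v))) involves only the likelihood.

```python
import math
import random
from typing import Callable, List, Tuple

Vector = List[float]


def pcn_step(
    u: Vector,
    neg_log_lik: Callable[[Vector], float],
    beta: float,
    prior_std: float,
    rng: random.Random,
) -> Tuple[Vector, bool]:
    """One pCN step for a posterior with (assumed) prior N(0, prior_std^2 I).

    Proposal: v = sqrt(1 - beta^2) * u + beta * xi, xi ~ N(0, prior_std^2 I).
    Because this proposal is reversible w.r.t. the prior, the acceptance
    probability only involves the negative log-likelihood Phi:
        min(1, exp(Phi(u) - Phi(v))).
    """
    if not 0.0 < beta <= 1.0:
        raise ValueError("beta must lie in (0, 1]")
    shrink = math.sqrt(1.0 - beta**2)
    v = [shrink * ui + beta * rng.gauss(0.0, prior_std) for ui in u]
    log_accept = neg_log_lik(u) - neg_log_lik(v)
    # Short-circuit so exp() is only evaluated on non-positive arguments.
    if log_accept >= 0.0 or rng.random() < math.exp(log_accept):
        return v, True
    return u, False


def run_pcn(
    u0: Vector,
    neg_log_lik: Callable[[Vector], float],
    beta: float,
    prior_std: float,
    n_steps: int,
    seed: int = 0,
) -> Tuple[List[Vector], float]:
    """Run a pCN chain; return the samples and the empirical acceptance rate."""
    rng = random.Random(seed)
    u = list(u0)
    samples: List[Vector] = []
    n_accepted = 0
    for _ in range(n_steps):
        u, accepted = pcn_step(u, neg_log_lik, beta, prior_std, rng)
        n_accepted += int(accepted)
        samples.append(u)
    return samples, n_accepted / n_steps


if __name__ == "__main__":
    # Toy problem: prior N(0, 1), one observation y = 1 with unit noise,
    # so Phi(u) = (u - 1)^2 / 2 and the exact posterior is N(0.5, 0.5).
    samples, rate = run_pcn([0.0], lambda u: 0.5 * (u[0] - 1.0) ** 2,
                            beta=0.5, prior_std=1.0, n_steps=50_000)
    mean = sum(s[0] for s in samples) / len(samples)
    # The estimate should be close to the exact posterior mean 0.5.
    print(f"acceptance rate: {rate:.2f}, posterior mean estimate: {mean:.3f}")
```

Setting `beta = 1` turns the step into an independence sampler that proposes fresh prior draws; smaller values of `beta` give more local moves and higher acceptance rates.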
The dependencies are listed in `setup.py` and can be installed by running:

```bash
git clone https://github.com/google/wide_bnn_sampling
cd wide_bnn_sampling
pip install -e .
```
Note that `jaxlib` is also required and must be installed according to your hardware (CPU, GPU, or TPU); refer to the installation instructions in JAX's repository.
- `config.py`: contains the parameters and the specifications for the experiments.
- `datasets.py`: manages the data (loading and preprocessing).
- `main.py`: core of the experimental procedure.
- `measurements.py`:
- `models.py`: defines the neural network architectures (FCN and ResNet).
- `reparametrisation.py`: implements the reparametrisation of the posterior over the weights (see Hron et al. for more details).
- `samplers.py`: contains the implementations of the MCMC procedures used, specifically Hamiltonian Monte Carlo/Langevin Monte Carlo, pCN, pCNL, and Metropolis-Hastings with a simple Gaussian proposal.
- `utils.py`: auxiliary methods.
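To contrast with pCN, here is a similarly hedged, plain-Python sketch of Metropolis-Hastings with a simple Gaussian (random-walk) proposal, the baseline listed for `samplers.py`. `rwmh_step`, `run_rwmh`, and `step_size` are hypothetical names, and an isotropic proposal covariance is assumed; the actual implementation may differ. Because the random walk does not preserve the prior, the acceptance ratio must use the full unnormalised log-posterior (log-prior plus log-likelihood).

```python
import math
import random
from typing import Callable, List, Tuple

Vector = List[float]


def rwmh_step(
    u: Vector,
    log_target: Callable[[Vector], float],
    step_size: float,
    rng: random.Random,
) -> Tuple[Vector, bool]:
    """One random-walk Metropolis-Hastings step with an isotropic Gaussian proposal.

    Proposal: v = u + step_size * xi, xi ~ N(0, I). The proposal is symmetric,
    so the acceptance probability is min(1, pi(v) / pi(u)), where log pi is the
    unnormalised log-posterior (log-prior + log-likelihood).
    """
    v = [ui + step_size * rng.gauss(0.0, 1.0) for ui in u]
    log_accept = log_target(v) - log_target(u)
    # Short-circuit so exp() is only evaluated on non-positive arguments.
    if log_accept >= 0.0 or rng.random() < math.exp(log_accept):
        return v, True
    return u, False


def run_rwmh(
    u0: Vector,
    log_target: Callable[[Vector], float],
    step_size: float,
    n_steps: int,
    seed: int = 0,
) -> Tuple[List[Vector], float]:
    """Run a random-walk MH chain; return the samples and the acceptance rate."""
    rng = random.Random(seed)
    u = list(u0)
    samples: List[Vector] = []
    n_accepted = 0
    for _ in range(n_steps):
        u, accepted = rwmh_step(u, log_target, step_size, rng)
        n_accepted += int(accepted)
        samples.append(u)
    return samples, n_accepted / n_steps


if __name__ == "__main__":
    # Same toy posterior as a pCN example would use: prior N(0, 1),
    # one observation y = 1 with unit noise, exact posterior N(0.5, 0.5).
    def log_target(u: Vector) -> float:
        return -0.5 * u[0] ** 2 - 0.5 * (u[0] - 1.0) ** 2

    samples, rate = run_rwmh([0.0], log_target, step_size=1.0, n_steps=50_000)
    mean = sum(s[0] for s in samples) / len(samples)
    print(f"acceptance rate: {rate:.2f}, posterior mean estimate: {mean:.3f}")
```

In high dimensions the random-walk step size typically has to shrink to keep a reasonable acceptance rate, which is a well-known motivation for pCN-type proposals.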