ruqizhang/discrete-langevin

Question about the Bayesian NN implementation

Closed this issue · 2 comments

Hi,

I'm going through the code for the BayesianNN implementation, and in the constructor there is num_particles argument

def __init__(self, X_train, y_train, batch_size, num_particles, hidden_dim):

I don't understand the purpose of this argument. Is it a way of training the NN before re-sampling theta from its distribution?

P.S.: Thank you for sharing the code for this great work! (I'm going through the code to port it to JAX)

Hi Habush,

We use multi-chain Langevin dynamics (running multiple Langevin dynamics in parallel) to fully leverage the GPU. To do this, we randomly initialize num_particles networks and run Langevin dynamics with these networks independently.
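To illustrate the idea (this is a hypothetical sketch, not the repo's code): each of num_particles chains gets its own randomly initialized parameter vector, and one vectorized Langevin update advances all chains at once. Here the target is a 1D standard Gaussian so the update stays short; in the actual BayesianNN code each "particle" is a full network.

```python
import numpy as np

def grad_log_prob(theta):
    # Gradient of log N(theta; 0, 1): d/dtheta [-theta^2 / 2] = -theta.
    return -theta

def run_chains(num_particles, num_steps, step_size, seed=0):
    # num_particles independent unadjusted Langevin chains, run in parallel:
    # theta is a vector holding one state per chain, so each update below
    # advances every chain simultaneously (the GPU-friendly pattern).
    rng = np.random.default_rng(seed)
    theta = rng.normal(size=num_particles)  # random init, one per chain
    for _ in range(num_steps):
        noise = rng.normal(size=num_particles)
        theta = (theta
                 + step_size * grad_log_prob(theta)
                 + np.sqrt(2.0 * step_size) * noise)
    return theta  # final state of each chain ~ approx. N(0, 1)

samples = run_chains(num_particles=1000, num_steps=2000, step_size=0.01)
```

After enough steps, the final states across chains approximate draws from the target posterior, which is why running many particles independently is both a parallelism trick and a way to get multiple posterior samples at once.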

That makes sense. Thanks for the response!