Code for Kansal et al., "Particle Cloud Generation with Message Passing Generative Adversarial Networks", NeurIPS 2021 (arXiv:2106.11535).
This repository contains PyTorch code for the message-passing GAN (MPGAN) model, as well as scripts for training the models from scratch and for generating and plotting particle clouds. We also include the weights of the fully trained models discussed in the paper.
Additionally, we release the standalone JetNet library, which provides a PyTorch Dataset class for our JetNet dataset, implementations of the evaluation metrics discussed in the paper, and further utilities useful for machine learning on jets.
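As a quick illustration of the library, the sketch below loads the gluon jet dataset and evaluates one of the paper's metrics on the real data. It is only a sketch: the JetNet constructor keyword, the per-item structure, and the (score, error) return of w1m are assumptions about the jetnet API, which differs slightly between versions, so consult the jetnet documentation.

import numpy as np
from jetnet.datasets import JetNet
from jetnet.evaluation import w1m

# Gluon jets; the dataset is downloaded on first use.
# The jet_type keyword is an assumption -- check the jetnet docs for your version.
dataset = JetNet(jet_type="g")

def particle_cloud(item):
    # Some jetnet versions return a (particle_features, jet_features) pair per item.
    features = item[0] if isinstance(item, (tuple, list)) else item
    return np.asarray(features)

# Stack into a (num_jets, num_particles, num_features) array, keeping only the
# (eta_rel, phi_rel, pT_rel) particle features used by the evaluation metrics.
jets = np.stack([particle_cloud(dataset[i]) for i in range(len(dataset))])[..., :3]

# One of the paper's metrics: the Wasserstein distance between jet mass distributions (W1-M).
# Comparing two disjoint halves of the real data should give a score close to zero.
half = len(jets) // 2
score, err = w1m(jets[:half], jets[half:])
print(f"W1-M between real halves: {score:.5f} +/- {err:.5f}")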
Dependencies:

MPGAN model:
- torch >= 1.8.0

Training, plotting, and evaluation:
- torch >= 1.8.0
- jetnet >= 0.1.0
- numpy >= 1.21.0
- matplotlib
- mplhep

External models:
- torch
- torch_geometric
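The Python packages can be installed with pip, for example as below (torch and torch_geometric usually need a platform- and CUDA-specific installation, so follow their own installation instructions):

pip install "jetnet>=0.1.0" "numpy>=1.21.0" matplotlib mplhep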
A Docker image containing all necessary libraries can be found here (Dockerfile).
Start training with:
python train.py --name test_model --jets g [args]
By default, model parameters, figures of particle and jet features, and plots of the training losses and evaluation metrics over time are saved every five epochs in an automatically created outputs/[name] directory.
Some notes:
- Training runs on a GPU by default if one is available.
- The default arguments correspond to the final model architecture and training configuration used in the paper.
- Run python train.py --help or look at setup_training.py for a full list of arguments.
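For example, to train with the default architecture on top quark jets under a new run name (assuming t is one of the accepted values for --jets, following the JetNet jet classes; check python train.py --help for the exact options):

python train.py --name mp_t --jets t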
Samples can be generated from the pre-trained generators using their saved state dictionaries and training arguments, for example:
python gen.py --G-state-dict trained_models/mp_g/G_best_epoch.pt --G-args trained_models/mp_g/args.txt --num-samples 50000 --output-file trained_models/mp_g/gen_jets.npy
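The generated jets are saved as a plain NumPy array, so they can be inspected directly. The sketch below assumes the array has shape (num_jets, num_particles, num_features) with relative (eta, phi, pT) particle features, as in the paper; this may not match every configuration.

import numpy as np
import matplotlib.pyplot as plt

# Load the jets produced by gen.py; shape is assumed to be
# (num_jets, num_particles, num_features), e.g. (50000, 30, 3).
gen_jets = np.load("trained_models/mp_g/gen_jets.npy")
print(gen_jets.shape)

# Quick look at the relative particle pT, assumed to be the last feature.
plt.hist(gen_jets[..., -1].ravel(), bins=100, histtype="step")
plt.xlabel("particle relative pT")
plt.ylabel("particles")
plt.savefig("gen_jets_ptrel.png")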