# Set Transformer-based amortized clustering
The purpose of this repository is to demonstrate a basic ML workflow compatible with Stanford cluster computing resources. Specifically, this repo contains a reimplementation of Experiment 5.3 ("Amortized Clustering with Mixture of Gaussians") from the Set Transformer paper (Lee et al., 2019).

A demo of training progress is included as `amortized_clustering_training.mp4`.
## Sherlock (Slurm) usage

- (Setup) From the base directory of this repository, run

  ```shell
  python3 -m venv env
  source env/bin/activate
  pip3 install -r requirements-jax.txt
  pip3 install -r requirements.txt
  ```

  If the last command fails, try

  ```shell
  pip3 install -r requirements.txt --no-cache-dir
  ```

- (Running) Now run

  ```shell
  sbatch submit.sh
  ```

  to kick off a simple hyperparameter sweep as detailed in `submit.sh`.

- (Monitoring) To check the status of your jobs, run

  ```shell
  squeue -u $USER
  ```

  Once you can see that they've started, you should notice that job log files (of the form `slurm-XXXXXXXX_X.out`) and a `checkpoints/` directory have been created. Observe the progress of your jobs from the command line with, e.g., `tail -f slurm*`, and/or navigate to https://login.sherlock.stanford.edu/ to run an OnDemand TensorBoard session (note that you will have to provide the relevant TensorBoard logdir, e.g., `$HOME/set_transformer/checkpoints`).
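For orientation, a Slurm array-job script for a sweep like the one above might look like the following sketch. This is not the repository's actual `submit.sh`: the array size, time limit, and swept hyperparameter values here are assumptions for illustration.

```shell
#!/bin/bash
#SBATCH --job-name=set_transformer
#SBATCH --array=0-3               # one task per hyperparameter setting (assumed sweep size)
#SBATCH --time=04:00:00           # assumed time limit
#SBATCH --output=slurm-%A_%a.out  # produces log files matching the slurm-XXXXXXXX_X.out pattern

# Hypothetical sweep over learning rates; the real submit.sh may sweep other flags.
LEARNING_RATES=(1e-4 3e-4 1e-3 3e-3)
LR=${LEARNING_RATES[$SLURM_ARRAY_TASK_ID]}

source env/bin/activate
python3 main.py --config.learning_rate="$LR"
```

Each array task picks one setting via `$SLURM_ARRAY_TASK_ID`, which is why `squeue -u $USER` shows several jobs from a single `sbatch submit.sh`.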
## Local usage

Install JAX with your preferred accelerator support and then install the rest of the dependencies with `pip3 install -r requirements.txt`. Train from the command line with `python3 main.py`; hyperparameters listed in `config.py` may be configured with config flags, e.g.,

```shell
python3 main.py --config.input_encoding=sinusoidal --config.learning_rate=1e-3
```

In interactive workflows, you may choose to directly modify the `ConfigDict` returned by `config.get_config()` for use with the functions in `train_eval.py`, e.g., `train_eval.train_and_evaluate`.
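The interactive pattern can be sketched as follows. To keep the snippet self-contained, a stdlib `SimpleNamespace` stands in for the `ConfigDict` (typically an `ml_collections.ConfigDict`), and `get_config` here is a hypothetical stand-in with made-up default values; in practice you would call the repository's `config.get_config()` and pass the result to `train_eval.train_and_evaluate`.

```python
# Sketch of the interactive override workflow; NOT the repository's actual config.
from types import SimpleNamespace

def get_config():
    # Hypothetical defaults standing in for config.get_config();
    # the real field names and values live in config.py.
    return SimpleNamespace(
        input_encoding="learned",
        learning_rate=1e-4,
    )

config = get_config()
# Attribute-style overrides, just as with a ConfigDict:
config.input_encoding = "sinusoidal"
config.learning_rate = 1e-3

# In the real workflow you would now run, e.g.:
# train_eval.train_and_evaluate(config, workdir="checkpoints")
```

This mirrors what the `--config.*` command-line flags do, but lets you tweak hyperparameters between runs in a notebook or REPL without restarting the process.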