set_transformer

Amortized clustering with a set transformer in JAX/Flax.


Set Transformer-based amortized clustering

The purpose of this repository is to demonstrate a basic ML workflow compatible with Stanford cluster computing resources. Specifically, this repo contains a reimplementation of Experiment 5.3 ("Amortized Clustering with Mixture of Gaussians") from the Set Transformer paper (Lee et al., "Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks," ICML 2019).
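In that experiment, the model is trained to infer mixture parameters directly from a set of points drawn from a random mixture of Gaussians. As an illustration of the kind of data involved (not this repo's actual data pipeline), a batch of such datasets might be generated like the sketch below; `generate_mog_batch` and all parameter values are hypothetical:

```python
import numpy as np

def generate_mog_batch(rng, batch_size=8, num_points=100, num_clusters=4, dim=2):
    """Sample a batch of mixture-of-Gaussians point sets.

    Each dataset in the batch gets its own random cluster means, so the
    model must infer the mixture from the point set alone.
    """
    # Cluster means drawn uniformly in a box, one set per dataset.
    means = rng.uniform(-4.0, 4.0, size=(batch_size, num_clusters, dim))
    # Assign each point to a cluster uniformly at random.
    labels = rng.integers(0, num_clusters, size=(batch_size, num_points))
    # Gather the mean for each point, then add isotropic Gaussian noise.
    point_means = np.take_along_axis(
        means, labels[:, :, None].repeat(dim, axis=2), axis=1
    )
    points = point_means + 0.3 * rng.standard_normal((batch_size, num_points, dim))
    return points, labels

rng = np.random.default_rng(0)
points, labels = generate_mog_batch(rng)
print(points.shape)  # (8, 100, 2)
```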

(Training demo video: amortized_clustering_training.mp4)

Sherlock (Slurm) usage

  1. (Setup) From the base directory of this repository, run
    python3 -m venv env
    source env/bin/activate
    pip3 install -r requirements-jax.txt
    pip3 install -r requirements.txt
    If the last command fails, try pip3 install -r requirements.txt --no-cache-dir.
  2. (Running) Now run
    sbatch submit.sh
    to kick off a simple hyperparameter sweep as detailed in submit.sh.
  3. (Monitoring) To check the status of your jobs, run
    squeue -u $USER
    Once they've started, you should see that job log files (of the form slurm-XXXXXXXX_X.out) and a checkpoints/ directory have been created. Follow the progress of your jobs from the command line with, e.g.,
    tail -f slurm*
    and/or navigate to https://login.sherlock.stanford.edu/ to run an OnDemand TensorBoard session (note that you will have to provide the relevant TensorBoard logdir, e.g., $HOME/set_transformer/checkpoints).
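The contents of submit.sh are repository-specific, but a Slurm array script for a small sweep typically looks like the following sketch; the resource requests and the swept learning-rate values here are illustrative assumptions, not the repo's actual settings:

```shell
#!/bin/bash
#SBATCH --job-name=set_transformer
#SBATCH --time=04:00:00
#SBATCH --gres=gpu:1
#SBATCH --array=0-2            # one array task per hyperparameter setting

# Illustrative sweep over learning rates; the values are assumptions.
LEARNING_RATES=(1e-4 3e-4 1e-3)

source env/bin/activate
python3 main.py --config.learning_rate="${LEARNING_RATES[$SLURM_ARRAY_TASK_ID]}"
```

Each array task picks its own hyperparameter via $SLURM_ARRAY_TASK_ID, which is what produces the slurm-XXXXXXXX_X.out log files mentioned above.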

General usage

Install JAX with your preferred accelerator support and then install the rest of the dependencies with pip3 install -r requirements.txt. Train from the command line with python3 main.py; hyperparameters listed in config.py may be configured with config flags, e.g.,

python3 main.py --config.input_encoding=sinusoidal --config.learning_rate=1e-3

In interactive workflows, you may choose to directly modify the ConfigDict returned by config.get_config() for use with the functions in train_eval.py, e.g., train_eval.train_and_evaluate.
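In an interactive session, the pattern looks roughly like the sketch below. It assumes config.get_config() returns an ml_collections.ConfigDict; since this sketch can't import the repo, a minimal stand-in dataclass plays the role of the config. The field names learning_rate and input_encoding come from the flags shown above, while their default values are hypothetical:

```python
from dataclasses import dataclass

# Stand-in for the ConfigDict returned by config.get_config();
# in the real repo you would mutate that object directly.
@dataclass
class Config:
    learning_rate: float = 1e-4       # hypothetical default
    input_encoding: str = "learned"   # hypothetical default

cfg = Config()

# Equivalent of passing --config.learning_rate=1e-3 and
# --config.input_encoding=sinusoidal on the command line:
cfg.learning_rate = 1e-3
cfg.input_encoding = "sinusoidal"

# The modified config would then be passed to, e.g.,
# train_eval.train_and_evaluate(cfg, workdir="checkpoints/").
print(cfg.learning_rate, cfg.input_encoding)
```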