
SPECTRE

Reference implementation for SPECTRE: Spectral Conditioning Helps to Overcome the Expressivity Limits of One-shot Graph Generators (ICML 2022)

Setup

To run the code, install all the necessary packages by creating the conda environment named SPECTRE:

conda env create -f environment.yml
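
Then activate the environment before running any of the commands below:

conda activate SPECTRE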

Code structure

full_gan.py holds the main PyTorch Lightning model

model folder holds individual generators, discriminators, and their internal components

test.py tests the model. After training a model, supply the appropriate checkpoint to this script. It also accepts an --EMD flag to compute the MMDs using the earth mover's distance, following GraphRNN, instead of the much faster Gaussian TV kernel used by GRAN. An illustrative invocation is shown below.
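
For example (the checkpoint argument name here is illustrative, not the script's confirmed interface; run python test.py --help for the exact arguments):

python test.py --checkpoint logs/<run_name>/checkpoints/last.ckpt --EMD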

If the orbit count MMD does not work (i.e., it is always zero), you might need to recompile the util/orca/orca executable; on Linux, also make sure that the user running the code has permission to execute it. A sketch of both steps follows.
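
A minimal sketch of both fixes, assuming the standard single-file ORCA source (orca.cpp) is present in util/orca:

# Recompile the ORCA executable (assumes orca.cpp sits next to it)
g++ -O2 -std=c++11 -o util/orca/orca util/orca/orca.cpp

# If compilation is not needed and only the execute permission is missing:
chmod +x util/orca/orca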

data.py holds all of the code required to build and load the datasets we used. Our generated synthetic datasets can be found in the data folder. If you use these datasets in your code, you can use the function below to get matching splits:

import torch

def load_graphs(filename, data_dir='data'):
    adjs, eigvals, eigvecs, n_nodes, max_eigval, min_eigval, same_sample, n_max = torch.load(f'{data_dir}/{filename}.pt')
    print(f'Dataset {filename} loaded from file')

    # 20% of the graphs go to test; the rest are split 80/20 into train and validation
    test_len = int(round(len(adjs) * 0.2))
    train_len = int(round((len(adjs) - test_len) * 0.8))
    val_len = len(adjs) - train_len - test_len
    print(f'Dataset sizes: train {train_len}, val {val_len}, test {test_len}')

    # Fixed seed so the resulting splits match ours
    train, val, test = torch.utils.data.random_split(
        adjs, [train_len, val_len, test_len],
        generator=torch.Generator().manual_seed(1234))

    return train, val, test
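
For example (the dataset filename below is illustrative; use the name of one of the .pt files in the data folder):

train, val, test = load_graphs('community_12_21_100')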

visualization.py holds the function used to produce the figures in the paper from the generated_graphs.pt files produced by test.py.

Output files

Tensorboard logs, the hyperparameters used, and model checkpoints are saved in the logs folder. Running test.py on a trained model creates a test subfolder for the run containing six further folders, each of which holds generated graphs in a generated_graphs.pt file: one folder each for conditioning on nothing, on the true eigenvalues, and on the true eigenvalues and eigenvectors, repeated for the exponential moving average (EMA) model weights. Graphs are stored as a list of tensors; a sketch for loading one of these files is shown below.
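
A minimal Python sketch for inspecting one of these files (the run name and subfolder name are illustrative; use the actual names in your logs directory):

import torch

# Illustrative path: point it at one of the six test subfolders for your run
path = 'logs/my_run/test/eigval_conditioned/generated_graphs.pt'
graphs = torch.load(path)  # a list of adjacency-matrix tensors
print(f'{len(graphs)} graphs, first graph has shape {graphs[0].shape}')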

Example commands used to train our models

Planar:

python full_gan.py --batch_size 10 --max_epochs=12000 --log_every_n_steps=60 --check_val_every_n_epoch 90 --gen_gelu --disc_gelu --k_eigval 2 --eigvec_temp_decay --decay_eigvec_temp_over 2000 --min_eigvec_temp 0.8 --n_eigval_warmup_epochs 2000 --n_eigvec_warmup_epochs 2000 --eigval_temp_decay --decay_eigval_temp_over 2000 --min_eigval_temp 0.8 --SON_D_full_readout --noisy_gen --SON_normalize_left --lambda_gating --lambda_last_gating --lambda_upsample --SON_small --noisy_disc --derived_eigval_noise --normalize_noise --spectral_norm --eigvec_right_noise --gp_shared_alpha --no_restart --SON_gumbel_temperature_decay --decay_SON_gumbel_temp_over 10000 --SON_gumbel_temperature_warmup_epochs 0 --n_data_workers 4 --dataset planar --n_nodes 64 --gp_do_backwards --eigvec_sign_flip --ignore_first_eigv --gp_include_unpermuted --clip_grad_norm 1.0 --seed 0

SBM:

python full_gan.py --batch_size 5 --max_epochs=6000 --log_every_n_steps=60 --check_val_every_n_epoch 45 --gen_gelu --disc_gelu --k_eigval 4 --eigvec_temp_decay --decay_eigvec_temp_over 1000 --min_eigvec_temp 0.8 --n_eigval_warmup_epochs 1000 --n_eigvec_warmup_epochs 1000 --eigval_temp_decay --decay_eigval_temp_over 1000 --min_eigval_temp 0.8 --SON_D_full_readout --noisy_gen --SON_normalize_left --lambda_gating --lambda_last_gating --lambda_upsample --SON_small --noisy_disc --derived_eigval_noise --normalize_noise --spectral_norm --eigvec_right_noise --gp_shared_alpha --no_restart --SON_gumbel_temperature_decay --decay_SON_gumbel_temp_over 5000 --SON_gumbel_temperature_warmup_epochs 0 --n_data_workers 4 --dataset sbm --gp_do_backwards --eigvec_sign_flip --ignore_first_eigv --gp_include_unpermuted --clip_grad_norm 1.0

Proteins:

python full_gan.py --batch_size 1 --max_epochs=1020 --log_every_n_steps=60 --check_val_every_n_epoch 20 --gen_gelu --disc_gelu --k_eigval 16 --eigvec_temp_decay --decay_eigvec_temp_over 176 --min_eigvec_temp 0.8 --n_eigval_warmup_epochs 176 --n_eigvec_warmup_epochs 176 --eigval_temp_decay --decay_eigval_temp_over 176 --min_eigval_temp 0.8 --SON_D_full_readout --noisy_gen --SON_normalize_left --lambda_gating --lambda_last_gating --lambda_upsample --SON_small --noisy_disc --derived_eigval_noise --normalize_noise --spectral_norm --eigvec_right_noise --gp_shared_alpha --no_restart --SON_gumbel_temperature_decay --decay_SON_gumbel_temp_over 875 --SON_gumbel_temperature_warmup_epochs 0 --n_data_workers 4 --dataset protein --gp_do_backwards --eigvec_sign_flip --ignore_first_eigv --gp_include_unpermuted --clip_grad_norm 1.0 --accelerator 'ddp' --gpus 4 --seed 0

Community:

python full_gan.py --batch_size 10 --max_epochs=12000 --log_every_n_steps=60 --check_val_every_n_epoch 90 --gen_gelu --disc_gelu --k_eigval 2 --eigvec_temp_decay --decay_eigvec_temp_over 2000 --min_eigvec_temp 0.8 --n_eigval_warmup_epochs 2000 --n_eigvec_warmup_epochs 2000 --eigval_temp_decay --decay_eigval_temp_over 2000 --min_eigval_temp 0.8 --SON_D_full_readout --noisy_gen --SON_normalize_left --lambda_gating --lambda_last_gating --lambda_upsample --SON_small --noisy_disc --derived_eigval_noise --normalize_noise --spectral_norm --eigvec_right_noise --gp_shared_alpha --no_restart --SON_gumbel_temperature_decay --decay_SON_gumbel_temp_over 10000 --SON_gumbel_temperature_warmup_epochs 0 --n_data_workers 4 --dataset community --n_start 12 --n_end 21 --n_graphs 100 --gp_do_backwards --eigvec_sign_flip --ignore_first_eigv --gp_include_unpermuted --clip_grad_norm 1.0

QM9 (following MolGAN):

python full_gan.py --batch_size 128 --max_epochs=30 --log_every_n_steps=80 --check_val_every_n_epoch 1 --gen_gelu --disc_gelu --n_G 3 --n_D 3 --k_eigval 2 --eigvec_temp_decay --decay_eigvec_temp_over 10 --min_eigvec_temp 0.8 --n_eigval_warmup_epochs 10 --n_eigvec_warmup_epochs 15 --eigval_temp_decay --decay_eigval_temp_over 10 --min_eigval_temp 0.8 --SON_D_full_readout --noisy_gen --SON_normalize_left --lambda_gating --lambda_last_gating --lambda_upsample --SON_small --noisy_disc --derived_eigval_noise --normalize_noise --spectral_norm --eigvec_right_noise --gp_shared_alpha --no_restart --SON_gumbel_temperature_decay --decay_SON_gumbel_temp_over 25 --SON_gumbel_temperature_warmup_epochs 0 --n_data_workers 4 --dataset qm9 --n_graphs -1 --gp_do_backwards --eigvec_sign_flip --ignore_first_eigv --gp_include_unpermuted --clip_grad_norm 1.0

Adding the --qm9_strict_eval flag during evaluation, or during training (for model selection), counts molecules as valid only if they consist of a single connected component.

To train MolGAN* add --mlp_gen, to train GG-GAN* add --use_fixed_emb, and to train GG-GAN (RS)* add --no_cond. These baselines only generate adjacency matrices, like SPECTRE trained with the --adj_only flag; see the example below.
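
For example, to train the MolGAN* baseline, append the corresponding flag to the QM9 command above (shortened here for readability):

python full_gan.py --dataset qm9 <remaining QM9 arguments as above> --mlp_gen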

The code has been simplified and refactored, so the results might slightly differ from the published ones.

The PPGN GAN can sometimes get stuck. If you encounter stability issues when training the model on your problem, you can try increasing the number of eigenvectors considered (--k_eigval), increasing the permutations used for the gradient penalty (--gp_adj_rewire and --gp_adj_noise), using the --cat_mult_eigvals option, or trying a different random seed (--seed), as some seeds can be unlucky.