BGRL

Large-Scale Representation Learning on Graphs via Bootstrapping

This project provides an implementation of the Bootstrapped Graph Latents (BGRL) model, a graph representation learning method that learns by predicting alternative augmentations of the input. The graph encoder is first trained in a fully unsupervised manner, producing an embedding for each node. A simple linear model is then trained on top of these frozen embeddings for node classification.
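
At a high level, BGRL maintains two encoders: an online encoder, followed by an MLP predictor, is trained to regress the node embeddings produced by a target encoder on a differently augmented view, while the target's weights are an exponential moving average of the online weights. The sketch below illustrates one such update in PyTorch/PyG; every name in it (GCNEncoder, train_step, the view objects, tau) is an illustrative assumption, not this repository's actual API.

    # Minimal sketch of one BGRL update, assuming PyTorch Geometric.
    # All names below are illustrative, not the repository's actual API.
    import torch
    import torch.nn.functional as F
    from torch_geometric.nn import GCNConv

    class GCNEncoder(torch.nn.Module):
        """Two-layer GCN mapping node features to one embedding per node."""
        def __init__(self, in_dim, hidden_dim=512, out_dim=256):
            super().__init__()
            self.conv1 = GCNConv(in_dim, hidden_dim)
            self.conv2 = GCNConv(hidden_dim, out_dim)

        def forward(self, x, edge_index):
            return self.conv2(F.relu(self.conv1(x, edge_index)), edge_index)

    def train_step(online, target, predictor, optimizer, view1, view2, tau=0.99):
        optimizer.zero_grad()
        # The online encoder (plus the MLP predictor) embeds both views ...
        q1 = predictor(online(view1.x, view1.edge_index))
        q2 = predictor(online(view2.x, view2.edge_index))
        # ... and regresses the target encoder's embeddings of the other view.
        with torch.no_grad():
            y1 = target(view1.x, view1.edge_index)
            y2 = target(view2.x, view2.edge_index)
        # Symmetric negative cosine similarity; gradients reach only the
        # online encoder and the predictor.
        loss = -(F.cosine_similarity(q1, y2, dim=-1).mean()
                 + F.cosine_similarity(q2, y1, dim=-1).mean())
        loss.backward()
        optimizer.step()
        # The target encoder tracks the online encoder by moving average.
        with torch.no_grad():
            for p_t, p_o in zip(target.parameters(), online.parameters()):
                p_t.mul_(tau).add_((1 - tau) * p_o)
        return loss.item()

    # Setup: the target starts as a frozen copy of the online encoder, e.g.
    #   import copy
    #   online = GCNEncoder(num_features)
    #   target = copy.deepcopy(online)
    #   for p in target.parameters():
    #       p.requires_grad_(False)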

Setup

  1. Clone the repository:

    git clone https://github.com/al3ssandrocaruso/BGRL.git
    cd BGRL
  2. Install dependencies:

    pip install -r requirements.txt

Available Datasets

  • WikiCS
  • Amazon_Computers
  • Amazon_Photos
  • Coauthor_CS
  • Coauthor_Physics
  • Cora
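
All of these are standard node-classification benchmarks available through PyTorch Geometric. As a rough sketch of how the names above could map onto PyG dataset classes (the load_dataset helper and the root paths are hypothetical; the repository's own loader may differ):

    from torch_geometric.datasets import Amazon, Coauthor, Planetoid, WikiCS

    def load_dataset(name, root="./data"):
        # Hypothetical mapping from the CLI dataset names to PyG datasets.
        if name == "WikiCS":
            return WikiCS(root=f"{root}/WikiCS")
        if name == "Amazon_Computers":
            return Amazon(root=f"{root}/Amazon", name="Computers")
        if name == "Amazon_Photos":
            return Amazon(root=f"{root}/Amazon", name="Photo")
        if name == "Coauthor_CS":
            return Coauthor(root=f"{root}/Coauthor", name="CS")
        if name == "Coauthor_Physics":
            return Coauthor(root=f"{root}/Coauthor", name="Physics")
        if name == "Cora":
            return Planetoid(root=f"{root}/Planetoid", name="Cora")
        raise ValueError(f"Unknown dataset: {name}")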

Running the Script

  1. With Default Dataset (Amazon_Photos): Simply run:

    python runner.py
  2. With a Specific Dataset: Use the --dataset argument to choose one of the datasets listed above:

    python runner.py --dataset WikiCS

Command-Line Arguments

  • --hidden_dim_encoder: Dimension of hidden layers in the GCN encoder (default: 512)
  • --g_embedding_dim: Dimension of the embedding generated by the encoder (default: 256)
  • --hidden_dim_predictor: Dimension of hidden layers in the MLP predictor (default: 512)
  • --num_epochs: Number of epochs for training (default: 300)
  • --pf_view_1: Probability of feature perturbation for the first view (default: 0.3; see the augmentation sketch after this list)
  • --pf_view_2: Probability of feature perturbation for the second view (default: 0.2)
  • --pe_view_1: Probability of edge perturbation for the first view (default: 0.3)
  • --pe_view_2: Probability of edge perturbation for the second view (default: 0.4)
  • --dataset: Dataset to use (default: Amazon_Photos)
  • --optimizer: Optimizer to use (adam or sgd, default: adam)
  • --lr: Learning rate (default: 1e-5)
  • --use_batch_norm: Use batch normalization (default: False)
  • --use_layer_norm: Use layer normalization (default: False)
  • --save_weights: Save the model weights after training (default: False)
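
The pf_* and pe_* probabilities parameterize the two stochastic views that BGRL contrasts: pf is the chance of masking a feature dimension and pe the chance of dropping an edge. A minimal sketch of such an augmentation, assuming PyG-style x/edge_index tensors (the augment helper is hypothetical, not the repository's actual code):

    import torch

    def augment(x, edge_index, pf, pe):
        """Return one stochastic view: mask feature columns with
        probability pf and drop edges with probability pe."""
        feat_mask = torch.rand(x.size(1), device=x.device) >= pf
        edge_mask = torch.rand(edge_index.size(1), device=edge_index.device) >= pe
        return x * feat_mask, edge_index[:, edge_mask]

    # e.g. view 1: augment(data.x, data.edge_index, pf=0.3, pe=0.3)
    #      view 2: augment(data.x, data.edge_index, pf=0.2, pe=0.4)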

Example:

    python runner.py --num_epochs 200 --dataset WikiCS --use_batch_norm --save_weights