CaST

Deciphering Spatio-Temporal Graph Forecasting: A Causal Lens and Treatment

This repo provides the implementation code for our NeurIPS 2023 paper, Deciphering Spatio-Temporal Graph Forecasting: A Causal Lens and Treatment. The code is implemented with PyTorch 1.10.2 and was run on a server with an NVIDIA RTX A6000 GPU.


Description

We present CaST, a new framework that takes a causal look at Spatio-Temporal Graph (STG) forecasting, tackling two key challenges: temporal out-of-distribution shifts and dynamic spatial causation. CaST employs two causal tools: the back-door adjustment, implemented through a disentanglement block that separates the invariant part of the signal from temporal environments, and the front-door adjustment, which introduces a surrogate variable to emulate node information filtered by its causal relationships.
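To give some intuition for the back-door branch: CaST discretizes temporal environments with a VQ-VAE-style codebook (see Code Reference below), whose size is set by the n_envs argument. Below is a minimal, illustrative sketch of such a lookup; the class and variable names are ours, not the actual implementation:

import torch
import torch.nn as nn

class EnvCodebook(nn.Module):
    """Illustrative environment codebook: maps an encoded temporal feature
    to the nearest of `n_envs` learned environment embeddings (a VQ-VAE-style
    lookup; a simplified stand-in, not CaST's actual code)."""

    def __init__(self, n_envs: int, hid_dim: int):
        super().__init__()
        self.codebook = nn.Embedding(n_envs, hid_dim)

    def forward(self, z_env: torch.Tensor) -> torch.Tensor:
        # z_env: (batch, hid_dim); distance from each sample to each codeword
        dists = torch.cdist(z_env, self.codebook.weight)  # (batch, n_envs)
        idx = dists.argmin(dim=-1)                        # nearest environment
        return self.codebook(idx)                         # quantized embedding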

Requirements

CaST uses the following dependencies:

  • PyTorch 1.10.2 and its dependencies
  • NumPy and SciPy
  • CUDA 11.3 or later, with cuDNN
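For reference, one possible way to install a matching PyTorch build (an example assuming pip and a CUDA 11.3 toolchain; adjust to your environment):

pip install torch==1.10.2+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
pip install numpy scipy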

Dataset

Overview

The performance of CaST was validated on three datasets: PEMS08, AIR-BJ, and AIR-GZ. AIR-BJ and AIR-GZ contain one year of PM$_{2.5}$ readings from air quality monitoring stations located in Beijing and Guangzhou, respectively. PEMS08 contains traffic flow data collected by sensors deployed on the road network. Traffic flow is often considered a complex and challenging type of spatio-temporal data because of the many factors that can affect it, such as weather, time of day, and road conditions.

For proper execution, place each dataset at ./data/[dataset_name]/dataset.npy and make sure it follows the shape (num_samples, num_nodes, input_dim).

For the PEMS08 dataset, the dataset.npy file can be generated with the following code:

import numpy as np

# Extract the raw array from the PEMS08 archive and save it as dataset.npy
data = np.load('./data/PEMS08/pems08.npz')['data']
np.save('./data/PEMS08/dataset.npy', data)
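After generating the file, a quick sanity check on the shape may be useful (a minimal snippet; the expected PEMS08 shape of (17856, 170, 3) is our recollection of the standard release, so verify it against your copy):

import numpy as np

data = np.load('./data/PEMS08/dataset.npy')
assert data.ndim == 3, "expected (num_samples, num_nodes, input_dim)"
print("samples, nodes, channels:", data.shape)  # e.g., (17856, 170, 3)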

Edge Features

For detailed information on how we create edge attributes, please refer to Appendix D of our paper, where we discuss them at length. You may also implement your own method for creating edge attributes as an alternative to the Pearson correlation or the time-delayed Dynamic Time Warping (DTW) method used in our study.

If you prefer to follow our approach, here is example code to generate the peacor_adj.npy file:

import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns

def get_peacor_adj(data_path, threshold, save=False):
    # Load the dataset
    data = np.load(data_path + 'train.npz')['data']
    print("Data shape:", data.shape)

    # Compute the Pearson correlation coefficient matrix across nodes
    peacor = torch.corrcoef(torch.Tensor(data[..., 0]).permute(1, 0))

    # Zero out weak correlations and the self-loops on the diagonal
    peacor[peacor < threshold] = 0
    peacor[torch.eye(peacor.shape[0], dtype=bool)] = 0

    # Min-max normalize the remaining (non-zero) coefficients to [0, 1]
    nonzero_peacor = peacor[peacor != 0]
    p_min, p_max = nonzero_peacor.min(), nonzero_peacor.max()
    peacor[peacor != 0] = (nonzero_peacor - p_min) / (p_max - p_min)

    # Visualize the resulting adjacency matrix
    plt.figure(dpi=100)
    sns.heatmap(peacor)
    plt.show()

    # Save the result
    if save:
        np.save(data_path + 'peacor_adj.npy', peacor.numpy())
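For the DTW-based edge attributes, the exact time-delayed formulation is given in Appendix D. As a rough stand-in, a standard DTW distance restricted to a window of allowed time shifts could look like the sketch below (the band-constrained variant and the max_shift parameter are our assumptions, not the paper's exact method):

import numpy as np

def dtw_distance(x, y, max_shift=12):
    # Classic dynamic-programming DTW between two 1-D series, restricted
    # to alignments within `max_shift` steps (a Sakoe-Chiba band)
    n, m = len(x), len(y)
    D = np.full((n + 1, m + 1), np.inf)
    D[0, 0] = 0.0
    for i in range(1, n + 1):
        for j in range(max(1, i - max_shift), min(m, i + max_shift) + 1):
            cost = abs(x[i - 1] - y[j - 1])
            D[i, j] = cost + min(D[i - 1, j], D[i, j - 1], D[i - 1, j - 1])
    return D[n, m]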

For reproducibility, we also provide peacor_adj.npy, sparse_adj.npy, and dist_adj.npy in the ./data/PEMS08/ directory for reference.

Arguments

We introduce the major arguments of our main function below.

Training settings:

  • mode: the run mode, e.g., train or test
  • gpu: which GPU to use for training
  • seed: the random seed for experiments
  • dataset: which dataset to run
  • base_lr: the initial learning rate
  • lr_decay_ratio: the ratio of learning rate decay
  • batch_size: the training or testing batch size
  • seq_len: the number of historical steps
  • horizon: the number of future steps
  • input_dim: the dimension of the inputs
  • output_dim: the dimension of the outputs
  • max_epochs: the maximum number of training epochs
  • patience: the patience of early stopping
  • save_preds: whether to save prediction results
  • train_ratio: the ratio of the training set
  • val_ratio: the ratio of the validation set

Model hyperparameters:

  • hid_dim: the hidden dimension in CaST
  • dropout: the dropout rate
  • n_envs: the number of environments
  • node_embed_dim: the dimensionality of node embeddings
  • K: the depth of the HL Deconfounder block

Training and Evaluation

The following example commands train CaST on each dataset:

# PEMS08
python ./experiments/cast/main.py --dataset PEMS08 --mode 'train' --hid_dim 64 --n_envs 20 --node_embed_dim 5 --K 2
# AIR-BJ
python ./experiments/cast/main.py --dataset AIR_BJ --mode 'train' --hid_dim 64 --n_envs 10 --node_embed_dim 10 --K 3
# AIR-GZ
python ./experiments/cast/main.py --dataset AIR_GZ --mode 'train' --hid_dim 64 --n_envs 20 --node_embed_dim 5 --K 2
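To evaluate a trained model, the mode argument can be switched to test (we assume the remaining hyperparameters should match the training run), e.g.:

# PEMS08, evaluation
python ./experiments/cast/main.py --dataset PEMS08 --mode 'test' --hid_dim 64 --n_envs 20 --node_embed_dim 5 --K 2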

Code Reference

HL-HGAT: https://github.com/JH-415/HL-HGAT

VQVAE: https://github.com/ritheshkumar95/pytorch-vqvae

Citation

If you find our work useful in your research, please cite:

@article{xia2023deciphering,
  title={Deciphering Spatio-Temporal Graph Forecasting: A Causal Lens and Treatment},
  author={Xia, Yutong and Liang, Yuxuan and Wen, Haomin and Liu, Xu and Wang, Kun and Zhou, Zhengyang and Zimmermann, Roger},
  journal={arXiv preprint arXiv:2309.13378},
  year={2023}
}