

Spatial-Temporal Bipartite Graph Attention Network for Traffic Forecasting (STBGAT)

This is the PyTorch implementation of the paper "Spatial-Temporal Bipartite Graph Attention Network for Traffic Forecasting", published at PAKDD'24.

Architecture

Setup

The implementation uses Python 3.10.11 and PyTorch 1.13.1.
Create a new conda environment and install the dependencies from the requirements file:

conda create -n stbgat python=3.10.11
conda activate stbgat
pip install -r requirements.txt
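To confirm that an environment matches the versions listed above, a small check like the following may help. It is a minimal sketch, not part of the repository: it assumes PyTorch was installed under the pip distribution name `torch`, and the function name `check_environment` is hypothetical.

```python
import sys
from importlib import metadata

# Versions stated in this README for the reference setup.
EXPECTED_PYTHON = (3, 10, 11)
EXPECTED_TORCH = "1.13.1"


def check_environment():
    """Return a list of mismatches with the README's setup (empty if none)."""
    problems = []
    current_python = tuple(sys.version_info[:3])
    if current_python != EXPECTED_PYTHON:
        problems.append(
            f"Python {'.'.join(map(str, current_python))} found, expected 3.10.11"
        )
    try:
        # Assumes PyTorch is installed as the pip distribution "torch".
        torch_version = metadata.version("torch")
    except metadata.PackageNotFoundError:
        problems.append("PyTorch (torch) is not installed")
    else:
        # Strip local build tags such as "+cu117" before comparing.
        if torch_version.split("+")[0] != EXPECTED_TORCH:
            problems.append(f"PyTorch {torch_version} found, expected 1.13.1")
    return problems


if __name__ == "__main__":
    print(check_environment() or "Environment matches the README.")
```

A mismatch is reported rather than raised, since other versions may still work; the listed versions are simply the ones the authors used.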

Training

Config files and pre-processed graphs are included for the PEMS04, PEMS07 and PEMS08 datasets; check out the relevant branch for each dataset.
Before training, make sure you have checked out the correct branch and are using the matching config file and data files.
Before running the Python script, create the directory where model output is saved (see the config file). Then train the model with the following commands:

mkdir model_output_dir
python main.py
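If you prefer to derive the output directory from the config instead of typing it by hand, the sketch below shows one way. It rests on assumptions not confirmed by the repository: that the config file is JSON and that the output location is a directory stored under a key called `model_output_dir` (a hypothetical name echoing the placeholder above). Adjust both to the actual config on your branch.

```python
import json
import os


def prepare_output_dir(config_path, key="model_output_dir"):
    """Create the output directory named in a JSON config and return its path.

    Hypothetical helper: the JSON format and the default `key` are assumptions,
    not the repository's actual config schema.
    """
    with open(config_path) as f:
        config = json.load(f)
    out_dir = config[key]
    # Unlike a bare `mkdir`, this also creates missing parent directories
    # and does not fail if the directory already exists.
    os.makedirs(out_dir, exist_ok=True)
    return out_dir
```

Using `os.makedirs(..., exist_ok=True)` makes the step safe to repeat when switching between dataset branches that share an output location.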

Evaluation

Pretrained models can be downloaded from Google Drive.
Before running the Python script, place the model weight file at the 'model_input_path' specified in the config file. Then evaluate the model with the following command:

python evaluate.py
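A quick pre-flight check can catch a misplaced weight file before evaluation starts. This is a minimal sketch that assumes the config is a JSON file; only the key 'model_input_path' comes from this README, and the helper name `find_pretrained_weights` is hypothetical.

```python
import json
import os


def find_pretrained_weights(config_path):
    """Return the weight file path from the config, raising if it is missing.

    Assumes a JSON config; 'model_input_path' is the key named in the README.
    """
    with open(config_path) as f:
        config = json.load(f)
    weights_path = config["model_input_path"]
    if not os.path.isfile(weights_path):
        raise FileNotFoundError(
            f"No weight file at {weights_path!r} (model_input_path). "
            "Download the pretrained model and place it there before "
            "running evaluate.py."
        )
    return weights_path
```

Failing early with the expected path in the message is usually clearer than an error raised deep inside model loading.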