This is the PyTorch implementation of the paper "Spatial-Temporal Bipartite Graph Attention Network for Traffic Forecasting", published at PAKDD'24.
We used Python 3.10.11 and PyTorch 1.13.1 for the implementation.
Create a new conda environment and install the dependencies from the requirements file:
conda create -n stbgat python=3.10.11
conda activate stbgat
pip install -r requirements.txt
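After installing the dependencies, you may want to confirm that the active environment matches the versions listed above. The following is a minimal sketch, not part of this repository: the helper names `version_matches` and `check_environment` are hypothetical, and it assumes PyTorch was installed as the `torch` distribution (as `pip install -r requirements.txt` would do). Local build suffixes such as `+cu117` are ignored when comparing versions.

```python
import sys
from importlib.metadata import PackageNotFoundError, version

# Versions stated in this README; hypothetical constants for this check only.
REQUIRED_PYTHON = "3.10.11"
REQUIRED_TORCH = "1.13.1"


def version_matches(installed: str, required: str) -> bool:
    """Return True if `installed` equals `required`, ignoring a local
    build suffix such as '+cu117' that PyTorch wheels often carry."""
    public = installed.split("+", 1)[0]
    return public == required


def check_environment() -> bool:
    """Print the detected Python/PyTorch versions and report whether both match."""
    py = ".".join(str(part) for part in sys.version_info[:3])
    ok = version_matches(py, REQUIRED_PYTHON)
    print(f"Python {py} (expected {REQUIRED_PYTHON})")
    try:
        torch_version = version("torch")
    except PackageNotFoundError:
        print("PyTorch is not installed in this environment")
        return False
    print(f"PyTorch {torch_version} (expected {REQUIRED_TORCH})")
    return ok and version_matches(torch_version, REQUIRED_TORCH)


if __name__ == "__main__":
    print("Environment OK" if check_environment() else "Version mismatch")
```

Run it inside the activated `stbgat` environment; a mismatch does not necessarily break training, but it is the first thing to rule out if results differ.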
We have included the config files and pre-processed graphs for the PEMS04, PEMS07 and PEMS08 datasets (check out the relevant branch).
Before training, it is important to check out the correct branch and use the matching config file and data files.
To train the model, use the following commands. Before running the Python script, create the directory where the model output will be saved (see the config file):
mkdir model_output_dir
python main.py
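A plain `mkdir` fails if the directory already exists, and the directory name must agree with the output path in the config file. As an optional alternative, here is a minimal sketch of an idempotent equivalent in Python; `ensure_output_dir` is a hypothetical helper, and `model_output_dir` is only the placeholder name from the command above, not a value read from the config.

```python
from pathlib import Path


def ensure_output_dir(path: str) -> Path:
    """Create the model output directory (and any missing parents).

    Unlike a bare `mkdir`, this does not fail when the directory already
    exists, so it is safe to call before every training run.
    """
    out_dir = Path(path)
    out_dir.mkdir(parents=True, exist_ok=True)
    return out_dir


if __name__ == "__main__":
    # Hypothetical path: replace with the output path set in your config file.
    print(ensure_output_dir("model_output_dir").resolve())
```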
You can download the pretrained models from Google Drive.
Before running the evaluation script, place the model weight file at the path given by 'model_input_path' in the config file.
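To catch a misplaced weight file before the evaluation starts, you could check the path first. This is a minimal sketch under stated assumptions: `check_model_input` is a hypothetical helper, you pass it the value of 'model_input_path' yourself (how the config file is parsed is not shown), and the file name in the example is made up.

```python
from pathlib import Path


def check_model_input(model_input_path: str) -> bool:
    """Return True if a weight file exists at `model_input_path`.

    A directory at that path does not count; the config entry should point
    at the weight file itself.
    """
    weights = Path(model_input_path)
    if not weights.is_file():
        print(f"Weight file not found: {weights.resolve()}")
        return False
    return True


if __name__ == "__main__":
    # Hypothetical file name: use the 'model_input_path' value from your config.
    if check_model_input("pretrained/stbgat_pems08.pt"):
        print("Weights found; you can run evaluate.py")
```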
To evaluate the model, use the following command.
python evaluate.py