A framework that uses the same dataloader and evaluation methods to compare different deep learning algorithms for spatio-temporal traffic prediction. Better data augmentation and further evaluation methods will be added in the future (e.g., testing model robustness to sensor failure).
This repository contains the following implementations:
- DCRNN (Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting)
- Graph WaveNet (Graph WaveNet for Deep Spatial-Temporal Graph Modeling)
- GMAN (Graph Multi-Attention Network for Traffic Prediction)

For citations, see below.
This work is preliminary and will change over the coming months!
Where precomputed graph embeddings are used, they have also been regenerated from the adjacency matrix, based on the information given in the corresponding paper.
Next, the sensor at location (+) is disabled, and the resulting predictions are compared against the original predictions on the validation set.
Increasing model robustness involves the following:
- A change in one sensor's observations should not influence the predictions of far-away sensors.
- The local influence of faulty measurements should be mitigated.
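A minimal sketch of this experiment, assuming a model is any callable that maps an input window of shape (T_in, N) to a forecast of shape (T_out, N), and that a disabled sensor is simulated by setting its readings to 0. The names sensor_failure_impact and toy_model are hypothetical and not part of this repository; the toy model only mixes each sensor with its two direct neighbours, so the first robustness criterion can be checked by eye.

```python
from typing import Callable

import numpy as np


def sensor_failure_impact(
    model: Callable[[np.ndarray], np.ndarray],
    x: np.ndarray,
    sensor: int,
    fill_value: float = 0.0,
) -> np.ndarray:
    """Mean absolute change of every sensor's forecast when one input sensor fails.

    x has shape (T_in, N) and model(x) has shape (T_out, N); returns shape (N,).
    """
    baseline = model(x)
    x_failed = x.copy()
    x_failed[:, sensor] = fill_value  # simulate a dead sensor for the whole input window
    perturbed = model(x_failed)
    return np.abs(perturbed - baseline).mean(axis=0)


def toy_model(x: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in model: averages each sensor's last reading with its
    two neighbours on a ring, repeated over a 12-step horizon."""
    last = x[-1]
    smoothed = (np.roll(last, 1) + last + np.roll(last, -1)) / 3.0
    return np.tile(smoothed, (12, 1))


x = np.full((12, 6), 60.0)  # 12 input steps, 6 sensors, constant speed
impact = sensor_failure_impact(toy_model, x, sensor=2)
print(impact)  # only sensors 1, 2 and 3 change; far-away sensors are unaffected
```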
- Add DCRNN
- Add Graph WaveNet
- Add GMAN
- Add ST-GCN
- Add DCRNN-PyTorch
- Add PEMS-BAY dataset
- Create performance table
- Show robustness/inference statistics
Dependencies can be installed with the following command:
pip install -r requirements.txt
The traffic data files for Los Angeles (METR-LA) and the Bay Area (PEMS-BAY), i.e., metr-la.h5 and pems-bay.h5, are available at Google Drive or Baidu Yun, and should be put into the data/{metr-la|pems-bay}/ directory.
The *.h5 files store the data as a pandas.DataFrame in the HDF5 file format. Here is an example:
| | sensor_0 | sensor_1 | sensor_2 | sensor_n |
|---|---|---|---|---|
| 2018/01/01 00:00:00 | 60.0 | 65.0 | 70.0 | ... |
| 2018/01/01 00:05:00 | 61.0 | 64.0 | 65.0 | ... |
| 2018/01/01 00:10:00 | 63.0 | 65.0 | 60.0 | ... |
| ... | ... | ... | ... | ... |
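As a sketch of this layout, the snippet below rebuilds the first three sensor columns of the example table in memory and converts them into a (time steps x sensors) NumPy array. On the real files, pandas.read_hdf should return a frame with the same structure; the helper names load_traffic_frame and example_frame are hypothetical, and reading HDF5 requires the optional PyTables package.

```python
import pandas as pd


def load_traffic_frame(path: str) -> pd.DataFrame:
    """Read a METR-LA / PEMS-BAY style .h5 file (rows: timestamps, columns: sensors).

    Hypothetical helper; pd.read_hdf needs the optional PyTables package and,
    without a key argument, expects the file to hold a single table.
    """
    return pd.read_hdf(path)


def example_frame() -> pd.DataFrame:
    """In-memory copy of the first three sensors in the example table."""
    index = pd.date_range("2018-01-01 00:00:00", periods=3, freq="5min")
    return pd.DataFrame(
        {
            "sensor_0": [60.0, 61.0, 63.0],
            "sensor_1": [65.0, 64.0, 65.0],
            "sensor_2": [70.0, 65.0, 60.0],
        },
        index=index,
    )


df = example_frame()
values = df.to_numpy()                # shape (num_timesteps, num_sensors)
interval = df.index[1] - df.index[0]  # sampling interval between rows
print(values.shape, interval)
```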
Here is an article about Using HDF5 with Python.
# METR-LA
python dcrnn_test_pytorch.py --config_filename=data/metr-la/pretrained/dcrnn_test_pytorch.yaml
python gwnet_test.py --checkpoint data/metr-la/pretrained/graph_wavenet_repr.pth --data data/metr-la/metr-la.h5
python gwnet_test.py --lstm --nhid 256 --checkpoint data/metr-la/models/fc_lstm.pth --data data/metr-la/metr-la.h5
python gman_train.py --max_epoch 0 --SE_file "data/metr-la/SE(METR-LA).txt" --model_file data/metr-la/pretrained/GMAN_METR-LA --traffic_file data/metr-la/metr-la.h5
The generated predictions are stored in data/{metr-la|pems-bay}/results/.
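The evaluation methods are not detailed in this section. As an assumption, the sketch below computes the masked MAE, RMSE and MAPE commonly reported for METR-LA and PEMS-BAY (as in the DCRNN and Graph WaveNet code), where ground-truth entries equal to 0 are treated as missing and excluded. The function masked_metrics is hypothetical, and it works on in-memory arrays because the file format in results/ is not described here.

```python
import numpy as np


def masked_metrics(pred: np.ndarray, truth: np.ndarray, null_val: float = 0.0) -> dict:
    """MAE, RMSE and MAPE over entries whose ground truth differs from null_val.

    MAPE is returned as a fraction (0.05 means 5 %).
    """
    mask = truth != null_val  # missing readings are stored as null_val
    if not mask.any():
        raise ValueError("no valid ground-truth entries")
    err = pred[mask] - truth[mask]
    return {
        "mae": float(np.mean(np.abs(err))),
        "rmse": float(np.sqrt(np.mean(err ** 2))),
        "mape": float(np.mean(np.abs(err) / np.abs(truth[mask]))),
    }


truth = np.array([[60.0, 0.0], [50.0, 40.0]])  # 0.0 marks a missing reading
pred = np.array([[57.0, 30.0], [55.0, 40.0]])
print(masked_metrics(pred, truth))
```

The missing entry is ignored even though the prediction there (30.0) is far from 0, so a model is not penalised for failing to predict sensor dropouts.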
DCRNN requires pre-calculated road network distances, whereas Graph WaveNet can also learn the spatial dependencies implicitly through an adaptive adjacency matrix. The LSTM and the attention-based GMAN do not use road network distances.
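In the Graph WaveNet paper, this implicit graph is a self-adaptive adjacency matrix built from two learnable node-embedding matrices E1 and E2: A_adp = SoftMax(ReLU(E1 E2^T)), with the softmax applied row-wise. The NumPy sketch below only evaluates this formula with random embeddings; in the model, E1 and E2 are trained jointly with the rest of the network. The embedding size of 10 and the function name adaptive_adjacency are illustrative assumptions.

```python
import numpy as np


def adaptive_adjacency(e1: np.ndarray, e2: np.ndarray) -> np.ndarray:
    """Row-wise SoftMax(ReLU(E1 @ E2.T)) for node embeddings of shape (N, d)."""
    scores = np.maximum(e1 @ e2.T, 0.0)          # ReLU clips negative similarity scores to 0
    scores -= scores.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    weights = np.exp(scores)
    return weights / weights.sum(axis=1, keepdims=True)


rng = np.random.default_rng(0)
num_nodes, emb_dim = 5, 10
e1 = rng.normal(size=(num_nodes, emb_dim))  # source-node embeddings
e2 = rng.normal(size=(num_nodes, emb_dim))  # target-node embeddings
adj = adaptive_adjacency(e1, e2)
print(adj.shape, adj.sum(axis=1))  # each row sums to 1
```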
The pairwise pre-calculated road network distances between sensors (data/{metr-la|pems-bay}/distances.csv) are used to generate the adjacency matrix:
python -m scripts.gen_adj_mx --sensor_locations_filename=data/metr-la/graph_sensor_locations.csv --normalized_k=0.1 \
    --output_pkl_filename=data/metr-la/adj_mx.pkl
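The script itself is not reproduced here. Assuming it follows the thresholded Gaussian kernel of the original DCRNN code, each weight is W_ij = exp(-(d_ij / sigma)^2), where d_ij is the road network distance and sigma is the standard deviation of all finite distances; weights below normalized_k are set to 0, which keeps the graph sparse. A minimal NumPy sketch with a made-up 3-sensor distance matrix (gaussian_kernel_adjacency is a hypothetical name):

```python
import numpy as np


def gaussian_kernel_adjacency(dist: np.ndarray, normalized_k: float = 0.1) -> np.ndarray:
    """Thresholded Gaussian kernel over pairwise road network distances.

    dist[i, j] is the distance from sensor i to sensor j; np.inf marks
    pairs without a recorded distance.
    """
    finite = dist[~np.isinf(dist)]
    std = finite.std()                    # kernel width sigma
    adj = np.exp(-np.square(dist / std))  # exp(-inf) == 0 for unknown pairs
    adj[adj < normalized_k] = 0.0         # drop weak links below the threshold
    return adj


# Made-up distances between three sensors; sensors 0 and 2 have no recorded distance.
dist = np.array([
    [0.0, 1000.0, np.inf],
    [1000.0, 0.0, 4000.0],
    [np.inf, 4000.0, 0.0],
])
adj = gaussian_kernel_adjacency(dist)
print(np.round(adj, 3))
```

Raising normalized_k removes more long-distance links; with this toy matrix the 4000-unit link already falls below 0.1 and is dropped.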
The world-coordinate locations of the sensors are available at data/{metr-la|pems-bay}/graph_sensor_locations.csv.
To be added.
TensorFlow implementation of Diffusion Convolutional Recurrent Neural Network in the following paper:
Yaguang Li, Rose Yu, Cyrus Shahabi, Yan Liu, Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting, ICLR 2018.
PyTorch implementation of Graph WaveNet in the following paper:
Zonghan Wu, Shirui Pan, Guodong Long, Jing Jiang, Chengqi Zhang, Graph WaveNet for Deep Spatial-Temporal Graph Modeling, IJCAI 2019 (https://arxiv.org/abs/1906.00121).
TensorFlow implementation of Graph Multi-Attention Network in the following paper:
Chuanpan Zheng, Xiaoliang Fan, Cheng Wang, and Jianzhong Qi. "GMAN: A Graph Multi-Attention Network for Traffic Prediction", AAAI 2020 (https://arxiv.org/abs/1911.08415)
If you find this repository, e.g., the code and the datasets, useful in your research, please cite the following papers:
@inproceedings{li2018dcrnn_traffic,
title={Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting},
author={Li, Yaguang and Yu, Rose and Shahabi, Cyrus and Liu, Yan},
booktitle={International Conference on Learning Representations (ICLR '18)},
year={2018}
}
@article{wu_graph_2019,
title = {Graph {WaveNet} for Deep Spatial-Temporal Graph Modeling},
url = {http://arxiv.org/abs/1906.00121},
author = {Wu, Zonghan and Pan, Shirui and Long, Guodong and Jiang, Jing and Zhang, Chengqi},
year = {2019},
}
@inproceedings{GMAN-AAAI2020,
author = {Chuanpan Zheng and Xiaoliang Fan and Cheng Wang and Jianzhong Qi},
title = {GMAN: A Graph Multi-Attention Network for Traffic Prediction},
booktitle = {AAAI},
year = {2020}
}