Code and data for the paper Multi-Size Patched Spatial-Temporal Transformer Network for Short- and Long-Term Grid-based Crowd Flow Prediction
Please cite the paper above if you use this repository in your research.
Under construction
This repository covers the TaxiBJ dataset; more information can be found in MSP-STTN.
PyTorch > 1.7
Please refer to requirements.txt
- Process the data according to MSP-STTN-DATA.
- The data\ directory should look like this:
data
___ TaxiBJ
- Alternatively, the processed data can be downloaded from BAIDU_PAN (password: p3r0).
- Several pre-trained models can be downloaded from BAIDU_PAN (password: 9ius).
- The model\ directory should look like this:
model
___ Imp_0547
___ ___ pre_model_ep_19.pth
___ Imp_0548
___ ___ pre_model_ep_41.pth
___ Imp_1543
___ ___ pre_model_ep_0.pth
___ ___ pre_model_it_14700.pth
___ Imp_1545
___ ___ pre_model_ep_23.pth
___ Imp_3548
___ ___ pre_model_ep_22.pth
___ Imp_3805
___ ___ pre_model_ep_22.pth
___ Imp_5547
___ ___ pre_model_ep_27.pth
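The checkpoint files above follow a pre_model_ep_<N>.pth or pre_model_it_<N>.pth naming pattern. Reading "ep" as epoch and "it" as iteration is an assumption based on the names only. The following is a minimal sketch for listing the downloaded checkpoints per experiment folder, assuming every file follows that pattern; the helper find_checkpoints is hypothetical and not part of the repository, and it does not decide which checkpoint the scripts actually load.

```python
import re
from pathlib import Path

# Assumed naming: pre_model_ep_19.pth (epoch) or pre_model_it_14700.pth (iteration).
CKPT_PATTERN = re.compile(r"^pre_model_(ep|it)_(\d+)\.pth$")


def find_checkpoints(model_dir="model"):
    """Map each experiment folder (e.g. Imp_0547) to its checkpoint files.

    Each value is a list of (kind, number, path) tuples, where kind is
    "ep" or "it", sorted by kind and then numerically by number.
    Files that do not match the naming pattern are ignored.
    """
    found = {}
    for exp_dir in sorted(Path(model_dir).glob("Imp_*")):
        if not exp_dir.is_dir():
            continue
        ckpts = []
        for f in exp_dir.iterdir():
            m = CKPT_PATTERN.match(f.name)
            if m:
                ckpts.append((m.group(1), int(m.group(2)), f))
        found[exp_dir.name] = sorted(ckpts)
    return found


if __name__ == "__main__":
    # Print every checkpoint found under ./model, one per line.
    for exp, ckpts in find_checkpoints().items():
        for kind, num, path in ckpts:
            print(f"{exp}: {kind}={num} -> {path}")
```

Parsing the number as an integer keeps the ordering numeric (ep_2 before ep_19) rather than lexicographic.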
- Use sh BEST.sh for short-term prediction.
- Use sh BEST_long.sh for long-term prediction.
- Use sh TRAIN.sh to train the model for short-term prediction.
- Use sh TRAIN_long.sh to train the model for long-term prediction.
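To drive the four scripts from Python (for example, from a notebook or a sweep tool), a small wrapper could look like the sketch below. The wrapper, the SCRIPTS table, and the "best"/"train" stage names are hypothetical and only mirror the script names listed above; the scripts are assumed to be run from the repository root.

```python
import subprocess

# Hypothetical table: (stage, horizon) -> shell script shipped with the repo.
SCRIPTS = {
    ("best", "short"): "BEST.sh",
    ("best", "long"): "BEST_long.sh",
    ("train", "short"): "TRAIN.sh",
    ("train", "long"): "TRAIN_long.sh",
}


def script_command(stage, horizon):
    """Return the command list for one script, e.g. ["sh", "BEST.sh"]."""
    try:
        script = SCRIPTS[(stage, horizon)]
    except KeyError:
        raise ValueError(f"unknown stage/horizon: {stage!r}/{horizon!r}") from None
    return ["sh", script]


def run(stage, horizon, cwd="."):
    """Run the selected script; check=True raises if it exits non-zero."""
    subprocess.run(script_command(stage, horizon), cwd=cwd, check=True)
```

For example, run("train", "long") is equivalent to typing sh TRAIN_long.sh in the repository root.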
___ BEST_long.sh
___ BEST.sh
___ data # Data
___ dataset
___ model # Stores the trained weights
___ net # Network structure
___ pre_main_short.py # Main function for short-term prediction
___ pre_setting_bj_long.yaml # Configuration for long-term prediction
___ pre_setting_bj.yaml # Configuration for short-term prediction
___ README.md
___ record # Training and test records
___ TRAIN_long.sh
___ TRAIN.sh
___ util
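Before running any of the scripts, it can help to confirm that the checkout matches the tree above and that the data has been placed under data\TaxiBJ. The following is a minimal sketch, assuming the script is run from the repository root; the EXPECTED list is copied from the trees in this README and the helper missing_entries is hypothetical.

```python
from pathlib import Path

# Entries copied from the project tree; "data/TaxiBJ" comes from the data layout.
EXPECTED = [
    "BEST.sh", "BEST_long.sh", "TRAIN.sh", "TRAIN_long.sh",
    "pre_main_short.py", "pre_setting_bj.yaml", "pre_setting_bj_long.yaml",
    "README.md", "data/TaxiBJ", "dataset", "model", "net", "record", "util",
]


def missing_entries(root="."):
    """Return the expected entries that do not exist under root."""
    root = Path(root)
    return [rel for rel in EXPECTED if not (root / rel).exists()]


if __name__ == "__main__":
    missing = missing_entries()
    print("OK" if not missing else "Missing: " + ", ".join(missing))
```

The check only tests for existence, so it will not catch an empty or partially downloaded data\TaxiBJ directory.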