/BUAA-DL2021

BUAA-2021深度学习中作业

Primary LanguagePython

中作业任务说明

中作业将基于课程组开发的交通预测领域开源框架,要求同学们以两人一组(不能超过2人,不推荐单人组队)的形式选择一开源模型将其改写为符合框架接口约束的模型。

具体来说,课程组预先选定一批交通开源模型——涉及交通流量/速度/需求量预测、轨迹下一跳预测。并且,课程组前期已经为前述任务准备好了数据集并搭建了数据预处理、评估模块。因此,各小组的主要工作是将模型开源代码改写为符合框架抽象接口约束的模型类,虽然课程组已经实现了大部分通用的数据接口,但是部分模型可能需要对数据接口进行一定的修改

作业分数构成

中作业占总成绩的 10%(10分),由以下四部分组成:

  • 复现模型完成度(4分):要求复现的模型能够在框架中运行、代码风格良好、注释充分。代码风格规范要求见文档,要求使用第三方库flake8进行检查!
  • 技术报告(4分):具体要求参见技术报告模板。
  • 复现性能(2分):考量复现模型的性能。
  • 难易度加分(附加分):考虑到部分难度较高的模型,复现难度较高,因此完成中等难度模型复现工作(即能在框架中运行)的小组可获得 1 分额外加分,完成困难模型复现工作的小组获得 2 分额外加分。(注:若加分后总分超过 10 分,则按 10 分计算)。

框架介绍

框架以流水线的形式运行,主要分为五个步骤:

  1. 初始化流水线配置。(依托 Config 模块)
  2. 数据集加载与数据预处理,数据转换并划分训练集、验证集、测试集。(依托 Data 模块)
  3. 加载模型。(依托 Model 模块)
  4. 训练验证模型,并在测试集上进行测试。(依托 Executor 模块)
  5. 评估模型测试输出。(依托 Evaluator 模块)

各组的工作主要涉及 Model 模块,通过使用课程组预先构建的 Data 模块提供的输入数据进行任务预测,并输出符合课程组预先构建 Evaluator 模块的评估输入接口格式。对于 POI 轨迹下一跳预测任务与交通流量\速度预测任务具体接口格式的说明,可参见对应任务的说明 md 文件。

关于框架的具体介绍可以参考文档

框架使用

课程组预先构建了两个脚本文件,方便各组测试运行复现后的模型。

test_model.py

测试模型是否能跑通框架的脚本文件,各组可以参考该脚本文件在调式模式或命令行中测试运行模型,与通过实操完成对框架的深入理解。

run_model.py

训练模型并在测试集上进行预测,最后会将测试结果输出至命令行与 trafficdl/cache/evaluate_cache/ 文件夹下。

命令行运行示例:

python run_model.py --task traj_loc_pred --model DeepMove --dataset foursquare_tky

这里简单介绍部分常用命令行参数:

  • task:所要执行的任务名,默认为traj_loc_pred。需要各组修改为自己对应的任务名。
  • model:所要运行的模型名,默认为DeepMove。需要各组修改为自己对应的模板模型名。
  • dataset:所要运行的数据集,默认为 foursquare_tky
  • config_file:用户指定 config 文件名,默认为 None
  • saved_model:是否保存训练的模型结果,默认为 True
  • train:当模型已被训练时是否要重新训练,默认为 True

论文列表

POI 轨迹下一跳预测(POI 推荐)

模型名 难度 论文 开源代码
1 STRNN Predicting the Next Location: A Recurrent Model with Spatial and Temporal Contexts pytorch
2 LSTPM Where to Go Next: Modeling Long- and Short-Term User Preferences for Point-of-Interest Recommendation pytorch
3 GeoSAN Geography-Aware Sequential Location Recommendation pytorch
4 Flashback(RNN) Location Prediction over Sparse User Mobility Traces Using RNNs: Flashback in Hidden States pytorch
5 ATST-LSTM An Attention-based Spatiotemporal LSTM Network for Next POI Recommendation tensorflow
6 STAN STAN: Spatio-Temporal Attention Network for Next Location Recommendation pytorch
7 STF-RNN STF-RNN: Space Time Features-based Recurrent Neural Network for Predicting People Next Location keras
8 CARA A Contextual Attention Recurrent Architecture for Context-Aware Venue Recommendation keras

交通状态预测(流量、速度、需求量)

流量

编号 模型名 难度 论文 开源代码
9 ST-MetaNet Urban traffic prediction from spatio-temporal data using deep meta learning MXNet
10 STSGCN Spatial-Temporal Synchronous Graph Convolutional Networks: A New Framework for Spatial-Temporal Network Data Forecasting MXNet
11 DSAN Preserving Dynamic Attention for Long-Term Spatial-Temporal Prediction tf2
12 ST-GDN Traffic Flow Forecasting with Spatial-Temporal Graph Diffusion Network tf
13 STDN Revisiting spatial-temporal similarity: A deep learning framework for traffic prediction Keras
14 STFGNN Spatial-Temporal Fusion Graph Neural Networks for Traffic Flow Forecasting MXNet
15 STNN Spatio-Temporal Neural Networks for Space-Time Series Forecasting and Relations Discovery Pytorch
16 STAG-GCN Spatiotemporal Adaptive Gated Graph Convolution Network for Urban Traffic Flow Forecasting Pytorch
17 ST-CGA Spatial-Temporal Convolutional Graph Attention Networks for Citywide Traffic Flow Forecasting Keras
18 ResLSTM Deep Learning Architecture for Short-Term Passenger Flow Forecasting in Urban Rail Transit Keras
19 DGCN Dynamic Graph Convolution Network for Traffic Forecasting Based on Latent Network of Laplace Matrix Estimation Pytorch
20 Multi-STGCnet Multi-STGCnet: A Graph Convolution Based Spatial-Temporal Framework for Subway Passenger Flow Forecasting Keras
21 Conv-GCN Multi-Graph Convolutional Network for Short-Term Passenger Flow Forecasting in Urban Rail Transit Keras
22 TCC-LSTM-LSM A temporal-aware LSTM enhanced by loss-switch mechanism for traffic flow forecasting Keras
23 CRANN A Spatio-Temporal Spot-Forecasting Framework forUrban Traffic Prediction Pytorch

速度

编号 模型名 难度 论文 开源代码
24 BaiduTraffic Deep sequence learning with auxiliary information for traffic prediction tf
25 GMAN Gman: A graph multi-attention network for traffic prediction tf
26 MRA-BGCN Multi-Range Attentive Bicomponent Graph Convolutional Network for Traffic Forecasting Pytorch
27 FC-GAGA FC-GAGA: Fully Connected Gated Graph Architecture for Spatio-Temporal Traffic Forecasting tf
28 HGCN Hierarchical Graph Convolution Networks for Traffic Forecasting Pytorch
29 GTS Discrete Graph Structure Learning for Forecasting Multiple Time Series Pytorch
30 DKFN Graph Convolutional Networks with Kalman Filtering for Traffic Prediction Pytorch
31 GaAN GaAN: Gated Attention Networks for Learning on Large and Spatiotemporal Graphs MXNet
32 ST-MGAT ST-MGAT: Spatial-Temporal Multi-Head Graph Attention Networks for Traffic Forecasting Pytorch
33 DGFN Dynamic Graph Filters Networks: A Gray-box Model for Multistep Traffic Forecasting tf2
34 ATDM On the Inclusion of Spatial Information for Spatio-Temporal Neural Networks Pytorch

需求

编号 模型名 难度 论文 开源代码
35 DMVST-Net Deep Multi-View Spatial-Temporal Network for Taxi Demand Prediction Keras
36 STG2Seq Stg2seq: Spatial-temporal graph to sequence model for multi-step passenger demand forecasting tf
37 CCRNN Coupled Layer-wise Graph Convolution for Transportation Demand Prediction Pytorch
38 SHARE Semi-Supervised Hierarchical Recurrent Graph Neural Network for City-Wide Parking Availability Prediction Pytorch
39 PVCGN Physical-Virtual Collaboration Modeling for Intra-and Inter-Station Metro Ridership Prediction Pytorch

数据集

  1. POI 轨迹下一跳预测使用 foursqaure-tky 数据集下载链接,Gowalla 数据集下载链接

请将下载好的数据集存放于 Bigscity-TrafficDL/raw_data 文件夹下。

  1. 交通状态预测数据集下载链接

数据集和模型的对应关系查看文件交通状态数据集和模型对应关系.xlsx

请将下载好的数据集存放于 code/raw_data/数据集名/数据集具体的文件 文件夹下。(直接下载数据集对应的文件夹,解压到code/raw_data/下即可。)