/ST-SSL_Dataset

The data of paper "Spatio-Temporal Self-Supervised Learning for Traffic Flow Prediction".

Datasets used in ST-SSL

We provide several datasets used in the ST-SSL framework, which leverages self-supervised learning for traffic flow prediction.

The datasets range from {NYCBike1, NYCBike2, NYCTaxi, BJTaxi}.

Please use Git Large File Storage (LFS) to pull this repo to your computer.

You can also download the dataset at Beihang Cloud Drive or Google Drive.

Dataset Format

Each dataset is composed of 4 files, namely train.npz, val.npz, test.npz, and adj_mx.npz.

|----NYCBike1\
|    |----train.npz    # training data
|    |----adj_mx.npz   # predefined graph structure
|    |----test.npz     # test data
|    |----val.npz      # validation data

Train/Val/Test data is composed of 4 numpy.ndarray objects:

The train/val/test data is composed of 4 numpy.ndarray objects:

  • X: input data. It is a 4D tensor of shape (#samples, #lookback_window, #nodes, #flow_types), where # denotes the number sign.
  • Y: data to be predicted. It is a 4D tensor of shape (#samples, #predict_horizon, #nodes, #flow_types). Note that X and Y are paired in the sample dimension. For instance, (X_i, Y_i) is the i-the data sample with i indexing the sample dimension.
  • X_offset: a list indicating offsets of X's lookback window relative to the current time with offset 0.
  • Y_offset: a list indicating offsets of Y's prediction horizon relative to the current time with offset 0.

For all datasets, previous 2-hour flows as well as previous 3-day flows around the predicted time are used to forecast flows for the next time step.

adj_mx.npz is the graph adjacency matrix that indicates the spatial relation of every two regions/nodes in the studied area.

Dataset Usage

You can use the following code to view the data:

import numpy as np

data = np.load('./BJTaxi/train.npz')
for file in data.files:
    print(file, data[file].shape)

Raw Data

All datasets are processed by us as a sliding window view. Raw data of NYCBike1 and BJTaxi are collected from STResNet. Raw data of NYCBike2 and NYCTaxi are collected from STDN.