
The repo for Semi-supervised Semantics-guided Adversarial Training for Robust Trajectory Prediction (ICCV 2023)

Primary language: Python. License: MIT.

SSAT-for-Motion-Prediction

Install the required packages

conda create --name ssat python=3.7
conda activate ssat
conda install pytorch==1.5.1 torchvision cudatoolkit=10.2 -c pytorch

# install argoverse api
pip install git+https://github.com/argoai/argoverse-api.git

# or follow the installation instructions at https://github.com/argoverse/argoverse-api#installation

# install other dependencies
pip install scikit-image IPython tqdm ipdb
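To confirm the environment before training, a small import check can help. The sketch below is not part of the repository; `check_packages` is a hypothetical helper. It assumes the import names `argoverse` (for argoverse-api) and `skimage` (for scikit-image), and reports each package's `__version__` when the module exposes one.

```python
import importlib


def check_packages(module_names):
    """Map each module name to its version string, "unknown", or None if not importable."""
    report = {}
    for name in module_names:
        try:
            module = importlib.import_module(name)
        except ImportError:
            report[name] = None  # not installed, or failed to import
            continue
        # Most packages expose __version__; fall back to "unknown" otherwise.
        report[name] = getattr(module, "__version__", "unknown")
    return report


if __name__ == "__main__":
    # Import names of the packages installed above (pip name -> module name:
    # argoverse-api -> argoverse, scikit-image -> skimage).
    required = ["torch", "torchvision", "argoverse", "skimage", "IPython", "tqdm", "ipdb"]
    report = check_packages(required)
    for name, version in report.items():
        print(f"{name:12s} {'MISSING' if version is None else version}")
    if report.get("torch") is not None:
        import torch
        print("CUDA available:", torch.cuda.is_available())
```

A `MISSING` entry usually means the corresponding install step above failed.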

Prepare the data

Download the Argoverse 1 motion forecasting dataset and store it under the folder ./dataset.
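A quick sanity check after downloading can catch an incomplete extraction. This is a sketch, not code from the repository: `summarize_dataset` is a hypothetical helper that counts `.csv` files per sub-folder of `./dataset` and checks that one file carries the Argoverse 1 forecasting header (`TIMESTAMP,TRACK_ID,OBJECT_TYPE,X,Y,CITY_NAME`). The exact sub-folder layout the training script expects is not documented here, so the counts are only informational.

```python
import csv
import os

# Column header of Argoverse 1 motion forecasting CSV files.
EXPECTED_COLUMNS = ["TIMESTAMP", "TRACK_ID", "OBJECT_TYPE", "X", "Y", "CITY_NAME"]


def summarize_dataset(root="./dataset"):
    """Count .csv files per directory under `root` and check the header of one of them."""
    counts = {}
    first_csv = None
    for dirpath, _dirnames, filenames in os.walk(root):
        csvs = [f for f in filenames if f.endswith(".csv")]
        if csvs:
            counts[os.path.relpath(dirpath, root)] = len(csvs)
            if first_csv is None:
                first_csv = os.path.join(dirpath, csvs[0])
    header_ok = False
    if first_csv is not None:
        with open(first_csv, newline="") as fh:
            header_ok = next(csv.reader(fh), []) == EXPECTED_COLUMNS
    return counts, header_ok


if __name__ == "__main__":
    counts, header_ok = summarize_dataset()
    for subdir, n in sorted(counts.items()):
        print(f"{subdir}: {n} csv files")
    print("header matches Argoverse 1 format:", header_ok)
```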

Run the training script

python train_adv_ssat.py -m _ssat_model --resume=basenet.ckpt --att_pattern=ade
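Judging only from its name (an assumption, not confirmed by the source), `--att_pattern=ade` selects an attack objective based on the average displacement error (ADE), the standard trajectory prediction metric. For reference, ADE and the related final displacement error (FDE) can be computed as below; `displacement_errors` is an illustrative helper, not a function from this repository.

```python
import math


def displacement_errors(pred, gt):
    """Return (ADE, FDE) between two trajectories given as lists of (x, y) points.

    ADE: mean Euclidean distance over all time steps.
    FDE: Euclidean distance at the final time step.
    """
    if len(pred) != len(gt) or not pred:
        raise ValueError("trajectories must be non-empty and of equal length")
    dists = [math.hypot(px - gx, py - gy) for (px, py), (gx, gy) in zip(pred, gt)]
    return sum(dists) / len(dists), dists[-1]


if __name__ == "__main__":
    gt = [(0.0, 0.0), (1.0, 0.0), (2.0, 0.0)]
    pred = [(0.0, 0.0), (1.0, 1.0), (2.0, 2.0)]
    ade, fde = displacement_errors(pred, gt)
    print(ade, fde)  # 1.0 2.0
```

An attack driven by such an objective would perturb the observed history so that the model's prediction drifts as far as possible from the ground truth in this metric.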

Note

For simplicity, the released version generates the adversarial trajectories on the fly during training, which is quite slow. To speed this up, you can add a few lines that store the generated adversarial trajectories on disk and load them during adversarial training.
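One way to do the storing and loading described in the note is a small per-sample disk cache. The sketch below is hypothetical: `AdvTrajectoryCache`, `get_or_generate`, and the sample-id keys are not part of the repository, and `generate_fn` stands in for whatever routine currently produces the adversarial trajectory at runtime. It uses `pickle`; if the trajectories are PyTorch tensors, `torch.save`/`torch.load` are the more usual choice.

```python
import os
import pickle


class AdvTrajectoryCache:
    """Hypothetical disk cache for adversarial trajectories, keyed by sample id."""

    def __init__(self, cache_dir):
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)

    def _path(self, sample_id):
        return os.path.join(self.cache_dir, f"{sample_id}.pkl")

    def get_or_generate(self, sample_id, generate_fn):
        """Load the cached trajectory if present; otherwise generate, save, and return it."""
        path = self._path(sample_id)
        if os.path.exists(path):
            with open(path, "rb") as fh:
                return pickle.load(fh)
        adv_traj = generate_fn()  # the slow runtime attack, called only on a cache miss
        # Write to a temporary file first, then rename, so an interrupted run
        # does not leave a truncated cache entry behind.
        tmp_path = path + ".tmp"
        with open(tmp_path, "wb") as fh:
            pickle.dump(adv_traj, fh)
        os.replace(tmp_path, path)
        return adv_traj
```

In a training loop, the call would look like `cache.get_or_generate(sample_id, lambda: attack(model, batch))`, where `attack` is a placeholder for the existing generation code. Keep in mind that cached examples were crafted against the model weights at the time they were generated; as training updates the model they may become weaker attacks, so clearing the cache periodically is a reasonable compromise between speed and attack strength.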