Official PyTorch code for MANTRA: Memory Augmented Networks for Multiple Trajectory Prediction (CVPR 2020).
To install the required packages in a Python 3 environment, run:
pip install -r requirements.txt
We provide a dataloader for the KITTI dataset in dataset_invariance.py. The dataloader yields samples of (past, future) trajectories paired with a semantic map of the surrounding scene.
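To make the (past, future) pairing concrete, here is a minimal sketch of how one agent track could be cut into samples using the default lengths of 20 past and 40 future positions. The helper names `split_trajectory` and `to_relative`, the sliding-window `stride`, and the translation to the last observed position are all assumptions for illustration. dataset_invariance.py may build its samples differently, and the semantic map is omitted here.

```python
from typing import List, Sequence, Tuple

Point = Tuple[float, float]  # (x, y) position of one agent at one timestep


def split_trajectory(
    track: Sequence[Point], past_len: int = 20, future_len: int = 40, stride: int = 1
) -> List[Tuple[List[Point], List[Point]]]:
    """Cut one agent track into overlapping (past, future) windows.

    Hypothetical helper: dataset_invariance.py may build its samples differently.
    """
    window = past_len + future_len
    samples = []
    # Slide a window of past_len + future_len positions along the track.
    for start in range(0, len(track) - window + 1, stride):
        chunk = list(track[start:start + window])
        samples.append((chunk[:past_len], chunk[past_len:]))
    return samples


def to_relative(past: List[Point], future: List[Point]) -> Tuple[List[Point], List[Point]]:
    """Translate both segments so the last observed position is the origin (assumed normalization)."""
    ox, oy = past[-1]
    rel_past = [(x - ox, y - oy) for x, y in past]
    rel_future = [(x - ox, y - oy) for x, y in future]
    return rel_past, rel_future
```

For a 70-step track with stride=5, this yields three windows starting at steps 0, 5 and 10. Tracks shorter than past_len + future_len produce no samples.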
To train MANTRA, first train the autoencoder, then the writing controller, and finally the Iterative Refinement Module (IRM). Training can be monitored with TensorBoard; logs are stored in the runs/ folder (runs-pretrain, runs-createMem, runs-IRM). The pretrained_models folder contains pretrained models of the different components (autoencoder, writing controller, MANTRA).
python train_ae.py
The autoencoder is trained with the train_ae.py script, which calls trainer_ae.py. The model will be saved into the folder test/[current_date]. A pretrained model can be found in pretrained_models/model_AE/.
python train_controllerMem.py --model pretrained_autoencoder_model_path
The memory writing controller, built on top of the autoencoder, is trained with train_controllerMem.py, which calls trainer_controllerMem.py. The path of a pretrained autoencoder model must be passed to the script via --model (it defaults to the pretrained model we provide). A pretrained model (autoencoder + writing controller) can be found in pretrained_models/model_controller/.
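As a rough illustration of what "writing" into the memory means, the sketch below stores (past encoding, future encoding) pairs, reads the k futures whose keys are most similar to a query past by cosine similarity, and writes a new pair only when none of the retrieved futures is close to the true one. This is a toy simplification under stated assumptions: the class `PairMemory`, the fixed `threshold`, and measuring the error directly between encodings are hypothetical. In MANTRA the writing controller is a trained module, not a fixed threshold.

```python
import math
from typing import List, Sequence

Vector = List[float]


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity, 0.0 if either vector is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


class PairMemory:
    """Toy memory of (past encoding -> future encoding) pairs.

    Hypothetical sketch: MANTRA's writing controller is a trained module; here it is
    replaced by a fixed error threshold measured directly in encoding space.
    """

    def __init__(self, threshold: float) -> None:
        self.threshold = threshold
        self.keys: List[Vector] = []    # past encodings
        self.values: List[Vector] = []  # paired future encodings

    def read(self, past_enc: Sequence[float], k: int) -> List[Vector]:
        """Return the future encodings of the k keys most similar to past_enc."""
        order = sorted(
            range(len(self.keys)),
            key=lambda i: cosine(past_enc, self.keys[i]),
            reverse=True,
        )
        return [self.values[i] for i in order[:k]]

    def maybe_write(self, past_enc: Sequence[float], future_enc: Sequence[float], k: int) -> bool:
        """Store the pair only if none of the k retrieved futures is close enough."""
        candidates = self.read(past_enc, k)
        best_error = min((math.dist(c, future_enc) for c in candidates), default=math.inf)
        if best_error > self.threshold:
            self.keys.append(list(past_enc))
            self.values.append(list(future_enc))
            return True
        return False
```

The design keeps the memory compact: a sample that the existing entries already explain well adds nothing and is skipped, while a novel past-future pair is stored.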
python train_IRM.py --model pretrained_autoencoder+controller_model_path
train_IRM.py calls trainer_IRM.py. The script trains the IRM, which generates the final prediction from the decoded trajectory and the context map. The paths of a pretrained autoencoder + writing controller model and of the populated memories must be passed to the script (they default to the pretrained models we provide). A pretrained MANTRA model can be found in pretrained_models/model_complete/.
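The three training stages must run in order, each consuming the checkpoint of the previous one. A small driver like the following sketch can make that ordering explicit. The function names are hypothetical, the checkpoint paths are placeholders you must fill in after each stage finishes, and only the --model flag shown in the commands above is used; how the memories path is passed to train_IRM.py is not covered here.

```python
import shlex
import subprocess
from typing import List


def build_pipeline(ae_model: str, controller_model: str) -> List[List[str]]:
    """Training commands in the required order.

    ae_model and controller_model are placeholders for the checkpoints produced by the
    previous stage (e.g. the autoencoder is saved under test/[current_date]).
    """
    return [
        ["python", "train_ae.py"],
        ["python", "train_controllerMem.py", "--model", ae_model],
        ["python", "train_IRM.py", "--model", controller_model],
    ]


def run_pipeline(commands: List[List[str]], dry_run: bool = True) -> None:
    """Print each command; execute it only when dry_run is False."""
    for cmd in commands:
        print(shlex.join(cmd))
        if not dry_run:
            # check=True stops the pipeline if a stage fails
            subprocess.run(cmd, check=True)
```

Running with the default dry_run=True only prints the commands, which is a safe way to check the paths before launching long trainings.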
python test.py --model pretrained_complete_model_path --withIRM True/False --saved_memory True/False
test.py calls evaluate_MemNet.py. This script computes metrics on the KITTI dataset using a trained model: Average Displacement Error (ADE) and Final Displacement Error (FDE, also referred to as Error@K or Horizon Error).
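For reference, ADE is the mean Euclidean distance between predicted and ground-truth positions over all future steps, and FDE is the distance at the last step. The sketch below computes both; the best-of-K variant, taking the minimum of each metric independently over the K predictions, is a common protocol for multi-modal predictors and is an assumption here, not necessarily what evaluate_MemNet.py does.

```python
import math
from typing import Sequence, Tuple

Point = Tuple[float, float]


def ade_fde(pred: Sequence[Point], gt: Sequence[Point]) -> Tuple[float, float]:
    """ADE = mean L2 distance over all future steps; FDE = L2 distance at the last step."""
    if len(pred) != len(gt) or not gt:
        raise ValueError("pred and gt must be non-empty and equally long")
    dists = [math.dist(p, g) for p, g in zip(pred, gt)]
    return sum(dists) / len(dists), dists[-1]


def best_of_k(preds: Sequence[Sequence[Point]], gt: Sequence[Point]) -> Tuple[float, float]:
    """Minimum ADE and minimum FDE over K predictions (assumed evaluation protocol)."""
    scores = [ade_fde(p, gt) for p in preds]
    return min(s[0] for s in scores), min(s[1] for s in scores)
```

For example, a prediction [(0, 0), (1, 1), (2, 2)] against ground truth [(0, 0), (1, 0), (2, 0)] has per-step distances 0, 1 and 2, so ADE = 1.0 and FDE = 2.0.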
--cuda Enable/Disable GPU device (default=True).
--batch_size Number of samples that will be fed to MANTRA in one iteration (default=32).
--past_len Length of the past trajectory, in timesteps (default=20).
--future_len Length of the future trajectory to predict, in timesteps (default=40).
--preds Number of predictions generated by the MANTRA model (default=5).
--model Path of the pretrained model used for evaluation (default='pretrained_models/MANTRA/model_MANTRA').
--visualize_dataset Save all examples of the dataset (in *folder_test/dataset_train* and *folder_test/dataset_test*).
--saved_memory Select which memories are used in evaluation.
If True, memories are loaded from the folder given by --memories_path.
If False, new memories are generated; the past-future pairs to store are decided by the model's writing controller.
--memories_path Folder of the saved memories; used only if --saved_memory is True.
--withIRM Generate predictions with (True) or without (False) the Iterative Refinement Module.
--saveImages Save, in the test folder, dataset examples together with the predictions generated by MANTRA.
If None, no qualitative examples are saved, only quantitative results.
If 'All', all examples are saved.
If 'Subset', only the examples listed in index_qualitative.py (hand-picked, most significant samples) are saved.
(default=None)
--dataset_file Name of the JSON file containing the dataset (default='kitti_dataset.json').
--info Name of the evaluation, used to name the test folder (default='').
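Several of these options take True/False strings on the command line. The sketch below mirrors the documented flags and defaults with argparse; it is an illustration, not the parser in test.py. Assumptions: --visualize_dataset is modeled as an on/off switch, --saveImages accepts 'All' or 'Subset' and is None when omitted, the defaults of --saved_memory, --memories_path and --withIRM are left unset because they are not documented, and the consistency check at the end is hypothetical.

```python
import argparse
from typing import List, Optional


def str2bool(value: str) -> bool:
    """Parse 'True'/'False'-style strings (argparse's type=bool would accept 'False' as True)."""
    lowered = value.strip().lower()
    if lowered in ("true", "1", "yes"):
        return True
    if lowered in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected True or False, got {value!r}")


def parse_eval_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="MANTRA evaluation options (sketch)")
    parser.add_argument("--cuda", type=str2bool, default=True)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--past_len", type=int, default=20)
    parser.add_argument("--future_len", type=int, default=40)
    parser.add_argument("--preds", type=int, default=5)
    parser.add_argument("--model", default="pretrained_models/MANTRA/model_MANTRA")
    parser.add_argument("--visualize_dataset", action="store_true")  # assumed on/off switch
    parser.add_argument("--saved_memory", type=str2bool)  # default not documented
    parser.add_argument("--memories_path")  # default not documented
    parser.add_argument("--withIRM", type=str2bool)  # default not documented
    parser.add_argument("--saveImages", choices=["All", "Subset"], default=None)
    parser.add_argument("--dataset_file", default="kitti_dataset.json")
    parser.add_argument("--info", default="")
    args = parser.parse_args(argv)
    # Hypothetical consistency check: memories_path only makes sense with saved_memory=True.
    if args.saved_memory and not args.memories_path:
        parser.error("--memories_path is required when --saved_memory is True")
    return args
```

The str2bool helper matters because bool() of any non-empty string is True, so a naive type=bool would silently turn "--withIRM False" into True.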
If you use our code or find it useful in your research, please cite the following paper:
@inproceedings{cvpr_2020,
  author = {Marchetti, Francesco and Becattini, Federico and Seidenari, Lorenzo and Del Bimbo, Alberto},
  booktitle = {International Conference on Computer Vision and Pattern Recognition (CVPR)},
  publisher = {IEEE},
  title = {MANTRA: Memory Augmented Networks for Multiple Trajectory Prediction},
  year = {2020}
}
@ARTICLE{Geiger2013IJRR,
  author = {Andreas Geiger and Philip Lenz and Christoph Stiller and Raquel Urtasun},
  title = {Vision meets Robotics: The KITTI Dataset},
  journal = {International Journal of Robotics Research (IJRR)},
  year = {2013}
}
This source code is shared under the CC-BY-NC-SA license; please refer to the LICENSE file for more information.
This source code is shared only for R&D or for evaluating this model on user databases.
Any commercial utilization is strictly forbidden.
For any utilization with a commercial goal, please contact contact_cs or bendahan