/multimodal_gaze_target_prediction

This repo provides the training and testing code for our paper "A Modular Multimodal Architecture for Gaze Target Prediction: Application to Privacy-Sensitive Settings" published at the GAZE workshop at CVPR 2022

Primary LanguagePython

Overview

This repo provides the training and testing code for our paper "A Modular Multimodal Architecture for Gaze Target Prediction: Application to Privacy-Sensitive Settings" published at the GAZE workshop at CVPR 2022. [paper] [video]

Setup

We use the GazeFollow and VideoAttentionTarget datasets for training and testing our models. Please download them at the following link provided by ejcgt/attention-target-detection:
GazeFollow extended: link
VideoAttentionTarget: link

Next, extract the pose and depth modalities for both datasets following the instructions in modality_extraction.md

After, please update the paths to the datasets in the config.py file.

We use pytorch for our experiments. Use the provided environment file to create the conda environment for the experiments.

conda env create -f environment.yml

Training

Training on GazeFollow

Step 1. Train the single modality models
python train_on_gazefollow.py --modality image --backbone_name efficientnet-b1 --log_dir <path>
python train_on_gazefollow.py --modality depth --backbone_name efficientnet-b0 --log_dir <path>
python train_on_gazefollow.py --modality pose --backbone_name efficientnet-b0 --log_dir <path>

The trained model weights will be saved in the specified log_dir.

Step 2. Initialize the weights for the attention model
python initialize_attention_model.py --image_weights <path> --depth_weights <path> --pose_weights <path> --attention_weights <path>

Provide the paths to the pretrained image, depth and pose models. The attention model with initialized weights will be saved in the path specified by the attention_weights argument.

Step 3. Train the attention model
python train_on_gazefollow.py --modality attention --init_weights <path> --log_dir <path>

Provide the path to the initialized attention model weights. The trained model weights will be saved in the specified log_dir.

Training on VideoAttentionTarget

Set pred_inout=True in the config.py file.

Train the single modality models
python train_on_videoatttarget.py --modality image --init_weights <path> --backbone_name efficientnet-b1 --log_dir <path>
python train_on_videoatttarget.py --modality depth --init_weights <path> --backbone_name efficientnet-b0 --log_dir <path>
python train_on_videoatttarget.py --modality pose --init_weights <path> --backbone_name efficientnet-b0 --log_dir <path>

Provide the initial weights from training on GazeFollow. The trained model weights will be saved in the specified log_dir.

Train the attention model
python train_on_videoatttarget.py --modality attention --init_weights <path> --log_dir <path>

Provide the initial weights from training on GazeFollow. The trained model weights will be saved in the specified log_dir.

Training the privacy-sensitive models

Simply set privacy=True in the config.py file. Then follow the same steps as above to train the respective models.

Testing

Testing on GazeFollow

python eval_on_gazefollow.py --model_weights <path> 

Provide the path to the model weights with the model_weights argument.

Testing on VideoAttentionTarget

python eval_on_videoatttarget.py --model_weights <path>

Provide the path to the model weights with the model_weights argument.

Pre-trained models

Pre-trained human-centric module: link
Pre-trained attention model on GazeFollow: link
Pre-trained attention model on VideoAttentionTarget: link

Citation

If you use our code, please cite:

@inproceedings{gupta2022modular,
  title={A Modular Multimodal Architecture for Gaze Target Prediction: Application to Privacy-Sensitive Settings},
  author={Gupta, Anshul and Tafasca, Samy and Odobez, Jean-Marc},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops (CVPRW)},
  pages={5041--5050},
  year={2022}
}

References

Parts of the code have been adapted from ejcgt/attention-target-detection