
[CVPR 2024] Depth-aware Test-Time Training for Zero-shot Video Object Segmentation


Depth-aware Test-Time Training for Zero-shot Video Object Segmentation




Weihuang Liu¹, Xi Shen², Haolun Li¹, Xiuli Bi³, Bo Liu³, Chi-Man Pun¹*, Xiaodong Cun⁴*

¹ University of Macau   ² Intellindust
³ Chongqing University of Posts and Telecommunications   ⁴ Tencent AI Lab

CVPR 2024

Overview


Mainstream zero-shot video object segmentation (ZSVOS) solutions learn a single model on large-scale video datasets, and such models struggle to generalize to unseen videos. We introduce Depth-aware Test-Time Training (DATTT) to address this problem. Our key insight is to enforce the model to predict consistent depth during the test-time training (TTT) process: the model is required to predict consistent depth maps for the same video frame under different data augmentations. The model is progressively updated and provides increasingly precise mask predictions.
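The consistency objective described above can be sketched in a few lines of NumPy. This is an illustration under stated assumptions, not the repository's implementation: brightness and contrast jitter stand in for the photometric distortions, `photometric_distort` and `depth_consistency_loss` are hypothetical names, and the loss (mean squared deviation of each view's depth map from the per-pixel mean over views) is one possible choice, since the README does not state the exact distance used.

```python
import numpy as np


def photometric_distort(img, rng, max_brightness=0.1, contrast_range=(0.8, 1.2)):
    """Random brightness/contrast jitter on an HxWx3 image with values in [0, 1].

    Only pixel intensities change and the geometry is untouched, so the
    depth of every pixel should stay the same across augmented copies.
    (Hypothetical helper; the real augmentation set may differ.)
    """
    contrast = rng.uniform(*contrast_range)
    brightness = rng.uniform(-max_brightness, max_brightness)
    mean = img.mean(axis=(0, 1), keepdims=True)   # per-channel mean color
    out = (img - mean) * contrast + mean + brightness
    return np.clip(out, 0.0, 1.0)


def depth_consistency_loss(depths):
    """Mean squared deviation of each view's depth map from the per-pixel
    mean over views; it is zero exactly when all views agree.

    depths: array-like of shape (num_views, H, W). (Assumed loss form.)
    """
    depths = np.asarray(depths, dtype=np.float64)
    mean_depth = depths.mean(axis=0, keepdims=True)
    return float(((depths - mean_depth) ** 2).mean())


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    frame = rng.uniform(0.0, 1.0, size=(4, 4, 3))
    views = [photometric_distort(frame, rng) for _ in range(3)]
    # A depth network would map each view to a depth map; two toy maps here:
    print(depth_consistency_loss([np.zeros((2, 2)), np.ones((2, 2))]))  # 0.25
```

During test-time training, the depth maps predicted for `views` would be fed to the loss and its gradient used to update the network.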

Pipeline


We add a depth decoder to the commonly used two-stream ZSVOS architecture to learn 3D knowledge. The model is first trained on large-scale datasets for object segmentation and depth estimation. Then, for each test video, we apply photometric distortion-based data augmentation to the frames. The error between the predicted depth maps is backpropagated to update the image encoder. Finally, the updated model is applied to segment the object.

Environment

This code was implemented with Python 3.6 and PyTorch 1.10.0. You can install all the requirements via:

pip install -r requirements.txt
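Since the code was implemented with specific versions, a quick check of the active environment can save debugging time. The helper below is a hypothetical sketch, not a script shipped with the repository; it only reads `sys.version_info`, `torch.__version__`, and `torch.cuda.is_available()`, and reports whether they match the versions listed above.

```python
import sys


def check_environment(expected_python=(3, 6), expected_torch="1.10.0"):
    """Compare the running interpreter and installed PyTorch with the
    versions listed in this README; returns a dict of findings."""
    report = {
        "python": ".".join(str(v) for v in sys.version_info[:3]),
        "python_matches": tuple(sys.version_info[:2]) == tuple(expected_python),
        "torch": None,
        "torch_matches": False,
        "cuda_available": False,
    }
    try:
        import torch  # installed via requirements.txt
    except ImportError:
        return report
    report["torch"] = torch.__version__
    # CUDA builds carry a local suffix such as "+cu113"; compare the base version.
    report["torch_matches"] = torch.__version__.split("+")[0] == expected_torch
    report["cuda_available"] = bool(torch.cuda.is_available())
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```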

Quick Start

  1. Download the YouTube-VOS, DAVIS-16, FBMS, Long-Videos, MCL, and SegTrackV2 datasets. You can also use the processed data provided by HFAN. The depth maps are obtained with MonoDepth2; we also provide the processed data here.
  2. Download the pre-trained MiT-B1 or Swin-Tiny backbone.
  3. Training:
     python train.py --config ./configs/train_sample.yaml
  4. Evaluation (base model, without test-time training):
     python ttt_demo.py --config configs/test_sample.yaml --model model.pth --eval_type base
  5. Test-time training:
     python ttt_demo.py --config configs/test_sample.yaml --model model.pth --eval_type TTT-MWI
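The three commands above can also be driven from a single Python script, for example to run evaluation and test-time training back to back with the same checkpoint. This is a minimal sketch: `build_commands` and `run_pipeline` are hypothetical helpers, only the flags and paths shown in the Quick Start are used, and it assumes the script is run from the repository root. It defaults to a dry run that only prints the commands.

```python
import shlex
import subprocess
import sys


def build_commands(model_path="model.pth",
                   train_cfg="./configs/train_sample.yaml",
                   test_cfg="configs/test_sample.yaml",
                   python=sys.executable):
    """Return the training, base-evaluation and TTT commands as argv lists.

    Only the flags shown in the Quick Start are used (--config, --model,
    --eval_type); nothing else about train.py or ttt_demo.py is assumed.
    """
    train = [python, "train.py", "--config", train_cfg]
    evaluate = [python, "ttt_demo.py", "--config", test_cfg,
                "--model", model_path, "--eval_type", "base"]
    ttt = [python, "ttt_demo.py", "--config", test_cfg,
           "--model", model_path, "--eval_type", "TTT-MWI"]
    return {"train": train, "eval": evaluate, "ttt": ttt}


def run_pipeline(steps=("train", "eval", "ttt"), dry_run=True, **kwargs):
    """Print (and, if dry_run is False, execute) the selected steps in order."""
    commands = build_commands(**kwargs)
    for name in steps:
        argv = commands[name]
        # shlex.join needs Python 3.8+, so quote each argument for 3.6 support.
        print("[{}] {}".format(name, " ".join(shlex.quote(a) for a in argv)))
        if not dry_run:
            # check=True stops the pipeline if a step exits with an error.
            subprocess.run(argv, check=True)
    return [commands[name] for name in steps]


if __name__ == "__main__":
    # Dry run: only print the commands. Pass dry_run=False to execute them.
    run_pipeline(steps=("eval", "ttt"), model_path="model.pth")
```

`shlex.quote` is used instead of `shlex.join` so the sketch still runs on the Python 3.6 interpreter listed under Environment.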

We provide our checkpoints here.

Citation

If you find this useful in your research, please consider citing:

@inproceedings{liu2024depth,
  title={Depth-aware Test-Time Training for Zero-shot Video Object Segmentation},
  author={Weihuang Liu and Xi Shen and Haolun Li and Xiuli Bi and Bo Liu and Chi-Man Pun and Xiaodong Cun},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  year={2024}
}

Acknowledgements

Our code borrows heavily from EVP, Swin, and SegFormer. We thank the authors for sharing their wonderful code.