This repository contains the PyTorch implementation of the paper TSA-Net: Tube Self-Attention Network for Action Quality Assessment (ACM-MM'21 Oral).
If this repository is helpful to you, please star it. If you find our work useful in your research, please consider citing:
@inproceedings{TSA-Net,
  title={TSA-Net: Tube Self-Attention Network for Action Quality Assessment},
  author={Wang, Shunli and Yang, Dingkang and Zhai, Peng and Chen, Chixiao and Zhang, Lihua},
  booktitle={Proceedings of the 29th ACM International Conference on Multimedia},
  year={2021},
  pages={4902--4910},
  numpages={9}
}
In this repository, we open-source the code of TSA-Net on the FR-FS dataset. The initialization process is as follows:
# 1. Clone this repository
git clone https://github.com/Shunli-Wang/TSA-Net.git ./TSA-Net
cd ./TSA-Net
# 2. Create and activate a conda environment
conda create -n TSA-Net python
conda activate TSA-Net
conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch
pip install -r requirements.txt
# 3. Download the pre-trained model and the FR-FS dataset (all download links are listed below)
# PATH/TO/rgb_i3d_pretrained.pt
# PATH/TO/FRFS
# 4. Create the data directory (use absolute paths for PATH/TO, since the commands below run inside ./data)
mkdir ./data && cd ./data
mv PATH/TO/rgb_i3d_pretrained.pt ./
ln -s PATH/TO/FRFS ./FRFS
After initialization, please check that the directory structure looks like this:
.
├── data
│ ├── FRFS -> PATH/TO/FRFS
│ └── rgb_i3d_pretrained.pt
├── dataset.py
├── train.py
├── test.py
...
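If you want to verify the layout before training, a small sanity check like the one below may help. This is a hypothetical helper, not part of the repository: the function name `check_data_layout` is illustrative, and it only checks the two paths shown in the tree above (it does not validate the contents of FR-FS).

```python
from pathlib import Path


def check_data_layout(root="."):
    """Return a list of problems with the expected ./data layout (empty list means OK)."""
    data = Path(root) / "data"
    frfs = data / "FRFS"
    weights = data / "rgb_i3d_pretrained.pt"
    problems = []
    # Path.is_dir() follows symlinks, so a dangling FRFS link is reported as missing.
    if not frfs.is_dir():
        problems.append(f"missing or broken dataset link: {frfs}")
    if not weights.is_file():
        problems.append(f"missing pretrained weights: {weights}")
    return problems


if __name__ == "__main__":
    # Run from the repository root.
    issues = check_data_layout()
    for issue in issues:
        print(issue)
    print("data layout OK" if not issues else f"{len(issues)} problem(s) found")
```

Because `ln -s` with a relative target is resolved relative to the link's own directory, a broken `FRFS` link is a common failure; the check above reports it the same way as a missing directory.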
Download links:
- FR-FS Dataset: You can download the FR-FS dataset (about 2.5 GB) from BaiduNetDisk (code: star) or Google Drive.
- rgb_i3d_pretrained.pt: the I3D backbone pretrained on Kinetics, taken from Gated-Spatio-Temporal-Energy-Graph and used in our work. Download it from BaiduNetDisk (code: i3dm) or Google Drive.
- Tracking boxes for AQA-7 & MTL-AQA: Because of ongoing work, we are sorry that we cannot share the source code for MTL-AQA and AQA-7. We do provide the original tracking boxes for AQA-7 and MTL-AQA at BaiduNetDisk (code: 6v51) or Google Drive.
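The tracking boxes define the spatio-temporal "tube" around the athlete, and the TSA module restricts self-attention to features inside that tube. The box file format and the exact way TSA-Net maps boxes onto I3D feature maps are not documented in this README, so the following is only a hypothetical sketch: it assumes one `(x1, y1, x2, y2)` pixel box per frame, ignores temporal downsampling, and marks a feature cell as inside the tube if it overlaps the scaled box. `boxes_to_tube_mask` is an illustrative name, not a function from this repository.

```python
from typing import List, Tuple

# Assumed box format: (x1, y1, x2, y2) in pixel coordinates, one box per frame.
Box = Tuple[float, float, float, float]


def boxes_to_tube_mask(boxes: List[Box], img_w: int, img_h: int,
                       feat_w: int, feat_h: int) -> List[List[List[int]]]:
    """Rasterize per-frame boxes into a T x feat_h x feat_w binary mask.

    Hypothetical illustration only: a feature cell is set to 1 if its
    rectangle on the feature grid overlaps the scaled box.
    """
    sx, sy = feat_w / img_w, feat_h / img_h
    mask = []
    for x1, y1, x2, y2 in boxes:
        # Scale the box from pixel coordinates to feature-grid coordinates.
        fx1, fy1, fx2, fy2 = x1 * sx, y1 * sy, x2 * sx, y2 * sy
        frame = [[0] * feat_w for _ in range(feat_h)]
        for r in range(feat_h):
            for c in range(feat_w):
                # Cell (r, c) covers [c, c + 1) x [r, r + 1) on the grid.
                if c < fx2 and c + 1 > fx1 and r < fy2 and r + 1 > fy1:
                    frame[r][c] = 1
        mask.append(frame)
    return mask
```

For example, with 224x224 frames and a 7x7 feature grid (32 pixels per cell), a box covering the top-left 64x64 pixels marks a 2x2 block of cells, and a small box at (100, 100, 120, 120) marks only cell (3, 3).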
We provide training and testing code for both TSA-Net and Plain-Net. The only difference between the two is whether the TSA module is used, which is controlled by the --TSA flag.
# TSA-Net (with the TSA module)
python train.py --gpu 0 --model_path TSA-USDL --TSA
python test.py --gpu 0 --pt_w Exp/TSA-USDL/best.pth --TSA
# Plain-Net (without the TSA module)
python train.py --gpu 0 --model_path USDL
python test.py --gpu 0 --pt_w Exp/USDL/best.pth
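The commands suggest that --TSA is a boolean switch and that the best checkpoint of a run named by --model_path ends up at Exp/<model_path>/best.pth. The argparse sketch below illustrates that convention. It is a hypothetical reconstruction from the commands above, not the repository's actual argument parser; the argument types, help strings, and the helpers `describe` and `default_checkpoint` are assumptions.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser mirroring only the flags used in the commands above;
    # the real train.py / test.py may define more options and other defaults.
    p = argparse.ArgumentParser(description="TSA-Net / Plain-Net (sketch)")
    p.add_argument("--gpu", type=str, default="0", help="GPU id")
    p.add_argument("--model_path", type=str, help="experiment name under Exp/")
    p.add_argument("--pt_w", type=str, help="checkpoint to evaluate")
    # store_true: passing --TSA enables the TSA module; omitting it gives Plain-Net.
    p.add_argument("--TSA", action="store_true", help="use the TSA module")
    return p


def describe(argv):
    """Return which variant a given command line selects."""
    args = build_parser().parse_args(argv)
    return "TSA-Net" if args.TSA else "Plain-Net"


def default_checkpoint(model_path: str) -> str:
    # Assumed from the commands above: best weights saved as Exp/<model_path>/best.pth.
    return f"Exp/{model_path}/best.pth"
```

With `store_true`, the flag takes no value, which is why the commands simply append `--TSA` (for both training and testing) rather than passing something like `--TSA 1`.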
Our code is adapted from MUSDL, and we are very grateful to its authors for their excellent implementation. All tracking boxes in our project were generated with SiamMask, and we sincerely thank its authors for their contributions as well.
If you have any questions about our work, please contact slwang19@fudan.edu.cn.