This repository contains the PyTorch code for:
GaitGraph: Graph Convolutional Network for Skeleton-Based Gait Recognition
Torben Teepe, Ali Khan, Johannes Gilg, Fabian Herzog, Stefan Hörmann, Gerhard Rigoll
- Python >= 3.6
- CUDA >= 10
First, create a virtual environment or install dependencies directly with:
pip3 install -r requirements.txt
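To verify the environment meets these requirements, a minimal check (not part of the repository) can be run in Python:

```python
# Quick environment check: Python, PyTorch, and CUDA availability.
import sys
import torch

assert sys.version_info >= (3, 6), "Python >= 3.6 is required"
print("PyTorch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("CUDA version:", torch.version.cuda)  # should report >= 10
```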
To obtain the pose data for CASIA-B, either run the preprocessing commands below or download the preprocessed data using:
cd data
sh ./download_data.sh
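To sanity-check the download, you can inspect the CSVs with pandas. The file name below is an assumption based on the preprocessing steps later in this README; adjust it to whatever download_data.sh actually fetches:

```python
# Hypothetical sanity check of the downloaded pose data, run from data/.
# Assumption: download_data.sh leaves CSV files such as
# casia-b_pose_coco.csv in this directory.
import pandas as pd

df = pd.read_csv("casia-b_pose_coco.csv", header=None)
print(df.shape)   # rows = frames, columns = path + keypoint values
print(df.head())  # first few rows for a quick visual check
```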
Optional: If you choose to run the preprocessing, download the dataset and run the following commands.
# Download required weights
cd models
sh ./download_weights.sh
# Copy extraction script
# <PATH_TO_CASIA-B> should be something like: /home/ ... /datasets/CASIA_Gait_Dataset/DatasetB
cd ../data
cp extract_frames.sh <PATH_TO_CASIA-B>
cd <PATH_TO_CASIA-B>
mkdir frames
sh extract_frames.sh
cd frames
find . -type f -regex ".*\.jpg" -print | sort | grep -v bkgrd > ../casia-b_all_frames.csv
cp ../casia-b_all_frames.csv <PATH_TO_REPO>/data
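If find and grep are unavailable (e.g. on Windows), an equivalent frame list can be built in Python; this sketch mirrors the shell pipeline above and is run from inside the frames directory:

```python
# Python equivalent of the find | sort | grep pipeline above:
# list all .jpg frames, drop background ("bkgrd") recordings, and
# write one path per line, './'-prefixed to match find's output.
from pathlib import Path

paths = sorted(
    "./" + p.as_posix()
    for p in Path(".").rglob("*.jpg")
    if "bkgrd" not in p.as_posix()
)
Path("../casia-b_all_frames.csv").write_text("\n".join(paths) + "\n")
```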
cd <PATH_TO_REPO>/src
export PYTHONPATH=${PWD}:$PYTHONPATH
cd preparation
python3 prepare_detection.py <PATH_TO_CASIA-B> ../../data/casia-b_all_frames.csv ../../data/casia-b_detections.csv
python3 prepare_pose_estimation.py <PATH_TO_CASIA-B> ../../data/casia-b_detections.csv ../../data/casia-b_pose_coco.csv
python3 split_casia-b.py ../../data/casia-b_pose_coco.csv --output_dir ../../data
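The three scripts chain their outputs: frame list → detections → poses → splits. A rough sanity check of the final pose file might look like the sketch below; the CSV layout and the filename parsing (standard CASIA-B naming, subject-condition-sequence-view) are assumptions:

```python
# Rough sanity check of casia-b_pose_coco.csv, run from src/preparation.
import csv
from collections import Counter

subjects, conditions = Counter(), Counter()
with open("../../data/casia-b_pose_coco.csv") as f:
    for row in csv.reader(f):
        # Assumption: the first field is a frame path whose first
        # directory component follows CASIA-B naming, e.g.
        # "001-nm-01-090" = subject 001, condition nm, seq 01, view 090.
        name = row[0].lstrip("./").split("/")[0]
        subject, condition = name.split("-")[:2]
        subjects[subject] += 1
        conditions[condition] += 1

print(len(subjects), "subjects found")  # CASIA-B has 124 subjects
print(dict(conditions))                 # frames per condition (nm/bg/cl)
```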
To train the model, run the train.py script. To see all options, run:
cd src
export PYTHONPATH=${PWD}:$PYTHONPATH
python3 train.py --help
Check experiments/1_train_*.sh
to see the configurations used in the paper.
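For orientation when reading the configs: graph convolutional models of this family (ST-GCN, ResGCN) consume skeleton sequences as tensors of shape (N, C, T, V). With COCO poses there are V = 17 keypoints, and C plausibly holds the 2D coordinates plus confidence. A sketch of such an input; the exact shape train.py expects may differ:

```python
# Illustrative input layout for a skeleton-based GCN (an assumption,
# the actual shapes used in train.py may differ):
#   N = batch size, C = channels (x, y, confidence),
#   T = frames per sequence, V = 17 COCO keypoints.
import torch

N, C, T, V = 8, 3, 60, 17
dummy_batch = torch.randn(N, C, T, V)
print(dummy_batch.shape)  # torch.Size([8, 3, 60, 17])
```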
Optionally, start TensorBoard with:
tensorboard --logdir=save/casia-b_tensorboard
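If you add your own logging, PyTorch's built-in SummaryWriter writes the same kind of event files; a minimal sketch reusing the log directory from the command above:

```python
# Minimal TensorBoard logging sketch using PyTorch's built-in writer.
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="save/casia-b_tensorboard/example_run")
for step in range(100):
    writer.add_scalar("train/loss", 1.0 / (step + 1), step)  # dummy value
writer.close()
```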
Evaluate the models using the evaluate.py script. To see all options, run:
python3 evaluate.py --help
Top-1 accuracy per probe angle, excluding identical-view cases, for the provided models on the CASIA-B dataset:
|        | 0    | 18   | 36   | 54   | 72   | 90   | 108  | 126  | 144  | 162  | 180  | mean |
|--------|------|------|------|------|------|------|------|------|------|------|------|------|
| NM#5-6 | 85.3 | 88.5 | 91.0 | 92.5 | 87.2 | 86.5 | 88.4 | 89.2 | 87.9 | 85.9 | 81.9 | 87.7 |
| BG#1-2 | 75.8 | 76.7 | 75.9 | 76.1 | 71.4 | 73.9 | 78.0 | 74.7 | 75.4 | 75.4 | 69.2 | 74.8 |
| CL#1-2 | 69.6 | 66.1 | 68.8 | 67.2 | 64.5 | 62.0 | 69.5 | 65.6 | 65.7 | 66.1 | 64.3 | 66.3 |
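Here, "excluding identical-view cases" means a probe is never matched against gallery sequences recorded from the same view. A simplified sketch of such a rank-1 computation; the Euclidean distance is an illustrative choice, and evaluate.py may compute it differently:

```python
# Simplified rank-1 accuracy with identical-view gallery entries masked out.
# gallery_emb / probe_emb: (G, D) and (P, D) embedding matrices;
# *_label and *_view: per-sample subject ids and view angles (numpy arrays).
import numpy as np

def rank1_excluding_identical_view(gallery_emb, gallery_label, gallery_view,
                                   probe_emb, probe_label, probe_view):
    correct = 0
    for emb, label, view in zip(probe_emb, probe_label, probe_view):
        mask = gallery_view != view             # drop identical-view entries
        dists = np.linalg.norm(gallery_emb[mask] - emb, axis=1)
        correct += int(gallery_label[mask][np.argmin(dists)] == label)
    return correct / len(probe_emb)
```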
The pre-trained model is available here.
GaitGraph itself is released under the MIT License (see LICENSE).
The following parts of the code are borrowed from other projects. Thanks for their wonderful work!
- Object Detector: eriklindernoren/PyTorch-YOLOv3
- Pose Estimator: HRNet/HRNet-Human-Pose-Estimation
- ST-GCN Model: yysijie/st-gcn
- ResGCNv1 Model: yfsong0709/ResGCNv1
- SupCon Loss: HobbitLong/SupContrast
If you use GaitGraph, please use the following BibTeX entry.
@misc{teepe2021gaitgraph,
  title={GaitGraph: Graph Convolutional Network for Skeleton-Based Gait Recognition},
  author={Torben Teepe and Ali Khan and Johannes Gilg and Fabian Herzog and Stefan H\"ormann and Gerhard Rigoll},
  year={2021},
  eprint={2101.11228},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}