While porting the code here from PyTorch to JAX and Equinox, I was curious whether the keypoint extractor, as is, could be trained on the CelebA dataset to infer keypoints on human faces.
It turns out the answer is yes (these are evaluation samples: red dots are the ground-truth keypoints and green dots are the predictions).
Not bad for about 100 training steps on 18,000 images split 8:2 between training and evaluation (training takes about 10 minutes on a machine with an NVIDIA RTX 3060 GPU).
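For concreteness, an 8:2 split of 18,000 images leaves 14,400 for training and 3,600 for evaluation. Below is a minimal JAX sketch of such a split. It assumes a random shuffle of sample indices. The helper name `train_eval_split` is hypothetical, and `scripts/training.py` may split the data differently.

```python
import jax


def train_eval_split(num_samples: int, train_parts: int = 8, eval_parts: int = 2, seed: int = 0):
    """Shuffle sample indices and split them train_parts:eval_parts (hypothetical helper)."""
    # Random permutation of the indices 0..num_samples-1
    perm = jax.random.permutation(jax.random.PRNGKey(seed), num_samples)
    # Integer arithmetic avoids float rounding: 18000 * 8 // 10 == 14400
    n_train = num_samples * train_parts // (train_parts + eval_parts)
    return perm[:n_train], perm[n_train:]


train_idx, eval_idx = train_eval_split(18000)
print(train_idx.shape, eval_idx.shape)  # (14400,) (3600,)
```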
The interesting part is the Hourglass architecture with a convolutional head.
```bash
# virtual env is recommended
pip install -e .
python scripts/training.py
python scripts/inference.py some_human_face_picture.jpeg
```


