FaceFlow is a PyTorch Lightning-based repository that simplifies building models for detecting facial biomechanics through Facial Action Units.
It stores configurations as code, making each experiment easy to reproduce.
The regular model can use any backbone, e.g. ConvNeXt (see the short timm sketch below).
You can improve domain generalization by adding unsupervised data with the FixMatch model (lib/fixmatch.py), or by using the MIRO training procedure (lib/miro.py).
The models train best on images cropped to faces; the cropping code is in notebooks/preprocess_disfa.ipynb.
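The backbone mentioned above is expected to be timm-compatible (see the configuration glossary further down). As a minimal illustration of what that means, the sketch below builds a timm feature extractor; the model name and input size are just examples, not something FaceFlow prescribes:

import timm
import torch

# Any timm model created with num_classes=0 acts as a pure feature extractor.
backbone = timm.create_model("convnext_tiny", pretrained=True, num_classes=0)
features = backbone(torch.randn(1, 3, 224, 224))
print(features.shape)  # (1, backbone.num_features), e.g. (1, 768) for convnext_tiny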
You can also track your experiments on Weights & Biases (they have a free tier!). A detailed guide is coming later; for now, enable it by running wandb login and entering your API key when prompted. You also need to register your data as artifacts, for example as sketched below.
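A minimal sketch of registering a dataset as a W&B artifact; the project name faceflow, the artifact name disfa-crops, and the data path are placeholders for your own setup:

import wandb

# Log the dataset directory once; later runs can then consume it by name and version.
run = wandb.init(project="faceflow", job_type="upload-data")
artifact = wandb.Artifact("disfa-crops", type="dataset")
artifact.add_dir("/data/disfa")
run.log_artifact(artifact)
run.finish()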
- Prepare data in the format AU1,AU2,...,filename. For DISFA, use notebooks/preprocess_disfa.ipynb.
- Modify training parameters in the params folder. The files there are provided as examples that you can edit to change datasets, architectures, or hyperparameters.
- Train a model by running python3 src/train.py
- Evaluate the model by running python3 src/test.py
- Export your model to ONNX with python3 src/export.py, or run inference with python3 src/infer.py (see the ONNX sketch right after this list).
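Once exported, the model can be run without PyTorch at all. A minimal onnxruntime sketch, assuming a file called model.onnx and a 224x224 RGB input; the actual file name, input name, and input shape depend on your export settings:

import numpy as np
import onnxruntime as ort

# Load the exported graph and look up the name of its input tensor.
session = ort.InferenceSession("model.onnx")
input_name = session.get_inputs()[0].name
# Dummy batch standing in for a preprocessed, face-cropped image.
dummy = np.random.rand(1, 3, 224, 224).astype(np.float32)
outputs = session.run(None, {input_name: dummy})
print(outputs)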
The easiest way to run the repository is by using nvidia-docker as a glorified virtualenv:
docker build --tag faceflow .  # <- run only once
docker run -it --rm --gpus all -v /path/to/repo:/home -v /path/to/data:/data --shm-size=4gb faceflow
This way you can edit files locally and immediately use them inside docker.
Configuration:
- datamodule: set the action units you want to train on and the location of the data
- model: edit the backbone model and hyperparameters. You can also use FixMatchModel or MIRO in place of the regular AUModel.
- trainer: edit the training schedule (epochs, monitoring, devices)
Model Variants:
- The regular AUModel uses AUDataModule (lib/data/datamodules/vanilla.py)
- The MIRO model also uses AUDataModule
- FixMatchModel and DeFixMatchModel use SSLDataModule (lib/data/datamodules/ssl.py)
The model (to be released later) is trained on the DISFA Dataset.
The dataset consists of 27 videos of different people making facial expressions, one video for each person.
You can request the data from its authors here (only for research purposes).
Labels that go into the model must be in a CSV file with columns AU1,AU2,...,filename.
All the necessary preprocessing can be done with notebooks/preprocess_disfa.ipynb.
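For illustration, a minimal sketch of writing a label file in that shape with pandas; the action units, intensity values, and file names below are made up, and the notebook above produces the real thing for DISFA:

import pandas as pd

# Hypothetical rows: per-frame AU intensities plus the path to the cropped face image.
labels = pd.DataFrame(
    {
        "AU1": [0, 2, 0],
        "AU2": [1, 0, 0],
        "AU4": [0, 3, 1],
        "filename": ["SN001/0001.jpg", "SN001/0002.jpg", "SN001/0003.jpg"],
    }
)
labels.to_csv("labels.csv", index=False)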
- datamodule is a LightningDataModule that abstracts away datasets and the corresponding dataloaders.
- model is a LightningModule wrapper that assembles the model together with its training loop.
- backbone is a timm-compatible feature extractor.
- heads_partials are functools.partial instances of core.AUHead that encapsulate the task logic, complete with the prediction head, the loss function, and the final activation. Each head needs the backbone output size to be fully instantiated.
- optimizer_partial is a torch.optim or a timm optimizer. It needs model parameters to get instantiated.
- scheduler_partial is a timm scheduler. It needs an optimizer on init.
- trainer is a Trainer that handles the entire training loop, including checkpoints, logging, accelerators, precision, etc.
- ckpt_path refers to a Lightning checkpoint from which training should be resumed.
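Putting the glossary together, here is a hypothetical sketch of the building blocks a params file wires up. Only the timm, torch, and Lightning calls are known APIs; the core.AUHead arguments are an assumption (hence commented out), so check the example files in params for the real layout:

from functools import partial

import timm
import torch
from pytorch_lightning import Trainer  # lightning.pytorch on newer Lightning versions
from timm.scheduler import CosineLRScheduler

# Backbone: any timm feature extractor; the heads read its output size later.
backbone = timm.create_model("convnext_tiny", pretrained=True, num_classes=0)

# Heads, optimizer, and scheduler are configured as partials and finished later on:
# the heads receive the backbone output size, the optimizer receives the model
# parameters, and the scheduler receives the optimizer.
# heads_partials = [partial(core.AUHead, ...)]  # AUHead arguments omitted (assumption)
optimizer_partial = partial(torch.optim.AdamW, lr=3e-4)
scheduler_partial = partial(CosineLRScheduler, t_initial=30)

# Trainer: a standard Lightning Trainer controls epochs, devices, logging, and checkpoints.
trainer = Trainer(max_epochs=30, accelerator="gpu", devices=1)
ckpt_path = None  # set to a .ckpt file to resume training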
The current params folder is given as an example of what's possible, e.g. for the DISFA dataset.
If you want to make a pull request, please revert any changes you made to params during your experiments, unless you intend to modify the examples.
Git commands to do that:
git rm -r params
git checkout upstream/main params
git commit -m "reverting params"
- examples on how to use WandB (~Dec '23)
- training examples for using unsupervised data (~Dec '23)
- release model trained on DISFA (~Dec '23)
- train and publish models on other datasets (~Q1'24)
How to make WandB use AWS credentials from an EC2 IAM role:
import json
import os
import subprocess

# Fetch temporary credentials for the instance's IAM role from the EC2 metadata service (IMDSv2).
cmd = 'TOKEN=`curl -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 21600"` && curl -H "X-aws-ec2-metadata-token: $TOKEN" -v http://169.254.169.254/latest/meta-data/iam/security-credentials/AmazonS3FullAccess'
out = subprocess.run(cmd, shell=True, capture_output=True)
creds = json.loads(out.stdout.decode("utf8"))
# Expose them via the standard AWS environment variables so WandB and boto3 pick them up.
os.environ["AWS_ACCESS_KEY_ID"] = creds["AccessKeyId"]
os.environ["AWS_SECRET_ACCESS_KEY"] = creds["SecretAccessKey"]
os.environ["AWS_SESSION_TOKEN"] = creds["Token"]
Made with 💖+☕ by TensorSense