TriPlaneNet inverts an input image into the latent space of a 3D GAN for novel view rendering.
This is the official repository containing the source code for the arXiv papers TriPlaneNet v1 & TriPlaneNet v2.
[Paper v2] [Paper v1] [Project Page] [Video]
If you find TriPlaneNet useful for your work please cite:
```bibtex
@inproceedings{bhattarai2024triplanenet,
  title={TriPlaneNet: An Encoder for EG3D Inversion},
  author={Bhattarai, Ananta R. and Nie{\ss}ner, Matthias and Sevastopolsky, Artem},
  booktitle={IEEE/CVF Winter Conference on Applications of Computer Vision (WACV)},
  year={2024}
}
```
- 02.11.2023: Code for TriPlaneNet v2 is released. 🔥
- 29.10.2023: The second version of the article has been accepted to WACV 2024 and features additional contributions and better results. 🔥 Code for TriPlaneNet v2 is coming soon.
- Clone the repository and run the script to set up the environment:
```bash
git clone https://github.com/anantarb/triplanenet.git --recursive
source ./scripts/install_deps.sh
```
This will set up a conda environment `triplanenet` with all dependencies. The script also downloads the required pre-trained models and places them into the respective folders. For the Basel Face Model 2009 (BFM09) used by Deep3DFaceRecon_pytorch, request access to the model as described in Deep3DFaceRecon_pytorch and organize the files in the expected structure.
- (Optional) Install `ffmpeg` on your system if you want to run the model on videos.
Datasets are stored in a directory containing PNG/JPG images, a metadata file `dataset.json` with the camera labels, and confidence maps produced by a pre-trained network from unsup3d. Each label is a 25-length list of floating-point numbers: the concatenation of the flattened 4x4 camera extrinsic matrix and the flattened 3x3 camera intrinsic matrix. We provide an example of the dataset structure in `dataset_preprocessing/ffhq/example_dataset/`. The training dataset is pre-processed using the procedure described here. For inference, images should be pre-processed so that they align with the training data, following dataset preprocessing.
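To illustrate the label format, here is a minimal sketch of how one 25-length label can be assembled; the matrix values, the file name, and the `"labels"` layout are placeholders based on the EG3D-style convention, so check `dataset_preprocessing/ffhq/example_dataset/` for the exact format expected by this repository.
```python
import json
import numpy as np

# Hypothetical example: assemble the 25-length camera label for one image.
# The extrinsic/intrinsic values below are placeholders, not real calibration data.
extrinsic = np.eye(4, dtype=np.float32)                    # 4x4 cam2world, flattened to 16 values
intrinsic = np.array([[4.26, 0.0, 0.5],
                      [0.0, 4.26, 0.5],
                      [0.0, 0.0, 1.0]], dtype=np.float32)  # 3x3 intrinsics, flattened to 9 values

label = extrinsic.flatten().tolist() + intrinsic.flatten().tolist()
assert len(label) == 25  # 16 extrinsic + 9 intrinsic values

# dataset.json then maps each image file to its label (layout shown here is
# illustrative; see the example dataset for the exact structure).
metadata = {"labels": [["00000.png", label]]}
with open("dataset.json", "w") as f:
    json.dump(metadata, f)
```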
We release the following pre-trained models. Each TriPlaneNet checkpoint contains the entire TriPlaneNet architecture, including the encoder and decoder weights.
| Path | Description |
|---|---|
| EG3D Inversion | TriPlaneNet trained with the FFHQ dataset + synthesized EG3D samples for EG3D inversion. |
| EG3D Inversion | TriPlaneNet v2 trained with the FFHQ dataset + synthesized EG3D samples for EG3D inversion. |
If you wish to use a pretrained model for training or inference, you may do so using the flag `--checkpoint_path`.
In addition, we provide various auxiliary models needed for training your own TriPlaneNet model from scratch as well as pretrained models needed for evaluation.
| Path | Description |
|---|---|
| FFHQ EG3D | EG3D model pretrained on FFHQ, taken from NVlabs, with 512x512 output resolution. |
| IR-SE50 Model | Pretrained IR-SE50 model taken from TreB1eN for use in our ID loss during TriPlaneNet training. |
| CurricularFace Backbone | Pretrained CurricularFace model taken from HuangYG123 for use in ID similarity metric computation. |
| MTCNN | Weights for the MTCNN model taken from TreB1eN for use in ID similarity metric computation. (Unpack the tar.gz to extract the 3 model weights.) |
By default, we assume that all auxiliary models are downloaded and saved to the directory `pretrained_models`. However, you may use your own paths by changing the necessary values in `configs/paths_config.py`.
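As an illustration only, the model-path entries in `configs/paths_config.py` might look like the sketch below; the key names and file names here are hypothetical, so use the ones actually defined in the file shipped with the repository.
```python
# Hypothetical sketch of model-path entries in configs/paths_config.py.
# Key names and file names are illustrative, not the repository's actual values.
model_paths = {
    'eg3d_ffhq': 'pretrained_models/eg3d_ffhq.pkl',                      # FFHQ EG3D generator
    'ir_se50': 'pretrained_models/model_ir_se50.pth',                    # IR-SE50 for ID loss
    'curricular_face': 'pretrained_models/CurricularFace_Backbone.pth',  # ID similarity metric
    'mtcnn': 'pretrained_models/mtcnn/',                                 # MTCNN weights directory
}
```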
- Refer to `configs/paths_config.py` to define the necessary data paths and model paths for training and evaluation. For example, the dataset paths inside `configs/paths_config.py` should look like this:
```python
dataset_paths = {
    'train': '/path/to/dataset/train/',
    'test': '/path/to/dataset/test/',
}
```
We provide an example of how the dataset directory should be structured here.
- Don't forget to pre-process the data as described here.
The main training script can be found in `scripts/train_v2.py`. Intermediate training results are saved to `opts.exp_dir`. This includes checkpoints, train outputs, and test outputs. Additionally, if you have tensorboard installed, you can visualize tensorboard logs in `opts.exp_dir/logs`.
Experiments can also be tracked with Weights & Biases. To enable Weights & Biases (`wandb`), first make an account on the platform's webpage and install `wandb` using `pip install wandb`. Then, to train TriPlaneNet v2 using `wandb`, simply add the flag `--use_wandb`.
Training, for example, can be launched with the following command:
```bash
python scripts/train_v2.py \
--exp_dir=/path/to/experiment \
--use_wandb
```
We provide a highly configurable implementation. See `options/train_options_v2.py` for a complete list of the configuration options.
The main training script can be found in `scripts/train.py`. Intermediate training results are saved to `opts.exp_dir`. This includes checkpoints, train outputs, and test outputs. Additionally, if you have tensorboard installed, you can visualize tensorboard logs in `opts.exp_dir/logs`.
Experiments can also be tracked with Weights & Biases. To enable Weights & Biases (`wandb`), first make an account on the platform's webpage and install `wandb` using `pip install wandb`. Then, to train TriPlaneNet using `wandb`, simply add the flag `--use_wandb`.
Training, for example, can be launched with the following command:
```bash
python scripts/train.py \
--exp_dir=/path/to/experiment \
--device=cuda:0 \
--n_styles=14 \
--batch_size=4 \
--test_batch_size=4 \
--workers=8 \
--test_workers=8 \
--val_interval=2500 \
--save_interval=5000 \
--use_wandb
```
We provide a highly configurable implementation. See `options/train_options.py` for a complete list of the configuration options.
Having trained your model, you can use `scripts/inference_v2.py` to apply the model to a set of images.
```bash
python scripts/inference_v2.py \
--exp_dir=/path/to/experiment \
--checkpoint_path=experiment/checkpoints/best_model.pt \
--data_path=/path/to/test_data \
--test_batch_size=4 \
--test_workers=4 \
--shapes \
--novel_view_angles -0.3 0.3
```
You can use `scripts/inference_video_v2.py` to apply the model to a video.
```bash
python scripts/inference_video_v2.py \
--exp_dir=/path/to/experiment \
--checkpoint_path=experiment/checkpoints/best_model.pt \
--data_path=/path/to/extracted_frames \
--test_batch_size=4 \
--test_workers=4 \
--frame_rate=30 \
--novel_view_angles -0.3 0.3
```
To extract frames from a video, refer to dataset preprocessing. For more video inference options, see `options/test_options_videos_v2.py`.
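If you only need to dump raw frames and do not have the preprocessing pipeline at hand, a minimal OpenCV-based sketch like the one below can extract them; the file names and output directory are illustrative, and the resulting frames still need to be aligned via dataset preprocessing before inference.
```python
# Minimal sketch: extract frames from a video with OpenCV.
# Paths are illustrative; frames must still be preprocessed/aligned afterwards.
import os
import cv2

def extract_frames(video_path: str, out_dir: str) -> None:
    os.makedirs(out_dir, exist_ok=True)
    cap = cv2.VideoCapture(video_path)
    idx = 0
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        cv2.imwrite(os.path.join(out_dir, f"{idx:06d}.png"), frame)
        idx += 1
    cap.release()

extract_frames("input.mp4", "extracted_frames")
```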
Having trained your model, you can use `scripts/inference.py` to apply the model to a set of images.
```bash
python scripts/inference.py \
--exp_dir=/path/to/experiment \
--checkpoint_path=experiment/checkpoints/best_model.pt \
--data_path=/path/to/test_data \
--test_batch_size=4 \
--test_workers=4 \
--couple_outputs \
--resize_outputs \
--novel_view_angles -0.3 0.3
```
For challenging cases, you can also apply cascaded test-time refinement (CTTR) by simply adding the flag `--CTTR`. For more inference options, see `options/test_options.py`.
- We have merged the code of TriPlaneNet v2 and recommend running inference with TriPlaneNet v2 since it leads to better results.
- In the v1 paper, we report the final ID similarity metric for all methods by applying the normalization `(ID_SIM) * 0.5 + 0.5`. However, we drop this normalization in the code (see the snippet below).
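For reference, that normalization is just the following mapping of a similarity score from [-1, 1] to [0, 1]; `normalize_id_sim` is only an illustrative helper, not a function in this repository.
```python
def normalize_id_sim(id_sim: float) -> float:
    # Normalization used for reporting in the v1 paper; the code in this
    # repository reports the raw (unnormalized) ID similarity instead.
    return id_sim * 0.5 + 0.5
```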
Our work builds on top of amazing open-source networks and codebases. We thank the authors for providing them.