JAX implementation of Tensorial Radiance Fields (TensoRF), written as an exercise. This is a copy of code from Brent Yi that we used for the course CS294: Geometry and Learning for 3D Vision in Spring 2022 at UC Berkeley.

    @misc{TensoRF,
      title={TensoRF: Tensorial Radiance Fields},
      author={Anpei Chen and Zexiang Xu and Andreas Geiger and Jingyi Yu and Hao Su},
      year={2022},
      eprint={2203.09517},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
    }
We don't attempt to reproduce the original paper exactly, but we can achieve decent results after 5~10 minutes of training:

- Download the `nerf_synthetic` dataset: Google Drive. With the default training script arguments, we expect this to be extracted to `./data`, e.g. `./data/nerf_synthetic/lego`.
- Install dependencies. You'll probably want the GPU version of JAX; see the official instructions. Then: `pip install -r requirements.txt`
- To print training options: `python ./train_lego.py --help`
- To monitor training, we use Tensorboard: `tensorboard --logdir=./runs/`
- To generate some renders, visit `./render.ipynb` in Jupyter. Or: `python ./render_360.py --help`
Our implementation doesn't exactly match the official one:
- The official implementation relies heavily on masking operations to improve runtime (for example, by using a weight threshold for sampled points). These require dynamic shapes and are currently difficult to implement in JAX, so we replace them with workarounds like weighted sampling.
- Several training details that would likely improve performance are not yet implemented: bounding box refinement, ray filtering, regularization, etc.
- We include mixed-precision training, which can increase training throughput by a significant factor. (Whether this actually reduces wall-clock time to convergence is unclear.)
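As a toy illustration of the fixed-shape workaround described above (a sketch with made-up shapes and function names, not this repository's actual code): instead of boolean-masking samples below a weight threshold, which would produce dynamic shapes, we can keep a fixed number of samples per ray with `jax.lax.top_k`, so everything stays `jit`-compatible.

```python
import jax
import jax.numpy as jnp


def topk_weighted_features(weights, features, k=8):
    """Keep a fixed number of samples per ray instead of boolean masking.

    weights:  (num_rays, num_samples) volume-rendering weights.
    features: (num_rays, num_samples, feature_dim) per-sample features.

    Returns a weighted sum over only the top-k samples of each ray. All
    shapes are static, so the function can be jit-compiled.
    """
    top_weights, top_indices = jax.lax.top_k(weights, k)  # (num_rays, k)
    # Gather the features corresponding to the top-k weights of each ray.
    top_features = jnp.take_along_axis(features, top_indices[..., None], axis=1)
    return jnp.sum(top_weights[..., None] * top_features, axis=1)


num_rays, num_samples, feature_dim = 4, 32, 3
key = jax.random.PRNGKey(0)
weights = jax.random.uniform(key, (num_rays, num_samples))
features = jax.random.normal(key, (num_rays, num_samples, feature_dim))
out = jax.jit(topk_weighted_features)(weights, features)
print(out.shape)  # (4, 3)
```

Since `k` is a static constant, the cost is a fixed gather per ray rather than a data-dependent compaction.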
Implementation details are based loosely on the original PyTorch implementation, apchenstu/TensoRF.
unixpickle/learn-nerf and google-research/jaxnerf were also really helpful for understanding core NeRF concepts + connecting them to JAX!
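For context on the mixed-precision point above, here's a minimal sketch of the general pattern (our own toy example, not this repository's code): parameters and the final output stay in float32, while the heavy matrix multiplies run in a lower-precision compute dtype.

```python
import jax
import jax.numpy as jnp


def mlp_forward(params, x, compute_dtype=jnp.float16):
    """Toy mixed-precision forward pass.

    Parameters are stored in float32; each layer casts inputs and weights
    down to `compute_dtype` for the matmul, and the result is cast back to
    float32 at the end for a numerically stable loss.
    """
    for w, b in params:
        x = x.astype(compute_dtype) @ w.astype(compute_dtype) + b.astype(compute_dtype)
        x = jax.nn.relu(x)
    return x.astype(jnp.float32)


key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
params = [
    (jax.random.normal(k1, (8, 16)) * 0.1, jnp.zeros(16)),
    (jax.random.normal(k2, (16, 4)) * 0.1, jnp.zeros(4)),
]
x = jnp.ones((2, 8))
y = mlp_forward(params, x)
print(y.dtype, y.shape)
```

Libraries like `jmp` wrap this pattern (plus loss scaling) more robustly; the sketch only shows the core idea.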
- Blender dataloading
- Main implementation
  - Point sampling
  - Feature MLP
  - Rendering
  - VM decomposition
    - Basic implementation
    - Vectorized
- Training
  - Learning rate scheduler
    - ADAM + grouped LR
    - Exponential decay
    - Reset decay after upsampling
  - Running
  - Checkpointing
  - Logging
    - Loss
    - PSNR
    - Test metrics
    - Test images
  - Ray filtering
  - Bounding box refinement
  - Incremental upsampling
  - Regularization terms
- Performance
  - Weight thresholding for computing appearance features
    - Per-ray top-k
    - Global top-k (bad & deleted)
  - Mixed-precision
    - Implemented
    - Stable
  - Multi-GPU (should be quick)
- Rendering
  - RGB
  - Depth (median)
  - Depth (mean)
  - Batching
  - Generate some GIFs
- Misc engineering
  - Actions
  - Understand vmap performance differences (details)
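As a footnote to the depth-rendering items above: mean and median depth can both be derived from the same volume-rendering weights. A hedged sketch (toy function, not this repository's API):

```python
import jax.numpy as jnp


def depths_from_weights(weights, ts):
    """Two common depth estimates from volume-rendering weights.

    weights: (num_samples,) rendering weights along one ray.
    ts:      (num_samples,) sorted sample distances along the ray.
    """
    # Normalize so the weights form a distribution over samples.
    w = weights / jnp.maximum(jnp.sum(weights), 1e-8)
    # Mean depth: expected distance under the weight distribution.
    mean_depth = jnp.sum(w * ts)
    # Median depth: first sample where the cumulative weight crosses 0.5.
    cdf = jnp.cumsum(w)
    median_depth = ts[jnp.argmax(cdf >= 0.5)]
    return mean_depth, median_depth


ts = jnp.linspace(2.0, 6.0, 64)
weights = jnp.exp(-0.5 * ((ts - 4.0) / 0.2) ** 2)  # density peaked near t=4
mean_d, median_d = depths_from_weights(weights, ts)
print(float(mean_d), float(median_d))
```

The median variant is less sensitive to weight spread out along the ray (e.g. in semi-transparent regions), which is why both show up as separate checklist items.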