
tensorf-jax

JAX implementation of Tensorial Radiance Fields, written as an exercise. This is a copy of the code from Brent Yi, which we used for the course CS294: Geometry and Learning for 3D Vision in Spring 2022 at UC Berkeley.

@misc{TensoRF,
      title={TensoRF: Tensorial Radiance Fields},
      author={Anpei Chen and Zexiang Xu and Andreas Geiger and Jingyi Yu and Hao Su},
      year={2022},
      eprint={2203.09517},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

We don't attempt to reproduce the original paper exactly, but we can achieve decent results after 5–10 minutes of training:

Lego rendering GIF

Instructions

  1. Download the nerf_synthetic dataset: Google Drive. With the default training script arguments, we expect this to be extracted to ./data, e.g. ./data/nerf_synthetic/lego.

  2. Install dependencies. You'll probably want the GPU version of JAX; see the official installation instructions. Then:

    pip install -r requirements.txt
  3. To print training options:

    python ./train_lego.py --help
  4. To monitor training, we use Tensorboard:

    tensorboard --logdir=./runs/
  5. To generate some renders, visit ./render.ipynb in Jupyter. Or:

    python ./render_360.py --help

Differences from the PyTorch implementation

Our code doesn't exactly match the official implementation:

  • The official implementation relies heavily on masking operations to improve runtime (for example, by using a weight threshold for sampled points). These require dynamic shapes and are currently difficult to implement in JAX, so we replace them with workarounds like weighted sampling.
  • Several training details that would likely improve performance are not yet implemented: bounding box refinement, ray filtering, regularization, etc.
  • We include mixed-precision training, which can significantly increase training throughput. (Whether this actually reduces wall-clock time to convergence is unclear.)
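As an illustration of the fixed-shape workaround mentioned above, a per-ray top-k over rendering weights can stand in for dynamic boolean masking. This is a hypothetical sketch, not code from this repo; the function and variable names are illustrative:

```python
import jax
import jax.numpy as jnp


def topk_weights_per_ray(weights: jnp.ndarray, k: int):
    """Keep only the k largest volume-rendering weights per ray.

    weights: (num_rays, num_samples) array of sample weights.
    Returns (values, indices), both with static shape (num_rays, k),
    so downstream appearance-feature lookups stay jit-compatible,
    unlike a boolean mask whose true-count varies per ray.
    """
    return jax.lax.top_k(weights, k)


weights = jnp.array([[0.10, 0.60, 0.05, 0.25],
                     [0.40, 0.10, 0.30, 0.20]])
vals, idx = topk_weights_per_ray(weights, k=2)
# vals holds the two largest weights in each row; idx holds their
# sample indices, usable with jnp.take_along_axis on per-sample features.
```

The key point is that `k` is a compile-time constant, so every ray produces the same output shape regardless of how many samples would have passed a weight threshold.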

References

Implementation details are loosely based on the original PyTorch implementation, apchenstu/TensoRF.

unixpickle/learn-nerf and google-research/jaxnerf were also really helpful for understanding core NeRF concepts and connecting them to JAX!

To-do

  • Blender dataloading
  • Main implementation
    • Point sampling
    • Feature MLP
    • Rendering
    • VM decomposition
      • Basic implementation
      • Vectorized
  • Training
    • Learning rate scheduler
      • ADAM + grouped LR
      • Exponential decay
      • Reset decay after upsampling
    • Running
    • Checkpointing
    • Logging
      • Loss
      • PSNR
      • Test metrics
      • Test images
    • Ray filtering
    • Bounding box refinement
    • Incremental upsampling
    • Regularization terms
  • Performance
    • Weight thresholding for computing appearance features
      • per ray top-k
      • global top-k (bad & deleted)
    • Mixed-precision
      • implemented
      • stable
    • Multi-GPU (should be quick)
  • Rendering
    • RGB
    • Depth (median)
    • Depth (mean)
    • Batching
    • Generate some GIFs
  • Misc engineering
    • Actions
    • Understand vmap performance differences (details)