This pure-JAX implementation of Instant-NGP [2] still needs optimization: it currently runs much slower than the official CUDA implementation.
- Set up a virtual Python 3 environment, then install the required packages with these commands:

      conda create --name jax_ngp python=3
      conda activate jax_ngp
      pip install -r requirements.txt
- Alternatively, open `notebook_train.ipynb` in Google Colab and run it.
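Because speed depends heavily on which accelerator JAX uses, it can help to check after installation whether JAX sees a GPU or has silently fallen back to the CPU. The following is a minimal sketch. It assumes `jax` is installed via `requirements.txt`, and the helper name `check_jax_setup` is hypothetical, not part of this repository:

```python
import jax
import jax.numpy as jnp


def check_jax_setup():
    """Hypothetical helper: report JAX's backend/devices and run a tiny jit test."""
    backend = jax.default_backend()  # e.g. "cpu", "gpu", or "tpu"
    devices = jax.devices()          # devices visible to the default backend
    # Compile and run a trivial function to confirm XLA compilation works.
    x = jnp.arange(4.0)              # [0., 1., 2., 3.]
    total = jax.jit(lambda v: (v * 2.0).sum())(x)
    return backend, devices, float(total)


if __name__ == "__main__":
    backend, devices, value = check_jax_setup()
    print(f"backend: {backend}")
    print(f"devices: {devices}")
    print(f"jit check: {value}")  # 2 * (0 + 1 + 2 + 3) = 12.0
```

If the backend reports `cpu` on a machine with an NVIDIA GPU, a CPU-only `jaxlib` may have been installed. In that case, consult the JAX installation guide for the CUDA-enabled packages.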
[1] Mildenhall, Ben, et al. "NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis." ECCV (2020).
[2] Müller, Thomas, et al. "Instant Neural Graphics Primitives with a Multiresolution Hash Encoding." SIGGRAPH (2022).