Primary language: Python. License: GNU Affero General Public License v3.0 (AGPL-3.0).

Jax-Volumetric-Renderer

A Jax-based renderer for Plenoxels (work in progress).

This is an ongoing project, for didactic purposes, based on the Plenoxels paper: https://alexyu.net/plenoxels/

I am attempting to make a fast renderer in Jax while using only an 8 GB GPU. The full models are 512x512x512 voxel grids, but we render only 128x128x128. Higher resolutions are possible but require more batching and potentially other tricks such as voxel partitioning.
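To see why resolution matters so much on an 8 GB card, here is a rough back-of-the-envelope sketch. It assumes a dense grid (ignoring any sparsity in the actual weights) and 28 values per voxel (1 density plus 27 spherical harmonic coefficients, as described in the Plenoxels paper) stored in float16. The function name and its defaults are illustrative and not taken from this repository.

```python
def dense_grid_bytes(resolution, values_per_voxel=28, bytes_per_value=2):
    """Memory for a dense cubic voxel grid, in bytes.

    values_per_voxel=28 is an assumption (1 density + 27 SH coefficients);
    bytes_per_value=2 corresponds to float16.
    """
    return resolution ** 3 * values_per_voxel * bytes_per_value


for res in (128, 256, 512):
    gib = dense_grid_bytes(res) / 2 ** 30
    print(f"{res}^3 grid: {gib:.2f} GiB")
```

Under these assumptions a 512x512x512 grid alone takes 7 GiB, which leaves almost nothing for rays and intermediate buffers, while a 128x128x128 grid takes about 112 MiB.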

Usage:

  • The main work loop is in renderer_script.py. You will have to replace the relative paths so that the script can find the model weights.
  • It is best to run the script in an interactive interpreter. There are a few calls to Open3D just to visualize what is happening, but they are not strictly necessary.
  • The final render ends up in an image buffer.

Other small things I've done to make things fit in memory:

  • Work in float16
  • Render one channel at a time
  • Find the silhouette of the model to be rendered
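The three tricks above can be sketched together as follows. This is a minimal illustration, not the code in renderer_script.py. It is written with NumPy so it runs anywhere, and it assumes uniformly spaced samples along each ray and that "silhouette" means skipping rays that never hit any density. All function and variable names are hypothetical.

```python
import numpy as np


def composite_channel(sigma, color, delta):
    """Alpha-composite one color channel along each ray.

    sigma: (rays, samples) densities.
    color: (rays, samples) values of a single color channel.
    delta: spacing between consecutive samples (assumed uniform).
    """
    # float16 is used for storage, but sums are accumulated in float32
    # because long running sums lose precision in half precision.
    sigma = sigma.astype(np.float32)
    color = color.astype(np.float32)
    tau = sigma * delta  # optical thickness of each sample
    # Transmittance reaching sample i: exp(-sum of tau over earlier samples).
    tau_before = np.cumsum(tau, axis=-1) - tau
    weights = np.exp(-tau_before) * (1.0 - np.exp(-tau))
    return np.sum(weights * color, axis=-1)


def render(sigma, rgb, delta):
    """sigma: (rays, samples) float16, rgb: (rays, samples, 3) float16."""
    # Silhouette (assumed meaning): rays with no density stay black
    # and are never composited.
    hit = sigma.max(axis=-1) > 0
    image = np.zeros((sigma.shape[0], 3), dtype=np.float16)
    for c in range(3):  # one channel at a time to cap peak memory
        channel = rgb[..., c]
        image[hit, c] = composite_channel(sigma[hit], channel[hit], delta)
    return image
```

Only one channel's intermediate arrays exist at any time, so peak memory for the compositing step is roughly a third of what an all-channels-at-once version would need.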

Things I've tried that did not work out:

  • Octrees: they just don't work well with Jax because of memory coalescing issues.

You need to download the model weights from here: https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1

To have all the requirements installed, you need:

  • Jax
  • OpenCV
  • NumPy
  • Open3D, just to visualize some stuff.

Tips to install Jax:

  • Follow the steps on the website: https://jax.readthedocs.io/en/latest/installation.html
  • If you work in PyCharm, first activate the virtual env in a terminal, then launch PyCharm from that same terminal.
  • If you use TF2 or Torch in the same process as Jax (or even just import them), you will likely run out of GPU memory. They don't like each other.
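Part of the memory pressure comes from Jax preallocating a large fraction of GPU memory by default. The sketch below uses the `XLA_PYTHON_CLIENT_PREALLOCATE` and `XLA_PYTHON_CLIENT_MEM_FRACTION` environment variables described in the Jax GPU memory allocation documentation. Whether this is enough to let Jax coexist with TF2 or Torch on an 8 GB card has not been checked here, so treat it as something to try rather than a fix.

```python
import os

# These variables are read when Jax initializes its GPU backend,
# so set them before the first `import jax` in the process.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # allocate on demand

# Alternative: keep preallocation but cap it at a fraction of VRAM.
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"

# import jax  # import Jax only after the variables above are set
```

Allocating on demand can increase fragmentation, so capping the fraction may be the better choice if the renderer is the only thing using the GPU.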

TODO:

  • Figure out where the random black dots in the render are coming from
  • Provide an install script
  • Build a Python wheel for installation
  • Work on the optimization/training steps
  • Get a GPU with more VRAM

Bloopers: