A Jax-based renderer for Plenoxels (work in progress).
This is an ongoing, didactic project based on the Plenoxels paper: https://alexyu.net/plenoxels/
I am attempting to build a fast renderer in Jax using only an 8 GB GPU. The full models are 512x512x512 voxels, but we render only 128x128x128. Higher resolutions are possible but require more batching and possibly other tricks, such as voxel partitioning.
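A quick back-of-the-envelope calculation suggests why 512x512x512 does not fit on an 8 GB card if the grid is held densely on the GPU. This sketch assumes 28 values per voxel: one density plus 9 spherical-harmonic coefficients for each of the 3 color channels (degree-2 SH, as described in the Plenoxels paper). The helper name `dense_grid_gib` is hypothetical.

```python
# Rough memory estimate for a DENSE voxel grid.
# Assumption: 1 density + 3 colour channels * 9 SH coefficients = 28 values per voxel.
VALUES_PER_VOXEL = 1 + 3 * 9


def dense_grid_gib(resolution: int, bytes_per_value: int) -> float:
    """Size in GiB of a resolution^3 grid storing VALUES_PER_VOXEL values per voxel."""
    n_voxels = resolution ** 3
    return n_voxels * VALUES_PER_VOXEL * bytes_per_value / 2**30


for res in (512, 256, 128):
    for name, nbytes in (("float32", 4), ("float16", 2)):
        print(f"{res}^3 {name}: {dense_grid_gib(res, nbytes):.2f} GiB")
```

Under these assumptions a dense 512^3 grid needs 14 GiB in float32 and still 7 GiB in float16, leaving almost nothing of an 8 GB GPU for rays and intermediate buffers, while 128^3 in float16 needs only about 0.11 GiB.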
Usage:
- The main loop is in renderer_script.py. You will have to edit its relative paths so that the script can find the model weights.
- It is best run in an interactive interpreter. A few calls to Open3D visualize what is happening, but they are not required.
- The final result is stored in an image buffer.
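To save that buffer to disk with OpenCV, it usually has to be converted to 8-bit BGR first. The sketch below assumes the buffer is an HxWx3 float RGB array with values in [0, 1] (check what renderer_script.py actually produces); `buffer_to_bgr_uint8` is a hypothetical helper, not part of the repository.

```python
import numpy as np


def buffer_to_bgr_uint8(rgb: np.ndarray) -> np.ndarray:
    """Convert an HxWx3 float RGB buffer (assumed in [0, 1]) to uint8 BGR for OpenCV."""
    rgb = np.asarray(rgb, dtype=np.float32)  # float16 -> float32 before scaling
    rgb = np.nan_to_num(rgb, nan=0.0)        # guard against NaN values before casting
    rgb = np.clip(rgb, 0.0, 1.0)
    u8 = np.round(rgb * 255.0).astype(np.uint8)
    # OpenCV expects BGR channel order; make the flipped view contiguous.
    return np.ascontiguousarray(u8[..., ::-1])
```

Usage: `cv2.imwrite("render.png", buffer_to_bgr_uint8(image))`, where `image` stands for the rendered buffer.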
Tricks used:
- Work in float16 (half the memory of float32)
- Render one color channel at a time
- Find the silhouette of the model to be rendered
- Tried octrees, but they do not work well with Jax: their irregular memory access defeats memory coalescing.
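The sketch below shows one way the float16, one-channel-at-a-time and silhouette tricks could fit together in Jax. It is not the code in renderer_script.py: it assumes density and per-channel color have already been sampled along each ray into `(n_rays, n_samples)` float16 arrays, uses a constant step size, and computes the silhouette for an orthographic camera aligned with a grid axis. The function names are hypothetical. Storage stays in float16, but the compositing math is done in float32, a design choice to limit rounding error in `exp` and `cumprod`.

```python
import jax
import jax.numpy as jnp


def orthographic_silhouette(density_grid, threshold=1.0, axis=2):
    """Boolean mask of pixels whose ray crosses at least one occupied voxel.

    Hypothetical helper: assumes an orthographic camera looking along `axis`,
    so each pixel's ray is exactly one column of the grid.
    """
    return jnp.any(density_grid > threshold, axis=axis)


def composite_channel(sigma, color, delta):
    """Alpha-composite ONE colour channel along a batch of rays.

    sigma, color: (n_rays, n_samples) arrays, e.g. stored as float16.
    delta: distance between consecutive samples (assumed constant).
    Returns an (n_rays,) array with the rendered channel value per ray.
    """
    # Upcast for the exponentials and products; inputs can stay float16 in memory.
    sigma = sigma.astype(jnp.float32)
    color = color.astype(jnp.float32)
    alpha = 1.0 - jnp.exp(-jax.nn.relu(sigma) * delta)
    # Transmittance reaching each sample: exclusive cumulative product of (1 - alpha).
    trans = jnp.cumprod(1.0 - alpha, axis=-1)
    trans = jnp.concatenate([jnp.ones_like(trans[:, :1]), trans[:, :-1]], axis=-1)
    return jnp.sum(trans * alpha * color, axis=-1)


def render_in_chunks(sigma, color, delta, chunk=4096):
    """Run composite_channel over fixed-size chunks of rays to bound peak memory."""
    n_rays, n_samples = sigma.shape
    pad = (-n_rays) % chunk  # pad so the ray count is a multiple of `chunk`
    # Padded rays have zero density, so they render to 0 and are sliced off below.
    sigma = jnp.pad(sigma, ((0, pad), (0, 0))).reshape(-1, chunk, n_samples)
    color = jnp.pad(color, ((0, pad), (0, 0))).reshape(-1, chunk, n_samples)
    # lax.map processes the chunks sequentially instead of all rays at once.
    out = jax.lax.map(lambda sc: composite_channel(sc[0], sc[1], delta), (sigma, color))
    return out.reshape(-1)[:n_rays]
```

For an RGB image, `render_in_chunks` would be called three times, once per channel, so only one channel's samples need to live on the GPU at a time. Pixels outside the silhouette mask can simply be filled with the background color instead of being rendered.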
You need to download the model weights from here: https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1
Dependencies:
- Jax
- OpenCV
- NumPy
- Open3D (only used for visualization)
Installation notes:
- Install Jax with GPU support by following the official instructions: https://jax.readthedocs.io/en/latest/installation.html
- If you work in PyCharm, first activate the virtual environment in a terminal, then launch PyCharm from that same terminal so that it inherits the environment.
- If you use TF2 or Torch in the same process as Jax (or even just import them), you will likely run out of GPU memory, because the frameworks compete for the same VRAM.
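By default, JAX preallocates 75% of the total GPU memory when the first JAX operation runs, which leaves little room for another framework in the same process. A minimal sketch of limiting this with JAX's documented environment variables follows; they must be set before JAX initializes its GPU backend, and setting them before `import jax` is the safest option. The 50% cap is only an example value.

```python
import os

# Must be set before JAX touches the GPU; doing it before `import jax` is safest.
# Option 1: allocate GPU memory on demand instead of preallocating.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Option 2 (alternative): keep preallocation but cap it at 50% of the GPU.
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"

import jax  # noqa: E402  (imported only after the environment is configured)

print(jax.devices())
```

On the TensorFlow side, `tf.config.experimental.set_memory_growth(gpu, True)` similarly stops TF from reserving most of the GPU up front.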
TODO:
- Find the cause of the random black dots in the render
- Provide an install script
- Build a Python wheel for installation
- Work on the optimization/training steps
- Get a GPU with more VRAM