brentyi/tilted

Why jax?


hiyyg commented

Thanks for the great work! I want to ask a question unrelated to the paper:

Why did you choose JAX? Is it faster and easier to use than PyTorch?

The answer here is not super exciting; I just enjoy working with JAX much more than I enjoy working with PyTorch. Most of my other projects are in PyTorch, but I personally find JAX dramatically simpler to use, although this is a contentious topic. 🙂

As far as speed goes, my experience has been very positive: being forced to JIT everything generally results in faster code with good CPU/GPU parallelism, and things like the multi-GPU parallelism in the visualization code are easier to implement.
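To illustrate what "JIT everything" looks like in practice, here is a minimal sketch. It is not code from this repository: `render_ray` is a hypothetical stand-in for a per-ray computation (a toy Gaussian density, not a NeRF). `jax.vmap` batches it over rays, `jax.jit` compiles the batched function with XLA, and `jax.pmap` maps the same function over a leading device axis so each local device (a GPU, or the CPU when no GPU is present) processes its own chunk of rays.

```python
import jax
import jax.numpy as jnp


def render_ray(origin, direction, ts):
    # Hypothetical per-ray function: sample points along one ray and
    # sum a toy Gaussian "density" (not a real NeRF model).
    points = origin[None, :] + ts[:, None] * direction[None, :]  # (num_samples, 3)
    density = jnp.exp(-jnp.sum(points**2, axis=-1))              # (num_samples,)
    return jnp.sum(density)


# Batch over rays (origins/directions); the sample positions ts are shared.
batched = jax.vmap(render_ray, in_axes=(0, 0, None))

# Single device: compile the batched function once with XLA.
render_rays = jax.jit(batched)

origins = jnp.zeros((4, 3))
directions = jnp.ones((4, 3)) / jnp.sqrt(3.0)  # unit-length directions
ts = jnp.linspace(0.0, 1.0, 8)
out = render_rays(origins, directions, ts)  # shape (4,)

# Multiple devices: add a leading device axis and let pmap give each
# local device its own chunk of rays (pmap also compiles the function).
n_dev = jax.local_device_count()
render_rays_multi = jax.pmap(batched, in_axes=(0, 0, None))
out_multi = render_rays_multi(
    jnp.zeros((n_dev, 4, 3)),
    jnp.ones((n_dev, 4, 3)) / jnp.sqrt(3.0),
    ts,
)  # shape (n_dev, 4)
```

The per-ray function is written once for a single ray; batching and device placement are layered on as transformations instead of being written by hand. Recent JAX releases also support multi-device execution through `jax.jit` with sharding annotations, which may be preferable to `pmap` in new code.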

That said, for NeRF work specifically, I think the PyTorch ecosystem has large advantages. Many CUDA tools, like tiny-cuda-nn and nerfacc, are more readily available for PyTorch. Some NeRF codebases also use boolean masking, which relies on dynamic shapes; these are possible in PyTorch but not in JAX. I also sometimes run into obscure JIT-related issues (example: jax-ml/jax#10332).
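The boolean-masking limitation can be seen in a small sketch (the function names are hypothetical, not code from this repository). Indexing with a mask, as in `x[x > 0]`, produces an array whose length depends on the data, so it works eagerly but raises `jax.errors.NonConcreteBooleanIndexError` under `jax.jit`. A common workaround is to keep shapes static by masking values with `jnp.where` instead of dropping them.

```python
import jax
import jax.numpy as jnp


def mean_positive_masked(x):
    # PyTorch-style boolean indexing: the output length depends on the
    # data, so this cannot be traced with static shapes under jax.jit.
    return jnp.mean(x[x > 0])


def mean_positive_static(x):
    # Static-shape alternative: keep every element, zero out the unwanted
    # ones, and divide by the number of kept elements (at least 1).
    mask = x > 0
    total = jnp.sum(jnp.where(mask, x, 0.0))
    count = jnp.sum(mask)
    return total / jnp.maximum(count, 1)


x = jnp.array([-1.0, 2.0, 3.0, -4.0])

eager_result = mean_positive_masked(x)             # works outside jit
jit_result = jax.jit(mean_positive_static)(x)      # compiles fine

jit_failed = False
try:
    jax.jit(mean_positive_masked)(x)
except jax.errors.NonConcreteBooleanIndexError:
    # Boolean indices must be concrete; under jit they are tracers.
    jit_failed = True
```

When the selected elements themselves are needed, `jnp.nonzero` accepts a static `size=` argument that pads the result to a fixed length, which is another way to keep shapes known at trace time.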