yklcs/jaxsplat

JAX JIT does not work


JAX JIT is not supported because one of the intermediate arrays has a dynamic, data-dependent shape.
ops::project::fwd::xla computes the number of intersections between tiles and Gaussians (num_intersects).
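This value is used in ops::rasterize::fwd::xla as the length of a sorted map of intersections to Gaussians (gaussian_ids).
gaussian_ids is also needed in ops::rasterize::bwd::xla, which is where the problem arises: to reach the backward pass, gaussian_ids has to be returned to JAX-side code, but JAX requires every array's shape to be known at trace time, while the length of gaussian_ids (num_intersects) is only known after ops::project::fwd::xla has run.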
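A minimal sketch of the underlying JAX constraint, using a hypothetical 1D toy model in which each Gaussian covers a contiguous range of tiles, so the number of intersections depends on the input values. The names tile_start, tile_end, gaussian_ids_dynamic, gaussian_ids_padded and max_intersects are illustrative only, not jaxsplat or gsplat APIs. Building the intersection list works eagerly, but under jax.jit the output length of jnp.repeat must be static. Padding to a static upper bound via jnp.repeat's total_repeat_length is one generic workaround, shown here only for comparison; it is not the fix proposed in this issue.

```python
import jax
import jax.numpy as jnp

# Hypothetical 1D toy model (not jaxsplat/gsplat code): Gaussian i covers
# tiles tile_start[i], ..., tile_end[i] - 1.
tile_start = jnp.array([2, 0, 1])
tile_end = jnp.array([4, 2, 3])


def gaussian_ids_dynamic(tile_start, tile_end):
    # One entry per (tile, Gaussian) intersection; the length is
    # sum(tile_end - tile_start), i.e. it depends on the input values.
    counts = tile_end - tile_start
    return jnp.repeat(jnp.arange(tile_start.shape[0]), counts)


def gaussian_ids_padded(tile_start, tile_end, max_intersects):
    # Same idea, but padded to a static length `max_intersects`
    # (a hypothetical upper bound chosen by the caller).
    n = tile_start.shape[0]
    counts = tile_end - tile_start
    num_intersects = counts.sum()
    gauss = jnp.repeat(jnp.arange(n), counts, total_repeat_length=max_intersects)
    first = jnp.cumsum(counts) - counts  # offset of each Gaussian's run
    tile = tile_start[gauss] + jnp.arange(max_intersects) - first[gauss]
    valid = jnp.arange(max_intersects) < num_intersects
    # Sort intersections by (tile, Gaussian); padding gets the largest key.
    key = jnp.where(valid, tile * n + gauss, jnp.iinfo(jnp.int32).max)
    order = jnp.argsort(key)
    return jnp.where(valid[order], gauss[order], -1), num_intersects


# Eager mode works, but the output shape depends on the data.
print(gaussian_ids_dynamic(tile_start, tile_end).shape)  # (6,)

# Under jit, shapes must be known at trace time, so this raises.
try:
    jax.jit(gaussian_ids_dynamic)(tile_start, tile_end)
except jax.errors.ConcretizationTypeError as err:
    print("jit failed:", type(err).__name__)

# With a static upper bound the shape is fixed and jit succeeds.
padded = jax.jit(gaussian_ids_padded, static_argnums=2)
ids, num = padded(tile_start, tile_end, 8)
print(ids.tolist(), int(num))  # [1, 1, 2, 0, 2, 0, -1, -1] 6
```

The padded version trades memory for a fixed shape and needs a bound that is never exceeded, which is why it is only a point of comparison here.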


gsplat's kernels will probably need to be changed to solve this.
Some discussion on possible solutions: nerfstudio-project/gsplat#175

I think the easiest way to fix this right now is to simply recalculate gaussian_ids in ops::rasterize::bwd::xla.
This means that ops::rasterize::fwd::xla no longer needs to return gaussian_ids, and ops::rasterize::bwd::xla only needs statically shaped inputs.

The downside is that the same computation is then performed twice, once in the forward pass and once in the backward pass.
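The recompute approach can be sketched with jax.custom_vjp, assuming a hypothetical pure-JAX toy "rasterizer" in which each tile sums the weights of the Gaussians covering it. The forward rule saves only the statically shaped inputs as residuals, and the backward rule rebuilds the intersection list itself. build_intersections, NUM_TILES and MAX_INTERSECTS are illustrative assumptions; in the proposal above, the recomputation would instead happen inside the ops::rasterize::bwd::xla kernel, where presumably the data-dependent size never has to appear in a JAX-visible shape.

```python
import jax
import jax.numpy as jnp

NUM_TILES = 4        # hypothetical tile count for this toy example
MAX_INTERSECTS = 16  # hypothetical static bound used by the pure-JAX helper


def build_intersections(tile_start, tile_end):
    # Hypothetical helper: lists every (tile, Gaussian) pair, padded to
    # MAX_INTERSECTS entries and masked by `valid`.
    n = tile_start.shape[0]
    counts = tile_end - tile_start
    gauss = jnp.repeat(jnp.arange(n), counts, total_repeat_length=MAX_INTERSECTS)
    first = jnp.cumsum(counts) - counts
    tile = tile_start[gauss] + jnp.arange(MAX_INTERSECTS) - first[gauss]
    valid = jnp.arange(MAX_INTERSECTS) < counts.sum()
    # Point padding entries at tile 0; they are masked out by `valid`.
    return jnp.where(valid, tile, 0), gauss, valid


@jax.custom_vjp
def rasterize(weights, tile_start, tile_end):
    # Toy forward pass: each tile sums the weights of the Gaussians covering it.
    tile, gauss, valid = build_intersections(tile_start, tile_end)
    contrib = jnp.where(valid, weights[gauss], 0.0)
    return jax.ops.segment_sum(contrib, tile, num_segments=NUM_TILES)


def rasterize_fwd(weights, tile_start, tile_end):
    # Save only the statically shaped inputs; the intersection list is not
    # kept as a residual.
    return rasterize(weights, tile_start, tile_end), (tile_start, tile_end)


def rasterize_bwd(res, g):
    tile_start, tile_end = res
    # Recompute the intersection list instead of reading a saved copy.
    tile, gauss, valid = build_intersections(tile_start, tile_end)
    contrib = jnp.where(valid, g[tile], 0.0)
    d_weights = jax.ops.segment_sum(contrib, gauss, num_segments=tile_start.shape[0])
    # Integer inputs get no cotangent.
    return d_weights, None, None


rasterize.defvjp(rasterize_fwd, rasterize_bwd)


def loss(weights, tile_start, tile_end):
    # Scale each tile differently so the gradient is easy to check by hand.
    return (rasterize(weights, tile_start, tile_end) * jnp.arange(1.0, 5.0)).sum()


weights = jnp.array([1.0, 10.0, 100.0])
tile_start = jnp.array([2, 0, 1])
tile_end = jnp.array([4, 2, 3])

print(jax.jit(rasterize)(weights, tile_start, tile_end).tolist())
# [10.0, 110.0, 101.0, 1.0]
print(jax.jit(jax.grad(loss))(weights, tile_start, tile_end).tolist())
# [7.0, 3.0, 5.0]
```

Nothing with a data-dependent length crosses the forward/backward boundary in this sketch; the price is that build_intersections runs in both passes, which is the duplicate work described above.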