JAX JIT does not work
JAX JIT is not supported because one of the intermediate arrays has a dynamic (data-dependent) shape.
`ops::project::fwd::xla` computes the number of intersections between tiles and Gaussians (`num_intersects`).
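As a rough illustration of the counting step, here is a toy sketch in JAX. The bounding-box layout and the name `count_intersections` are hypothetical, not gsplat's actual kernel. The point it shows is that `num_intersects` is a 0-d array whose *value* depends on the input data, so computing it under `jax.jit` is fine on its own.

```python
import jax
import jax.numpy as jnp

def count_intersections(tile_min, tile_max):
    # Hypothetical layout: per-Gaussian tile-space bounding boxes, shape (N, 2), max exclusive.
    extent = jnp.maximum(tile_max - tile_min, 0)   # tiles covered along x and y
    tiles_touched = extent[:, 0] * extent[:, 1]    # tiles overlapped by each Gaussian
    num_intersects = jnp.sum(tiles_touched)        # 0-d array: only its value is data dependent
    return tiles_touched, num_intersects

# Tracing succeeds: every output shape is known from the input shapes alone.
count_jit = jax.jit(count_intersections)
```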
This value is used in `ops::rasterize::fwd::xla` as the length of a sorted mapping from intersections to Gaussians (`gaussian_ids`).
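The trouble starts when that value becomes a length. The following 1D sketch is hypothetical (not gsplat's code; `build_gaussian_ids` and the `[tile_lo, tile_hi)` layout are made up for illustration). It builds a tile-sorted `gaussian_ids` array with one entry per intersection. Called eagerly it works, but under `jax.jit` the `jnp.repeat` call cannot determine its output length at trace time and JAX raises a `ConcretizationTypeError`.

```python
import jax
import jax.numpy as jnp

def build_gaussian_ids(tile_lo, tile_hi):
    # Hypothetical 1D toy: Gaussian g overlaps tiles [tile_lo[g], tile_hi[g]).
    counts = tile_hi - tile_lo  # tiles touched per Gaussian
    # One entry per intersection, so the length is sum(counts) == num_intersects.
    gaussian_ids = jnp.repeat(jnp.arange(counts.shape[0]), counts)
    # Exclusive prefix sum = first slot of each Gaussian; slot - start = offset into its tile range.
    starts = jnp.repeat(jnp.cumsum(counts) - counts, counts)
    tile_ids = tile_lo[gaussian_ids] + jnp.arange(gaussian_ids.shape[0]) - starts
    order = jnp.argsort(tile_ids)  # sort intersections by tile id
    return tile_ids[order], gaussian_ids[order]

# Eager calls work because counts is concrete.
# jax.jit(build_gaussian_ids) fails: sum(counts) is unknown while tracing.
```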
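`gaussian_ids` is also needed in `ops::rasterize::bwd::xla`, which is where the problem arises: passing it from the forward pass to the backward pass means a dynamically shaped array has to exist in JAX-side code, and `jax.jit` requires every array shape to be known at trace time.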
`gsplat`'s kernels will probably need to be changed to solve this.
Some discussion on possible solutions: nerfstudio-project/gsplat#175
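For comparison, a generic JAX pattern for data-dependent lengths (not necessarily one of the options discussed in #175) is to allocate a fixed capacity and mask the unused slots. `MAX_INTERSECTS`, `SENTINEL`, and `build_gaussian_ids_padded` are hypothetical; `total_repeat_length` is a real argument of `jnp.repeat` that makes its output length static, so the function below traces under `jax.jit`.

```python
import jax
import jax.numpy as jnp

MAX_INTERSECTS = 16                   # hypothetical static capacity, chosen ahead of time
SENTINEL = jnp.iinfo(jnp.int32).max   # tile id for unused slots so they sort last

@jax.jit
def build_gaussian_ids_padded(tile_lo, tile_hi):
    # Hypothetical 1D toy: Gaussian g overlaps tiles [tile_lo[g], tile_hi[g]).
    counts = tile_hi - tile_lo
    num_intersects = jnp.sum(counts)  # traced value, never used as a shape
    # total_repeat_length fixes the output length; entries past the capacity are dropped.
    gaussian_ids = jnp.repeat(jnp.arange(counts.shape[0]), counts,
                              total_repeat_length=MAX_INTERSECTS)
    starts = jnp.repeat(jnp.cumsum(counts) - counts, counts,
                        total_repeat_length=MAX_INTERSECTS)
    slot = jnp.arange(MAX_INTERSECTS)
    tile_ids = tile_lo[gaussian_ids] + slot - starts
    tile_ids = jnp.where(slot < num_intersects, tile_ids, SENTINEL)  # mask padding slots
    order = jnp.argsort(tile_ids)
    return tile_ids[order], gaussian_ids[order], num_intersects
```

The catch is that the capacity has to be picked up front: too small and intersections are silently dropped, too large and memory is wasted.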
I think the easiest fix for now is to recalculate `gaussian_ids` inside `ops::rasterize::bwd::xla`.
This means `ops::rasterize::fwd::xla` no longer needs to return `gaussian_ids`, and `ops::rasterize::bwd::xla` only needs statically sized inputs.
The downside is that computing `gaussian_ids` is duplicated across the forward and backward passes.
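The recompute approach can be sketched with `jax.custom_vjp`. In this toy 1D version, host functions run through `jax.pure_callback` stand in for the XLA custom calls, and all names (`rasterize`, `_intersections`, `_fwd_kernel`, `_bwd_kernel`, `NUM_TILES`) are hypothetical rather than gsplat's API. The dynamically sized `gaussian_ids` only ever exists inside the callbacks; JAX sees statically shaped inputs, outputs, and residuals, so `jax.jit(jax.grad(...))` traces.

```python
import numpy as np
import jax
import jax.numpy as jnp

NUM_TILES = 8  # hypothetical static tile count (stands in for the image size)

def _intersections(lo, hi):
    # Host-side stand-in for the kernel's intersection step. Dynamic lengths are fine
    # here because these arrays never leave the "kernel".
    lo, hi = np.asarray(lo), np.asarray(hi)
    counts = hi - lo
    gaussian_ids = np.repeat(np.arange(lo.shape[0]), counts)
    starts = np.repeat(np.cumsum(counts) - counts, counts)
    tile_ids = lo[gaussian_ids] + np.arange(gaussian_ids.shape[0]) - starts
    order = np.argsort(tile_ids, kind="stable")
    return tile_ids[order], gaussian_ids[order]

def _fwd_kernel(lo, hi, w):
    tile_ids, gaussian_ids = _intersections(lo, hi)
    out = np.zeros(NUM_TILES, dtype=np.float32)
    np.add.at(out, tile_ids, np.asarray(w, dtype=np.float32)[gaussian_ids])  # per-tile sum
    return out

def _bwd_kernel(lo, hi, g):
    tile_ids, gaussian_ids = _intersections(lo, hi)  # recomputed instead of saved
    grad_w = np.zeros(np.shape(lo)[0], dtype=np.float32)
    np.add.at(grad_w, gaussian_ids, np.asarray(g, dtype=np.float32)[tile_ids])
    return grad_w

@jax.custom_vjp
def rasterize(lo, hi, w):
    out_spec = jax.ShapeDtypeStruct((NUM_TILES,), jnp.float32)
    return jax.pure_callback(_fwd_kernel, out_spec, lo, hi, w)

def _rasterize_fwd(lo, hi, w):
    # Only statically shaped inputs are kept as residuals; gaussian_ids is not returned.
    return rasterize(lo, hi, w), (lo, hi)

def _rasterize_bwd(res, g):
    lo, hi = res
    grad_spec = jax.ShapeDtypeStruct(lo.shape, jnp.float32)
    grad_w = jax.pure_callback(_bwd_kernel, grad_spec, lo, hi, g)
    return None, None, grad_w  # no gradients for the integer tile ranges

rasterize.defvjp(_rasterize_fwd, _rasterize_bwd)
```

Here `_intersections` runs once in the forward callback and again in the backward callback, which is exactly the duplicated work mentioned above.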