Observation about new Kernel Operator
adam-hartshorne opened this issue · 3 comments
adam-hartshorne commented
I notice the new kernel operator makes use of nested for loops and update operations. In JAX that is a very bad idea: Python for loops are unrolled when a function is traced under jit, which inflates compile time, so they should be avoided at all costs.
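As an illustration only (this is a hypothetical RBF kernel, not the actual kernel operator code), a dense kernel matrix built with nested loops and per-element writes can usually be replaced by one broadcasted expression. The sketch uses NumPy so it runs anywhere; the vectorized version uses only calls that `jax.numpy` also provides, so it should carry over to the JAX backend unchanged.

```python
import numpy as np

def rbf_kernel_loops(X, Y, lengthscale=1.0):
    # Anti-pattern: nested Python loops with one write per entry.
    # Under jax.jit these loops would be unrolled at trace time (n * m traced ops),
    # and each write would become a separate K.at[i, j].set(...) update.
    K = np.zeros((X.shape[0], Y.shape[0]))
    for i in range(X.shape[0]):
        for j in range(Y.shape[0]):
            d2 = np.sum((X[i] - Y[j]) ** 2)
            K[i, j] = np.exp(-0.5 * d2 / lengthscale**2)
    return K

def rbf_kernel_vectorized(X, Y, lengthscale=1.0):
    # Broadcast (n, 1, d) against (1, m, d) to get all pairwise differences at once.
    diff = X[:, None, :] - Y[None, :, :]
    # Squared Euclidean distances, shape (n, m).
    d2 = np.sum(diff**2, axis=-1)
    return np.exp(-0.5 * d2 / lengthscale**2)

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 3))
Y = rng.normal(size=(4, 3))
# Both versions agree; only the vectorized one traces to a fixed, small graph.
print(np.allclose(rbf_kernel_loops(X, Y), rbf_kernel_vectorized(X, Y)))
```

The vectorized form trades the loops for an (n, m, d) intermediate, so for very large inputs a blocked or `vmap`-based formulation may be preferable for memory.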
AndPotap commented
Thanks for pointing this out. I'll think about how to do this in a way that is also compatible with PyTorch.
mfinzi commented
I think we would just want to add an xnp.scan function to the backends and then use that. Also, wherever we currently use xnp.for_loop, we should probably switch to scan if possible.
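A minimal sketch of what a backend-level scan could look like, assuming `xnp.scan` would mirror the signature of `jax.lax.scan` (`f(carry, x) -> (carry, y)`, returning the final carry and the stacked per-step outputs). The name `scan` and its placement in the backends are assumptions for illustration, not existing API. The JAX backend could simply alias `jax.lax.scan`; a NumPy or PyTorch backend could fall back to a plain Python loop with the same semantics, as below.

```python
import numpy as np

# In a JAX backend this could be a one-line alias:  scan = jax.lax.scan
def scan(f, init, xs):
    """Hypothetical non-JAX fallback with jax.lax.scan semantics.

    Iterates over the leading axis of xs, threading `carry` through f and
    collecting each per-step output y. Assumes xs is non-empty and each y is
    an array (jax.lax.scan also supports pytrees, which this sketch does not).
    """
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)

# Usage: a running sum, returning both the total and every partial sum.
total, partial_sums = scan(lambda c, x: (c + x, c + x), 0, np.arange(1, 5))
print(total, partial_sums)  # 10 [ 1  3  6 10]
```

Because the loop body is a pure function of `(carry, x)`, the same code path can be compiled once by `jax.lax.scan` instead of being unrolled, while other backends still get identical results.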