wilson-labs/cola

Observation about new Kernel Operator

adam-hartshorne opened this issue · 3 comments

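I noticed that the new kernel operator uses nested Python `for` loops and indexed update operations. In JAX that is a very bad idea: under `jit`, Python loops are unrolled during tracing, so the traced program (and the compile time) grows with the number of iterations. `for` loops like this should be avoided wherever possible.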

That's a good point. Do you think we can replace these with a `scan`, @AndPotap?
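To make the concern concrete, here is a minimal sketch contrasting the two styles for a kernel matrix-vector product. It assumes an RBF kernel and invents the names `rbf_kernel`, `kernel_matvec_loops` and `kernel_matvec_scan` purely for illustration; it is not CoLA's actual kernel operator. The loop version issues one `.at[i].add` per pair and is unrolled element by element when traced. The scan version traces its body once and computes one block of rows per step.

```python
import jax
import jax.numpy as jnp


def rbf_kernel(x1, x2, lengthscale=1.0):
    # Hypothetical RBF kernel: rows of x1 (n, d) vs rows of x2 (m, d) -> (n, m)
    sq = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-0.5 * sq / lengthscale**2)


def kernel_matvec_loops(x, v, lengthscale=1.0):
    # Anti-pattern: nested Python loops with per-element indexed updates.
    # Under jax.jit, all n * n iterations are unrolled into the traced program.
    n = x.shape[0]
    out = jnp.zeros(n)
    for i in range(n):
        for j in range(n):
            k_ij = jnp.exp(-0.5 * jnp.sum((x[i] - x[j]) ** 2) / lengthscale**2)
            out = out.at[i].add(k_ij * v[j])
    return out


def kernel_matvec_scan(x, v, lengthscale=1.0, block_size=2):
    # lax.scan over blocks of rows: the body is traced once, whatever n is.
    # Assumes n is divisible by block_size (kept simple for the sketch).
    n, d = x.shape
    x_blocks = x.reshape(n // block_size, block_size, d)

    def body(carry, x_block):
        # One (block_size, n) slab of the kernel matrix, multiplied by v.
        # Only this slab is materialized at each step, never the full matrix.
        return carry, rbf_kernel(x_block, x, lengthscale) @ v

    _, out_blocks = jax.lax.scan(body, None, x_blocks)
    return out_blocks.reshape(n)
```

When jitting the scan version, `block_size` must be static because it determines array shapes, e.g. `jax.jit(kernel_matvec_scan, static_argnames="block_size")`. The block size trades memory per step against the number of scan iterations.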

Thanks for pointing this out. I'll think about how to do this in a way that is also compatible with PyTorch.

I think we would just want to add an `xnp.scan` function to the backends and then use that. Also, in cases where we currently use `xnp.for_loop`, we should probably replace it with `scan` where possible.
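One way the proposed `xnp.scan` could look is sketched below. It assumes each backend exposes a function with the `(carry, y)` contract of `jax.lax.scan`. The names `scan_jax` and `scan_fallback` are hypothetical, and so is the idea of passing a `stack` function; this is not how CoLA's backends are actually wired. The JAX backend delegates to `lax.scan`. A backend without a native scan, such as PyTorch, could use the eager fallback loop and pass `stack=torch.stack`.

```python
import jax
import jax.numpy as jnp


def scan_jax(f, init, xs, length=None):
    # JAX backend (hypothetical name): delegate to lax.scan, which traces f
    # once and compiles a single loop instead of unrolling it.
    return jax.lax.scan(f, init, xs, length=length)


def scan_fallback(f, init, xs, length=None, stack=jnp.stack):
    # Hypothetical eager fallback with the same (carry, y) contract as lax.scan.
    # Assumes xs is a single array (or None with an explicit length) and that
    # f returns one array y per step. A PyTorch backend could pass torch.stack.
    n = length if xs is None else len(xs)
    carry, ys = init, []
    for i in range(n):
        x = None if xs is None else xs[i]
        carry, y = f(carry, x)
        ys.append(y)
    return carry, stack(ys)
```

With this shape, call sites that currently use `xnp.for_loop` to accumulate a carry could be rewritten as a `scan` step function, such as a running sum that returns `(carry, carry)`. The same code would then run on either backend.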