Rebuild the whole system on top of JAX
tbenthompson opened this issue · 0 comments
In the sunny light of the future, writing raw CUDA is no longer the right choice for a library like this. The whole system could be replaced with roughly a tenth of the code by rewriting it in JAX, which is basically NumPy for GPUs. The new JAX-based cutde (or whatever you want to rename it!) would be:
- much simpler
- probably faster
- cross-platform: one piece of code that runs on CPU, GPU, and TPU.
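
To make the "NumPy for GPUs" point concrete, here is a minimal sketch of what an all-pairs kernel evaluation might look like in JAX. It is not cutde's actual math: `point_kernel` is a hypothetical stand-in (a simple 1/(4πr) kernel) for the real triangular dislocation formulas, and the point arrays are made-up inputs. The idea is just that `jax.vmap` and `jax.jit` take the place of hand-written CUDA thread indexing.

```python
import jax
import jax.numpy as jnp


def point_kernel(obs, src):
    # Hypothetical stand-in kernel: 1 / (4 * pi * r) between one observation
    # point and one source point. NOT the triangular dislocation formula.
    r = jnp.linalg.norm(obs - src)
    return 1.0 / (4.0 * jnp.pi * r)


# The inner vmap maps over sources, the outer vmap maps over observation
# points, and jit compiles the result for whatever backend is available.
all_pairs = jax.jit(
    jax.vmap(jax.vmap(point_kernel, in_axes=(None, 0)), in_axes=(0, None))
)

# Made-up inputs: 2 observation points and 3 source points in 3D.
obs_pts = jnp.array([[0.0, 0.0, 1.0], [0.0, 0.0, 2.0]])
src_pts = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])

matrix = all_pairs(obs_pts, src_pts)  # shape (n_obs, n_src)
print(matrix.shape, jax.default_backend())
```

The same script should run unchanged on CPU, GPU, or TPU, depending on which JAX backend is installed. The kernel itself is written for a single pair of points, and the batching is layered on top.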
I have no intention of ever doing this. I work on other stuff now. But, I thought I'd put this issue here in case anyone stumbles on it!
I could also potentially be paid to make this kind of upgrade.