mintisan/awesome-kan

Kolmogorov-Arnold Networks with Equinox [kanx]

stergiosba opened this issue · 1 comment

There seems to be a race to find the best implementation of KANs on the PyTorch side of the ML universe (currently four implementations), but there is a dearth of efforts on the JAX side (only one, built with Flax). To be frank, the Flax implementation seems, well... slow, and it is not tested.

I had a go at it and made a pure Equinox-based implementation that just works, in 125 lines of code: KANX.
I hope others find it interesting and use it.

It includes:

  • A KANLayer that can be stacked inside an eqx.nn.Sequential like any other module. I tested it with an MLP and it works (further testing is needed to iron out any issues, but it should work).
  • A modified version of the MNIST CNN example from the Equinox website, used for testing: it reports 96.7% accuracy on the validation set with minimal tuning, and training took 13 seconds.

Just for reference, the fastest PyTorch implementation takes about 180 seconds to reach the same validation accuracy (I still have to perform more head-to-head comparisons, though).

Let me know what you think.

Thank you for your detailed description; it has been recorded.

Next time, you can submit a pull request.

Thanks again.