Kolmogorov-Arnold Networks with equinox [kanx]
stergiosba opened this issue
There seems to be a race to find the best implementation of KANs on the PyTorch side of the ML universe (currently 4), but there is a dearth of efforts on the JAX side (only 1, with flax). The flax implementation seems, well... slow, to be frank, and untested.
I had a go at it and made a pure equinox-based implementation that just works in 125 lines of code: KANX.
I hope others find it interesting and use it.
It includes:
- A KANLayer that can be stacked inside an eqx.nn.Sequential like any other module. I tested it with an MLP and it works (further testing is needed to iron things out, but it should work).
- The MNIST CNN example code from the equinox website, modified to test the layer: it reports 96.7% accuracy on the validation set with minimal tuning, and it took 13 seconds.
For reference, the fastest PyTorch implementation takes about 180 seconds to reach the same validation accuracy (I still have to run more head-to-head comparisons, though).
Let me know what you think.
Thank you for your detailed description; it has been recorded.
Next time, you can submit a pull request.
Thanks again.