TorchKAN introduces a simplified KAN model and its variations, including KANvolver and KAL-Net, designed for high-performance image classification by leveraging polynomial transformations for enhanced feature detection.
- Overview of the Simplified KAN Model
- KANvolver: Monomial Basis Functions for MNIST Image Classification
- KAL-Net: Utilizing Legendre Polynomials in Kolmogorov-Arnold Legendre Networks
- KAC-Net: Utilizing Chebyshev Polynomials
- Noise Analysis Using a Generalized KAN Model: For curve fitting or regression problems, use the `nKAN.py` script as a general structure.
This project showcases the training, validation, and quantization of the KAN model using PyTorch with CUDA acceleration. The `torchkan` model is evaluated on the MNIST dataset, demonstrating significant accuracy improvements.
KANs belong to the family of Generalized Additive Models (GAMs), which have shown promising results since the 1980s. Inspired by a range of sources, this initial implementation of KAN in `torchkan.py` achieves over 97% accuracy with an evaluation time of 0.6 seconds; the quantized model further reduces this to under 0.55 seconds on the MNIST dataset within 8 epochs, using an Nvidia RTX 4090 on Ubuntu 22.04.
Current Understanding: While there is considerable hype around KANs, it is worth remembering that learning the weights feeding into activation functions (as in MLPs) and learning the activation functions themselves are both established ideas. The extent of interpretability, scalability, or efficiency they offer remains unclear. Quantizability, however, does not seem to be an issue: quantized evaluation of the base model costs only a ~0.6% drop in test performance.
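For reference, post-training dynamic quantization of the linear layers takes only a few lines in PyTorch. The snippet below is a minimal, self-contained sketch; the toy model stands in for the trained network and is not the exact code in this repository:

```python
import torch
import torch.nn as nn

# Minimal sketch of post-training dynamic quantization in PyTorch.
# The toy model below stands in for the trained network; only its
# nn.Linear layers are converted to int8 weights.
model = nn.Sequential(nn.Linear(784, 256), nn.SiLU(), nn.Linear(256, 10))

quantized = torch.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)

x = torch.randn(1, 784)
with torch.no_grad():
    logits = quantized(x)  # int8 weights, fp32 activations at runtime
print(logits.shape)        # torch.Size([1, 10])
```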
Note: The model is still under active research, and exploration of its full potential is ongoing. Contributions, questions, and critiques are welcome; merge requests that clearly outline the issue, the solution, and its effectiveness will be processed promptly.
Note: The PyPI pipeline is currently deprecated and will be stabilized following the release of Version 1.
The `KANvolver` model is a specialized neural network designed for classifying images from the MNIST dataset. It achieves an accuracy of ~99.56% with a minimal error rate of 0.18%. The model combines convolutional neural networks (CNNs) with polynomial feature expansions, effectively capturing both simple and complex patterns.
I am conducting a large-scale analysis to investigate how KANs can be made more interpretable.
Thanks to @cometscome for writing this version in Julia: https://github.com/cometscome/FluxKAN.jl
Convolutional Feature Extraction: The model begins with two convolutional layers, each paired with ReLU activation and max-pooling. The first layer employs 16 filters of size 3x3, while the second increases the feature maps to 32 channels.
Polynomial Feature Transformation: After feature extraction, the model applies polynomial transformations up to the n-th order to the flattened convolutional outputs, enhancing its ability to discern non-linear relationships.
How Monomials Work: In this model, monomials are polynomial powers of the input features. By computing monomials up to a specified order, the model captures non-linear interactions between the features, potentially leading to richer and more informative representations for downstream tasks.
For a given input image, the monomials of its flattened pixel values are computed and then used to adjust the output of linear layers before activation. This approach introduces an additional dimension of feature interaction, allowing the network to learn more complex patterns in the data.
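A minimal sketch of such a monomial expansion (the helper name is illustrative, not the repository's exact implementation):

```python
import torch

def monomial_expand(x: torch.Tensor, order: int) -> torch.Tensor:
    """Concatenate the powers x, x^2, ..., x^order along the feature dim.

    x: (batch, features) flattened inputs.
    Returns: (batch, features * order).
    """
    return torch.cat([x ** p for p in range(1, order + 1)], dim=-1)

x = torch.randn(8, 32)            # e.g. flattened convolutional features
feats = monomial_expand(x, 3)     # shape: (8, 96)
```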
- Input Reshaping: Images are reshaped from vectors of 784 elements to 1x28x28 tensors for CNN processing.
- Feature Extraction: Spatial features are extracted and pooled through convolutional layers.
- Polynomial Expansion: Features undergo polynomial expansion to capture higher-order interactions.
- Linear Processing: The expanded features are processed by linear layers with normalization and activation.
- Output Generation: The network produces logits for each digit class in MNIST (a compact sketch of the full pipeline follows this list).
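The sketch below condenses these five steps into a toy module. Layer sizes follow the description above, but the class itself is illustrative rather than the repository's exact `KANvolver`:

```python
import torch
import torch.nn as nn

class MiniKANvolver(nn.Module):
    """Illustrative sketch of the pipeline above; not the repo's exact model."""

    def __init__(self, order: int = 2, num_classes: int = 10):
        super().__init__()
        self.order = order
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        )
        flat = 32 * 7 * 7                      # 28x28 -> 7x7 after two poolings
        self.norm = nn.LayerNorm(flat * order)
        self.head = nn.Linear(flat * order, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.view(-1, 1, 28, 28)              # reshape 784-vectors to images
        z = self.features(x).flatten(1)        # convolutional feature extraction
        z = torch.cat([z ** p for p in range(1, self.order + 1)], dim=-1)
        return self.head(torch.relu(self.norm(z)))  # logits per digit class

logits = MiniKANvolver()(torch.randn(4, 784))  # -> shape (4, 10)
```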
The `KANvolver` model's 99.5% accuracy on MNIST underscores its robustness in leveraging CNNs and polynomial expansions for effective digit classification. While it shows significant potential, the model remains open to further adaptation and exploration in broader image-processing challenges. Here are the results:
Note that KANvolver uses polynomials that are distinct from those in the original KANs [1].
KANs seem to handle noise better compared to MLPs for functional approximation. This requires further investigation.
To reproduce the results, use the `nKAN.py` script.
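For orientation, a hypothetical version of such a noise experiment might look like the following. The MLP baseline and all names here are illustrative stand-ins; `nKAN.py` defines the actual models:

```python
import torch
import torch.nn as nn

# Hypothetical sketch of the noise study: fit y = sin(pi * x) under
# additive Gaussian noise and measure error against the clean target.
# Swap in a KAN variant for the MLP below to compare the two.
x = torch.linspace(-1, 1, 1024).unsqueeze(-1)
y_clean = torch.sin(torch.pi * x)
y_noisy = y_clean + 0.2 * torch.randn_like(y_clean)

model = nn.Sequential(nn.Linear(1, 64), nn.SiLU(), nn.Linear(64, 1))
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

for _ in range(2000):                       # plain MSE regression loop
    opt.zero_grad()
    loss = nn.functional.mse_loss(model(x), y_noisy)
    loss.backward()
    opt.step()

with torch.no_grad():                       # error vs. the clean curve
    print(nn.functional.mse_loss(model(x), y_clean).item())
```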
`KAL_Net` implements the Kolmogorov-Arnold Legendre Network (KAL-Net), a GAM architecture that uses Legendre polynomials in place of the spline-based approximations traditionally used in KANs.
- Polynomial Order: Utilizes Legendre polynomials up to a specified order, applied to each normalized input, capturing nonlinear relationships more efficiently than simpler polynomial approximations.
- Efficient Computations: Leverages `functools.lru_cache` to avoid redundant polynomial computations, enhancing the forward pass's speed (sketched below).
- Activation Function: Employs the SiLU (Sigmoid Linear Unit) for improved performance in deeper networks due to its non-monotonic nature.
- Layer Normalization: Stabilizes each layer's output using layer normalization, enhancing training stability and convergence speed.
- Weight Initialization: Weights are initialized with the Kaiming uniform distribution (configured for linear layers), ensuring a robust start for training.
- Dynamic Weight and Normalization Management: Manages weights for base transformations and polynomial expansions dynamically, scaling with input features and polynomial order.
- Flexibility in High-Dimensional Spaces: Legendre polynomials offer a more systematic approach to capturing interactions in high-dimensional data compared to splines, which often require manual knot placement and struggle with dimensionality issues.
- Analytical Efficiency: The caching and recurrence relations in Legendre polynomial computations minimize the computational overhead associated with spline evaluations, especially in high dimensions.
- Generalization: The orthogonal properties of Legendre polynomials typically lead to better generalization in machine learning model fitting, avoiding common overfitting issues with higher-degree splines.
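As an illustration of the caching idea, the sketch below evaluates the Legendre basis via the Bonnet recurrence with per-degree memoization; the function is illustrative, not the exact code in `KAL_Net`:

```python
import torch
from functools import lru_cache

def legendre_basis(x: torch.Tensor, order: int) -> torch.Tensor:
    """Stack P_0(x), ..., P_order(x) computed via the Bonnet recurrence.

    Illustrative sketch: memoizing by degree avoids recomputing lower-order
    polynomials, mirroring the lru_cache trick described above.
    x is assumed normalized to [-1, 1]; returns shape (*x.shape, order + 1).
    """
    @lru_cache(maxsize=None)
    def P(n: int) -> torch.Tensor:
        if n == 0:
            return torch.ones_like(x)
        if n == 1:
            return x
        # Bonnet recurrence: n * P_n = (2n - 1) * x * P_{n-1} - (n - 1) * P_{n-2}
        return ((2 * n - 1) * x * P(n - 1) - (n - 1) * P(n - 2)) / n

    return torch.stack([P(n) for n in range(order + 1)], dim=-1)

basis = legendre_basis(torch.rand(4, 8) * 2 - 1, order=3)  # shape: (4, 8, 4)
```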
- Accuracy: `KAL_Net` achieves 97.8% accuracy on the MNIST dataset, showcasing its ability to handle complex patterns in image data.
- Efficiency: The average forward pass takes only 500 microseconds, reflecting the computational efficiency gained by caching Legendre polynomials and optimizing tensor operations in PyTorch.
Ensure the following are installed on your system:
- Python (version 3.9 or higher)
- CUDA Toolkit (compatible with your PyTorch installation's CUDA version)
- cuDNN (compatible with your installed CUDA Toolkit)
Tested on macOS and Linux.
Clone the `torchkan` repository and set up the project environment:
git clone https://github.com/1ssb/torchkan.git
cd torchkan
pip install -r requirements.txt
export PATH=/usr/local/cuda/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
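To verify that the CUDA toolchain is visible to PyTorch after setting these paths, a quick sanity check:

```python
import torch

# Quick check that PyTorch was installed with working CUDA support.
print(torch.__version__)
print(torch.cuda.is_available())          # True if CUDA is usable
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))  # e.g. "NVIDIA GeForce RTX 4090"
```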
To monitor experiments and model performance with wandb:
- Set Up wandb Account:
- Sign up or log in at Weights & Biases.
- Locate your API key in your account settings.
- Initialize wandb in Your Project:
Before running the training script, initialize wandb:
wandb login
Enter your API key when prompted to link your script executions to your wandb account.
- Adjust the Entity Name in `mnist.py` to Your Username (default is `1ssb`)
python mnist.py
This script trains the model, validates it, quantizes it, and logs performance metrics using wandb.
For inquiries or support, please contact: Subhransu.Bhattacharjee@anu.edu.au
If this project is used in your research or referenced for baseline results, please use the following BibTeX entry.
@misc{torchkan,
author = {Subhransu S. Bhattacharjee},
title = {{TorchKAN}: Simplified {KAN} Model with Variations},
year = {2024},
howpublished = {\url{https://github.com/1ssb/torchkan/}}
}
Contributions are welcome. Please raise issues as needed. Maintained solely by @1ssb.
- [0] Ziming Liu et al., "KAN: Kolmogorov-Arnold Networks", arXiv:2404.19756, 2024. https://arxiv.org/abs/2404.19756
- [1] pykan: https://github.com/KindXiaoming/pykan
- [2] efficient-kan: https://github.com/Blealtan/efficient-kan