Deep Learning and Optimization with JAX
Implementations of optimization algorithms in JAX and NumPy for solving deep learning problems.
The image_deblurring and image_inpainting notebooks implement low-level solutions to the Image Deblurring and Image Inpainting problems, respectively, using only JAX and NumPy and no other external libraries.
The optimization_algos notebook implements, from scratch in JAX and NumPy, a variety of optimization algorithms commonly used in deep learning, and applies them to an image classification problem:
- Stochastic Gradient Descent
- Stochastic Gradient Descent with Momentum + L2 Regularization
- ADAM
- L-BFGS
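As a rough illustration of what "from scratch" can look like for two of the optimizers listed above, here is a minimal sketch of the momentum-plus-L2 and Adam update rules in JAX. It is not the notebook's code: the function names (`loss_fn`, `sgd_momentum_step`, `adam_step`), the hyperparameter values, and the synthetic data are all assumptions made for this example, and L-BFGS is left out to keep it short.

```python
import jax
import jax.numpy as jnp


def loss_fn(w, X, y):
    """Mean binary cross-entropy of a linear logistic model (no bias term)."""
    logits = X @ w
    # log(1 + exp(z)) - y * z is a numerically stable form of the BCE.
    return jnp.mean(jnp.logaddexp(0.0, logits) - y * logits)


grad_fn = jax.jit(jax.grad(loss_fn))


def sgd_momentum_step(w, v, g, lr=0.1, beta=0.9, weight_decay=1e-4):
    """One SGD step with heavy-ball momentum and L2 regularization."""
    g = g + weight_decay * w  # gradient of (weight_decay / 2) * ||w||^2
    v = beta * v + g          # velocity accumulates past gradients
    return w - lr * v, v


def adam_step(w, m, s, g, t, lr=0.05, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam step; t is the 1-based step count used for bias correction."""
    m = b1 * m + (1 - b1) * g       # first-moment estimate
    s = b2 * s + (1 - b2) * g ** 2  # second-moment estimate
    m_hat = m / (1 - b1 ** t)
    s_hat = s / (1 - b2 ** t)
    return w - lr * m_hat / (jnp.sqrt(s_hat) + eps), m, s


# Synthetic, linearly separable data (an assumption; stands in for image features).
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(k1, (256, 5))
w_true = jnp.array([1.0, -2.0, 0.5, 0.0, 3.0])
y = (X @ w_true > 0).astype(jnp.float32)


def sample_batch(key, batch_size=32):
    """Draw a mini-batch of row indices without replacement."""
    idx = jax.random.choice(key, X.shape[0], (batch_size,), replace=False)
    return X[idx], y[idx]


def run(optimizer, num_steps=200):
    w = jnp.zeros(X.shape[1])
    state = (jnp.zeros_like(w),) if optimizer == "sgd" else (jnp.zeros_like(w), jnp.zeros_like(w))
    key = jax.random.PRNGKey(1)
    for t in range(1, num_steps + 1):
        key, sub = jax.random.split(key)
        Xb, yb = sample_batch(sub)
        g = grad_fn(w, Xb, yb)
        if optimizer == "sgd":
            w, v = sgd_momentum_step(w, state[0], g)
            state = (v,)
        else:
            w, m, s = adam_step(w, state[0], state[1], g, t)
            state = (m, s)
    return w


w_sgd = run("sgd")
w_adam = run("adam")
print("initial loss:", float(loss_fn(jnp.zeros(5), X, y)))
print("SGD+momentum loss:", float(loss_fn(w_sgd, X, y)))
print("Adam loss:", float(loss_fn(w_adam, X, y)))
```

Keeping each update rule as a pure function that takes and returns its state fits JAX's functional style and makes it straightforward to wrap a step in `jax.jit` later.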
The CelebA dataset was converted into a binary classification problem: predicting whether the person in an image is Male or Female. This was done intentionally to keep the problem very simple, since the aim is to gain a deeper understanding of deep learning, optimization, and their implementation with JAX and NumPy. Training uses the first 15,000 images and testing uses the last 5,000 images. A Logistic Regression classifier achieves a validation (left) and test (right) AUC of 95+.
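For readers who want to reproduce the evaluation without an external metrics library, the sketch below shows one way to take a first-N / last-M split and compute AUC directly in JAX. The helper names (`split_first_last`, `roc_auc`) and the toy arrays are hypothetical, not taken from the notebook. The pairwise AUC formulation is only one of several equivalent ways to compute the metric, and it returns a value on a 0 to 1 scale.

```python
import jax.numpy as jnp


def split_first_last(X, y, n_train, n_test):
    """Use the first n_train rows for training and the last n_test rows for testing."""
    return (X[:n_train], y[:n_train]), (X[-n_test:], y[-n_test:])


def roc_auc(scores, labels):
    """AUC as the fraction of (positive, negative) pairs ranked correctly.

    Ties count as half a correct pair. Memory use is O(n_pos * n_neg),
    which is acceptable for a few thousand test examples.
    """
    pos = scores[labels == 1]
    neg = scores[labels == 0]
    diff = pos[:, None] - neg[None, :]
    return jnp.mean((diff > 0) + 0.5 * (diff == 0))


# Toy example: 3 of the 4 positive/negative pairs are ordered correctly.
scores = jnp.array([0.1, 0.4, 0.35, 0.8])
labels = jnp.array([0, 0, 1, 1])
print(float(roc_auc(scores, labels)))  # 0.75

# Toy split on 10 rows (the README's split would use n_train=15000, n_test=5000).
X_toy = jnp.arange(20.0).reshape(10, 2)
y_toy = jnp.arange(10) % 2
(train_X, train_y), (test_X, test_y) = split_first_last(X_toy, y_toy, n_train=6, n_test=3)
```

In this formulation a perfect ranking gives 1.0 and a constant score for every image gives 0.5.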