A deep learning library with reverse-mode automatic differentiation, implemented from scratch in Python with NumPy.
As a student in CMU’s 11-785: Introduction to Deep Learning course, I gained hands-on experience coding various deep learning primitives through explicit forward and backward passes. The following summer, I was hired as a TA for the course and was responsible for the optional Autograd 1/2 assignments, where students implement automatic differentiation for a small subset of deep learning modules. Although I had initially opted out of this assignment due to a busy semester, my role as a lead TA exposed me to the elegance and efficiency of reverse-mode automatic differentiation for computing gradients. Fascinated by its simplicity and wanting to reinforce my learning, I decided to re-implement the deep learning models I had previously coded, this time using reverse-mode automatic differentiation for backpropagation.
This turned out to be a highly rewarding experience. Not only did I gain a deeper understanding of automatic differentiation and its core principles, but I also solidified my knowledge of computational graph construction, gradient flow, and optimization techniques. Revisiting these topics from the ground up gave me a renewed appreciation for how abstracted concepts in high-level frameworks are built from first principles.
- This library was built and tested with:
numpy==1.24.3
- The tests were built with:
torch==2.2.2
- attention.py is a PyTorch implementation of a single attention mechanism and was built with torch==2.2.2
- PuruTorch
  - utils.py: Contains the ContextManager class for managing state context during forward calls, the Functional superclass from which all differentiable Tensor operations inherit, and a helper function for unbroadcasting gradients.
  - tensor.py: Contains the Tensor class and the Parameter subclass.
  - functional.py: Contains differentiable primitive operations with bindings to the Tensor class.
  - nn
    - module.py: Contains the Module class, a superclass of the various nn primitives. Manages train/eval state and tracks learnable parameters.
    - act_functional.py: Contains some activation functionals.
    - activation.py: Contains the Identity, Sigmoid, ReLU, Tanh and Softmax activation modules.
    - loss_functional.py: Contains some loss functionals.
    - loss.py: Contains the Loss superclass and MSELoss, CrossEntropyLoss and CTCLoss.
    - batchnorm1d.py: Contains the BatchNorm1D module.
    - batchnorm2d.py: Contains the BatchNorm2D module.
    - dropout_functional.py: Contains some dropout functionals.
    - dropout.py: Contains the Dropout, Dropout1D and Dropout2D modules.
    - linear.py: Contains the Linear module.
    - conv_functional.py: Contains some conv functionals.
    - pool.py: Contains the MaxPool1D, MeanPool1D, MaxPool2D and MeanPool2D modules.
    - conv1d.py: Contains the Conv1D module.
    - conv2d.py: Contains the Conv2D module.
    - convtranspose1d.py: Contains the ConvTranspose1D module.
    - convtranspose2d.py: Contains the ConvTranspose2D module.
    - rnncell.py: Contains the RNNCell module.
    - grucell.py: Contains the GRUCell module.
    - CTCDecoding.py: Contains GreedySearchDecoder and BeamSearchDecoder, NumPy implementations of the greedy search and beam search algorithms for decoding the sequences predicted by a model.
    - attention.py: Contains Attention, a PyTorch implementation of a single attention mechanism.
  - models
    - MLP.py: A simple MLP model built with this library.
    - CNN.py: A simple CNN model built with this library.
    - ResBlock: A ResBlock module built with this library.
    - RNNClassifier.py: An RNNClassifier model built with this library.
    - GRUClassifier.py: A GRUClassifier model built with this library.
  - optim
    - optimizer.py: Contains the Optimizer superclass, responsible for updating model parameters after gradient computation and for zeroing out gradients.
    - sgd.py: Contains the SGD optimizer.
    - adam.py: Contains the Adam optimizer.
    - adamW.py: Contains the AdamW optimizer.
- tests: Contains tests for correctness, which can also serve as examples of how to use the library. All tests can be run with python tests/runner.py. The following parts of the library have been successfully tested against equivalent PyTorch implementations for correctness:
Summary:
+--------------------------------------------+-------------+
| Test | Score |
+--------------------------------------------+-------------+
| Functional: Add Forward/Backward | 1 |
| Functional: Neg Forward/Backward | 1 |
| Functional: Sub Forward/Backward | 1 |
| Functional: Mul Forward/Backward | 1 |
| Functional: Div Forward/Backward | 1 |
| Functional: Pow Forward/Backward | 1 |
| Functional: Transpose Forward/Backward | 1 |
| Functional: Reshape Forward/Backward | 1 |
| Functional: Squeeze Forward/Backward | 1 |
| Functional: Unsqueeze Forward/Backward | 1 |
| Functional: Matmul Forward/Backward | 1 |
| Functional: Slice Forward/Backward | 1 |
| Functional: Log Forward/Backward | 1 |
| Functional: Exp Forward/Backward | 1 |
| Functional: Sum Forward/Backward | 1 |
| Functional: Mac Forward/Backward | 1 |
| Functional: Mean Forward/Backward | 1 |
| Functional: Var Forward/Backward | 1 |
| Activation: Identity Forward/Backward | 1 |
| Activation: Sigmoid Forward/Backward | 1 |
| Activation: ReLU Forward/Backward | 1 |
| Activation: Tanh Forward/Backward | 1 |
| Activation: Softmax Forward/Backward | 1 |
| Loss: MSELoss Forward/Backward | 1 |
| Loss: CELoss Forward/Backward | 1 |
| Loss: CTCLoss Forward/Backward | 1 |
| Reg: Batchnorm1d Forward/Backward | 1 |
| Reg: Batchnorm2d Forward/Backward | 1 |
| Reg: Dropout Forward/Backward | 1 |
| Reg: Dropout1D Forward/Backward | 1 |
| Reg: Dropout2D Forward/Backward | 1 |
| Resampling: Upsample1D Forward/Backward | 1 |
| Resampling: Downsample1D Forward/Backward | 1 |
| Resampling: Upsample2D Forward/Backward | 1 |
| Resampling: Downsample2D Forward/Backward | 1 |
| Pooling: MaxPool1D Forward/Backward | 1 |
| Pooling: MeanPool1D Forward/Backward | 1 |
| Pooling: MaxPool2D Forward/Backward | 1 |
| Pooling: MeanPool2D Forward/Backward | 1 |
| Layer: Linear Forward/Backward | 1 |
| Layer: Conv1D Forward/Backward | 1 |
| Layer: Conv2D Forward/Backward | 1 |
| Layer: ConvTranspose1D Forward/Backward | 1 |
| Layer: ConvTranspose2D Forward/Backward | 1 |
| Layer: RNNCell Forward/Backward | 1 |
| Layer: GRUCell Forward/Backward | 1 |
| Optim: SGD Step | 1 |
| Optim: Adam Step | 1 |
| Optim: AdamW Step | 1 |
| Model: MLP Forward/Backward | 1 |
| Model: CNN Forward/Backward | 1 |
| Model: ResBlock Forward/Backward | 1 |
| Model: RNNClassifier Forward/Backward | 1 |
| Model: GRUClassifier Forward/Backward | 1 |
| Decoding: Greedy Search (Numpy impl.) | 1 |
| Decoding: Beam Search (Numpy impl.) | 1 |
| Attention Forward/Backward (Pytorch impl.) | 1 |
+--------------------------------------------+-------------+
| TOTAL | 57/57 |
+--------------------------------------------+-------------+
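The file list above mentions a helper in utils.py for unbroadcasting gradients. Its actual signature is not shown in this README, so the following is only a minimal NumPy sketch of the idea, with the name `unbroadcast` and its arguments chosen here for illustration: when an operand was broadcast in the forward pass, the gradient arriving at it has the larger broadcast shape and must be summed back down to the operand's original shape.

```python
import numpy as np

def unbroadcast(grad: np.ndarray, shape: tuple) -> np.ndarray:
    """Sum `grad` down to `shape`, undoing NumPy broadcasting (illustrative helper)."""
    # Axes that broadcasting prepended on the left are summed away entirely.
    while grad.ndim > len(shape):
        grad = grad.sum(axis=0)
    # Axes that were stretched from size 1 are summed but kept as size 1.
    for axis, size in enumerate(shape):
        if size == 1 and grad.shape[axis] != 1:
            grad = grad.sum(axis=axis, keepdims=True)
    return grad

# Forward pass: (5, 5) + (5,) broadcasts the vector across every row.
a = np.ones((5, 5))
b = np.ones((5,))
upstream = np.ones_like(a + b)            # gradient flowing back into the sum
grad_a = unbroadcast(upstream, a.shape)   # shape stays (5, 5)
grad_b = unbroadcast(upstream, b.shape)   # summed over rows, shape (5,)

# A BatchNorm2D-style case: a (1, 3, 1, 1) operand broadcast to (10, 3, 26, 26).
grad_c = unbroadcast(np.ones((10, 3, 26, 26)), (1, 3, 1, 1))
```

The second case mirrors the (1, 3, 1, 1) operands that appear next to (10, 3, 26, 26) tensors in the batch-norm part of a computational graph.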
Usage is similar to PyTorch! Here are some examples:
import PuruTorch as pt
# initialize 3 Tensors
x = pt.Tensor.random.uniform(0.0, 1.0, size=(5,6), requires_grad=True)
y = pt.Tensor.random.uniform(0.0, 1.0, size=(6,5), requires_grad=True)
z = pt.Tensor.random.uniform(0.0, 1.0, size=(5,), requires_grad=True)
# calculations
out = x @ y
out += z
# initiate backprop
out.backward()
# get gradients
print(x.grad)
print(y.grad)
print(z.grad)
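To make concrete what a call like out.backward() has to do, here is a small self-contained sketch of reverse-mode automatic differentiation in plain NumPy. It is not PuruTorch's implementation: the `Node` class, the `_sum_to_shape` helper, and the choice to seed a non-scalar output with ones are all assumptions made for this illustration, and only 2-D matmul and broadcasting add are supported.

```python
import numpy as np

def _sum_to_shape(grad, shape):
    # Undo broadcasting: sum away prepended axes, then stretched size-1 axes.
    while grad.ndim > len(shape):
        grad = grad.sum(axis=0)
    for axis, size in enumerate(shape):
        if size == 1 and grad.shape[axis] != 1:
            grad = grad.sum(axis=axis, keepdims=True)
    return grad

class Node:
    """A value in the graph plus a closure that pushes its gradient to its parents."""
    def __init__(self, data, parents=()):
        self.data = np.asarray(data, dtype=float)
        self.grad = np.zeros_like(self.data)
        self.parents = parents
        self.backward_fn = None  # stays None for leaf nodes

    def __matmul__(self, other):
        out = Node(self.data @ other.data, parents=(self, other))
        def backward_fn():
            self.grad += out.grad @ other.data.T
            other.grad += self.data.T @ out.grad
        out.backward_fn = backward_fn
        return out

    def __add__(self, other):
        out = Node(self.data + other.data, parents=(self, other))
        def backward_fn():
            self.grad += _sum_to_shape(out.grad, self.data.shape)
            other.grad += _sum_to_shape(out.grad, other.data.shape)
        out.backward_fn = backward_fn
        return out

    def backward(self):
        # Post-order DFS lists parents before children; reversing it visits
        # every node only after all nodes that consume it.
        order, seen = [], set()
        def visit(node):
            if id(node) not in seen:
                seen.add(id(node))
                for parent in node.parents:
                    visit(parent)
                order.append(node)
        visit(self)
        self.grad = np.ones_like(self.data)  # seed gradient (assumption: ones)
        for node in reversed(order):
            if node.backward_fn is not None:
                node.backward_fn()

# Same shapes as the example above.
rng = np.random.default_rng(0)
x = Node(rng.uniform(0.0, 1.0, size=(5, 6)))
y = Node(rng.uniform(0.0, 1.0, size=(6, 5)))
z = Node(rng.uniform(0.0, 1.0, size=(5,)))
out = x @ y + z
out.backward()
```

With a seed of ones, x.grad equals ones(5, 5) @ y.T, y.grad equals x.T @ ones(5, 5), and every entry of z.grad is 5 because z was broadcast across five rows.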
from PuruTorch import Tensor, Module
from PuruTorch.nn import Identity, ReLU, BatchNorm2D, Conv2D, CrossEntropyLoss
from PuruTorch.optim import AdamW
import numpy as np
# Create a ResBlock
class ConvBn2D(Module):
    def __init__(self, in_channels:int, out_channels:int,
                 kernel_size:int, stride:int=1, padding:int=1) -> None:
        super().__init__()
        self.conv = Conv2D(in_channels, out_channels, kernel_size=kernel_size,
                           stride=stride, padding=padding)
        self.bn = BatchNorm2D(num_features=out_channels)

    def forward(self, x : Tensor) -> Tensor:
        out = self.bn(self.conv(x))
        return out

class ResBlock(Module):
    '''
    A Residual block built with PuruTorch library.
    '''
    def __init__(self, in_channels:int, out_channels:int,
                 kernel_size:int, stride:int=3, padding:int=1):
        super().__init__()
        self.cbn1 = ConvBn2D(in_channels, out_channels, kernel_size=kernel_size,
                             stride=stride, padding=padding)
        self.cbn2 = ConvBn2D(out_channels, out_channels, kernel_size=kernel_size,
                             stride=1, padding=padding)
        if stride != 1 or in_channels != out_channels or kernel_size != 1 or padding != 0:
            self.residual = ConvBn2D(in_channels, out_channels, kernel_size=1,
                                     stride=stride, padding=1)
        else:
            self.residual = Identity()
        self.act = ReLU()

    def forward(self, x : Tensor) -> Tensor:
        identity = x
        out = self.act(self.cbn1(x))
        out = self.cbn2(out)
        out = self.residual(identity) + out
        out = self.act(out)
        return out
# create input/target Tensors from numpy arrays
batch, in_c, in_width, out_c, out_width = 10, 24, 128, 3, 26
npx = np.random.randn(batch, in_c, in_width, in_width)
npy = np.random.randn(batch, out_c, out_width, out_width)
x = Tensor.tensor(npx, requires_grad=True)
y = Tensor.tensor(npy, requires_grad=True)
# init model
model = ResBlock(in_channels=in_c, out_channels=out_c,
                 kernel_size=5, stride=5, padding=5//2)
# loss function and optimizer
criterion = CrossEntropyLoss(reduction='mean')
optim = AdamW(model.parameters(), lr=0.001, betas=[0.9, 0.99], weight_decay=0.01)
# pass input through model
z = model(x)
# get loss
loss = criterion(z, y)
# backpropagate
loss.backward()
# step
optim.step()
# reset gradients to zero
optim.zero_grad()
You can also call Tensor.bfs on any node Tensor. This starts a breadth-first traversal of the computational graph and prints, for each node, the node Tensor's shape, the Functional used to create it, and the shapes of the input Tensors to that Functional. For example, calling Tensor.bfs(loss) in the above example would print:
{% raw %}
Depth: 0
____________________
| out_shape: ()
| op: <PuruTorch.nn.loss_functional.SoftmaxCrossEntropy object at 0x7f7f30a6e310>
| inp1_shape: (10, 3, 26, 26)
____________________
* Depth: 1
____________________
| out_shape: (10, 3, 26, 26)
| op: <PuruTorch.nn.act_functional.ReLU object at 0x7f7f30a6e250>
| inp1_shape: (10, 3, 26, 26)
____________________
** Depth: 2
____________________
| out_shape: (10, 3, 26, 26)
| op: <PuruTorch.functional.Add object at 0x7f7f30a6e190>
| inp1_shape: (10, 3, 26, 26)
| inp2_shape: (10, 3, 26, 26)
____________________
*** Depth: 3
____________________
| out_shape: (10, 3, 26, 26)
| op: <PuruTorch.functional.Add object at 0x7f7f30a6e0d0>
| inp1_shape: (10, 3, 26, 26)
| inp2_shape: (1, 3, 1, 1)
____________________
*** Depth: 3
____________________
| out_shape: (10, 3, 26, 26)
| op: <PuruTorch.functional.Add object at 0x7f7f30a67fd0>
| inp1_shape: (10, 3, 26, 26)
| inp2_shape: (1, 3, 1, 1)
____________________
**** Depth: 4
____________________
| out_shape: (10, 3, 26, 26)
| op: <PuruTorch.functional.Mul object at 0x7f7f30a6bf10>
| inp1_shape: (10, 3, 26, 26)
| inp2_shape: (1, 3, 1, 1)
____________________
**** Depth: 4
____________________
| out_shape: (1, 3, 1, 1)
| op: <PuruTorch.functional.Reshape object at 0x7f7f30a6bfd0>
| inp1_shape: (3,)
____________________
**** Depth: 4
____________________
| out_shape: (10, 3, 26, 26)
| op: <PuruTorch.functional.Mul object at 0x7f7f30a67e50>
| inp1_shape: (10, 3, 26, 26)
| inp2_shape: (1, 3, 1, 1)
____________________
**** Depth: 4
____________________
| out_shape: (1, 3, 1, 1)
| op: <PuruTorch.functional.Reshape object at 0x7f7f30a67f10>
| inp1_shape: (3,)
____________________
***** Depth: 5
____________________
| out_shape: (10, 3, 26, 26)
| op: <PuruTorch.functional.Div object at 0x7f7f30a6bd90>
| inp1_shape: (10, 3, 26, 26)
| inp2_shape: (1, 3, 1, 1)
____________________
***** Depth: 5
____________________
| out_shape: (1, 3, 1, 1)
| op: <PuruTorch.functional.Reshape object at 0x7f7f30a6be50>
| inp1_shape: (3,)
____________________
***** Depth: 5
____________________
| out_shape: (3,)
| op: None
____________________
***** Depth: 5
____________________
| out_shape: (10, 3, 26, 26)
| op: <PuruTorch.functional.Div object at 0x7f7f30a67cd0>
| inp1_shape: (10, 3, 26, 26)
| inp2_shape: (1, 3, 1, 1)
____________________
***** Depth: 5
____________________
| out_shape: (1, 3, 1, 1)
| op: <PuruTorch.functional.Reshape object at 0x7f7f30a67d90>
| inp1_shape: (3,)
____________________
***** Depth: 5
____________________
| out_shape: (3,)
| op: None
____________________
****** Depth: 6
____________________
| out_shape: (10, 3, 26, 26)
| op: <PuruTorch.functional.Sub object at 0x7f7f30a6ba60>
| inp1_shape: (10, 3, 26, 26)
| inp2_shape: (1, 3, 1, 1)
____________________
****** Depth: 6
____________________
| out_shape: (1, 3, 1, 1)
| op: <PuruTorch.functional.Pow object at 0x7f7f30a6bca0>
| inp1_shape: (1, 3, 1, 1)
____________________
****** Depth: 6
____________________
| out_shape: (3,)
| op: None
____________________
****** Depth: 6
____________________
| out_shape: (10, 3, 26, 26)
| op: <PuruTorch.functional.Sub object at 0x7f7f30a679a0>
| inp1_shape: (10, 3, 26, 26)
| inp2_shape: (1, 3, 1, 1)
____________________
****** Depth: 6
____________________
| out_shape: (1, 3, 1, 1)
| op: <PuruTorch.functional.Pow object at 0x7f7f30a67be0>
| inp1_shape: (1, 3, 1, 1)
____________________
****** Depth: 6
____________________
| out_shape: (3,)
| op: None
____________________
******* Depth: 7
____________________
| out_shape: (10, 3, 26, 26)
| op: <PuruTorch.nn.conv_functional.Downsample2D object at 0x7f7f30a6b310>
| inp1_shape: (10, 3, 130, 130)
____________________
******* Depth: 7
____________________
| out_shape: (1, 3, 1, 1)
| op: <PuruTorch.functional.Reshape object at 0x7f7f30a53460>
| inp1_shape: (3,)
____________________
******* Depth: 7
____________________
| out_shape: (1, 3, 1, 1)
| op: <PuruTorch.functional.Add object at 0x7f7f30a6bbe0>
| inp1_shape: (1, 3, 1, 1)
| inp2_shape: (1,)
____________________
******* Depth: 7
____________________
| out_shape: (10, 3, 26, 26)
| op: <PuruTorch.nn.conv_functional.Downsample2D object at 0x7f7f30a67250>
| inp1_shape: (10, 3, 26, 26)
____________________
******* Depth: 7
____________________
| out_shape: (1, 3, 1, 1)
| op: <PuruTorch.functional.Reshape object at 0x7f7f30a530d0>
| inp1_shape: (3,)
____________________
******* Depth: 7
____________________
| out_shape: (1, 3, 1, 1)
| op: <PuruTorch.functional.Add object at 0x7f7f30a67b20>
| inp1_shape: (1, 3, 1, 1)
| inp2_shape: (1,)
____________________
******** Depth: 8
____________________
| out_shape: (10, 3, 130, 130)
| op: <PuruTorch.nn.conv_functional.Conv2D_stride1 object at 0x7f7f30a6b280>
| inp1_shape: (10, 24, 130, 130)
| inp2_shape: (3, 24, 1, 1)
| inp3_shape: (3,)
____________________
******** Depth: 8
____________________
| out_shape: (3,)
| op: <PuruTorch.functional.Mean object at 0x7f7f30a6b430>
| inp1_shape: (10, 3, 26, 26)
____________________
******** Depth: 8
____________________
| out_shape: (1, 3, 1, 1)
| op: <PuruTorch.functional.Reshape object at 0x7f7f30a6bb20>
| inp1_shape: (3,)
____________________
******** Depth: 8
____________________
| out_shape: (1,)
| op: None
____________________
******** Depth: 8
____________________
| out_shape: (10, 3, 26, 26)
| op: <PuruTorch.nn.conv_functional.Conv2D_stride1 object at 0x7f7f30a671c0>
| inp1_shape: (10, 3, 30, 30)
| inp2_shape: (3, 3, 5, 5)
| inp3_shape: (3,)
____________________
******** Depth: 8
____________________
| out_shape: (3,)
| op: <PuruTorch.functional.Mean object at 0x7f7f30a67370>
| inp1_shape: (10, 3, 26, 26)
____________________
******** Depth: 8
____________________
| out_shape: (1, 3, 1, 1)
| op: <PuruTorch.functional.Reshape object at 0x7f7f30a67a60>
| inp1_shape: (3,)
____________________
******** Depth: 8
____________________
| out_shape: (1,)
| op: None
____________________
********* Depth: 9
____________________
| out_shape: (10, 24, 130, 130)
| op: <PuruTorch.nn.conv_functional.Pad2D object at 0x7f7f30a6b0d0>
| inp1_shape: (10, 24, 128, 128)
____________________
********* Depth: 9
____________________
| out_shape: (3, 24, 1, 1)
| op: None
____________________
********* Depth: 9
____________________
| out_shape: (3,)
| op: None
____________________
********* Depth: 9
____________________
| out_shape: (3,)
| op: <PuruTorch.functional.Var object at 0x7f7f30a6b4c0>
| inp1_shape: (10, 3, 26, 26)
____________________
********* Depth: 9
____________________
| out_shape: (10, 3, 30, 30)
| op: <PuruTorch.nn.conv_functional.Pad2D object at 0x7f7f30a5afd0>
| inp1_shape: (10, 3, 26, 26)
____________________
********* Depth: 9
____________________
| out_shape: (3, 3, 5, 5)
| op: None
____________________
********* Depth: 9
____________________
| out_shape: (3,)
| op: None
____________________
********* Depth: 9
____________________
| out_shape: (3,)
| op: <PuruTorch.functional.Var object at 0x7f7f30a67400>
| inp1_shape: (10, 3, 26, 26)
____________________
********** Depth: 10
____________________
| out_shape: (10, 24, 128, 128)
| op: None
____________________
********** Depth: 10
____________________
| out_shape: (10, 3, 26, 26)
| op: <PuruTorch.nn.act_functional.ReLU object at 0x7f7f30a5af10>
| inp1_shape: (10, 3, 26, 26)
____________________
*********** Depth: 11
____________________
| out_shape: (10, 3, 26, 26)
| op: <PuruTorch.functional.Add object at 0x7f7f30a5ae50>
| inp1_shape: (10, 3, 26, 26)
| inp2_shape: (1, 3, 1, 1)
____________________
************ Depth: 12
____________________
| out_shape: (10, 3, 26, 26)
| op: <PuruTorch.functional.Mul object at 0x7f7f30a5acd0>
| inp1_shape: (10, 3, 26, 26)
| inp2_shape: (1, 3, 1, 1)
____________________
************ Depth: 12
____________________
| out_shape: (1, 3, 1, 1)
| op: <PuruTorch.functional.Reshape object at 0x7f7f30a5ad90>
| inp1_shape: (3,)
____________________
************* Depth: 13
____________________
| out_shape: (10, 3, 26, 26)
| op: <PuruTorch.functional.Div object at 0x7f7f30a5ab50>
| inp1_shape: (10, 3, 26, 26)
| inp2_shape: (1, 3, 1, 1)
____________________
************* Depth: 13
____________________
| out_shape: (1, 3, 1, 1)
| op: <PuruTorch.functional.Reshape object at 0x7f7f30a5ac10>
| inp1_shape: (3,)
____________________
************* Depth: 13
____________________
| out_shape: (3,)
| op: None
____________________
************** Depth: 14
____________________
| out_shape: (10, 3, 26, 26)
| op: <PuruTorch.functional.Sub object at 0x7f7f30a5a820>
| inp1_shape: (10, 3, 26, 26)
| inp2_shape: (1, 3, 1, 1)
____________________
************** Depth: 14
____________________
| out_shape: (1, 3, 1, 1)
| op: <PuruTorch.functional.Pow object at 0x7f7f30a5aa60>
| inp1_shape: (1, 3, 1, 1)
____________________
************** Depth: 14
____________________
| out_shape: (3,)
| op: None
____________________
*************** Depth: 15
____________________
| out_shape: (10, 3, 26, 26)
| op: <PuruTorch.nn.conv_functional.Downsample2D object at 0x7f7f30a5a070>
| inp1_shape: (10, 3, 128, 128)
____________________
*************** Depth: 15
____________________
| out_shape: (1, 3, 1, 1)
| op: <PuruTorch.functional.Reshape object at 0x7f7f30a49d00>
| inp1_shape: (3,)
____________________
*************** Depth: 15
____________________
| out_shape: (1, 3, 1, 1)
| op: <PuruTorch.functional.Add object at 0x7f7f30a5a9a0>
| inp1_shape: (1, 3, 1, 1)
| inp2_shape: (1,)
____________________
**************** Depth: 16
____________________
| out_shape: (10, 3, 128, 128)
| op: <PuruTorch.nn.conv_functional.Conv2D_stride1 object at 0x7f7f30a53f70>
| inp1_shape: (10, 24, 132, 132)
| inp2_shape: (3, 24, 5, 5)
| inp3_shape: (3,)
____________________
**************** Depth: 16
____________________
| out_shape: (3,)
| op: <PuruTorch.functional.Mean object at 0x7f7f30a5a1c0>
| inp1_shape: (10, 3, 26, 26)
____________________
**************** Depth: 16
____________________
| out_shape: (1, 3, 1, 1)
| op: <PuruTorch.functional.Reshape object at 0x7f7f30a5a8e0>
| inp1_shape: (3,)
____________________
**************** Depth: 16
____________________
| out_shape: (1,)
| op: None
____________________
***************** Depth: 17
____________________
| out_shape: (10, 24, 132, 132)
| op: <PuruTorch.nn.conv_functional.Pad2D object at 0x7f7f30a53520>
| inp1_shape: (10, 24, 128, 128)
____________________
***************** Depth: 17
____________________
| out_shape: (3, 24, 5, 5)
| op: None
____________________
***************** Depth: 17
____________________
| out_shape: (3,)
| op: None
____________________
***************** Depth: 17
____________________
| out_shape: (3,)
| op: <PuruTorch.functional.Var object at 0x7f7f30a5a250>
| inp1_shape: (10, 3, 26, 26)
____________________