A compilation of PyTorch modules implementing various ML papers, especially in computer vision. It contains a mix of my own implementations and code taken from unofficial and official implementations. More to be added.


$ pip install torch-modules-compilation



Bottleneck Residual Block


Your basic bottleneck residual block in ResNets. Image from the paper "Deep Residual Learning for Image Recognition"


in_channels (int): number of input channels

bottleneck_channels (int): number of bottleneck channels; usually less than the number of input channels

dropout (float): dropout rate; performed after every convolution


import torch
from torch_modules_compilation import modules

x = torch.randn(32, 256, 16, 16) # (batch_size, channels, height, width)
block = modules.BottleneckResBlock(in_channels=256, bottleneck_channels=64)

block(x).shape # (32, 256, 16, 16)
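
To make the structure concrete, here is a minimal sketch of a bottleneck residual block in plain PyTorch. It is an illustration, not the package's actual code: the class name `BottleneckSketch`, the use of batch normalization, and the exact dropout placement are assumptions.

```python
import torch
from torch import nn


class BottleneckSketch(nn.Module):
    """Illustrative bottleneck residual block (hypothetical, not the package's code)."""

    def __init__(self, in_channels: int, bottleneck_channels: int, dropout: float = 0.0):
        super().__init__()
        self.body = nn.Sequential(
            # 1x1 convolution reduces the channels to the bottleneck width
            nn.Conv2d(in_channels, bottleneck_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(bottleneck_channels),
            nn.ReLU(),
            nn.Dropout(dropout),
            # 3x3 convolution operates at the cheaper bottleneck width
            nn.Conv2d(bottleneck_channels, bottleneck_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(bottleneck_channels),
            nn.ReLU(),
            nn.Dropout(dropout),
            # 1x1 convolution restores the original channel count
            nn.Conv2d(bottleneck_channels, in_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(in_channels),
        )
        self.act = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Identity skip connection, so the output keeps the input shape
        return self.act(x + self.body(x))


x = torch.randn(2, 256, 16, 16)  # (batch_size, channels, height, width)
out = BottleneckSketch(in_channels=256, bottleneck_channels=64)(x)
print(out.shape)  # torch.Size([2, 256, 16, 16])
```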

Depthwise Separable Convolution


A depthwise separable convolution, consisting of a depthwise convolution followed by a pointwise convolution. Used in MobileNets, from the paper "MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications". Image also from this paper.


in_channels (int): Number of input channels

out_channels (int): Number of output channels

kernel_size (int): Size of depthwise convolution kernel

stride (int): Stride of depthwise convolution


import torch
from torch_modules_compilation import modules

x = torch.randn(32, 64, 16, 16) # (batch_size, channels, height, width)
block = modules.DepthwiseSepConv(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)

block(x).shape # (32, 128, 16, 16)
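
The benefit of splitting a convolution into depthwise and pointwise parts shows up in the weight count. The arithmetic below is a small sketch for the channel sizes used above; the helper names are hypothetical and biases are ignored for simplicity.

```python
def standard_conv_params(in_channels: int, out_channels: int, kernel_size: int) -> int:
    # Every output channel has its own k x k filter over every input channel
    return kernel_size * kernel_size * in_channels * out_channels


def depthwise_separable_params(in_channels: int, out_channels: int, kernel_size: int) -> int:
    depthwise = kernel_size * kernel_size * in_channels  # one k x k filter per input channel
    pointwise = in_channels * out_channels  # 1x1 convolution that mixes channels
    return depthwise + pointwise


standard = standard_conv_params(64, 128, 3)
separable = depthwise_separable_params(64, 128, 3)
print(standard, separable)  # 73728 8768
```

For these sizes the separable version needs roughly 12% of the weights of a standard 3x3 convolution.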

SAGAN self-attention module


A feature map self-attention module used in SAGAN, from the paper "Self-Attention Generative Adversarial Networks". Image also from this paper. This implementation was copied and modified from https://github.com/rosinality/sagan-pytorch/blob/master/model.py#L82 under the Apache 2.0 License. The modification removes the spectral initialization.


in_channels (int): Number of input channels


import torch
from torch_modules_compilation import modules

x = torch.randn(32, 64, 16, 16) # (batch_size, channels, height, width)
block = modules.FeatureMapSelfAttention(in_channels=64)

block(x).shape # (32, 64, 16, 16)
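
As a rough illustration of what self-attention over a feature map involves, the sketch below follows the general SAGAN recipe: 1x1 convolutions for queries, keys, and values, a softmax over all spatial positions, and a learned scalar `gamma` on the residual branch. The class name and the `in_channels // 8` reduction are assumptions and may differ from the packaged module.

```python
import torch
from torch import nn


class SelfAttentionSketch(nn.Module):
    """Illustrative feature-map self-attention (hypothetical, not the package's code)."""

    def __init__(self, in_channels: int):
        super().__init__()
        self.query = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.key = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        # gamma starts at zero, so the block is initially an identity mapping
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        q = self.query(x).flatten(2).transpose(1, 2)  # (b, h*w, c // 8)
        k = self.key(x).flatten(2)  # (b, c // 8, h*w)
        v = self.value(x).flatten(2)  # (b, c, h*w)
        # Row i holds the weights that position i places on every position
        attn = torch.softmax(q @ k, dim=-1)  # (b, h*w, h*w)
        out = v @ attn.transpose(1, 2)  # (b, c, h*w)
        return x + self.gamma * out.reshape(b, c, h, w)


x = torch.randn(2, 64, 16, 16)  # (batch_size, channels, height, width)
out = SelfAttentionSketch(in_channels=64)(x)
print(out.shape)  # torch.Size([2, 64, 16, 16])
```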

Global-Local Attention Module


A convolutional attention module introduced in the paper "All the attention you need: Global-local, spatial-channel attention for image retrieval". Image also from this paper.


in_channels (int): number of channels of the input feature map

num_reduced_channels (int): number of channels that the local and global spatial attention modules reduce the input feature map to. Refer to figures 3 and 5 in the paper.

feature_map_size (int): height/width of the feature map. The height/width of the input feature maps must be at least 7, due to the 7x7 convolution (3x3 dilated conv) in the module.

kernel_size (int): scope of the inter-channel attention


import torch
from torch_modules_compilation import modules

x = torch.randn(32, 64, 16, 16) # (batch_size, channels, height, width)

block = modules.GLAM(in_channels=64, num_reduced_channels=48, feature_map_size=16, kernel_size=5)
# height and width are both equal to feature_map_size

block(x).shape # (32, 64, 16, 16)

Global Context Module


A sort of self-attention (non-local) block on feature maps. Implementation of "GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond".


input_channels (int): Number of input channels


import torch
from torch_modules_compilation import modules

x = torch.randn(32, 64, 16, 16) # (batch_size, channels, height, width)

block = modules.GlobalContextModule(input_channels=64)

block(x).shape # (32, 64, 16, 16)

LFSA Tokenizer and Refinement Block


Implementation of the tokenizer and refinement block in "Learning Token-Based Representation for Image Retrieval". There are two modules. The tokenizer module takes feature maps from a CNN (in the paper's case, feature maps from a local-feature-self-attention module) and tokenizes them "into L visual tokens". It is used before the refinement block, as described in the paper. The refinement block "enhance[s] the obtained visual tokens with self-attention and cross-attention."


LFSA Tokenizer

in_channels (int): number of input channels

num_att_maps (int): number of tokens to tokenize the input into; also the number of channels used by the spatial attention

Refinement Block

d_model (int): dimensionality/channels of input

nhead (int): number of attention heads in the transformer

dim_feedforward (int): number of hidden dimensions in the feedforward layers

dropout (float): dropout rate


import torch
from torch_modules_compilation import modules

x = torch.randn(32, 64, 16, 16) # (batch_size, channels, height, width)

tokenizer = modules.LFSATokenizer(in_channels=64, num_att_maps=48)
refinement_block = modules.RefinementBlock(d_model=64, nhead=2, dim_feedforward=48*4, dropout=0.1)

visual_tokens, cnn_output = tokenizer(x)
print(visual_tokens.shape) # (32, 48, 64)
print(cnn_output.shape) # (32, 16*16, 64)

output = refinement_block(visual_tokens, cnn_output)
print(output.shape) # (32, 48, 64)

Parameter-Free Channel Attention (PFCA)


A channel attention module for convolutional feature maps without any trainable parameters. From the paper "PARAMETER-FREE CHANNEL ATTENTION FOR IMAGE CLASSIFICATION AND SUPER-RESOLUTION". Image also from this paper.


feature_map_size (int): Height/width of the input feature map

_lambda (float): A hyperparameter that is added to the variance (default: 1e-4)


import torch
from torch_modules_compilation import modules

x = torch.randn(32, 64, 16, 16) # (batch_size, channels, height, width)
block = modules.ParameterFreeChannelAttention(feature_map_size=16)

block(x).shape # (32, 64, 16, 16)

Patch Merger


Merges N tokens into M tokens in transformer models. Typically added in-between transformer layers. Introduced in the paper "LEARNING TO MERGE TOKENS IN VISION TRANSFORMERS". Image from this paper. Copied from lucidrains' repo under the MIT license.


dim (int): dimensionality/channels of the tokens

output_tokens (int): number of output merged tokens

norm (bool): normalize the input before merging

scale (bool): scale the attention matrix by the square root of dim (for numerical stability)


import torch
from torch_modules_compilation import modules

x = torch.randn(32, 64, 16) # (batch_size, seq_length, channels)
block = modules.PatchMerger(dim=16, output_tokens=48, scale=True)

block(x).shape # (32, 48, 16)
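
The merging step can be sketched in a few lines: M learned query vectors attend over the N input tokens, and each output token is a softmax-weighted average of the inputs. This is a minimal sketch assuming that design; the class name `PatchMergerSketch` is hypothetical and the packaged module may differ in detail.

```python
import torch
from torch import nn


class PatchMergerSketch(nn.Module):
    """Illustrative token merger (hypothetical, not the package's code)."""

    def __init__(self, dim: int, output_tokens: int, norm: bool = True, scale: bool = True):
        super().__init__()
        # Optional scaling by 1 / sqrt(dim) keeps the logits in a stable range
        self.scale = dim ** -0.5 if scale else 1.0
        self.norm = nn.LayerNorm(dim) if norm else nn.Identity()
        # One learned query per output token
        self.queries = nn.Parameter(torch.randn(output_tokens, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.norm(x)  # (b, N, dim)
        sim = (self.queries @ x.transpose(-1, -2)) * self.scale  # (b, M, N)
        attn = sim.softmax(dim=-1)  # each output token's weights over the N inputs
        return attn @ x  # (b, M, dim)


x = torch.randn(2, 64, 16)  # (batch_size, seq_length, channels)
out = PatchMergerSketch(dim=16, output_tokens=48)(x)
print(out.shape)  # torch.Size([2, 48, 16])
```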



Residual Block


Your basic residual block. Used in ResNets. Image from the original paper "Deep Residual Learning for Image Recognition".


in_channels (int): number of input channels

kernel_size (int): kernel size

dropout (float): dropout rate


import torch
from torch_modules_compilation import modules

x = torch.randn(32, 64, 16, 16) # (batch_size, channels, height, width)
block = modules.ResBlock(in_channels=64, kernel_size=3, dropout=0.2)

block(x).shape # (32, 64, 16, 16)

Up/Down sample ResBlock

Composed of several residual blocks and a down/up sampling at the end; adapted from Stable Diffusion's ResnetBlock.


in_channels (int): number of input channels

out_channels (int): number of output channels

num_groups (int): number of groups for Group Normalization

num_layers (int): number of residual blocks

dropout (float): dropout rate

sample (str): One of "down", "up", or "none". For downsampling 2x, use "down". For upsampling 2x, use "up". Use "none" for no down/up sampling.


import torch
from torch_modules_compilation import modules

x = torch.randn(32, 64, 96, 96) # (batch_size, channels, height, width)
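# in_channels, out_channels, and sample follow from the shapes in this example;
# the num_groups, num_layers, and dropout values are illustrative
block = modules.ResBlockUpDownSample(in_channels=64, out_channels=128, num_groups=8, num_layers=2, dropout=0.1, sample="down")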

block(x).shape # (32, 128, 48, 48)

Residual MLP Block

An improvement of standard MLPs along with residual connections. From "Generalizing MLPs With Dropouts, Batch Normalization, and Skip Connections". This implements the residual MLP block (eq. 5 in the paper).


dim (int): number of input dimensions

ic_first (bool): normalize and dropout at the start

dropout (float): dropout rate


import torch
from torch_modules_compilation import modules

x = torch.randn(32, 96) # (batch_size, dim)
block = modules.ResidualMLP_block(dim=96, ic_first=True, dropout=0.1)

block(x).shape # (32, 96)

Residual MLP Downsampling Block

An improvement of standard MLPs along with residual connections. From "Generalizing MLPs With Dropouts, Batch Normalization, and Skip Connections". This implements the residual MLP downsampling block (eq. 6 in the paper).


dim (int): number of input dimensions

downsample_dim (int): number of output dimensions

dropout (float): dropout rate


import torch
from torch_modules_compilation import modules

x = torch.randn(32, 96) # (batch_size, dim)
block = modules.ResidualMLP_downsample(dim=96, downsample_dim=48, dropout=0.1)

block(x).shape # (32, 48)

Transformer Encoder Layer

Standard transformer encoder layer with queries, keys, and values as inputs.


d_model (int): model dimensionality

nhead (int): number of attention heads

dim_feedforward (int): number of hidden dimensions in the feedforward layers

dropout (float): dropout rate

kdim (int, optional): dimensions of the keys

vdim (int, optional): dimensions of the values


import torch
from torch_modules_compilation import modules

queries = torch.randn(32, 20, 64) # (batch_size, seq_length, dim)
keys = torch.randn(32, 19, 48) # (batch_size, seq_length, dim)
values = torch.randn(32, 19, 96) # (batch_size, seq_length, dim)

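# d_model, kdim, and vdim follow from the tensor shapes above;
# the nhead, dim_feedforward, and dropout values are illustrative
block = modules.TransformerEncoderLayer(d_model=64, nhead=4, dim_feedforward=256, dropout=0.1, kdim=48, vdim=96)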

block(queries, keys, values).shape # (32, 20, 64)
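
For reference, the attention step with differently sized keys and values can be reproduced with PyTorch's built-in `nn.MultiheadAttention`, which accepts `kdim` and `vdim`. This sketch covers only the attention part, not the feedforward layers, and `num_heads=4` is an arbitrary choice; whether the packaged layer wraps this class is an assumption.

```python
import torch
from torch import nn

# Keys and values are projected to embed_dim internally, so the output
# always has the query's sequence length and dimensionality.
attention = nn.MultiheadAttention(embed_dim=64, num_heads=4, kdim=48, vdim=96, batch_first=True)

queries = torch.randn(2, 20, 64)  # (batch_size, seq_length, dim)
keys = torch.randn(2, 19, 48)  # (batch_size, seq_length, kdim)
values = torch.randn(2, 19, 96)  # (batch_size, seq_length, vdim)

attended, weights = attention(queries, keys, values)
print(attended.shape)  # torch.Size([2, 20, 64])
print(weights.shape)  # torch.Size([2, 20, 19])
```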

UNet Encoder and Decoder


Standard UNet implementation. From the paper "U-Net: Convolutional Networks for Biomedical Image Segmentation".


UNet Encoder

channels (list of ints): A list containing the number of channels in the encoder. E.g. [3, 64, 128, 256]

dropout (float): dropout rate

UNet Decoder

channels (list of ints): A list containing the number of channels in the decoder. E.g. [256, 128, 64, 3]

dropout (float): dropout rate


import torch
from torch_modules_compilation import modules

images = torch.randn(16, 3, 224, 224) # (batch_size, channels, height, width)

unet_encoder = modules.UnetEncoder(channels=[3,64,128,256], dropout=0.1)
unet_decoder = modules.UnetDecoder(channels=[256,128,64,3], dropout=0.1)

encoder_features = unet_encoder(images)

output = unet_decoder(encoder_features)
print(output.shape) # (16, 64, 224, 224)

Squeeze-Excitation Module


Module that computes channel-wise interactions in a feature map. From "Squeeze-and-Excitation Networks".


in_channels (int): Number of input channels

reduced_channels (int): Number of channels to reduce to in the "squeeze" part of the module

feature_map_size (int): height/width of the feature map


import torch
from torch_modules_compilation import modules

feature_maps = torch.randn(16, 128, 64, 64) # (batch_size, channels, height, width)
se_module = modules.SEModule(in_channels=128, reduced_channels=32, feature_map_size=64)

se_module(feature_maps) # shape (16, 128, 64, 64); same as input
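
The squeeze and excitation steps can be sketched as follows: average each channel down to a single number, pass the resulting vector through a small two-layer gate, and rescale the channels by the gate's output. This is a minimal sketch; the class name is hypothetical, and it uses adaptive pooling instead of a `feature_map_size` argument, which is an assumption about one possible design.

```python
import torch
from torch import nn


class SESketch(nn.Module):
    """Illustrative squeeze-and-excitation gate (hypothetical, not the package's code)."""

    def __init__(self, in_channels: int, reduced_channels: int):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)  # squeeze: one value per channel
        self.gate = nn.Sequential(
            nn.Linear(in_channels, reduced_channels),
            nn.ReLU(),
            nn.Linear(reduced_channels, in_channels),
            nn.Sigmoid(),  # per-channel weights between 0 and 1
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, _, _ = x.shape
        weights = self.gate(self.pool(x).view(b, c))  # (b, c)
        return x * weights.view(b, c, 1, 1)  # excitation: rescale each channel


feature_maps = torch.randn(2, 128, 64, 64)  # (batch_size, channels, height, width)
out = SESketch(in_channels=128, reduced_channels=32)(feature_maps)
print(out.shape)  # torch.Size([2, 128, 64, 64])
```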

Token Learner


Module designed for reducing and generating visual tokens given a feature map. From "TokenLearner: What Can 8 Learned Tokens Do for Images and Videos?"


in_channels (int): Number of input channels

num_tokens (int): Number of tokens to reduce to


import torch
from torch_modules_compilation import modules

feature_maps = torch.randn(2, 16, 10, 10) # (batch_size, channels, height, width)
token_learner = modules.TokenLearner(in_channels=16, num_tokens=50) # reduce tokens from 10*10 to 50

token_learner(feature_maps) # shape (2, 50, 16)
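
A simplified sketch of the idea: predict one spatial attention map per output token, then pool the feature map under each map to get that token. The single 1x1 convolution here is a simplification chosen for brevity and the class name is hypothetical; the paper and the packaged module use a deeper attention network.

```python
import torch
from torch import nn


class TokenLearnerSketch(nn.Module):
    """Simplified token learner (hypothetical, not the package's code)."""

    def __init__(self, in_channels: int, num_tokens: int):
        super().__init__()
        # One spatial attention map per output token
        self.attention = nn.Sequential(
            nn.Conv2d(in_channels, num_tokens, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        maps = self.attention(x)  # (b, num_tokens, h, w)
        # Weight the features by each map and sum over the spatial positions
        tokens = torch.einsum("bshw,bchw->bsc", maps, x)
        return tokens / (x.shape[-2] * x.shape[-1])  # spatial average


feature_maps = torch.randn(2, 16, 10, 10)  # (batch_size, channels, height, width)
tokens = TokenLearnerSketch(in_channels=16, num_tokens=50)(feature_maps)
print(tokens.shape)  # torch.Size([2, 50, 16])
```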

Triplet Attention


Computes attention in a feature map across all three dimensions (channel and both spatial dims). From "Rotate to Attend: Convolutional Triplet Attention Module".


in_channels (int): Number of input channels

height (int): height of feature map

width (int): width of feature map

kernel_size (int): kernel size of the convolutions. Default: 7


import torch
from torch_modules_compilation import modules

feature_maps = torch.randn(2, 16, 10, 10) # (batch_size, channels, height, width)
triplet_attention = modules.TripletAttention(in_channels=16, height=10, width=10)

triplet_attention(feature_maps) # shape (2, 16, 10, 10); same as input


Some of these modules are copied or adapted from other repositories and remain under their original licenses, such as MIT and Apache 2.0; these are noted in the relevant sections above. Take note of these licenses when using this code in your work. The rest are my own implementations, which are under the MIT license. See this repo's license file.