LambdaNetworks

Implementing Lambda Networks using PyTorch

LambdaNetworks: Modeling Long-Range Interactions Without Attention

Experiments (CIFAR10)

| Model | k | h | u | m | Params (M) | Acc (%) |
|---|---|---|---|---|---|---|
| ResNet18 baseline (ref) | | | | | 14 | 93.02 |
| LambdaResNet18 | 16 | 4 | 4 | 9 | 8.6 | 92.21 (70 Epochs) |
| LambdaResNet18 | 16 | 4 | 4 | 7 | 8.6 | 94.20 (67 Epochs) |
| LambdaResNet18 | 16 | 4 | 4 | 5 | 8.6 | 91.58 (70 Epochs) |
| LambdaResNet18 | 16 | 4 | 1 | 23 | 8 | 91.36 (69 Epochs) |
| ResNet50 baseline (ref) | | | | | 23.5 | 93.62 |
| LambdaResNet50 | 16 | 4 | 4 | 7 | 13 | 93.74 (70 Epochs) |

Usage

import torch

from model import LambdaConv, LambdaResNet50, LambdaResNet152

x = torch.randn([2, 3, 32, 32])
conv = LambdaConv(3, 128)
print(conv(x).size()) # [2, 128, 32, 32]

# reference
# https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325
def get_n_params(model):
    """Return the total number of parameters in a PyTorch model."""
    # numel() gives the element count of each parameter tensor.
    return sum(p.numel() for p in model.parameters())

model = LambdaResNet50()
print(get_n_params(model)) # 14.9M (Ours) / 15M (Paper)

model = LambdaResNet152()
print(get_n_params(model)) # 32.8M (Ours) / 35M (Paper)
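
`get_n_params` returns a raw integer, while the comments above quote counts in millions. A minimal sketch of a formatting helper follows; `format_params` is a hypothetical name, not part of this repository.

```python
def format_params(n_params, digits=1):
    """Format a raw parameter count in millions, e.g. for logging.

    `format_params` is an illustrative helper, not part of model.py.
    """
    return f"{n_params / 1e6:.{digits}f}M"


# Example with a made-up count; in practice pass get_n_params(model).
print(format_params(14_900_000))
```

With the snippet above, `print(format_params(get_n_params(model)))` would print the count in the same form as the comments.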

Parameters

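| Model | k | h | u | m | Params (M), Paper | Params (M), Ours |
|---|---|---|---|---|---|---|
| LambdaResNet50 | 16 | 4 | 1 | 23 | 15.0 | 14.9 |
| LambdaResNet50 | 16 | 4 | 4 | 7 | 16.0 | 16.0 |
| LambdaResNet152 | 16 | 4 | 1 | 23 | 35 | 32.8 |
| LambdaResNet200 | 16 | 4 | 1 | 23 | 42 | 35.29 |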
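
To see how far each implementation is from the paper, the gap can be expressed as a percentage of the paper's count. The sketch below uses only the numbers from the table above; `relative_gap` is a hypothetical helper name.

```python
# (model, k, h, u, m, paper params in M, our params in M), copied from the table.
ROWS = [
    ("LambdaResNet50", 16, 4, 1, 23, 15.0, 14.9),
    ("LambdaResNet50", 16, 4, 4, 7, 16.0, 16.0),
    ("LambdaResNet152", 16, 4, 1, 23, 35.0, 32.8),
    ("LambdaResNet200", 16, 4, 1, 23, 42.0, 35.29),
]


def relative_gap(paper, ours):
    """Signed difference of `ours` from `paper`, in percent of `paper`."""
    return (ours - paper) / paper * 100.0


for model, k, h, u, m, paper, ours in ROWS:
    print(f"{model} (k={k}, h={h}, u={u}, m={m}): {relative_gap(paper, ours):+.1f}%")
```

The LambdaResNet50 rows are within about 1% of the paper, while LambdaResNet200 is roughly 16% smaller, so the deeper variants are where this implementation departs most from the reported sizes.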

Ablation Parameters

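| k | h | u | Params (M), Paper | Params (M), Ours |
|---|---|---|---|---|
| ResNet baseline | | | 25.6 | 25.5 |
| 8 | 2 | 1 | 14.8 | 15.0 |
| 8 | 16 | 1 | 15.6 | 14.9 |
| 2 | 4 | 1 | 14.7 | 14.6 |
| 4 | 4 | 1 | 14.7 | 14.66 |
| 8 | 4 | 1 | 14.8 | 14.66 |
| 16 | 4 | 1 | 15.0 | 14.99 |
| 32 | 4 | 1 | 15.4 | 15.4 |
| 2 | 8 | 1 | 14.7 | 14.5 |
| 4 | 8 | 1 | 14.7 | 14.57 |
| 8 | 8 | 1 | 14.7 | 14.74 |
| 16 | 8 | 1 | 15.1 | 14.1 |
| 32 | 8 | 1 | 15.7 | 15.76 |
| 8 | 8 | 4 | 15.3 | 15.26 |
| 8 | 8 | 8 | 16.0 | 16.0 |
| 16 | 4 | 4 | 16.0 | 16.0 |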
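
The tables vary k, h, u and m without saying how each one feeds into the parameter count. The sketch below is a rough per-layer estimate that assumes the parameterization described in the LambdaNetworks paper: k is the query/key depth, h the number of heads, u the intra-depth, and m the side length of the local context window. The function name, the equal input/output width `d`, and the omission of biases and normalization layers are all assumptions made for illustration; the layer in `model.py` may differ in detail.

```python
def lambda_layer_params(d, k, h, u, m):
    """Rough parameter count of one lambda layer (biases and norms ignored).

    Assumptions (not taken from this repository's code):
      d: input and output channels (assumed equal)
      k: query/key depth, h: number of heads, u: intra-depth
      m: side length of the local context, so the window is m x m
    """
    v = d // h                 # value depth per head
    queries = d * k * h        # 1x1 projection to k*h channels
    keys = d * k * u           # 1x1 projection to k*u channels
    values = d * v * u         # 1x1 projection to v*u channels
    positions = k * u * m * m  # relative position embeddings
    return queries + keys + values + positions


# Compare two of the configurations from the tables at a hypothetical width d=64.
for k, h, u, m in [(16, 4, 1, 23), (16, 4, 4, 7)]:
    print((k, h, u, m), lambda_layer_params(64, k, h, u, m))
```

Under these assumptions, u scales the key, value and position terms, k scales the query, key and position terms, and m affects only the position embeddings.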