LambdaNetworks: Modeling long-range Interactions without Attention
Experimnets (CIFAR10)
Model |
k |
h |
u |
m |
Params (M) |
Acc (%) |
ResNet18 baseline (ref) |
|
|
|
|
14 |
93.02 |
LambdaResNet18 |
16 |
4 |
4 |
9 |
8.6 |
92.21 (70 Epochs) |
LambdaResNet18 |
16 |
4 |
4 |
7 |
8.6 |
94.20 (67 Epochs) |
LambdaResNet18 |
16 |
4 |
4 |
5 |
8.6 |
91.58 (70 Epochs) |
LambdaResNet18 |
16 |
4 |
1 |
23 |
8 |
91.36 (69 Epochs) |
ResNet50 baseline (ref) |
|
|
|
|
23.5 |
93.62 |
LambdaResNet50 |
16 |
4 |
4 |
7 |
13 |
93.74 (70 epochs) |
Usage
import torch
from model import LambdaConv, LambdaResNet50, LambdaResNet152
x = torch.randn([2, 3, 32, 32])
conv = LambdaConv(3, 128)
print(conv(x).size()) # [2, 128, 32, 32]
# reference
# https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325
def get_n_params(model):
pp=0
for p in list(model.parameters()):
nn=1
for s in list(p.size()):
nn = nn*s
pp += nn
return pp
model = LambdaResNet50()
print(get_n_params(model)) # 14.9M (Ours) / 15M(Paper)
model = LambdaResNet152()
print(get_n_params(model)) # 32.8M (Ours) / 35M (Paper)
Parameters
Model |
k |
h |
u |
m |
Params (M), Paper |
Params (M), Ours |
LambdaResNet50 |
16 |
4 |
1 |
23 |
15.0 |
14.9 |
LambdaResNet50 |
16 |
4 |
4 |
7 |
16.0 |
16.0 |
LambdaResNet152 |
16 |
4 |
1 |
23 |
35 |
32.8 |
LambdaResNet200 |
16 |
4 |
1 |
23 |
42 |
35.29 |
Ablation Parameters
k |
h |
u |
Params (M), Paper |
Params (M), Ours |
ResNet baseline |
|
|
25.6 |
25.5 |
8 |
2 |
1 |
14.8 |
15.0 |
8 |
16 |
1 |
15.6 |
14.9 |
2 |
4 |
1 |
14.7 |
14.6 |
4 |
4 |
1 |
14.7 |
14.66 |
8 |
4 |
1 |
14.8 |
14.66 |
16 |
4 |
1 |
15.0 |
14.99 |
32 |
4 |
1 |
15.4 |
15.4 |
2 |
8 |
1 |
14.7 |
14.5 |
4 |
8 |
1 |
14.7 |
14.57 |
8 |
8 |
1 |
14.7 |
14.74 |
16 |
8 |
1 |
15.1 |
14.1 |
32 |
8 |
1 |
15.7 |
15.76 |
8 |
8 |
4 |
15.3 |
15.26 |
8 |
8 |
8 |
16.0 |
16.0 |
16 |
4 |
4 |
16.0 |
16.0 |