
RepNeXt: A Fast Multi-Scale CNN using Structural Reparameterization



The top-1 accuracy is tested on ImageNet-1K, and the latency is measured on an iPhone 12 with iOS 16 over 20 experimental runs.

RepNeXt: A Fast Multi-Scale CNN using Structural Reparameterization.
Mingshu Zhao, Yi Luo, and Yong Ouyang [arXiv]

architecture

Abstract

We introduce RepNeXt, a novel model series that integrates multi-scale feature representations and incorporates both serial and parallel structural reparameterization (SRP) to enhance network depth and width without compromising inference speed. Extensive experiments demonstrate RepNeXt's superiority over current leading lightweight CNNs and ViTs, providing favorable latency across various vision benchmarks. RepNeXt-M4 matches RepViT-M1.5's 82.3% accuracy on ImageNet within 1.5ms on an iPhone 12, outperforms its AP$^{box}$ by 1.3 on MS-COCO, and reduces parameters by 0.7M.

transforms

Conclusion

In this paper, we introduced a multi-scale depthwise convolution integrated with both serial and parallel SRP mechanisms, enhancing feature diversity and expanding the network's expressive capacity without compromising inference speed. Specifically, we designed a reparameterized medium-kernel convolution to imitate the human foveal vision system. Additionally, we proposed our lightweight, general-purpose RepNeXts, which employ the distribute-transform-aggregate design philosophy across inner-stage blocks as well as downsampling layers, achieving a comparable or superior accuracy-efficiency trade-off across various vision benchmarks, especially on downstream tasks. Moreover, our flexible multi-branch design functions as a grouped depthwise convolution with additional inductive bias and efficiency trade-offs. It can also be reparameterized into a single-branch large-kernel depthwise convolution, enabling potential optimization towards different accelerators.

For example, the large-kernel depthwise convolution can be accelerated by the implicit GEMM algorithm: DepthWiseConv2dImplicitGEMM of RepLKNet.
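As a rough illustration, the fused single-branch kernel could be handed to that operator. This is only a hedged sketch: it assumes RepLKNet's optional cutlass extension is built and importable as depthwise_conv2d_implicit_gemm, and that DepthWiseConv2dImplicitGEMM(channels, kernel, bias) behaves like a drop-in nn.Conv2d with standard weight/bias attributes; check the RepLKNet repository for the exact interface.

import torch
import torch.nn as nn
# Assumption: RepLKNet's cutlass-based extension is installed and importable.
from depthwise_conv2d_implicit_gemm import DepthWiseConv2dImplicitGEMM

@torch.no_grad()
def to_implicit_gemm(fused: nn.Conv2d):
    # Copy a fused large-kernel depthwise Conv2d (groups == channels) into the implicit-GEMM operator.
    assert fused.groups == fused.in_channels == fused.out_channels
    fast = DepthWiseConv2dImplicitGEMM(fused.in_channels, fused.kernel_size[0], bias=fused.bias is not None)
    fast.weight.copy_(fused.weight)
    if fused.bias is not None:
        fast.bias.copy_(fused.bias)
    return fast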

Many token mixers can be generalized as a distribute-transform-aggregate process:

| Token Mixer | Distribution | Transforms | Aggregation |
|---|---|---|---|
| ChunkConv | Split | Conv, Identity | Cat |
| CopyConv | Clone | Conv | Cat |
| MixConv | Split | Conv | Cat |
| MHSA | Split | Attn | Cat |
| RepBlock | Clone | Conv | Add |

ChunkConv and CopyConv can be viewed as grouped depthwise convolutions.

  • Chunk Conv
import torch
import torch.nn as nn

class ChunkConv(nn.Module):
    """Split the input channels into four chunks and apply identity / 3x3 / 7x7 / 11x11 depthwise convs."""
    def __init__(self, in_channels, bias=True):
        super().__init__()
        self.bias = bias
        in_channels = in_channels // 4
        kwargs = {"in_channels": in_channels, "out_channels": in_channels, "groups": in_channels, "bias": bias}
        self.conv_i = nn.Identity()
        self.conv_s = nn.Conv2d(kernel_size=3, padding=1, **kwargs)
        self.conv_m = nn.Conv2d(kernel_size=7, padding=3, **kwargs)
        self.conv_l = nn.Conv2d(kernel_size=11, padding=5, **kwargs)

    def forward(self, x):
        # Distribute: split the channels into four equal chunks.
        i, s, m, l = torch.chunk(x, chunks=4, dim=1)
        # Transform each chunk with its own branch, then aggregate by concatenation.
        return torch.cat((self.conv_i(i), self.conv_s(s), self.conv_m(m), self.conv_l(l)), dim=1)

    @torch.no_grad()
    def fuse(self):
        conv_s_w, conv_s_b = self.conv_s.weight, self.conv_s.bias
        conv_m_w, conv_m_b = self.conv_m.weight, self.conv_m.bias
        conv_l_w, conv_l_b = self.conv_l.weight, self.conv_l.bias

        # The identity branch becomes a 1x1 "one" kernel; every branch is zero-padded to the largest (11x11) kernel.
        conv_i_w = nn.functional.pad(torch.ones(conv_l_w.shape[0], conv_l_w.shape[1], 1, 1, device=conv_l_w.device), [5, 5, 5, 5])
        conv_s_w = nn.functional.pad(conv_s_w, [4, 4, 4, 4])  # 3x3 -> 11x11
        conv_m_w = nn.functional.pad(conv_m_w, [2, 2, 2, 2])  # 7x7 -> 11x11

        # Concatenate the branch kernels along the output-channel dimension into one depthwise 11x11 conv.
        in_channels = self.conv_l.in_channels * 4
        conv = nn.Conv2d(in_channels, in_channels, kernel_size=11, padding=5, bias=self.bias, groups=in_channels)
        conv.weight.data.copy_(torch.cat((conv_i_w, conv_s_w, conv_m_w, conv_l_w), dim=0))

        if self.bias:
            conv_i_b = torch.zeros_like(conv_s_b)
            conv.bias.data.copy_(torch.cat((conv_i_b, conv_s_b, conv_m_b, conv_l_b), dim=0))
        return conv
  • Copy Conv
class CopyConv(nn.Module):
    def __init__(self, in_channels, bias=True):
        super().__init__()
        self.bias = bias
        kwargs = {"in_channels": in_channels, "out_channels": in_channels, "groups": in_channels, "bias": bias, "stride": 2}
        self.conv_s = nn.Conv2d(kernel_size=3, padding=1, **kwargs)
        self.conv_l = nn.Conv2d(kernel_size=7, padding=3, **kwargs)
         
    def forward(self, x):
        B, C, H, W = x.shape
        # Distribute by cloning: both branches see the full input; the stride-2 convs downsample it.
        s, l = self.conv_s(x), self.conv_l(x)
        # Aggregate by interleaving the two outputs channel-wise: (s0, l0, s1, l1, ...).
        return torch.stack((s, l), dim=2).reshape(B, C * 2, H // 2, W // 2)

    @torch.no_grad()
    def fuse(self):
        conv_s_w, conv_s_b = self.conv_s.weight, self.conv_s.bias
        conv_l_w, conv_l_b = self.conv_l.weight, self.conv_l.bias

        # Zero-pad the 3x3 kernel to 7x7 so both branches share one kernel size.
        conv_s_w = nn.functional.pad(conv_s_w, [2, 2, 2, 2])

        # Interleave the two branches into a single grouped depthwise conv with two outputs per group.
        in_channels = self.conv_l.in_channels
        conv = nn.Conv2d(in_channels, in_channels*2, kernel_size=7, padding=3, bias=self.bias, stride=self.conv_l.stride, groups=in_channels)
        conv.weight.data.copy_(torch.stack((conv_s_w, conv_l_w), dim=1).reshape(conv.weight.shape))

        if self.bias:
            conv.bias.data.copy_(torch.stack((conv_s_b, conv_l_b), dim=1).reshape(conv.bias.shape))
        return conv
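
As a quick sanity check (a minimal sketch reusing the two classes above, with arbitrary shapes), the fused single-branch convolutions should reproduce the multi-branch outputs up to floating-point tolerance:

import torch

x = torch.randn(2, 64, 56, 56)

chunk = ChunkConv(64).eval()
print(torch.allclose(chunk(x), chunk.fuse()(x), atol=1e-6))  # expected: True

copy = CopyConv(64).eval()
print(torch.allclose(copy(x), copy.fuse()(x), atol=1e-6))    # expected: True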

In summary, by focusing solely on the simplicity of the model’s overall architecture and disregarding its efficiency and parameter count, we can ultimately consolidate it into the single-branch structure shown in the figure below:

equivalent
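
A minimal sketch of how that consolidation could be applied module-by-module, assuming each multi-branch mixer exposes a fuse() method like the ChunkConv/CopyConv classes above (illustrative only; it mirrors the replace_batchnorm-style traversal used for inference-time conversion):

import torch.nn as nn

def consolidate(module: nn.Module) -> nn.Module:
    # Recursively swap every reparameterizable mixer for its fused single-branch convolution.
    for name, child in module.named_children():
        if hasattr(child, 'fuse'):
            setattr(module, name, child.fuse())
        else:
            consolidate(child)
    return module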


UPDATES 🔥

  • 2024/10/13: Added ImageNet-1K results for the single-branch equivalent forms of M0-M2 using StarNet's training recipe.
  • 2024/09/19: Added M0-M2 ImageNet-1K results using StarNet's training recipe (distilled), hitting 80.6% top-1 accuracy within 1ms on an iPhone 12.
  • 2024/09/08: Added the RepNeXt-M0 ImageNet-1K result using StarNet's training recipe, achieving 73.8% top-1 accuracy without distillation.
  • 2024/08/26: RepNeXt-M0 (distilled) has been released, achieving 74.2% top-1 accuracy within 0.6ms on an iPhone 12.
  • 2024/08/23: Finished compact model (M0) ImageNet-1K experiments.
  • 2024/07/23: Updated the README with the further simplified model structure.
  • 2024/06/25: Uploaded checkpoints and training logs of RepNeXt-M1 - M5.

Classification on ImageNet-1K

Models under the RepViT training strategy

We report the top-1 accuracy on ImageNet-1K with and without distillation using the same training strategy as RepViT.

| Model | Top-1 (distill) / Top-1 | #Params | MACs | Latency | Ckpt | Core ML | Log |
|---|---|---|---|---|---|---|---|
| M0 | 74.2 / 72.6 | 2.3M | 0.4G | 0.59ms | fused 300e / 300e | 300e | distill 300e / 300e |
| M1 | 78.8 / 77.5 | 4.8M | 0.8G | 0.86ms | fused 300e / 300e | 300e | distill 300e / 300e |
| M2 | 80.1 / 78.9 | 6.5M | 1.1G | 1.00ms | fused 300e / 300e | 300e | distill 300e / 300e |
| M3 | 80.7 / 79.4 | 7.8M | 1.3G | 1.11ms | fused 300e / 300e | 300e | distill 300e / 300e |
| M4 | 82.3 / 81.2 | 13.3M | 2.3G | 1.48ms | fused 300e / 300e | 300e | distill 300e / 300e |
| M5 | 83.3 / 82.4 | 21.7M | 4.5G | 2.20ms | fused 300e / 300e | 300e | distill 300e / 300e |

Models under the StarNet training strategy

We report the top-1 and top-5 accuracy on ImageNet-1K with and without distillation using the same training strategy as StarNet.

Models decorated with "E" denote the single-branch equivalent form of RepNeXts through channel-wise structural reparameterization.

| Model | Params (M) | MACs | Latency | Top-1 / Top-5 | Download | Top-1 / Top-5 (distill) | Download |
|---|---|---|---|---|---|---|---|
| M0 | 2.3 | 0.4G | 0.59ms | 73.8 / 91.6 | args / log / model | 75.4 / 92.1 | args / log / model |
| M0E | 2.5 | 0.5G | 0.62ms | 73.9 / 91.9 | args / log | 75.6 / 92.3 | args / log |
| M1 | 4.8 | 0.8G | 0.86ms | 77.9 / 94.0 | args / log / model | 79.7 / 94.5 | args / log / model |
| M1E | 5.3 | 1.0G | 0.89ms | 77.7 / 93.9 | args / log | 79.5 / 94.5 | args / log |
| M2 | 6.5 | 1.1G | 1.00ms | 78.8 / 94.5 | args / log / model | 80.6 / 95.1 | args / log / model |
| M2E | 7.0 | 1.3G | 1.04ms | 78.8 / 94.3 | args / log | 80.6 / 95.2 | args / log |

Model evaluation:

python moganet_valid.py \
--model repnext_m0 \
--img_size 224 \
--crop_pct 0.9 \
--data_dir data/imagenet \
--checkpoint repnext_m0_sz224_4xbs512_ep300.pth.tar

Tips: Convert a training-time RepNeXt into the inference-time structure

from timm.models import create_model
import utils

model = create_model('repnext_m1')
utils.replace_batchnorm(model)
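
In practice you would load trained weights before fusing. A minimal sketch (the checkpoint path is taken from the Testing section below; the checkpoint is assumed to store the state dict either directly or under a 'model' key):

import torch
from timm.models import create_model
import utils

model = create_model('repnext_m1')
ckpt = torch.load('pretrain/repnext_m1_distill_300e.pth', map_location='cpu')
state_dict = ckpt['model'] if isinstance(ckpt, dict) and 'model' in ckpt else ckpt
model.load_state_dict(state_dict)
utils.replace_batchnorm(model)  # fold BatchNorm layers into the preceding convolutions
model.eval()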

Latency Measurement

The latencies reported in RepNeXt for the iPhone 12 (iOS 16) are measured with the benchmark tool in Xcode 14.

RepNeXt-M1
RepNeXt-M2
RepNeXt-M3
RepNeXt-M4
RepNeXt-M5

Tips: export the model to Core ML model

python export_coreml.py --model repnext_m1 --ckpt pretrain/repnext_m1_distill_300e.pth

Tips: measure the throughput on GPU

python speed_gpu.py --model repnext_m1

ImageNet

Prerequisites

A conda virtual environment is recommended.

conda create -n repnext python=3.8
pip install -r requirements.txt

Data preparation

Download and extract ImageNet train and val images from http://image-net.org/. The training and validation data are expected to be in the train folder and val folder respectively:

# script to extract ImageNet dataset: https://github.com/pytorch/examples/blob/main/imagenet/extract_ILSVRC.sh
# ILSVRC2012_img_train.tar (about 138 GB)
# ILSVRC2012_img_val.tar (about 6.3 GB)
# organize the ImageNet dataset as follows:
imagenet
├── train
│   ├── n01440764
│   │   ├── n01440764_10026.JPEG
│   │   ├── n01440764_10027.JPEG
│   │   ├── ......
│   ├── ......
├── val
│   ├── n01440764
│   │   ├── ILSVRC2012_val_00000293.JPEG
│   │   ├── ILSVRC2012_val_00002138.JPEG
│   │   ├── ......
│   ├── ......

Training

To train RepNeXt-M1 on an 8-GPU machine:

python -m torch.distributed.launch --nproc_per_node=8 --master_port 12346 --use_env main.py --model repnext_m1 --data-path ~/imagenet --dist-eval

Tips: specify your data path and model name!

Training with the helper script:

sh dist_train_cifar.sh

Testing

For example, to test RepNeXt-M1:

python main.py --eval --model repnext_m1 --resume pretrain/repnext_m1_distill_300e.pth --data-path ~/imagenet

Fused model evaluation

For example, to evaluate RepNeXt-M1 with the fused model: Colab

python fuse_eval.py --model repnext_m1 --resume pretrain/repnext_m1_distill_300e_fused.pt --data-path ~/imagenet

Downstream Tasks

Object Detection and Instance Segmentation

| Model | $AP^b$ | $AP^b_{50}$ | $AP^b_{75}$ | $AP^m$ | $AP^m_{50}$ | $AP^m_{75}$ | Latency | Ckpt | Log |
|---|---|---|---|---|---|---|---|---|---|
| RepNeXt-M3 | 40.8 | 62.4 | 44.7 | 37.8 | 59.5 | 40.6 | 5.1ms | M3 | M3 |
| RepNeXt-M4 | 42.9 | 64.4 | 47.2 | 39.1 | 61.7 | 41.7 | 6.6ms | M4 | M4 |
| RepNeXt-M5 | 44.7 | 66.0 | 49.2 | 40.7 | 63.5 | 43.6 | 10.4ms | M5 | M5 |

Semantic Segmentation

| Model | mIoU | Latency | Ckpt | Log |
|---|---|---|---|---|
| RepNeXt-M3 | 40.6 | 5.1ms | M3 | M3 |
| RepNeXt-M4 | 43.3 | 6.6ms | M4 | M4 |
| RepNeXt-M5 | 45.0 | 10.4ms | M5 | M5 |

Feature Map Visualization

Run feature map visualization demo: Colab

Original Image Identity RepDWConvS RepDWConvM DWConvL

Ablation Study

Downsampling Layer Design

The downsampling layer between each stage is a modified version of the MetaNeXt block, where the shortcut connection bypasses the channel mixer.

When the downsampling layers of ConvNeXt-femto are replaced with our designs, top-1 accuracy improves by 1.8%.

| Model | Top-1 (%) | Params (M) | GMACs | Throughput (im/s) | Log | Ckpt |
|---|---|---|---|---|---|---|
| ConvNeXt | 72.37 | 5.22 | 0.78 | 3636 | femto | baseline |
| ModifiedA | 74.28 | 5.25 | 0.79 | 3544 | femto | replaced |
| ModifiedB | 74.19 | 5.25 | 0.79 | 3544 | femto | replaced |

top1_acc

# ModifiedA
class Downsample(nn.Module):
    def __init__(self, dim, mlp_ratio):
        super().__init__()
        out_dim = dim * 2
        self.dwconv = nn.Conv2d(dim, out_dim, kernel_size=7, padding=3, groups=dim, stride=2)
        self.norm = LayerNorm(out_dim, eps=1e-6)
        self.pwconv1 = nn.Linear(out_dim, mlp_ratio * out_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(mlp_ratio * out_dim, out_dim)

    def forward(self, x):
        x = self.dwconv(x)        # token mixer: (N, C, H, W) -> (N, 2C, H/2, W/2)
        input = x                 # bypass the channel mixer and the normalization layer
        x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
        return input + x

The code below shows our downsampling layer design, where the shortcut connection bypasses only the channel mixer. This design allows the batch normalization layer to be fused into the preceding convolution layer.

# ModifiedB
class Downsample(nn.Module):
    def __init__(self, dim, mlp_ratio=2):
        super().__init__()
        out_dim = dim * 2
        self.dwconv = nn.Conv2d(dim, out_dim, kernel_size=7, padding=3, groups=dim, stride=2)
        self.norm = LayerNorm(out_dim, eps=1e-6, data_format="channels_first")
        self.pwconv1 = nn.Linear(out_dim, mlp_ratio * out_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(mlp_ratio * out_dim, out_dim)

    def forward(self, x):
        x = self.dwconv(x)        # token mixer: (N, C, H, W) -> (N, 2C, H/2, W/2)
        x = self.norm(x)          # normalization layer
        input = x                 # bypass the channel mixer
        x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
        return input + x
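
As noted above, placing the shortcut after the normalization layer lets a batch normalization be folded into the preceding convolution at inference time. For reference, a generic sketch of that standard conv-BN folding (not code from this repository):

import torch
import torch.nn as nn

@torch.no_grad()
def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    # Fold a BatchNorm2d that directly follows a Conv2d into the conv's weight and bias.
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      stride=conv.stride, padding=conv.padding,
                      dilation=conv.dilation, groups=conv.groups, bias=True)
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)  # gamma / sqrt(var + eps), per channel
    fused.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))
    conv_bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
    fused.bias.copy_(bn.bias + (conv_bias - bn.running_mean) * scale)
    return fused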

Compact Model Experiments

CIFAR-100

| Model | Latency (ms) | Params (M) | GMACs | Top-1 (100e) | Top-1 (200e) | Top-1 (300e) | Top-1 (400e) | Log |
|---|---|---|---|---|---|---|---|---|
| RepViT-M0.6 | 0.6 | 2.2 | 0.39 | 71.73 | 78.45 | 80.16 | 80.60 | 100e / 200e / 300e / 400e |
| RepNeXt-M0E | 0.7 | 2.3 | 0.46 | 73.56 | 79.79 | 81.89 | 82.07 | 100e / 200e / 300e / 400e |
| RepNeXt-M0 | 0.6 | 2.0 | 0.39 | 74.53 | 80.14 | 81.73 | 81.97 | 100e / 200e / 300e / 400e |

top1_acc_cifar_100

RepNeXt-M0E is the single-branch equivalent form of RepNeXt-M0, in which the multi-branch design is replaced by a single-branch large-kernel depthwise convolution. Our multi-branch reparameterization design helps the model converge faster while keeping a smaller size and lower latency.

ImageNet-1K

| Model | Latency (ms) | Params (M) | GMACs | Top-1 (100e) | Top-1 (200e) | Top-1 (300e) | Top-1 (400e) | Log |
|---|---|---|---|---|---|---|---|---|
| RepViT-M0.6 | 0.6 | 2.5 | 0.39 | 69.29 | 71.63 | 72.34 | 72.87 | 100e / 200e / 300e / 400e |
| RepNeXt-M0E | 0.7 | 2.5 | 0.46 | 69.69 | 71.57 | 72.18 | 72.53 | 100e / 200e / 300e / 400e |
| RepNeXt-M0 | 0.6 | 2.3 | 0.39 | 70.14 | 71.93 | 72.56 | 72.78 | 100e / 200e / 300e / 400e |

top1_acc_cifar_100

As above, RepNeXt-M0E is the single-branch equivalent form of RepNeXt-M0; the multi-branch design again converges faster with a smaller size and lower latency.

Acknowledgement

The classification (ImageNet) codebase is partly built upon LeViT, PoolFormer, EfficientFormer, RepViT, and MogaNet.

The detection and segmentation pipeline is from MMCV (MMDetection and MMSegmentation).

Thanks for the great implementations!

Citation

If our code or models help your work, please cite our paper:

@misc{zhao2024repnext,
      title={RepNeXt: A Fast Multi-Scale CNN using Structural Reparameterization},
      author={Mingshu Zhao and Yi Luo and Yong Ouyang},
      year={2024},
      eprint={2406.16004},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}