jax-resnet

Implementations and checkpoints for ResNet, Wide ResNet, ResNeXt, ResNet-D, and ResNeSt in JAX (Flax).

Primary language: Python. License: MIT.

JAX ResNet - Implementations and Checkpoints for ResNet Variants


Note that this is a fork of the original repo n2cholas/jax-resnet.

A Flax (Linen) implementation of ResNet (He et al. 2015), Wide ResNet (Zagoruyko & Komodakis 2016), ResNeXt (Xie et al. 2017), ResNet-D (He et al. 2020), and ResNeSt (Zhang et al. 2020). The code is modular, so you can mix and match the various stem, residual, and bottleneck implementations.

Changes

Added more utilities for transfer learning; see the Transfer Learning section.

Installation

You can install this package from PyPI:

pip install jax-resnet

Or directly from GitHub:

pip install --upgrade git+https://github.com/n2cholas/jax-resnet.git

Usage

See the bottom of jax-resnet/resnet.py for the available aliases and options for the ResNet variants (all models are in Flax).
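One common way to define depth-specific aliases is to bind the per-depth options of a general constructor with functools.partial. The sketch below illustrates that pattern in plain Python. It does not reproduce the contents of resnet.py. build_resnet is a hypothetical stand-in constructor, and the stage sizes are the standard ones from He et al. 2015. The weighted_depth helper shows how each alias's name follows from its stage sizes:

```python
from functools import partial

# Hypothetical stand-in for a configurable ResNet constructor; it only
# records its configuration (the real models are Flax modules).
def build_resnet(*, stage_sizes, block, n_classes=1000):
    return {"stage_sizes": tuple(stage_sizes), "block": block, "n_classes": n_classes}

# Standard number of residual blocks per stage (He et al. 2015).
STAGE_SIZES = {
    18: (2, 2, 2, 2),
    34: (3, 4, 6, 3),
    50: (3, 4, 6, 3),
    101: (3, 4, 23, 3),
    152: (3, 8, 36, 3),
}

# Each alias fixes the depth-specific options and leaves the rest configurable.
ResNet18 = partial(build_resnet, stage_sizes=STAGE_SIZES[18], block="basic")
ResNet34 = partial(build_resnet, stage_sizes=STAGE_SIZES[34], block="basic")
ResNet50 = partial(build_resnet, stage_sizes=STAGE_SIZES[50], block="bottleneck")
ResNet101 = partial(build_resnet, stage_sizes=STAGE_SIZES[101], block="bottleneck")
ResNet152 = partial(build_resnet, stage_sizes=STAGE_SIZES[152], block="bottleneck")

def weighted_depth(config):
    """Stem conv + convs in every residual block + final dense layer."""
    convs_per_block = 2 if config["block"] == "basic" else 3
    return 1 + convs_per_block * sum(config["stage_sizes"]) + 1

for alias in (ResNet18, ResNet34, ResNet50, ResNet101, ResNet152):
    print(weighted_depth(alias(n_classes=10)))
```

The loop prints 18, 34, 50, 101 and 152. Basic blocks contain two convolutions and bottleneck blocks contain three, which is why ResNet-34 and ResNet-50 share stage sizes but differ in depth.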

Pretrained checkpoints from torch.hub are available for the following networks:

  • ResNet [18, 34, 50, 101, 152]
  • WideResNet [50, 101]
  • ResNeXt [50, 101]
  • ResNeSt [50-Fast, 50, 101, 200, 269]

The models have been tested to produce the same intermediate activations and outputs as the torch.hub implementations. The exception is ResNeSt-50 Fast: its activations don't match exactly, but its final accuracy does.
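Comparing activations across frameworks requires converting layouts. PyTorch stores activations as NCHW and conv weights as (out, in, kh, kw). Flax uses NHWC activations and (kh, kw, in, out) kernels. The sketch below is an illustration, not the library's checkpoint-conversion code. It transposes both the input and the weight of a single convolution and checks that both layouts give the same result:

```python
import jax
import jax.numpy as jnp
from jax import lax

key_x, key_w = jax.random.split(jax.random.PRNGKey(0))

# PyTorch layouts: activations NCHW, conv weights OIHW (out, in, kh, kw).
x_nchw = jax.random.normal(key_x, (2, 3, 8, 8))
w_oihw = jax.random.normal(key_w, (16, 3, 3, 3))

# Flax layouts: activations NHWC, conv kernels HWIO (kh, kw, in, out).
x_nhwc = jnp.transpose(x_nchw, (0, 2, 3, 1))
w_hwio = jnp.transpose(w_oihw, (2, 3, 1, 0))

y_torch_layout = lax.conv_general_dilated(
    x_nchw, w_oihw, window_strides=(1, 1), padding="SAME",
    dimension_numbers=("NCHW", "OIHW", "NCHW"))
y_flax_layout = lax.conv_general_dilated(
    x_nhwc, w_hwio, window_strides=(1, 1), padding="SAME",
    dimension_numbers=("NHWC", "HWIO", "NHWC"))

# Move the Flax-layout result back to NCHW and measure the largest difference.
max_diff = jnp.max(jnp.abs(jnp.transpose(y_flax_layout, (0, 3, 1, 2)) - y_torch_layout))
print(y_flax_layout.shape, float(max_diff))
```

The difference should be at floating-point rounding level. This is the sense in which ported activations "match": they agree up to numerical tolerance, not bit for bit.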

A pretrained checkpoint for ResNet-D 50 is available from fast.ai. Its activations do not match exactly, but the final accuracy does.

import jax.numpy as jnp
from jax_resnet import pretrained_resnest

ResNeSt50, variables = pretrained_resnest(50)
model = ResNeSt50()
out = model.apply(variables,
                  jnp.ones((32, 224, 224, 3)),  # ImageNet sized inputs.
                  mutable=False)  # Ensure `batch_stats` aren't updated.

You must install PyTorch yourself (see the official PyTorch installation instructions) to use these pretrained-checkpoint functions.
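In Flax, BatchNorm running statistics live in the batch_stats collection. Passing mutable=False to apply, as in the ResNeSt example above, keeps them fixed. The sketch below is a conceptual illustration in plain JAX, not Flax's BatchNorm. It shows the difference between evaluation mode, which reads the stored statistics, and training mode, which updates them. The momentum of 0.9 is an assumption chosen for illustration:

```python
import jax.numpy as jnp

def batch_norm(x, state, *, train, momentum=0.9, eps=1e-5):
    """Minimal batch norm over NHWC inputs (no learned scale or bias).

    Returns (output, new_state). In eval mode the stored running statistics
    are used and the state object is returned unchanged.
    """
    if train:
        mean = jnp.mean(x, axis=(0, 1, 2))
        var = jnp.var(x, axis=(0, 1, 2))
        new_state = {
            "mean": momentum * state["mean"] + (1 - momentum) * mean,
            "var": momentum * state["var"] + (1 - momentum) * var,
        }
    else:
        mean, var = state["mean"], state["var"]
        new_state = state
    return (x - mean) / jnp.sqrt(var + eps), new_state

state = {"mean": jnp.zeros(3), "var": jnp.ones(3)}
x = jnp.full((4, 2, 2, 3), 2.0)

y_eval, state_eval = batch_norm(x, state, train=False)    # stats untouched
y_train, state_train = batch_norm(x, state, train=True)   # stats updated
print(state_eval is state, state_train["mean"])
```

Running a pretrained model in training mode would silently shift these statistics toward the dummy input. That is why the inference example freezes them.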

Transfer Learning

To extract a subset of the model, you can use Sequential(model.layers[start:end]).
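Slicing works because a Sequential model is just an ordered list of layers applied one after another. The plain-Python sketch below is not Flax. SimpleSequential is a hypothetical stand-in container, and the toy arithmetic "layers" make the effect of model.layers[start:end] easy to trace:

```python
from typing import Callable, Sequence

class SimpleSequential:
    """Hypothetical stand-in for a Sequential container: applies layers in order."""

    def __init__(self, layers: Sequence[Callable]):
        self.layers = list(layers)

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

model = SimpleSequential([
    lambda x: x + 1,   # layer 0
    lambda x: x * 2,   # layer 1
    lambda x: x - 3,   # layer 2
    lambda x: x * 10,  # layer 3
])

# Keep layers 1 and 2 only, mirroring Sequential(model.layers[start:end]).
sub_model = SimpleSequential(model.layers[1:3])
print(model(5), sub_model(5))  # 90 7
```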

The slice_variables function (in common.py) extracts the corresponding subset of the variables dict. See its docstring for details.
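A sliced model only works if its variables are sliced and renumbered to match. The sketch below is a hypothetical re-implementation for illustration, not the code in common.py. It assumes that each collection maps submodule names of the form layers_{i} to their parameters, which is how Flax names a list of submodules. It also assumes that parameter-free layers, such as pooling, have no entry. The number of layers is passed explicitly, and start/end follow Python slice semantics:

```python
def slice_variables_sketch(variables, n_layers, start=0, end=None):
    """Keep entries for layers[start:end] and renumber them from 0."""
    kept = range(n_layers)[start:end]  # Python slice semantics, negatives allowed
    return {
        col_name: {
            f"layers_{new_i}": col[f"layers_{old_i}"]
            for new_i, old_i in enumerate(kept)
            if f"layers_{old_i}" in col  # skip parameter-free layers
        }
        for col_name, col in variables.items()
    }

# Toy variables for a 4-layer model whose layer 1 (e.g. pooling) has no params.
variables = {
    "params": {"layers_0": "w0", "layers_2": "w2", "layers_3": "w3"},
    "batch_stats": {"layers_0": "s0", "layers_2": "s2"},
}
print(slice_variables_sketch(variables, n_layers=4, start=1, end=-1))
```

Old layer 2 becomes layers_1, the same position it occupies in model.layers[1:3]. Real Flax variables may be FrozenDicts, whereas the sketch uses plain dicts.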

More Utilities

Retrieve the intermediate outputs from each stage of a ResNet. This is typically used to build feature pyramid networks (FPNs) for object detection models.

import jax.numpy as jnp
from jax_resnet import pretrained_resnet

ResNet50, variables = pretrained_resnet(50)
# Also return the outputs of stages 3, 4 and 5.
model = ResNet50(output_stages=[3, 4, 5])

print(model.output_layers)  # 8, 14, 17

out, (out3, out4, out5) = model.apply(
    variables,
    jnp.ones((32, 224, 224, 3)),
    mutable=False,
)

ResNet50 diagram borrowed from Towards Data Science.
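The printed indices 8, 14 and 17 are consistent with a flat layer list of the form [stem, max-pool, residual blocks..., global pool, classifier]. That structure is inferred here, not taken from the library source. Under that assumption, the sketch below computes the index of the last block in each requested stage from the stage sizes:

```python
from itertools import accumulate

def stage_output_layers(stage_sizes, output_stages, n_stem_layers=2):
    """Index of the last layer of each requested stage (stages numbered 2..5).

    Assumes blocks start right after n_stem_layers layers (stem + max-pool).
    """
    last_block = [n_stem_layers - 1 + c for c in accumulate(stage_sizes)]
    return [last_block[s - 2] for s in output_stages]

print(stage_output_layers([3, 4, 6, 3], [3, 4, 5]))  # [8, 14, 17]
```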

Slicing a ResNet and its variables:

import jax.numpy as jnp
from jax_resnet import pretrained_resnet, slice_resnet_and_variables

ResNet50, variables = pretrained_resnet(50)
model = ResNet50(output_stages=[3, 4, 5])
# end=-2 drops the last two layers (the global pooling and classifier head).
sliced_resnet, sliced_variables = slice_resnet_and_variables(
    model, variables, start=0, end=-2)

out, (out3, out4, out5) = sliced_resnet.apply(
    sliced_variables,
    jnp.ones((32, 480, 640, 3)),  # Inputs of arbitrary spatial size are accepted.
    mutable=False,
)
print(out3.shape, out4.shape, out5.shape)
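To anticipate what the print above reports, the sketch below assumes the usual ResNet-50 geometry. Stage s downsamples the input by a factor of 2**s, and the bottleneck stages output 256, 512, 1024 and 2048 channels. Rounding up mimics SAME-style padding, and a 480x640 input is divisible by 32, so no rounding occurs. These are expectations under stated assumptions, not output captured from the library:

```python
import math

# Assumed ResNet-50 output channels per stage (bottleneck expansion of 4).
RESNET50_STAGE_CHANNELS = {2: 256, 3: 512, 4: 1024, 5: 2048}

def stage_output_shape(input_shape, stage):
    """Expected NHWC shape of a stage output, assuming total stride 2**stage."""
    batch, height, width, _ = input_shape
    stride = 2 ** stage
    return (batch, math.ceil(height / stride), math.ceil(width / stride),
            RESNET50_STAGE_CHANNELS[stage])

for stage in (3, 4, 5):
    print(stage, stage_output_shape((32, 480, 640, 3), stage))
```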

Note that a stage listed in output_stages is dropped from the outputs if its output layer is sliced away. For example:

start = model.output_layers[0]+1   # 9
sliced_resnet, sliced_variables = slice_resnet_and_variables(
        model, variables, start=start, end=-2)

# out3 has been sliced out
out, (out4, out5) = sliced_resnet.apply(
    sliced_variables,
    jnp.ones((32, 480, 640, 3)),
    mutable=False
)
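Which requested stages survive follows from ordinary slice arithmetic. A stage survives only if its output layer index lies in range(n_layers)[start:end]. Its new index is then its old index minus start. The sketch below assumes the ResNet-50 numbers from this section: output layers 8, 14 and 17, and a total of 20 layers (stem, max-pool, 16 blocks, global pool, classifier). The total is an inference, not a value read from the library:

```python
def surviving_stages(output_layers, output_stages, n_layers, start=0, end=None):
    """Map each surviving stage to its layer index inside the sliced model."""
    kept = range(n_layers)[start:end]  # same semantics as model.layers[start:end]
    return {
        stage: layer - kept.start
        for stage, layer in zip(output_stages, output_layers)
        if layer in kept
    }

print(surviving_stages([8, 14, 17], [3, 4, 5], n_layers=20, start=9, end=-2))
```

With start=9, layer 8 (stage 3) falls outside the slice. Stages 4 and 5 remain, now at indices 5 and 8.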

Checkpoint Accuracies

The top-1 and top-5 accuracies reported below are on the ImageNet 2012 validation split. The data was preprocessed as in the official PyTorch ImageNet example.
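The PyTorch ImageNet example resizes the shorter side to 256, center-crops to 224, and normalizes with the ImageNet channel means and standard deviations. The sketch below approximates that pipeline in JAX for a single NHWC-bound image. jax.image.resize is not bit-identical to the PIL resizing used in PyTorch, so accuracies obtained this way may differ slightly from the table:

```python
import jax
import jax.numpy as jnp

# Normalization constants used by the PyTorch ImageNet example.
IMAGENET_MEAN = jnp.array([0.485, 0.456, 0.406])
IMAGENET_STD = jnp.array([0.229, 0.224, 0.225])

def preprocess(image, resize_size=256, crop_size=224):
    """Resize the shorter side, center-crop, and normalize one HWC uint8 image."""
    height, width, _ = image.shape
    scale = resize_size / min(height, width)
    new_h, new_w = round(height * scale), round(width * scale)
    image = jax.image.resize(image.astype(jnp.float32) / 255.0,
                             (new_h, new_w, 3), method="bilinear")
    top = (new_h - crop_size) // 2
    left = (new_w - crop_size) // 2
    image = image[top:top + crop_size, left:left + crop_size]
    return (image - IMAGENET_MEAN) / IMAGENET_STD

img = jnp.zeros((375, 500, 3), dtype=jnp.uint8)
batch = preprocess(img)[None]  # add a batch dimension -> NHWC
print(batch.shape)  # (1, 224, 224, 3)
```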

Model        Size  Top-1   Top-5
ResNet       18    69.75%  89.06%
ResNet       34    73.29%  91.42%
ResNet       50    76.13%  92.86%
ResNet       101   77.37%  93.53%
ResNet       152   78.30%  94.04%
Wide ResNet  50    78.48%  94.08%
Wide ResNet  101   78.88%  94.29%
ResNeXt      50    77.60%  93.70%
ResNeXt      101   79.30%  94.51%
ResNet-D     50    77.57%  93.85%

The ResNeSt validation data was preprocessed as in zhang1989/ResNeSt.

Model         Size  Crop Size  Top-1   Top-5
ResNeSt-Fast  50    224        80.53%  95.34%
ResNeSt       50    224        81.05%  95.42%
ResNeSt       101   256        82.82%  96.32%
ResNeSt       200   320        83.84%  96.86%
ResNeSt       269   416        84.53%  96.98%
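For reference, top-k accuracy counts an example as correct when its true label is among the k highest-scoring classes. The sketch below illustrates the metric in JAX on toy logits. It is not the evaluation script used to produce the tables above:

```python
import jax.numpy as jnp
from jax import lax

def top_k_accuracy(logits, labels, k):
    """Fraction of examples whose true label is among the k highest logits."""
    _, top_k = lax.top_k(logits, k)                  # (batch, k) class indices
    hits = jnp.any(top_k == labels[:, None], axis=-1)
    return jnp.mean(hits)

logits = jnp.array([
    [0.1, 0.7, 0.2],   # ranks classes 1, 2, 0
    [0.5, 0.3, 0.2],   # ranks classes 0, 1, 2
    [0.2, 0.3, 0.5],   # ranks classes 2, 1, 0
])
labels = jnp.array([1, 1, 0])
print(top_k_accuracy(logits, labels, 1), top_k_accuracy(logits, labels, 2))
```

Here top-1 accuracy is 1/3 and top-2 accuracy is 2/3. Top-5 on ImageNet is the same computation with k=5 over 1000 classes.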

References