google-research/vision_transformer

Fine-Tuning

IMvision12 opened this issue · 3 comments

I have a Flax model:

b = mlpmixer_b16(num_classes=10)

And pre-trained weights (ImageNet-21k, image size 224x224):

 with open("imagenet21k_Mixer-B_16.msgpack", "rb") as f:
        content = f.read()
        restored_params = flax.serialization.msgpack_restore(content)

So, I want to fine-tune this model with restored_params on a dataset whose images are 128x128. When I try to init or apply, I get this error:

import jax
import jax.numpy as jnp

dummy_inputs = jnp.ones((1, 128, 128, 3), dtype=jnp.float32)
rng = jax.random.PRNGKey(0)
x = b.apply({"params": restored_params}, dummy_inputs)

ScopeParamShapeError: Initializer expected to generate shape (196, 384) but got shape (64, 384) instead for parameter "kernel" in "/MixerBlock_0/token_mixing/Dense_0". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamShapeError)

If I change the shape to 224x224 it works fine:

jnp.ones((1, 224, 224, 3), dtype=jnp.float32)

How do I properly fine-tune a model using Flax?

When you want to fine-tune a ViT model at a different image size than it was pre-trained on, you need to adjust the position embeddings accordingly. Section 3.2 of the ViT paper proposes to perform 2D interpolation of the pre-trained position embeddings.

This is supported in this codebase when loading a checkpoint:

if 'posembed_input' in restored_params.get('Transformer', {}):
  # Rescale the grid of position embeddings. Param shape is (1,N,1024)
  posemb = restored_params['Transformer']['posembed_input']['pos_embedding']
  posemb_new = init_params['Transformer']['posembed_input']['pos_embedding']
  if posemb.shape != posemb_new.shape:
    logging.info('load_pretrained: resized variant: %s to %s', posemb.shape,
                 posemb_new.shape)
    posemb = interpolate_posembed(
        posemb, posemb_new.shape[1], model_config.classifier == 'token')
    restored_params['Transformer']['posembed_input']['pos_embedding'] = posemb
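For intuition, the resizing treats the (non-class) position embeddings as a 2D grid and resizes it to the new grid size. A minimal sketch of that idea, using jax.image.resize rather than the repo's interpolate_posembed:

import jax
import jax.numpy as jnp

def resize_posembed(posemb, num_tokens_new, has_class_token):
    # posemb: (1, N, D); if has_class_token, the first token is the [class] embedding.
    if has_class_token:
        cls_tok, grid = posemb[:, :1], posemb[0, 1:]
        num_grid_new = num_tokens_new - 1
    else:
        cls_tok, grid = posemb[:, :0], posemb[0]
        num_grid_new = num_tokens_new
    gs_old = int(round(grid.shape[0] ** 0.5))
    gs_new = int(round(num_grid_new ** 0.5))
    # Reshape the tokens into a (gs_old, gs_old, D) grid and bilinearly resize it.
    grid = grid.reshape(gs_old, gs_old, -1)
    grid = jax.image.resize(grid, (gs_new, gs_new, grid.shape[-1]), method='bilinear')
    grid = grid.reshape(1, gs_new * gs_new, -1)
    return jnp.concatenate([cls_tok, grid], axis=1)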

This is done automatically when you call checkpoint.load_pretrained(): you provide init_params that expect a certain image size (e.g. 128 in your example) and load from a checkpoint whose weights were trained at a different size (e.g. 224).
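In code, roughly (a sketch following the Colab; the argument names and module paths there are authoritative):

import jax
import jax.numpy as jnp
from vit_jax import checkpoint, models

# model_config: the hyper-parameters of the chosen variant, looked up from the
# configs in this repo (see the Colab cell linked below for the exact lookup).
model = models.VisionTransformer(num_classes=10, **model_config)

# Initialize at the *fine-tuning* resolution (128x128), so init_params already
# has the position-embedding shape you want.
variables = model.init(jax.random.PRNGKey(0),
                       jnp.ones((1, 128, 128, 3), jnp.float32),
                       train=False)

# load_pretrained() copies the pre-trained weights into that structure and
# resizes the position embeddings when the shapes differ (the snippet above).
params = checkpoint.load_pretrained(
    pretrained_path='ViT-B_16.npz',
    init_params=variables['params'],
    model_config=model_config)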

See the code in the main Colab that fine-tunes on CIFAR-10 (image size 32), specifically this cell:

https://colab.research.google.com/github/google-research/vision_transformer/blob/main/vit_jax.ipynb#scrollTo=zIXjOEDkvAWM

Thanks, is there such a thing for MLP-Mixer models? @andsteing

It's not as straightforward with MLP-Mixer, because the token-mixing MLP blocks are trained for a specific number of patches (as opposed to ViT, where the attention operation works on sequences of variable length and only the position embeddings need to be adjusted). That is exactly the shape error you see: Mixer-B/16 at 224x224 has (224/16)² = 196 tokens, while 128x128 gives (128/16)² = 64.

See Section C of the Mixer paper for details on how to do this.

(not implemented in this repo)
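Roughly, the idea there is to scale the resolution up by an integer factor K with a fixed patch size, so the sequence length grows from S to K²·S, and to replicate the pre-trained token-mixing weights block-diagonally K² times (which also requires ordering the patches so that each of the K² sub-grids is contiguous). A rough illustration of that initialization, not code from this repo:

import jax.numpy as jnp

def expand_token_mixing_mlp(w1, b1, w2, b2, k):
    # Pre-trained token-mixing MLP of one MixerBlock:
    #   w1: (S, C_s), b1: (C_s,), w2: (C_s, S), b2: (S,)
    # After a K-fold increase in resolution (fixed patch size) the sequence
    # length becomes K**2 * S; replicate the weights block-diagonally so each
    # of the K**2 sub-grids is initially mixed exactly as before.
    n = k * k
    w1_new = jnp.kron(jnp.eye(n, dtype=w1.dtype), w1)  # (n*S, n*C_s), block-diagonal
    b1_new = jnp.tile(b1, n)
    w2_new = jnp.kron(jnp.eye(n, dtype=w2.dtype), w2)  # (n*C_s, n*S), block-diagonal
    b2_new = jnp.tile(b2, n)
    return w1_new, b1_new, w2_new, b2_new

Note that this only helps when the new sequence length is an integer multiple of the old one; in your 224 -> 128 case the 64 tokens are not a multiple of 196, so the token-mixing weights would have to be re-initialized (or resized heuristically) and trained at the new sequence length.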