Fine-Tuning
IMvision12 opened this issue · 3 comments
I have a Flax model:
b = mlpmixer_b16(num_classes=10)
and pre-trained weights (ImageNet-21k, image size 224x224):
with open("imagenet21k_Mixer-B_16.msgpack", "rb") as f:
content = f.read()
restored_params = flax.serialization.msgpack_restore(content)
I want to fine-tune this model with restored_params on a dataset whose images are 128x128. When I try to init or apply, I get this error:
dummy_inputs = jnp.ones((1, 128, 128, 3), dtype=jnp.float32)
rng = jax.random.PRNGKey(0)
x = b.apply({"params": restored_params}, dummy_inputs)
ScopeParamShapeError: Initializer expected to generate shape (196, 384) but got shape (64, 384) instead for parameter "kernel" in "/MixerBlock_0/token_mixing/Dense_0". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamShapeError)
If I change the input shape to 224x224, it works fine:
jnp.ones((1, 224, 224, 3), dtype=jnp.float32)
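For context on where those two shapes come from: Mixer-B/16 splits the image into 16x16 patches, and the token-mixing Dense layers carry a kernel whose first dimension is the number of patches, so the patch count is baked into the pre-trained weights. A quick sanity check (a standalone sketch, not code from the repo):

```python
# Mixer-B/16: 16x16 patches, so the token-mixing kernel's first
# dimension equals the number of patches, which depends on image size.
patch_size = 16
for image_size in (224, 128):
    num_patches = (image_size // patch_size) ** 2
    print(image_size, "->", num_patches)  # 224 -> 196, 128 -> 64
```

This matches the error: the checkpoint expects a (196, 384) kernel (224x224 input), but a 128x128 input produces only 64 patches.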
How do I properly fine-tune a model using Flax?
When you want to fine-tune a ViT model at a different image size than it was pre-trained on, you'll need to adjust the position embeddings accordingly. Section 3.2 of the ViT paper proposes 2D interpolation.
This is supported in this codebase when loading a checkpoint:
vision_transformer/vit_jax/checkpoint.py
Lines 192 to 201 in 297866a
This is done automatically when you call checkpoint.load_pretrained() and provide init_params that expect a certain image size (e.g. 128 in your example) while loading from a checkpoint whose weights were trained at a different size (e.g. 224).
See the code in the main Colab that fine-tunes on cifar10 (size 32), specifically this cell:
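For intuition, a minimal standalone sketch of that 2D interpolation might look like the following (this is not the repo's exact implementation; the function name and the assumption that there is no class token are mine):

```python
import jax
import jax.numpy as jnp

def interpolate_posemb(posemb, new_grid):
    """Resize position embeddings [1, old_len, dim] (no class token)
    from an old_grid x old_grid layout to new_grid x new_grid."""
    _, old_len, dim = posemb.shape
    old_grid = int(old_len ** 0.5)
    grid = posemb.reshape(old_grid, old_grid, dim)
    # Bilinear 2D interpolation over the patch grid, as in ViT Sec. 3.2.
    grid = jax.image.resize(grid, (new_grid, new_grid, dim), method="bilinear")
    return grid.reshape(1, new_grid * new_grid, dim)

posemb = jnp.zeros((1, 196, 768))        # 14x14 grid: 224x224 / patch 16
resized = interpolate_posemb(posemb, 8)  # 8x8 grid: 128x128 / patch 16
print(resized.shape)  # (1, 64, 768)
```

The rest of the ViT weights are shape-compatible across image sizes, which is why only the position embeddings need this treatment.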
Thanks, is there such thing for mlp-mixer models? @andsteing
It's not as straightforward with MLP-Mixer, because the token-mixing MLP blocks are trained for a specific number of patches (as opposed to ViT, where the attention operation works on sequences of variable length and only the position embeddings need to be modified).
See Section C of the Mixer paper for details on how to do this.
(not implemented in this repo)