How to do transfer learning with flaxmodels?
riven314 opened this issue · 1 comment
riven314 commented
Sorry to open yet another issue. I am trying to apply transfer learning to ResNet using your flaxmodels.
However, I am stuck on how to extract the backbone and add a new head on top of the ResNet.
Do you have any sample code or guidance I could use as a reference?
riven314 commented
I saw an example of transfer learning in flaxvision: https://github.com/rolandgvc/flaxvision/blob/master/examples/transfer_learning.ipynb
However, it looks to me like the script does not carry the pre-trained model weights into training:
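import jax.numpy as jnp

def get_initial_params(key):
    # Dummy NHWC input batch, used only to trace parameter shapes during init
    init_shape = jnp.ones((1, 224, 224, 3), jnp.float32)
    # init() draws fresh random parameters from `key`; no pre-trained weights are loaded here
    initial_params = MyModel().init(key, init_shape)['params']
    return initial_params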