Modify checkpoint partially
KeremTurgutlu opened this issue · 1 comment
KeremTurgutlu commented
Is there a way to partially load a t5x checkpoint and modify it? I would like to take a pretrained checkpoint, say t5-small, and change the token embedder weights by adding new ids and overwriting the weights of some existing ids. Mainly, I would like to finetune the model, but with a partially changed (randomly initialized) embedding table.
Step by step:
- Train a sentencepiece model on a specialized corpus (e.g. medicine, politics, etc.). Let's say `piece_to_ids = {e: 0, a: 1, f: 2}`.
- Load `target.token_embedder.embedding`. Let's say the original sp model has `piece_to_ids = {a: 0, b: 1, c: 2, d: 3}` and an embedding table like:
```
[0.10, 0.10, 0.10]  # a
[0.20, 0.20, 0.20]  # b
[0.30, 0.30, 0.30]  # c
[0.40, 0.40, 0.40]  # d
```
- Keep the common tokens, randomly initialize the new ones, and order the rows according to the new sp model ids (a sketch follows this list):
```
[random init]       # e
[0.10, 0.10, 0.10]  # a
[random init]       # f
```
- Save the modified portion back to the t5x checkpoint dir.
- Modify the gin file to use a size-3 embedding table and use the new checkpoint for finetuning in t5x.
In my case I only change the embedding table, but this use case can of course be generalized to any partial modification of an existing pretrained model.
- Is it possible to do the above without a TPU VM? All the t5x tutorials use a TPU VM.
For example, on Kaggle I can load:
```python
t5x_model = checkpoints.load_t5x_checkpoint("/kaggle/input/t5x-small/checkpoint_1000000/")
```
but this would require a TPU VM.
- For larger models, loading the full checkpoint just to modify it won't be possible with this method, so how can this be done by reading only the portions of interest? (One possible approach is sketched below.)
KeremTurgutlu commented
This was possible using the `InteractiveModel` class.
- Load the model as an `InteractiveModel` from the desired pretrained checkpoint, which will later be partially reused for pre-finetuning.
- Modify the checkpoint, e.g. do the desired model surgery on the unfrozen `params` dict:
```python
import flax

params = flax.core.unfreeze(interactive_model._train_state.params)
# Mock a new spiece model with a different vocab size; always keep the table
# size a multiple of 128, which is friendlier to TPU layouts.
params['token_embedder']['embedding'] = params['token_embedder']['embedding'][:1128]
```
- Create a new, randomly initialized interactive model with the same model structure as the modified params (i.e. configured with the new vocab size).
- Replace the random params with the modified params and save:
```python
interactive_model._trainer.train_state = (
    interactive_model._trainer.train_state.replace_params(params)
)
interactive_model.save()  # save() uses self._trainer.train_state
```
Let me know if you have any comments, thanks!