How to fine-tune for the masked language modeling task?
Hi there,
I was reading through your example language modeling notebook and I'm not sure how to adapt that notebook (CLM) to MLM. Basically I was trying to do something like this:
model_cls = AutoModelForMaskedLM
pretrained_model_name = 'roberta-base' # cause gpt2 cannot be used for MLM
hf_arch, hf_config, hf_tokenizer, hf_model = BLURR.get_hf_objects(pretrained_model_name, model_cls=model_cls)
However, I'm stuck on how to modify HF_CausalLMBeforeBatchTransform, and there's no equivalent transform for MaskedLM in the library:
blocks = (
    HF_Seq2SeqBlock(before_batch_tfm=HF_CausalLMBeforeBatchTransform(hf_arch, hf_config, hf_tokenizer, hf_model)),
    noop
)
Could you have a look at this? Thank you in advance!
Yah, that is in the TODO list at the moment :)
There are a number of masking strategies (prefix language modeling, BERT-style, deshuffling, MASS-style, replace spans, drop tokens, random spans, etc...). See the T5 paper, Table 3 here: https://arxiv.org/abs/1910.10683
What I envision is an HF_MLMBeforeBatchTransform that takes an MLM_{Whatever}Strategy class that knows how to modify the inputs/targets accordingly. That object would essentially do what you see here in the causal LM batch transform.
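Very roughly, the shape I'm picturing is something like the sketch below (every name and signature here is a placeholder for illustration, not a final API):

```python
class HF_MLMBeforeBatchTransform:
    """Placeholder sketch: tokenize as the causal LM transform does today,
    then hand the samples off to a pluggable masking strategy."""

    def __init__(self, hf_arch, hf_config, hf_tokenizer, hf_model, masking_strategy):
        self.hf_tokenizer = hf_tokenizer
        self.masking_strategy = masking_strategy  # e.g. an MLM_BertMaskingStrategy

    def encodes(self, samples):
        # tokenization would happen here, just like in HF_CausalLMBeforeBatchTransform;
        # the strategy then corrupts the inputs and builds the targets
        return self.masking_strategy(samples)
```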
You want to give it a shot?
If so, here are some tips for how you might go about implementing this:
- The notebook to modify is 01zb_data-seq2seq-language-modeling (see the Masked LM section at the bottom).
- Start by creating an abstract base class called MLM_MaskingStrategy. Anything we start finding in common across strategies, we just stick in here.
- Then start with the most common corruption strategy: an MLM_BertMaskingStrategy class that inherits from MLM_MaskingStrategy. We're going to want to pass it the samples and have it return to us our updated_samples, where the inputs have been corrupted (masked) and the targets are the original text (see the sketch after this list).
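To make that concrete, here's a rough sketch of the kind of thing I have in mind, written against a single 1-D input_ids tensor and using the standard BERT 80/10/10 corruption split. Nothing here is final blurr API; the class names just follow the suggestions above and the __call__ signature is only for illustration:

```python
import torch


class MLM_MaskingStrategy:
    """Abstract base class: given a 1-D `input_ids` tensor, return the
    corrupted inputs and the labels to train against."""

    def __init__(self, hf_tokenizer, mask_prob=0.15):
        self.hf_tokenizer, self.mask_prob = hf_tokenizer, mask_prob

    def __call__(self, input_ids):
        raise NotImplementedError


class MLM_BertMaskingStrategy(MLM_MaskingStrategy):
    """BERT-style corruption: of the ~15% selected tokens, 80% become the
    mask token, 10% become a random token, and 10% are left unchanged."""

    def __call__(self, input_ids):
        input_ids = input_ids.clone()
        labels = input_ids.clone()

        # never corrupt special tokens (<s>, </s>, padding, etc.)
        special = torch.tensor(
            self.hf_tokenizer.get_special_tokens_mask(
                input_ids.tolist(), already_has_special_tokens=True),
            dtype=torch.bool)

        # pick the positions to corrupt
        probs = torch.full(labels.shape, self.mask_prob)
        probs.masked_fill_(special, 0.0)
        selected = torch.bernoulli(probs).bool()

        # loss is only computed on the corrupted positions
        labels[~selected] = -100

        # 80% of the selected tokens -> mask token
        masked = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & selected
        input_ids[masked] = self.hf_tokenizer.mask_token_id

        # 10% of the selected tokens -> a random token (half of the remaining 20%)
        random_tok = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & selected & ~masked
        input_ids[random_tok] = torch.randint(len(self.hf_tokenizer), labels.shape, dtype=torch.long)[random_tok]

        # the remaining 10% stay unchanged
        return input_ids, labels
```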
I think the above approach will work. We can reduce the code further by making a CausalMaskingStrategy and then turning the batch transform into just HF_LMBeforeBatchTransform.
Lmk if you give this a try.
If not, I'll try to work up the basic infrastructure for it, and you can then add some of the other MLM strategies if you choose.
Sorry, I still couldn't make it work on my own based on your hints :(
Yes, I think a basic infrastructure would be very helpful; I can then modify it based on my needs. Thanks a lot!
Btw, I was also thinking of making use of the DataCollatorForLanguageModeling class in transformers.data.data_collator. Do you think it could be used in your HF_LMBeforeBatchTransform?
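For reference, this is roughly how I'd use that collator standalone with a recent transformers version (roberta-base is just for illustration; this isn't wired into blurr):

```python
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("roberta-base")

# mlm_probability=0.15 selects ~15% of tokens; the collator then applies the
# usual 80% mask / 10% random / 10% unchanged corruption and sets labels
# to -100 everywhere except the selected positions
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

examples = [tokenizer(t, return_special_tokens_mask=True)
            for t in ["Hello world", "Masked language modeling"]]
batch = collator(examples)
# batch["input_ids"] holds the corrupted inputs, batch["labels"] the targets
```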
Ok, take a look at the repo. You'll have to do a dev install, as I haven't pushed a new release out yet.
The notebooks to check out:
- https://github.com/ohmeow/blurr/blob/master/nbs/01zb_data-seq2seq-language-modeling.ipynb
- https://github.com/ohmeow/blurr/blob/master/nbs/02zb_modeling-seq2seq-language-modeling.ipynb
Things that would be helpful:
- Review the BERT-style masking code ... verify that it conforms to the paper, and see whether there's a way to make it more efficient.
- Add other masking strategies that derive from LMStrategy. The T5 paper referenced in the notebooks/docs describes the core ones ... give it a go! You can do it if I can :)
Closing this out.
Now that v.1 is out, feel free to PR any new masking strategies you want added. I'm hoping to get some time for this later this year. Thanks.