huggingface/transformers

Pretraining BART language model

sajastu opened this issue · 6 comments

Feature request

Hi,

I'm looking into the BART docs. It seems that the provided examples cover fine-tuning BART on seq2seq summarization tasks. I'm wondering if there is any example of pre-training BART's language model itself, with the pre-training objectives (token infilling, token masking, etc.) mentioned in the original paper. I looked into this a couple of months ago and found this thread: #4151 and a relevant issue in fairseq: facebookresearch/fairseq#1899. I've now decided to ask directly here to see whether there has been any update so far.

Thanks, @patrickvonplaten @patil-suraj,

Motivation

Enabling further pre-training of BART (its language model).

I don't think we currently have plans to add code for BART pretraining, as it's quite a time-costly endeavor. Maybe we should revisit this decision at some point for some models (BART, T5, Wav2Vec2), as the community is asking for it more and more. Wdyt @sgugger @LysandreJik?

It's not something anyone on the team has time to focus on right now, but we would welcome a contribution for such an example.

I have attempted to reproduce BART pretraining for Swedish using fairseq here, with some instructions in the README. We trained a custom tokenizer with the Huggingface tokenizers package and figured we'd later retrofit the dict.txt generated by fairseq preprocessing back into a Huggingface-compatible tokenizer.json. This worked, but we later discovered our chosen method was unnecessarily complicated: we could have simply copy-pasted the vocabulary from our original tokenizer.json into a tab-separated dict.txt, where the first column is the token and the second column is a dummy frequency count of how often the token occurs in your tokenized data (you can set the frequencies to 0 or any integer).
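For anyone attempting the same shortcut, here is a rough sketch of what that copy-paste amounts to. This is my own illustration rather than code from our repo; it assumes a BPE-style tokenizer.json whose vocabulary sits under model.vocab, and it writes the tab-separated dict.txt format with dummy counts described above.

```python
import json

# Sketch only: copy the vocabulary out of a Huggingface tokenizer.json and
# write it as a fairseq-style dict.txt with dummy frequency counts.
with open("tokenizer.json", encoding="utf-8") as f:
    tokenizer_json = json.load(f)

# Assumption: a BPE-style tokenizer whose vocab is a {token: id} mapping
# under "model" -> "vocab" in tokenizer.json.
vocab = tokenizer_json["model"]["vocab"]

with open("dict.txt", "w", encoding="utf-8") as f:
    for token, token_id in sorted(vocab.items(), key=lambda kv: kv[1]):
        # First column: the token; second column: a dummy count (0 is fine).
        f.write(f"{token}\t0\n")
```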

We were unfamiliar with fairseq prior to doing this and based the entire reproduction on a close reading of the paper, trying to match its reported experimental settings against fairseq args by reading the documentation and the source code.

The finished model was uploaded to Huggingface. However, we learned quite a few lessons along the way that aren't documented there.

The first is the simpler process, described above, of using a custom Huggingface tokenizer with fairseq.

A second important one is that you should generate several epochs' worth of sharded and shuffled data, not just the shards for a single epoch; see the sketch below. Fairseq pretraining will crash once you reach the end of your listed shards, and it won't allow you to reuse or relist the same shard filename. This sort of thing may be obvious if you have the chance to ask your Facebook engineer colleagues questions over lunch, but to us it wasn't obvious what the best practices were, since they're not described in the docs.
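To make the point concrete, here is a minimal sketch (my own illustration with hypothetical file names, not our actual preprocessing script) of pre-generating one shuffled set of shards per planned epoch, so that fairseq never runs out of listed shards:

```python
import random
from pathlib import Path

# Sketch only: write a separately shuffled set of shards for every epoch you
# plan to train, since fairseq will not let you relist a shard once the
# listed shards are exhausted. "corpus.txt" (one document per line) and the
# output layout are hypothetical.
documents = Path("corpus.txt").read_text(encoding="utf-8").splitlines()

n_epochs = 10          # generate at least as many epochs as you plan to train
shards_per_epoch = 8

for epoch in range(n_epochs):
    rng = random.Random(epoch)          # different shuffle order per epoch
    shuffled = documents[:]
    rng.shuffle(shuffled)
    for shard_idx in range(shards_per_epoch):
        shard_lines = shuffled[shard_idx::shards_per_epoch]
        out = Path(f"shards/epoch{epoch:02d}/shard{shard_idx:02d}.txt")
        out.parent.mkdir(parents=True, exist_ok=True)
        out.write_text("\n".join(shard_lines) + "\n", encoding="utf-8")
```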

Anyway, maybe the pretraining task, along with the training args I pieced together while trying to replicate it, is of interest to you: https://github.com/kb-labb/kb_bart/blob/main/train_bart_args.sh. I cannot guarantee it is identical to the actual BART pretraining settings, but it was my best attempt based on reading the paper and source code.

Edit: It seems a BART author replied in the issue thread with the actual args they used in pretraining.

@sajastu
You can use this repo, and in the HF run_mlm.py script you have to:

  • replace AutoModelForMaskedLM by AutoModelForSeq2SeqLM
  • replace data_collator by DataCollatorForDenoisingTasks (token infilling + sentence permutation)
  • or replace data_collator by DataCollatorForTextInfilling (you need to process decoder_input_ids for this one)

I'm not the author of this repo, but @patrickvonplaten is one of the authors.
You need to convert the numpy outputs from the collator to torch tensors if you don't want to rewrite it; a sketch of this is below.
You can also check #12370 for a PyTorch implementation of DataCollatorForTextInfilling, which is very similar.
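A minimal sketch of the two changes, assuming the denoising collator class comes from that repo (it is not part of transformers itself) and returns a dict of numpy arrays as described above:

```python
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Assumption: DataCollatorForDenoisingTasks is imported from the repo
# mentioned above; the import path and constructor arguments are guesses.
# from pretrain_bart import DataCollatorForDenoisingTasks


class NumpyToTorchCollator:
    """Wrap a collator that returns numpy arrays so the Trainer receives torch tensors."""

    def __init__(self, numpy_collator):
        self.numpy_collator = numpy_collator

    def __call__(self, features):
        batch = self.numpy_collator(features)
        return {key: torch.tensor(value) for key, value in batch.items()}


# In run_mlm.py, swap AutoModelForMaskedLM for AutoModelForSeq2SeqLM ...
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-base")

# ... and pass the wrapped denoising collator to the Trainer instead of the
# default data_collator:
# data_collator = NumpyToTorchCollator(DataCollatorForDenoisingTasks(tokenizer=tokenizer))
```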

That's it!

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

We've released nanoT5, which reproduces T5 (a model similar to BART) pre-training. You can take a look! Any suggestions are more than welcome.