embedding-transfer

Code for “Bridging Subword Gaps in Pretrain-Finetune Paradigm for Natural Language Generation” (ACL 2021)

Description for directories

  • poattention (modified from Fairseq): Training the Position-Aware Embedding Generator for seq2seq models.
  • use_poattention (modified from Fairseq): Generating embeddings for unseen tokens and fine-tuning the seq2seq model with the downstream vocabulary on the downstream task.
  • bert_poattention (modified from Transformers): Training the Position-Aware Embedding Generator for BERT-like models.
  • bert_use_poattention (modified from Fairseq): Generating embeddings for unseen tokens, converting the parameters of the BERT-like model to a seq2seq model, and fine-tuning the seq2seq model with the newly generated vocabulary on the downstream task.

How to run

For seq2seq pretrained models

poattention

  1. Preprocess the upstream and downstream data (refer to Fairseq for details). Binarized data and vocabularies will be stored in data-bin.

  2. Move the seq2seq pretrained model (generated by Fairseq) to ./checkpoints and rename it as checkpoint_last.pt.

    cp path_to_pretrained_model ./checkpoints/checkpoint_last.pt

  3. Train the embedding generator.

    pip install .; bash train.sh

  4. Stop training when the model converges.
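
One way to decide when the model has converged is to watch the validation loss and stop once it no longer improves. The sketch below is only an illustration: the function name `has_converged`, the `patience` and `min_delta` values, and the assumption that you collect one validation loss per evaluation are hypothetical and not part of this repository or of train.sh.

```python
def has_converged(valid_losses, patience=3, min_delta=1e-3):
    """Return True if the last `patience` evaluations did not improve
    on the earlier best validation loss by at least `min_delta`.

    `valid_losses` is a list of validation losses, oldest first.
    All names and thresholds here are illustrative assumptions.
    """
    if len(valid_losses) <= patience:
        # Not enough history to judge yet.
        return False
    best_before = min(valid_losses[:-patience])
    best_recent = min(valid_losses[-patience:])
    return best_before - best_recent < min_delta


# Loss has flattened out over the last three evaluations.
flat = [5.0, 4.0, 3.5, 3.49995, 3.4999, 3.4999]
# Loss is still dropping steadily.
improving = [5.0, 4.0, 3.0, 2.5, 2.0, 1.5]

print(has_converged(flat))       # True
print(has_converged(improving))  # False
```

A larger `patience` makes the check more tolerant of noisy validation curves, at the cost of training longer.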

use_poattention

  1. Preprocess the upstream and downstream data (refer to Fairseq for details). Binarized data and vocabularies will be stored in data-bin.

  2. Get the mapping between upstream and downstream vocabulary.

    python get_map_index.py

    Note: Change the data name in get_map_index.py to match your data.

  3. Move the trained embedding generator checkpoint (generated by poattention) to ./checkpoints and rename it as checkpoint_last.pt.

    cp path_to_embedding_generator ./checkpoints/checkpoint_last.pt

  4. Generate embeddings for unseen tokens and fine-tune the downstream model with the downstream vocabulary.

    pip install .; bash train.sh
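
The idea behind the vocabulary mapping and the embedding initialization in the steps above can be sketched in plain Python. This is a minimal illustration and not the code in get_map_index.py or train.sh: the function names are hypothetical, the input is assumed to be Fairseq-style dict.txt lines ("token count") with the four special symbols placed first, and `generate` is a placeholder standing in for the trained embedding generator.

```python
# Fairseq dictionaries place these special symbols at indices 0-3.
SPECIALS = ["<s>", "<pad>", "</s>", "<unk>"]


def load_fairseq_dict(lines):
    """Parse dict.txt-style lines ("<token> <count>") into token -> index."""
    vocab = {tok: i for i, tok in enumerate(SPECIALS)}
    for line in lines:
        line = line.rstrip("\n")
        if not line:
            continue
        token = line.rsplit(" ", 1)[0]  # drop the trailing count
        if token not in vocab:
            vocab[token] = len(vocab)
    return vocab


def build_map_index(upstream, downstream):
    """For each downstream index, the upstream index of the same token,
    or -1 when the token is unseen in the upstream vocabulary."""
    mapping = [-1] * len(downstream)
    for token, d_idx in downstream.items():
        mapping[d_idx] = upstream.get(token, -1)
    return mapping


def init_downstream_embeddings(mapping, upstream_emb, downstream_tokens, generate):
    """Copy embeddings of shared tokens; call `generate` for unseen ones.

    `generate` is a hypothetical stand-in for the embedding generator.
    """
    rows = []
    for d_idx, u_idx in enumerate(mapping):
        if u_idx >= 0:
            rows.append(list(upstream_emb[u_idx]))
        else:
            rows.append(generate(downstream_tokens[d_idx]))
    return rows


upstream = load_fairseq_dict(["the 100", "low 50", "er 40"])
downstream = load_fairseq_dict(["lower 30", "the 20"])
mapping = build_map_index(upstream, downstream)
print(mapping)  # [0, 1, 2, 3, -1, 4]

# Toy 2-dimensional upstream embeddings: row i is [i, i].
upstream_emb = [[float(i)] * 2 for i in range(len(upstream))]
downstream_tokens = sorted(downstream, key=downstream.get)
embeddings = init_downstream_embeddings(
    mapping, upstream_emb, downstream_tokens, generate=lambda tok: [0.5, 0.5]
)
print(embeddings[4], embeddings[5])  # [0.5, 0.5] [4.0, 4.0]
```

Here "lower" is unseen upstream, so its row comes from the placeholder generator, while "the" reuses its pretrained upstream embedding.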

For BERT-like pretrained models

bert_poattention

  1. Prepare the upstream data (plain text) at ./examples/language-modeling/data.

  2. Train the embedding generator.

    pip install .

    cd ./examples/language-modeling

    bash train_mlm.sh

bert_use_poattention

  1. Preprocess the upstream and downstream data (refer to Fairseq for details). Binarized data and vocabularies will be stored in data-bin.

    Note: Sentences should be segmented with WordPiece. The bert-vocab-builder is suggested for building the vocabulary of the downstream data.

  2. Get the mapping between upstream and downstream vocabulary.

    python get_map_index.py

    Note: Change the data name in get_map_index.py to match your data.

  3. Generate embeddings for unseen tokens and fine-tune the downstream model with the downstream vocabulary.

    pip install path_to_bert_poattention

    pip install .; bash train.sh
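
For reference, WordPiece segmentation (required for the data in step 1) splits each word greedily into the longest matching vocabulary pieces, marking non-initial pieces with a "##" prefix. The following is a minimal sketch of that greedy longest-match-first procedure, assuming a toy vocabulary; the function name `wordpiece_tokenize` is illustrative, and in practice you would use the tokenizer that matches your pretrained BERT-like model.

```python
def wordpiece_tokenize(word, vocab, unk="[UNK]", prefix="##"):
    """Greedy longest-match-first WordPiece segmentation of a single word.

    Returns [unk] if any part of the word cannot be matched.
    """
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        piece = None
        # Try the longest remaining substring first, then shrink it.
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = prefix + candidate  # continuation piece
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:
            return [unk]
        pieces.append(piece)
        start = end
    return pieces


# Toy vocabulary, for illustration only.
toy_vocab = {"un", "##aff", "##able", "low", "##er"}

print(wordpiece_tokenize("unaffable", toy_vocab))  # ['un', '##aff', '##able']
print(wordpiece_tokenize("lower", toy_vocab))      # ['low', '##er']
print(wordpiece_tokenize("xyz", toy_vocab))        # ['[UNK]']
```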