/PlugLM

[ACL2023] Source code for Decouple knowledge from paramters for plug-and-play language modeling

Primary LanguagePython

Decouple knowledge from parameters for plug-and-play language modeling

This reposotory contains the code and data for this ACL2023 Findings paper Decouple knowledge from parameters for plug-and-play language modeling

For the first time, we challenge the current implicit knowledge encoding mechanism for PLMs, which have two fundamental drawbacks:

  • The knowledge is neither editable nor scalable once the model is trained, which is especially problematic in that knowledge is consistently evolving.
  • It lacks interpretability and prevents humans from understanding which knowledge PLM requires for a certain problem.

Based on the discovery that the Feed-Forward Network in PLM has the ability to store various types of knowledge during pre-training and is fundamentally a key-value memory network, we present PlugLM. This is the first model to separate knowledge storage from model parameters by utilizing an editable and scalable key-value memory, allowing for more adaptable and understandable knowledge encoding in PLMs.

We believe this architectural design would pave a new direction for future research on language model pre-training, especially for LLM.

Setup

Our pretraining code is based on the NVIDIA/BERT.

First install the requirements listed there including:

  • PyTorch
  • lddl is a utility library that minimizes the friction during dataset retrieval, preprocessing and loading for the language models in NVIDIA Deep Learning Examples.
  • APEX is a PyTorch extension with NVIDIA-maintained utilities to streamline mixed precision and distributed training, whereas AMP is an abbreviation used for automatic mixed precision training.
  • LAMB stands for Layerwise Adaptive Moments based optimizer, is a large batch optimization technique that helps accelerate training of deep neural networks using large minibatches.

Then install the following packages:

pip install transformers tqdm numpy datasets

Dataset

For pre-training data and pre-processing, we use this scripts to download Wikipedia.

For S2ORC pre-training and fine-tuning data, we send application form and download from here.

For PubMed pre-training data, we download from here, and Biomedical downstream dataset here.

For other datasets, we use datasets to download.

Pre-training

Here we provide the checkpoint of our model along with tokenized wikipedia knowledge base here.

The pre-training code (on 8*A100) is listed below:

cd src
python -m torch.distributed.launch --nproc_per_node 8 --master_port 12344 --use_env run_pretraining.py \
    --data_dir ../data/phase2 \
    --output_dir ../results/pluglm \
    --tokenized_document_path ../data/tokenized_wiki.npy \
    --total_train_batch_size_per_device 8192 \
    --gradient_accumulation_steps 256 \
    --max_train_steps 8000 \
    --refreshing_steps 200 \
    --ckpt_steps 500 \
    --learning_rate=6e-3 \
    --refreshing_batch_size 4096 \
    --knowledge_attention_type multi_head

Fine-tuning

Downstream finetuning code (for example GLUE/STS-B):

WKDIR=your/work/dir
RESULTDIR=$WKDIR/results/
PT=ckpt_8000.pt
MPI="-m torch.distributed.launch --nproc_per_node 4 --master_port 12345 --use_env"
BS=batch_size
LR=learning_rate
MODEL=ckpt_path
TASK=stsb


python $MPI \
    $WKDIR/src/run_classification.py \
    --train_file $WKDIR/data/$TASK/train.jsonl \
    --dev_file $WKDIR/data/$TASK/$dev \
    --max_seq_len 128 \
    --per_device_train_batch_size $BS \
    --learning_rate $LR \
    --num_train_epochs 10 \
    --output_dir /tmp/ \
    --logging_steps -1 \
    --metrics $metrics \
    --fix_keys \
    --fix_values \
    --values_device gpu \
    --retrieval_bert \
    --ckpt_path $CKPTDIR/ckpt/$PT \
    --config_path $CKPTDIR/config.json \
    --vocab_path src/vocab/vocab \
    --keys_path $CKPTDIR/keys.npy \
    --values_path $CKPTDIR/values.npy

For other tasks, please refer to run_ner.py, run_qa.py, run_classification.py.

Citation

If you find this project helpful, please consider cite our paper.

@article{cheng2023decouple,
  title={Decouple knowledge from paramters for plug-and-play language modeling},
  author={Cheng, Xin and Lin, Yankai and Chen, Xiuying and Zhao, Dongyan and Yan, Rui},
  journal={arXiv preprint arXiv:2305.11564},
  year={2023}
}