
The ParroT framework enhances and regulates translation abilities during chat, building on open-sourced LLMs (e.g., LLaMA-7b, Bloomz-7b1-mt) and human-written translation and evaluation data.


ParroT

ParroT: Translating During Chat Using Large Language Models tuned with Human Translation and Feedback


🔥 Update

  • [2023/10/12] ParroT accepted to EMNLP 2023 (Findings)!
  • [2023/09/23] Fixed the streaming mode for large local datasets; it originally supported only datasets hosted on Hugging Face Datasets. With streaming, use --max_steps instead of --num_train_epochs, because the data is loaded as an IterableDataset.
  • [2023/07/14] Incorporated flash-attention into BLOOM for long-context training; observed about 20-30% speedup with other settings fixed.
  • [2023/06/14] Releasing detailed instruction data and scripts on @InstructMT.
  • The WMT22 test sets are made available.
  • For medium-to-small models (e.g., 7B), we recommend ZeRO2+offload rather than ZeRO3; use gradient accumulation to maximize GPU usage.
  • Important optimizations: preprocess_function is now 4-5x faster; DataCollatorForSeq2Seq pads batch-wise, saving 5-10% GPU usage.
  • Introducing ParroT-LoRA which supports saving and restarting from the checkpoints (base model and lora weights) during finetuning.
  • Setting the default Transformers to >= 4.28.0.dev0 directly as it merged the PR of LLaMA. With this version on Torch 1.13.1 + CUDA 11.7, we find the finetuning process could be a bit faster (~18%) than our v1.0.0 implementation.
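The streaming note above means the step budget has to be worked out by hand, since an IterableDataset has no known length. A minimal sketch of that arithmetic, assuming the number of training examples is known in advance; the helper name `estimate_max_steps` and the 200,000-example figure are illustrative and do not come from the repository:

```python
import math


def estimate_max_steps(num_examples, num_epochs, per_device_batch_size,
                       num_devices, gradient_accumulation_steps=1):
    """Optimizer steps needed to cover `num_epochs` passes over the data."""
    # Examples consumed per optimizer step across all devices.
    global_batch = (per_device_batch_size * num_devices
                    * gradient_accumulation_steps)
    return math.ceil(num_examples * num_epochs / global_batch)


# 8 GPUs x batch size 16 x no accumulation = 128 examples per step.
steps = estimate_max_steps(200_000, 1.5, 16, 8, 1)
print(steps)  # 2344
```

If preprocessing changes the number of training sequences (for example by truncating or filtering), the count after preprocessing is the one to plug in.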

Highlight


Parrots are smart birds that can respond to simple commands or questions. The question is whether they're just mimicking, or really intelligent enough to communicate with humans. This is similar to what we currently speculate about LLMs.

Promoting the good is essential, but punishing the evil is also necessary to ensure that goodness prevails. Similarly, aligning LLMs with human feedback means learning from correct examples and learning to discriminate erroneous ones.

Large language models (LLMs) like ChatGPT and GPT-4 have exhibited remarkable abilities on a wide range of natural language processing (NLP) tasks, including various machine translation abilities accomplished during chat. However, these models are only accessible through restricted APIs, which creates barriers to new research and advancements in the field. Therefore, we propose the ParroT framework to enhance and regulate translation abilities during chat based on open-sourced LLMs (e.g., LLaMA, Bloomz) and human-written translation and evaluation data. Specifically, ParroT reformulates translation data into the instruction-following style and introduces a “Hint” field for incorporating extra requirements to regulate the translation process.


Figure 1: Framework of ParroT. Hints are (optional) extra requirements to regulate the translation process.

Configurations

Datasets

Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
We are translating the following sentences from Chinese to English.
    
### Input:
检查情况显示,市场销售的粮油、肉类、水果、蔬菜、蛋奶等生活必需品供应充足,商品价格基本稳定,未发现严重违法违规行为,市场经营秩序总体平稳。

### Hint: A translation with major accuracy/mistranslation errors could be

### Response:The results of the inspection indicate the sufficient supply of living necessities <v>on marketing</v> 
including cereals and oils, meat, fruits, vegetables, eggs and milk, and the basically stabilized commodity price. 
The inspection hasn’t found serious violation of laws and regulations. The market order is stable on an overall basis.
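The sample above can be assembled programmatically. The following is a minimal sketch, assuming the section layout shown in the sample; the function names `build_prompt` and `error_spans` are hypothetical, and the exact whitespace used by the repository's training scripts may differ:

```python
import re

PREAMBLE = ("Below is an instruction that describes a task. "
            "Write a response that appropriately completes the request.")


def build_prompt(instruction, source, hint=None):
    """Assemble an instruction-following prompt with an optional Hint field."""
    parts = [
        PREAMBLE,
        f"### Instruction:\n{instruction}",
        f"### Input:\n{source}",
    ]
    if hint:  # The Hint section is simply omitted when no hint is given.
        parts.append(f"### Hint: {hint}")
    parts.append("### Response:")
    return "\n\n".join(parts)


def error_spans(response):
    """Return the text wrapped in <v>...</v> markers in a response."""
    return re.findall(r"<v>(.*?)</v>", response, flags=re.DOTALL)


prompt = build_prompt(
    "We are translating the following sentences from Chinese to English.",
    "市场经营秩序总体平稳。",
    hint="A translation with major accuracy/mistranslation errors could be",
)
print(prompt)
print(error_spans("the sufficient supply of living necessities <v>on marketing</v>"))
```

In the sample, the `<v>` tags mark the erroneous span of the response, so `error_spans` is one way to pull those spans out for inspection.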

Environment

We develop ParroT based on open-sourced LLMs (e.g., LLaMA, Bloomz) with HuggingFace's transformers library.

Framework Versions:

pip install -r requirements.txt

Data Format Conversion

Convert the regular bilingual sentence pairs into Alpaca data format:

python3 scripts/convert_pair_to_alpaca.py \
    -s zh -t en \
    -if scripts/instruct_follow.txt \
    -sf data/train.zh-en.zh.txt \
    -tf data/train.zh-en.en.txt \
    -of data/train_alp.json

Convert the Alpaca data format to the training data format here:

python3 scripts/convert_alpaca_to_hf.py \
    -i data/train_alp.json \
    -o data/train_alp_hf.json
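For orientation, the first conversion step can be pictured as below. This is a sketch under stated assumptions, not the logic of convert_pair_to_alpaca.py: it assumes Alpaca-style records with `instruction`, `input`, and `output` fields, instruction templates that contain `[SRC]` and `[TGT]` placeholders, and a hypothetical `LANG_NAMES` mapping; picking a template at random per pair is also an assumption.

```python
import json
import random

# Hypothetical mapping from language codes to the names used in instructions.
LANG_NAMES = {"zh": "Chinese", "en": "English", "de": "German"}


def pairs_to_alpaca(src_lines, tgt_lines, templates, src, tgt, seed=0):
    """Turn aligned sentence pairs into Alpaca-style instruction records."""
    if len(src_lines) != len(tgt_lines):
        raise ValueError("source and target files must have the same length")
    rng = random.Random(seed)
    records = []
    for source, target in zip(src_lines, tgt_lines):
        template = rng.choice(templates)
        instruction = (template.replace("[SRC]", LANG_NAMES[src])
                               .replace("[TGT]", LANG_NAMES[tgt]))
        records.append({
            "instruction": instruction,
            "input": source.strip(),
            "output": target.strip(),
        })
    return records


records = pairs_to_alpaca(
    ["你好\n"], ["Hello\n"],
    ["Translate the following sentences from [SRC] to [TGT]."],
    "zh", "en",
)
print(json.dumps(records, ensure_ascii=False, indent=2))
```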

Finetune

We modify run_clm.py, the language modeling example script in transformers, for finetuning with the built-in HuggingFace Trainer, so it should be easy to get started if you are familiar with run_clm.py. The script also supports data streaming, which can help when handling larger datasets. DeepSpeed ZeRO stage 2/3 is adopted for distributed training.

The resulting finetuning scripts are named run_clm_llms.py and run_clm_lora.py, for full model training and LoRA training, respectively. In principle, run_clm_lora.py can handle both full model and LoRA training depending on the arguments, but we keep the former script for full model training as the safer option during development.

For LoRA training, we recommend using ZeRO2, since ZeRO3 is very unstable when saving adapter_model.bin.

For long-context training, we provide run_clm_llms_flash.py, which improves memory efficiency.
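As a rough illustration of the ZeRO2+offload setup recommended for 7B-scale models, the sketch below builds a DeepSpeed configuration and prints it as JSON. It is an assumption about what such a file could contain, not the contents of the repository's train/deepspeed_config_zero2.json; the variable name `zero2_offload_config` is made up for this example.

```python
import json

# Hypothetical ZeRO stage 2 + CPU optimizer offload settings. The "auto"
# values are filled in by the HuggingFace Trainer from its own arguments.
zero2_offload_config = {
    "fp16": {"enabled": "auto"},
    "zero_optimization": {
        "stage": 2,
        # Keep optimizer states in CPU memory to free GPU memory.
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
        "overlap_comm": True,
        "contiguous_gradients": True,
    },
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
}

print(json.dumps(zero2_offload_config, indent=2))
```

Saving the printed JSON to a file and passing its path to `--deepspeed` is how the Trainer would pick it up.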

LLaMA-7b:

  • Original weights for the LLaMA models can be obtained by filling out this Form
  • Convert the LLaMA weights into the HuggingFace format by following the instructions in this Doc
  • Optionally, use an already converted checkpoint: [LLaMA-7b]

Bloomz-7b1-mt:

Example usage on one node with 8 A100 GPUs:

Full Model
# Multi-nodes are also supported

export NCCL_DEBUG=INFO
export NCCL_SOCKET_IFNAME=eth1
export NCCL_IB_GID_INDEX=3
export NCCL_IB_SL=3
export NCCL_NET_GDR_READ=1

export MASTER_ADDR="${CHIEF_IP:=localhost}"
export MASTER_PORT="${MASTER_PORT:=29500}"

train_path=transformers/examples/pytorch/language-modeling/run_clm_llms.py
model_path=<your_proj_path>/llama-7b
model_save=<your_proj_path>/parrot-hint-7b

# HOST_NUM will be 1
torchrun --nnodes $HOST_NUM --node_rank $INDEX --nproc_per_node 8 \
    --master_addr $MASTER_ADDR --master_port $MASTER_PORT  \
    ${train_path} \
    --deepspeed train/deepspeed_config_zero2.json \
    --model_name_or_path ${model_path} \
    --train_file data/data_parrot_hf.json \
    --preprocessing_num_workers 16 \
    --dataloader_num_workers 8 \
    --dataloader_pin_memory True \
    --per_device_train_batch_size 16 \
    --per_device_eval_batch_size 2 \
    --gradient_accumulation_steps 1 \
    --num_train_epochs 1.5 \
    --save_strategy "steps" \
    --save_steps 500 \
    --save_total_limit 1 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 10 \
    --block_size 512 \
    --do_train \
    --evaluation_strategy "no" \
    --validation_split_percentage 0 \
    --fp16 True \
    --fp16_full_eval True \
    --ddp_timeout 3600 \
    --seed 1 \
    --gradient_checkpointing True \
    --output_dir ${model_save}

# Use streaming for large datasets and specify the max_steps
#    --streaming \
#    --max_steps 2500 \

LoRA
# Multi-nodes are also supported

export NCCL_DEBUG=INFO
export NCCL_SOCKET_IFNAME=eth1
export NCCL_IB_GID_INDEX=3
export NCCL_IB_SL=3
export NCCL_NET_GDR_READ=1

export MASTER_ADDR="${CHIEF_IP:=localhost}"
export MASTER_PORT="${MASTER_PORT:=29500}"

train_path=transformers/examples/pytorch/language-modeling/run_clm_lora.py
model_path=<your_proj_path>/llama-7b
model_save=<your_proj_path>/parrot-hint-lora-7b

# HOST_NUM will be 1
torchrun --nnodes $HOST_NUM --node_rank $INDEX --nproc_per_node 8 \
    --master_addr $MASTER_ADDR --master_port $MASTER_PORT  \
    ${train_path} \
    --deepspeed train/deepspeed_config_zero2.json \
    --model_name_or_path ${model_path} \
    --train_file data/data_parrot_hf.json \
    --use_lora True \
    --lora_config train/lora_config.json \
    --preprocessing_num_workers 16 \
    --dataloader_num_workers 8 \
    --dataloader_pin_memory True \
    --per_device_train_batch_size 16 \
    --per_device_eval_batch_size 2 \
    --gradient_accumulation_steps 1 \
    --num_train_epochs 1.5 \
    --save_strategy "steps" \
    --save_steps 500 \
    --save_total_limit 1 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 10 \
    --block_size 512 \
    --do_train \
    --evaluation_strategy "no" \
    --validation_split_percentage 0 \
    --fp16 True \
    --fp16_full_eval True \
    --ddp_timeout 3600 \
    --seed 1 \
    --gradient_checkpointing True \
    --output_dir ${model_save}

Inference

The scripts support generation with and without hints using different instructions. The hints are appended to the default instruction with ### as a delimiter. Simply switch the inference instruction for different strategies.

  • None: instruct_inf.txt
    • Translate the following sentences from [SRC] to [TGT].
  • No Errors: instruct_inf_e2t.txt
    • Translate the following sentences from [SRC] to [TGT].###A translation with no errors could be
  • Minor Errors: instruct_inf_e2t_minor.txt
    • Translate the following sentences from [SRC] to [TGT].###A translation with minor errors could be
  • Major Errors: instruct_inf_e2t_major.txt
    • Translate the following sentences from [SRC] to [TGT].###A translation with major errors could be
  • Preferred: instruct_inf_t2t.txt
    • Translate the following sentences from [SRC] to [TGT].###We prefer to translate it to
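A minimal sketch of how such an instruction line could be filled in and split at the `###` delimiter; the function name `fill_instruction` is hypothetical, and the repository's inference scripts may handle this differently:

```python
def fill_instruction(template, src_lang, tgt_lang):
    """Fill [SRC]/[TGT] and split off the optional hint after '###'."""
    filled = template.replace("[SRC]", src_lang).replace("[TGT]", tgt_lang)
    # partition() returns an empty hint when the delimiter is absent.
    instruction, _, hint = filled.partition("###")
    return instruction.strip(), (hint.strip() or None)


no_hint = fill_instruction(
    "Translate the following sentences from [SRC] to [TGT].",
    "Chinese", "English",
)
with_hint = fill_instruction(
    "Translate the following sentences from [SRC] to [TGT]."
    "###A translation with no errors could be",
    "Chinese", "English",
)
print(no_hint)
print(with_hint)
```

Switching strategies then amounts to reading a different template file; the rest of the pipeline stays the same.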

Example usages:

Full Model
# Translation
python3 inference.py --model-name-or-path <your_proj_path>/parrot-hint-7b \
    -lp 'zh-en' \
    -t 0.1 \
    -sa 'beam' \
    -ins test/instruct_inf.txt \
    -i test/test_rand_50.zh.txt \
    -o test/test_rand_50.zh-en.none-hint.txt
    
# Text generation
python3 inference.py --model-name-or-path <your_proj_path>/parrot-hint-7b \
    -t 0.7 \
    -sa 'sample' \
    -i test/test_case.txt \
    -o test/test_case.general-task.txt

LoRA
# Translation
python3 inference_lora.py --model-name-or-path <your_proj_path>/llama-7b \
    --lora-weights <your_proj_path>/parrot-hint-lora-7b/adapter_model \
    -lp 'zh-en' \
    -t 0.1 \
    -sa 'beam' \
    -ins test/instruct_inf.txt \
    -i test/test_rand_50.zh.txt \
    -o test/test_rand_50.zh-en.none-hint.txt
    
# Text generation
python3 inference_lora.py --model-name-or-path <your_proj_path>/llama-7b \
    --lora-weights <your_proj_path>/parrot-hint-lora-7b/adapter_model \
    -t 0.7 \
    -sa 'sample' \
    -i test/test_case.txt \
    -o test/test_case.general-task.txt

MT Evaluation

We adopt two metrics, SacreBLEU and COMET (Unbabel/wmt22-comet-da), which are driven by n-gram similarity and cross-lingual pretrained models, respectively.

# SacreBLEU
cat test_rand_50.zh-en.none-hint.txt.hyp | sacrebleu -w 2 test_rand_50.en.txt

# COMET
comet-score -r test_rand_50.en.txt -s test_rand_50.zh.txt -t test_rand_50.zh-en.none-hint.txt.hyp --quiet --only_system
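To make "n-gram similarity" concrete, the sketch below computes clipped n-gram precision, the building block of BLEU. It is a simplified illustration and not SacreBLEU itself: the real metric combines several n-gram orders with a brevity penalty and applies its own tokenization, whereas this example splits on whitespace. The function names are chosen for this example.

```python
from collections import Counter


def ngrams(tokens, n):
    """Count all n-grams of order n in a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def clipped_precision(hypothesis, reference, n):
    """Share of hypothesis n-grams that also occur in the reference."""
    hyp = ngrams(hypothesis.split(), n)
    ref = ngrams(reference.split(), n)
    total = sum(hyp.values())
    if total == 0:
        return 0.0
    # Each hypothesis n-gram is credited at most as often as it
    # appears in the reference ("clipping").
    matched = sum(min(count, ref[gram]) for gram, count in hyp.items())
    return matched / total


hyp = "the market order is stable"
ref = "the market order is generally stable"
print(clipped_precision(hyp, ref, 1))  # 1.0
print(clipped_precision(hyp, ref, 2))  # 0.75
```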

Finetuned LLMs and Results

So far, we have finetuned the following LLMs for ParroT, with evaluation mainly on the WMT22 test sets.

  • LLaMA-7b
  • Bloomz-mt-7b
  • ParroT-LoRA

There are several interesting observations:

  • ParroT based on Bloomz-mt-7b also works well with hints. Besides, Bloomz-mt-7b shows stronger ability in the modeling of Chinese texts.
  • LoRA seems to prevent LLMs from overfitting, which benefits the high-resource De-En translation but restricts the instruction learning of other directions. The limited number of trainable parameters (only ~4.2M) may explain this observation.

Caption: Translation performance of LLMs on Flores subsets and WMT22 test sets.
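The ~4.2M trainable-parameter figure in the observations above can be reproduced with simple arithmetic. The sketch assumes LoRA rank 8 applied to two square projection matrices per layer (for instance the query and value projections); those two choices are assumptions, since the LoRA configuration is not shown here, while the hidden size of 4096 and the 32 layers are those of LLaMA-7b.

```python
def lora_trainable_params(hidden_size, num_layers, rank, adapted_per_layer):
    """Trainable parameters added by LoRA on square projection matrices."""
    # Each adapted d x d matrix gains A (rank x d) and B (d x rank).
    per_matrix = rank * (hidden_size + hidden_size)
    return per_matrix * adapted_per_layer * num_layers


# LLaMA-7b: hidden size 4096, 32 layers; rank 8 on 2 matrices per layer.
print(lora_trainable_params(4096, 32, 8, 2))  # 4194304
```

Under these assumptions the count is 4,194,304, roughly 0.06% of a 7B-parameter model.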

Run LLMs on your MacBook

Try llama.cpp to run the LLMs using 4-bit quantization on a MacBook. We adopt a specific fork from comex/llama.cpp which supports the conversion of HuggingFace models to ggml format.

We recommend Python 3.10.10 for convert.py, since Python 3.9.5 raised the following error for us:

TypeError: 'staticmethod' object is not callable

# Clone the specific fork 
git clone --branch convert-script https://github.com/comex/llama.cpp.git
cd llama.cpp
make

# Install Python dependencies
python3 -m pip install -r requirements.txt

# Convert the 7b model to ggml fp16 format
python3 convert.py models/alpaca/pytorch_model.bin

# Quantize the model to 4 bits (using method 2 = q4_0)
./quantize models/alpaca/ggml-model-f16.bin models/alpaca/ggml-model-q4_0.bin 2 

# Run instruction mode with Alpaca
./main -m ./models/alpaca/ggml-model-q4_0.bin --color -f ./prompts/alpaca.txt -ins -b 256 --top_p 0.95 --top_k 50 --temp 0.7 --repeat_penalty 1 -t 7

Now you can talk to your own Chatbot!

Alpaca-7b

Caption: Alpaca cannot respond to the hints.

ParroT-Hint-7b

Caption: ParroT responds to the hints as expected.


Acknowledgement

This project could not have been developed without the following resources:

Citation

Please kindly cite our paper if you find it helpful:

@inproceedings{jiao2023parrot,
  title={ParroT: Translating during Chat using Large Language Models tuned with Human Translation and Feedback}, 
  author={Wenxiang Jiao and Jen-tse Huang and Wenxuan Wang and Zhiwei He and Tian Liang and Xing Wang and Shuming Shi and Zhaopeng Tu},
  booktitle = {Findings of EMNLP},
  year      = {2023}
}