DPO+ZeRO train error
tankeui opened this issue · 2 comments
I would like to ask for your advice on the following two questions.
- DPO training does not seem to support DeepSpeed ZeRO. After manually integrating `DPOAlignerArguments` with the `FinetunerArguments` class, I ran into the following error. How can it be resolved?
$ bash run_dpo_align.sh
[2024-06-27 21:07:18,259] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2024-06-27 21:07:21,008] [WARNING] [runner.py:196:fetch_hostfile] Unable to find hostfile, will proceed with training with local resources only.
[2024-06-27 21:07:21,009] [INFO] [runner.py:555:main] cmd = usr/anaconda3/envs/lmflow/bin/python -u -m deepspeed.launcher.launch --world_info=eyJsb2NhbGhvc3QiOiBbMCwgMSwgMiwgMywgNCwgNSwgNiwgN119 --master_addr=127.0.0.1 --master_port=11000 --enable_each_rank_log=None ../examples/dpo_train.py --model_name_or_path usr/huggingface/hub/LLM-Research/Meta-Llama-3-70B-Instruct --dataset_path ../data/dpo-mix-7k --output_dir output_models/dpo --per_device_train_batch_size 1 --gradient_accumulation_steps 2 --max_length 512 --run_name dpo --use_lora 1 --fp16 --max_steps 200 --learning_rate 1e-6 --sanity_check True --save_aggregated_lora 0 --logging_steps 20 --deepspeed ../configs/ds_config_zero3.json
[2024-06-27 21:07:22,280] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2024-06-27 21:07:23,742] [INFO] [launch.py:138:main] 0 NCCL_SOCKET_IFNAME=eth2
[2024-06-27 21:07:23,742] [INFO] [launch.py:145:main] WORLD INFO DICT: {'localhost': [0, 1, 2, 3, 4, 5, 6, 7]}
[2024-06-27 21:07:23,742] [INFO] [launch.py:151:main] nnodes=1, num_local_procs=8, node_rank=0
[2024-06-27 21:07:23,742] [INFO] [launch.py:162:main] global_rank_mapping=defaultdict(<class 'list'>, {'localhost': [0, 1, 2, 3, 4, 5, 6, 7]})
[2024-06-27 21:07:23,742] [INFO] [launch.py:163:main] dist_world_size=8
[2024-06-27 21:07:23,742] [INFO] [launch.py:165:main] Setting CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
[2024-06-27 21:07:32,024] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2024-06-27 21:07:32,024] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2024-06-27 21:07:32,073] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2024-06-27 21:07:32,075] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2024-06-27 21:07:32,092] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2024-06-27 21:07:32,098] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2024-06-27 21:07:32,101] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2024-06-27 21:07:32,102] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)
usr/anaconda3/envs/lmflow/lib/python3.10/site-packages/transformers/deepspeed.py:23: FutureWarning: transformers.deepspeed module is deprecated and will be removed in a future version. Please import deepspeed modules directly from transformers.integrations
warnings.warn(
[2024-06-27 21:07:33,228] [WARNING] [comm.py:152:init_deepspeed_backend] NCCL backend in DeepSpeed not yet implemented
[2024-06-27 21:07:33,228] [INFO] [comm.py:616:init_distributed] cdb=None
[2024-06-27 21:07:33,229] [INFO] [comm.py:643:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl
You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>. This is expected, and simply means that the legacy (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in huggingface/transformers#24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
[2024-06-27 21:08:25,089] [INFO] [partition_parameters.py:326:__exit__] finished initializing model with 70.55B parameters
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████| 30/30 [00:37<00:00, 1.24s/it]
trainable params: 16,384,000 || all params: 70,570,090,496 || trainable%: 0.0232
usr/anaconda3/envs/lmflow/lib/python3.10/site-packages/datasets/table.py:1421: FutureWarning: promote has been superseded by promote_options='default'.
table = cls._concat_blocks(blocks, axis=0)
Filter: 100%|██████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 62699.81 examples/s]
Filter: 100%|██████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 59218.16 examples/s]
Filter: 100%|██████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 56548.35 examples/s]
Filter: 100%|██████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 54218.69 examples/s]
Filter: 100%|██████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 67624.98 examples/s]
Filter: 100%|██████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 69677.45 examples/s]
Filter: 0%| | 0/750 [00:00<?, ? examples/s]
Filter: 100%|████████████████████████████████████████████████████████████████████████| 750/750 [00:00<00:00, 73451.98 examples/s]
Filter: 100%|████████████████████████████████████████████████████████████████████████| 750/750 [00:00<00:00, 71422.40 examples/s]
Filter: 100%|████████████████████████████████████████████████████████████████████████| 750/750 [00:00<00:00, 70427.80 examples/s]
usr/anaconda3/envs/lmflow/lib/python3.10/site-packages/transformers/training_args.py:1474: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead
warnings.warn(
Filter: 0%| | 0/1000 [00:00<?, ? examples/s]
Filter: 100%|██████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 67080.96 examples/s]
Filter: 100%|████████████████████████████████████████████████████████████████████████| 750/750 [00:00<00:00, 72595.96 examples/s]
Filter: 100%|████████████████████████████████████████████████████████████████████████| 750/750 [00:00<00:00, 78824.50 examples/s]
Filter: 100%|████████████████████████████████████████████████████████████████████████| 750/750 [00:00<00:00, 66904.76 examples/s]
Filter: 100%|██████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 65951.29 examples/s]
Filter: 100%|████████████████████████████████████████████████████████████████████████| 750/750 [00:00<00:00, 83919.65 examples/s]
usr/anaconda3/envs/lmflow/lib/python3.10/site-packages/huggingface_hub/utils/_deprecation.py:100: FutureWarning: Deprecated argument(s) used in '__init__': max_prompt_length, max_length. Will not be supported from version '1.0.0'.
Deprecated positional argument(s) used in DPOTrainer, please use the DPOConfig to set these arguments instead.
warnings.warn(message, FutureWarning)
Filter: 100%|████████████████████████████████████████████████████████████████████████| 750/750 [00:00<00:00, 74188.20 examples/s]
[rank0]: Traceback (most recent call last):
[rank0]: File "usr/LMFlow/scripts/../examples/dpo_train.py", line 53, in
[rank0]: aligned_model = aligner.align(
[rank0]: File "usr/LMFlow/src/lmflow/pipeline/dpo_aligner.py", line 157, in align
[rank0]: dpo_trainer = self._initialize_trainer(model, tokenizer)
[rank0]: File "usr/LMFlow/src/lmflow/pipeline/dpo_aligner.py", line 116, in _initialize_trainer
[rank0]: dpo_trainer = DPOTrainer(
[rank0]: File "usr/anaconda3/envs/lmflow/lib/python3.10/site-packages/huggingface_hub/utils/_deprecation.py", line 101, in inner_f
[rank0]: return f(*args, **kwargs)
[rank0]: File "usr/anaconda3/envs/lmflow/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 241, in init
[rank0]: model = model.merge_and_unload()
[rank0]: File "usr/anaconda3/envs/lmflow/lib/python3.10/site-packages/peft/tuners/lora/model.py", line 838, in merge_and_unload
[rank0]: return self._unload_and_optionally_merge(
[rank0]: File "usr/anaconda3/envs/lmflow/lib/python3.10/site-packages/peft/tuners/lora/model.py", line 457, in _unload_and_optionally_merge
[rank0]: target.merge(safe_merge=safe_merge, adapter_names=adapter_names)
[rank0]: File "usr/anaconda3/envs/lmflow/lib/python3.10/site-packages/peft/tuners/lora/layer.py", line 472, in merge
[rank0]: base_layer.weight.data = base_layer.weight.data + delta_weight
[rank0]: RuntimeError: The size of tensor a (0) must match the size of tensor b (8192) at non-singleton dimension 1
- Is LISA currently supported for multi-GPU training?
Hi, thanks for your interest in LMFlow!
- It seems to be an issue with parameter sharding under the parallel (ZeRO) setting, and unfortunately we don't have a solution for it at the moment. Sorry for the inconvenience 🙏. PR #867 will provide DPO with ZeRO support, and that PR is now waiting for the final review before merge. Feel free to switch to the `yizhenjia-dpo-v2` branch to try it first; the shell script is `run_dpov2_align.sh`.
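For context on the error itself: the `size of tensor a (0)` in the traceback is what you would expect when `merge_and_unload()` runs under ZeRO stage 3, where every rank holds only a shard of each base weight, so the full-size LoRA delta (8192 columns) cannot be added to it. The snippet below is only a hedged illustration of that diagnosis using DeepSpeed's parameter-gathering context manager; the helper is hypothetical and is not the fix implemented in PR #867.

```python
# Illustration only (not LMFlow/PR #867 code): under ZeRO-3, base_layer.weight is
# partitioned, so its .data tensor is empty on every rank until it is gathered.
# GatheredParameters temporarily materializes the full weight so an in-place update
# such as a LoRA merge could succeed.
import deepspeed
import torch


def merge_lora_delta_under_zero3(base_layer: torch.nn.Linear, delta_weight: torch.Tensor) -> None:
    """Hypothetical helper: gather a ZeRO-3 partitioned weight, then add the LoRA delta."""
    with deepspeed.zero.GatheredParameters([base_layer.weight], modifier_rank=0):
        # Inside this context, rank 0 sees the full (out_features, in_features) tensor;
        # outside it, the same tensor has 0 elements, which is exactly the RuntimeError above.
        if torch.distributed.get_rank() == 0:
            base_layer.weight.data += delta_weight
```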
Dataset formats for DPO v2:
- Scored (will be sampled into a paired dataset according to `--sampling_paired_method`, with the margin calculated from `score`; a toy sampling sketch follows the example below)
```
// This kind of dataset is commonly used in reward model training/prediction, as well as RL training.
{
    "type": "text_to_scored_textlist",
    "instances": [
        {
            "input": "what's your name?",
            "output": [
                {"score": 1.0, "text": "My name is John"},
                {"score": -0.8, "text": "I'm John"},
                {"score": -1.5, "text": "I'm John. Who are you?"}
            ]
        },
        {
            "input": "Who are you?",
            "output": [
                {"score": 1.5, "text": "My name is Amy"},
                {"score": 1.0, "text": "I'm Amy"}
            ]
        }
    ]
}
```
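For illustration, here is a minimal sketch of how a scored instance can be turned into a paired one; the function is hypothetical and may not match what `--sampling_paired_method` actually does in LMFlow:

```python
# Hypothetical sketch: build a paired example from a scored instance by pairing the
# highest- and lowest-scored outputs and using their score gap as the margin.
def scored_to_paired(instance: dict) -> dict:
    outputs = sorted(instance["output"], key=lambda o: o["score"], reverse=True)
    chosen, rejected = outputs[0], outputs[-1]
    return {
        "prompt": instance["input"],
        "chosen": chosen["text"],
        "rejected": rejected["text"],
        "margin": chosen["score"] - rejected["score"],
    }

# For the "what's your name?" instance above this yields
# chosen="My name is John", rejected="I'm John. Who are you?", margin=2.5.
```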
- Paired (all fields are required)
```
// This kind of dataset is commonly used in reward model training as well as RL training.
{
    "type": "paired_text_to_text",
    "instances": [
        {
            "prompt": "Who are you?",
            "chosen": "My name is Amy.",
            "rejected": "I'm Amy",
            "margin": 0.6
        },
        {
            "prompt": "what's your name?",
            "chosen": "My name is John.",
            "rejected": "I'm John",
            "margin": 0.5
        }
    ]
}
```
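As a quick sanity check of a paired file (a sketch, not part of LMFlow; the file name is made up):

```python
# Minimal sketch: load a paired_text_to_text dataset file and verify the required fields.
import json

REQUIRED_FIELDS = {"prompt", "chosen", "rejected", "margin"}

with open("dpo_paired.json") as f:  # hypothetical path
    data = json.load(f)

assert data["type"] == "paired_text_to_text"
for i, instance in enumerate(data["instances"]):
    missing = REQUIRED_FIELDS - instance.keys()
    if missing:
        raise ValueError(f"instance {i} is missing required field(s): {missing}")
```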
Note that users have to add the conversation template manually if needed (for example, change `prompt` to `<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nwhat's your name?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n` and `chosen` to `My name is Amy.<|eot_id|>`). Automatic templating will come soon in the next PRs.
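Until then, something along these lines can wrap each paired record in the Llama-3 Instruct format shown above (a sketch; the helper is not part of LMFlow):

```python
# Sketch: manually apply the Llama-3 Instruct conversation template to a paired record,
# as described in the note above.
LLAMA3_PROMPT_TEMPLATE = (
    "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
    "{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)


def apply_llama3_template(record: dict) -> dict:
    return {
        "prompt": LLAMA3_PROMPT_TEMPLATE.format(prompt=record["prompt"]),
        "chosen": record["chosen"] + "<|eot_id|>",
        "rejected": record["rejected"] + "<|eot_id|>",
        "margin": record["margin"],
    }
```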
- Unfortunately, LISA doesn't currently support multi-GPU training. Please stay tuned! Our roadmap: #862
Thanks! It works!