meta-llama/llama

How to fine-tune Llama 3.1 8B with a custom dataset on multiple GPUs?


Describe the bug

When fine-tuning Llama on a custom dataset, the following error occurs: RuntimeError: chunk expects at least a 1-dimensional tensor
The same code works on a single GPU but fails on multiple GPUs.
What are the possible reasons?

Minimal reproducible example

import torch, multiprocessing
from datasets import load_dataset, Dataset
from peft import LoraConfig, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    set_seed,
    TrainingArguments
)
from trl import SFTTrainer
from peft.utils.other import fsdp_auto_wrap_policy
from accelerate import Accelerator
import os

accelerator = Accelerator()
set_seed(1234)
# Use bf16 and FlashAttention 2 if supported
if torch.cuda.is_bf16_supported():
    os.system('pip install flash_attn')
    compute_dtype = torch.bfloat16
    attn_implementation = 'flash_attention_2'
else:
    compute_dtype = torch.float16
    attn_implementation = 'sdpa'

# Define model and tokenizer

model_id = "meta-llama/Meta-Llama-3-8B-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation=attn_implementation)

# Dataset mapping
------
------
------

if hasattr(accelerator.state, 'fsdp_plugin'):  
    fsdp_plugin = accelerator.state.fsdp_plugin  
    fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(trainer.model)  
else:  
    print("FSDP plugin is not available.")  
trainer.train()

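output_dir = "./fine-tuned-output"

# Gather the full (unsharded) weights before saving when FSDP is active
if trainer.is_fsdp_enabled:
    trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")

# trainer.save_model() unwraps DataParallel/FSDP and saves the trained PEFT (LoRA) model
trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)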

Output

Currently training with a batch size of: 2
The following columns in the training set don't have a corresponding argument in `PeftModelForCausalLM.forward` and have been ignored: text. If text are not expected by `PeftModelForCausalLM.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 489
  Num Epochs = 4
  Instantaneous batch size per device = 1
  Training with DataParallel so batch size has been adjusted to: 2
  Total train batch size (w. parallel, distributed & accumulation) = 32
  Gradient Accumulation steps = 16
  Total optimization steps = 50
  Number of trainable parameters = 41,943,040
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[42], line 8
      6 else:  
      7     print("FSDP plugin is not available.")  
----> 8 trainer.train()
     10 output_dir="./fine-tuned-output"
     12 model.save_pretrained(output_dir)

File /opt/conda/lib/python3.11/site-packages/transformers/trainer.py:2123, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   2121         hf_hub_utils.enable_progress_bars()
   2122 else:
-> 2123     return inner_training_loop(
   2124         args=args,
   2125         resume_from_checkpoint=resume_from_checkpoint,
   2126         trial=trial,
   2127         ignore_keys_for_eval=ignore_keys_for_eval,
   2128     )

File /opt/conda/lib/python3.11/site-packages/transformers/trainer.py:2481, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   2475 context = (
   2476     functools.partial(self.accelerator.no_sync, model=model)
   2477     if i == len(batch_samples) - 1
   2478     else contextlib.nullcontext
   2479 )
   2480 with context():
-> 2481     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
   2483 if (
   2484     args.logging_nan_inf_filter
   2485     and not is_torch_xla_available()
   2486     and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
   2487 ):
   2488     # if loss is nan or inf simply add the average of previous logged losses
   2489     tr_loss = tr_loss + tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)

File /opt/conda/lib/python3.11/site-packages/transformers/trainer.py:3579, in Trainer.training_step(self, model, inputs, num_items_in_batch)
   3576     return loss_mb.reduce_mean().detach().to(self.args.device)
   3578 with self.compute_loss_context_manager():
-> 3579     loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
   3581 del inputs
   3582 if (
   3583     self.args.torch_empty_cache_steps is not None
   3584     and self.state.global_step % self.args.torch_empty_cache_steps == 0
   3585 ):

File /opt/conda/lib/python3.11/site-packages/transformers/trainer.py:3633, in Trainer.compute_loss(self, model, inputs, return_outputs, num_items_in_batch)
   3631         loss_kwargs["num_items_in_batch"] = num_items_in_batch
   3632     inputs = {**inputs, **loss_kwargs}
-> 3633 outputs = model(**inputs)
   3634 # Save past state if it exists
   3635 # TODO: this needs to be fixed and made cleaner later.
   3636 if self.args.past_index >= 0:

File /opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File /opt/conda/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py:175, in DataParallel.forward(self, *inputs, **kwargs)
    170     if t.device != self.src_device_obj:
    171         raise RuntimeError("module must have its parameters and buffers "
    172                            f"on device {self.src_device_obj} (device_ids[0]) but found one of "
    173                            f"them on device: {t.device}")
--> 175 inputs, module_kwargs = self.scatter(inputs, kwargs, self.device_ids)
    176 # for forward function without any inputs, empty list and dict will be created
    177 # so the module can be executed on one device which is the first one in device_ids
    178 if not inputs and not module_kwargs:

File /opt/conda/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py:197, in DataParallel.scatter(self, inputs, kwargs, device_ids)
    191 def scatter(
    192     self,
    193     inputs: Tuple[Any, ...],
    194     kwargs: Optional[Dict[str, Any]],
    195     device_ids: Sequence[Union[int, torch.device]],
    196 ) -> Any:
--> 197     return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

File /opt/conda/lib/python3.11/site-packages/torch/nn/parallel/scatter_gather.py:73, in scatter_kwargs(inputs, kwargs, target_gpus, dim)
     71 r"""Scatter with support for kwargs dictionary."""
     72 scattered_inputs = scatter(inputs, target_gpus, dim) if inputs else []
---> 73 scattered_kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
     74 if len(scattered_inputs) < len(scattered_kwargs):
     75     scattered_inputs.extend(() for _ in range(len(scattered_kwargs) - len(scattered_inputs)))

File /opt/conda/lib/python3.11/site-packages/torch/nn/parallel/scatter_gather.py:59, in scatter(inputs, target_gpus, dim)
     53 # After scatter_map is called, a scatter_map cell will exist. This cell
     54 # has a reference to the actual function scatter_map, which has references
     55 # to a closure that has a reference to the scatter_map cell (because the
     56 # fn is recursive). To avoid this reference cycle, we set the function to
     57 # None, clearing the cell
     58 try:
---> 59     res = scatter_map(inputs)
     60 finally:
     61     scatter_map = None  # type: ignore[assignment]

File /opt/conda/lib/python3.11/site-packages/torch/nn/parallel/scatter_gather.py:50, in scatter.<locals>.scatter_map(obj)
     48     return [list(i) for i in zip(*map(scatter_map, obj))]
     49 if isinstance(obj, dict) and len(obj) > 0:
---> 50     return [type(obj)(i) for i in zip(*map(scatter_map, obj.items()))]
     51 return [obj for _ in target_gpus]

File /opt/conda/lib/python3.11/site-packages/torch/nn/parallel/scatter_gather.py:46, in scatter.<locals>.scatter_map(obj)
     44     return [type(obj)(*args) for args in zip(*map(scatter_map, obj))]
     45 if isinstance(obj, tuple) and len(obj) > 0:
---> 46     return list(zip(*map(scatter_map, obj)))
     47 if isinstance(obj, list) and len(obj) > 0:
     48     return [list(i) for i in zip(*map(scatter_map, obj))]

File /opt/conda/lib/python3.11/site-packages/torch/nn/parallel/scatter_gather.py:42, in scatter.<locals>.scatter_map(obj)
     40 def scatter_map(obj):
     41     if isinstance(obj, torch.Tensor):
---> 42         return Scatter.apply(target_gpus, None, dim, obj)
     43     if _is_namedtuple(obj):
     44         return [type(obj)(*args) for args in zip(*map(scatter_map, obj))]

File /opt/conda/lib/python3.11/site-packages/torch/autograd/function.py:553, in Function.apply(cls, *args, **kwargs)
    550 if not torch._C._are_functorch_transforms_active():
    551     # See NOTE: [functorch vjp and autograd interaction]
    552     args = _functorch.utils.unwrap_dead_wrappers(args)
--> 553     return super().apply(*args, **kwargs)  # type: ignore[misc]
    555 if not is_setup_ctx_defined:
    556     raise RuntimeError(
    557         "In order to use an autograd.Function with functorch transforms "
    558         "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
    559         "staticmethod. For more details, please see "
    560         "https://pytorch.org/docs/master/notes/extending.func.html"
    561     )

File /opt/conda/lib/python3.11/site-packages/torch/nn/parallel/_functions.py:96, in Scatter.forward(ctx, target_gpus, chunk_sizes, dim, input)
     93 if torch.cuda.is_available() and ctx.input_device == -1:
     94     # Perform CPU to GPU copies in a background stream
     95     streams = [_get_stream(torch.device("cuda", device)) for device in target_gpus]
---> 96 outputs = comm.scatter(input, target_gpus, chunk_sizes, ctx.dim, streams)
     97 # Synchronize with the copy stream
     98 if streams is not None:

File /opt/conda/lib/python3.11/site-packages/torch/nn/parallel/comm.py:187, in scatter(tensor, devices, chunk_sizes, dim, streams, out)
    185 if out is None:
    186     devices = [_get_device_index(d) for d in devices]
--> 187     return tuple(torch._C._scatter(tensor, devices, chunk_sizes, dim, streams))
    188 else:
    189     if devices is not None:

RuntimeError: chunk expects at least a 1-dimensional tensor
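The last frames show `DataParallel.scatter` walking the keyword arguments recursively and splitting every tensor along dim 0, one piece per GPU, while the frames at trainer.py:3631-3633 show Trainer merging `num_items_in_batch` into those keyword arguments. A plausible explanation, assumed here rather than confirmed by the log, is that `num_items_in_batch` arrives as a 0-dimensional tensor, which has no dim 0 to split. The sketch below reproduces that failure on CPU; `scatter_kwargs_cpu` is a hypothetical, simplified stand-in for `torch.nn.parallel.scatter_gather.scatter` that uses `torch.chunk` instead of copying pieces to GPUs.

```python
import torch


def scatter_kwargs_cpu(obj, num_devices):
    """Hypothetical, simplified CPU stand-in for DataParallel's recursive scatter."""
    if isinstance(obj, torch.Tensor):
        # DataParallel splits every tensor along dim 0; a 0-d tensor has no dim 0,
        # so torch.chunk raises "chunk expects at least a 1-dimensional tensor".
        return list(torch.chunk(obj, num_devices, dim=0))
    if isinstance(obj, dict) and len(obj) > 0:
        keys = list(obj)
        per_key = [scatter_kwargs_cpu(obj[k], num_devices) for k in keys]
        # zip() regroups the pieces into one kwargs dict per device
        return [dict(zip(keys, pieces)) for pieces in zip(*per_key)]
    # Non-tensor leaves are replicated to every device
    return [obj for _ in range(num_devices)]


batch = {
    "input_ids": torch.ones(2, 8, dtype=torch.long),
    "attention_mask": torch.ones(2, 8, dtype=torch.long),
}
shards = scatter_kwargs_cpu(batch, num_devices=2)
print([s["input_ids"].shape for s in shards])  # two shards of shape (1, 8)

# Adding a scalar tensor, like a 0-d num_items_in_batch, breaks the split
try:
    scatter_kwargs_cpu({**batch, "num_items_in_batch": torch.tensor(16)}, num_devices=2)
except RuntimeError as err:
    print(err)  # chunk expects at least a 1-dimensional tensor
```

With only one GPU visible, Trainer does not wrap the model in `torch.nn.DataParallel`, so this scatter step never runs, which would be consistent with the same code working on a single GPU.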

Runtime Environment

  • Model: Llama 3.1 8B
  • Using via Hugging Face: yes
  • OS: Ubuntu, running Jupyter Lab (CUDA 12.0, PyTorch 2.1)
  • GPU VRAM: 151 GB × 2
  • Number of GPUs: 2
  • GPU make: NVIDIA H200
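The log and the environment together suggest how Trainer parallelized this run. With one Jupyter kernel seeing two GPUs, Trainer reports "Training with DataParallel" and a total train batch size of 32, which is 1 (per device) × 2 (GPUs) × 16 (accumulation steps) with a world size of 1, i.e. a single process driving both GPUs. The sketch below re-derives those logged numbers and adds a rough check for that situation. It assumes `max_steps=50` (inferred from "Total optimization steps = 50", since the trainer setup is elided), mirrors Trainer's step arithmetic only approximately, and `likely_uses_data_parallel` is a hypothetical heuristic that assumes distributed launchers such as `torchrun` or `accelerate launch` export `WORLD_SIZE`. Inside a live session, `training_args.n_gpu > 1` is the more direct signal that Trainer will wrap the model in `torch.nn.DataParallel`.

```python
import math
import os
from typing import Mapping, Optional, Tuple


def dataparallel_step_math(num_examples: int, per_device_batch: int, n_gpu: int,
                           grad_accum: int, max_steps: int) -> Tuple[int, int, int]:
    """Re-derive the Trainer log numbers for one process using nn.DataParallel."""
    # One process feeds every GPU, so each dataloader batch holds per_device_batch * n_gpu rows
    batch = per_device_batch * max(1, n_gpu)
    world_size = 1  # DataParallel is not a distributed launch
    total_batch = batch * grad_accum * world_size
    batches_per_epoch = math.ceil(num_examples / batch)  # last partial batch is kept
    updates_per_epoch = max(batches_per_epoch // grad_accum, 1)
    # With max_steps set, the epoch count is derived from it and rounded up
    epochs = max_steps // updates_per_epoch + int(max_steps % updates_per_epoch > 0)
    return batch, total_batch, epochs


def likely_uses_data_parallel(visible_gpus: int,
                              env: Optional[Mapping[str, str]] = None) -> bool:
    """Hypothetical heuristic: several GPUs visible to a process not launched as distributed."""
    env = os.environ if env is None else env
    # Distributed launchers start one worker per GPU and export WORLD_SIZE to each of them
    world_size = int(env.get("WORLD_SIZE", "1"))
    return visible_gpus > 1 and world_size <= 1


# Values from the log: 489 examples, batch 1 per device, 2 GPUs, 16 accumulation steps
print(dataparallel_step_math(489, 1, 2, 16, max_steps=50))  # (2, 32, 4)
# In a notebook, visible_gpus would come from torch.cuda.device_count()
print(likely_uses_data_parallel(2, env={}))                   # True: plain notebook kernel
print(likely_uses_data_parallel(2, env={"WORLD_SIZE": "2"}))  # False: one process per GPU (DDP)
```

If the check is true, starting the same training code as a script with `accelerate launch --num_processes 2` or `torchrun --nproc_per_node 2` would typically give each GPU its own process (DDP), in which case the DataParallel scatter where the traceback ends is not used.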
