How to fine-tune Llama 3.1 8B with a custom dataset on multiple GPUs?
dataai1205 commented
Describe the bug
When fine-tuning Llama with a custom dataset, the following error occurs: RuntimeError: chunk expects at least a 1-dimensional tensor.
The same code works on a single GPU but fails on multiple GPUs.
What are the possible reasons?
Minimal reproducible example
import torch, multiprocessing
from datasets import load_dataset, Dataset
from peft import LoraConfig, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    set_seed,
    TrainingArguments,
)
from trl import SFTTrainer
from peft.utils.other import fsdp_auto_wrap_policy
from accelerate import Accelerator
import os
accelerator = Accelerator()
set_seed(1234)
# Use bf16 and FlashAttention 2 if supported, otherwise fall back to fp16 + SDPA
if torch.cuda.is_bf16_supported():
    os.system('pip install flash_attn')
    compute_dtype = torch.bfloat16
    attn_implementation = 'flash_attention_2'
else:
    compute_dtype = torch.float16
    attn_implementation = 'sdpa'
# Define model and tokenizer
model_id = "meta-llama/Meta-Llama-3-8B-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation=attn_implementation)
# Dataset mapping (elided): the custom dataset preparation and the SFTTrainer
# construction that defines `trainer` below go here.
if hasattr(accelerator.state, 'fsdp_plugin'):
    fsdp_plugin = accelerator.state.fsdp_plugin
    fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(trainer.model)
else:
    print("FSDP plugin is not available.")
trainer.train()
output_dir="./fine-tuned-output"
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
if trainer.is_fsdp_enabled:
trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
Output
Currently training with a batch size of: 2
The following columns in the training set don't have a corresponding argument in `PeftModelForCausalLM.forward` and have been ignored: text. If text are not expected by `PeftModelForCausalLM.forward`, you can safely ignore this message.
***** Running training *****
Num examples = 489
Num Epochs = 4
Instantaneous batch size per device = 1
Training with DataParallel so batch size has been adjusted to: 2
Total train batch size (w. parallel, distributed & accumulation) = 32
Gradient Accumulation steps = 16
Total optimization steps = 50
Number of trainable parameters = 41,943,040
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[42], line 8
6 else:
7 print("FSDP plugin is not available.")
----> 8 trainer.train()
10 output_dir="./fine-tuned-output"
12 model.save_pretrained(output_dir)
File /opt/conda/lib/python3.11/site-packages/transformers/trainer.py:2123, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
2121 hf_hub_utils.enable_progress_bars()
2122 else:
-> 2123 return inner_training_loop(
2124 args=args,
2125 resume_from_checkpoint=resume_from_checkpoint,
2126 trial=trial,
2127 ignore_keys_for_eval=ignore_keys_for_eval,
2128 )
File /opt/conda/lib/python3.11/site-packages/transformers/trainer.py:2481, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
2475 context = (
2476 functools.partial(self.accelerator.no_sync, model=model)
2477 if i == len(batch_samples) - 1
2478 else contextlib.nullcontext
2479 )
2480 with context():
-> 2481 tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
2483 if (
2484 args.logging_nan_inf_filter
2485 and not is_torch_xla_available()
2486 and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
2487 ):
2488 # if loss is nan or inf simply add the average of previous logged losses
2489 tr_loss = tr_loss + tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
File /opt/conda/lib/python3.11/site-packages/transformers/trainer.py:3579, in Trainer.training_step(self, model, inputs, num_items_in_batch)
3576 return loss_mb.reduce_mean().detach().to(self.args.device)
3578 with self.compute_loss_context_manager():
-> 3579 loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
3581 del inputs
3582 if (
3583 self.args.torch_empty_cache_steps is not None
3584 and self.state.global_step % self.args.torch_empty_cache_steps == 0
3585 ):
File /opt/conda/lib/python3.11/site-packages/transformers/trainer.py:3633, in Trainer.compute_loss(self, model, inputs, return_outputs, num_items_in_batch)
3631 loss_kwargs["num_items_in_batch"] = num_items_in_batch
3632 inputs = {**inputs, **loss_kwargs}
-> 3633 outputs = model(**inputs)
3634 # Save past state if it exists
3635 # TODO: this needs to be fixed and made cleaner later.
3636 if self.args.past_index >= 0:
File /opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)
File /opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None
File /opt/conda/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py:175, in DataParallel.forward(self, *inputs, **kwargs)
170 if t.device != self.src_device_obj:
171 raise RuntimeError("module must have its parameters and buffers "
172 f"on device {self.src_device_obj} (device_ids[0]) but found one of "
173 f"them on device: {t.device}")
--> 175 inputs, module_kwargs = self.scatter(inputs, kwargs, self.device_ids)
176 # for forward function without any inputs, empty list and dict will be created
177 # so the module can be executed on one device which is the first one in device_ids
178 if not inputs and not module_kwargs:
File /opt/conda/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py:197, in DataParallel.scatter(self, inputs, kwargs, device_ids)
191 def scatter(
192 self,
193 inputs: Tuple[Any, ...],
194 kwargs: Optional[Dict[str, Any]],
195 device_ids: Sequence[Union[int, torch.device]],
196 ) -> Any:
--> 197 return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
File /opt/conda/lib/python3.11/site-packages/torch/nn/parallel/scatter_gather.py:73, in scatter_kwargs(inputs, kwargs, target_gpus, dim)
71 r"""Scatter with support for kwargs dictionary."""
72 scattered_inputs = scatter(inputs, target_gpus, dim) if inputs else []
---> 73 scattered_kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
74 if len(scattered_inputs) < len(scattered_kwargs):
75 scattered_inputs.extend(() for _ in range(len(scattered_kwargs) - len(scattered_inputs)))
File /opt/conda/lib/python3.11/site-packages/torch/nn/parallel/scatter_gather.py:59, in scatter(inputs, target_gpus, dim)
53 # After scatter_map is called, a scatter_map cell will exist. This cell
54 # has a reference to the actual function scatter_map, which has references
55 # to a closure that has a reference to the scatter_map cell (because the
56 # fn is recursive). To avoid this reference cycle, we set the function to
57 # None, clearing the cell
58 try:
---> 59 res = scatter_map(inputs)
60 finally:
61 scatter_map = None # type: ignore[assignment]
File /opt/conda/lib/python3.11/site-packages/torch/nn/parallel/scatter_gather.py:50, in scatter.<locals>.scatter_map(obj)
48 return [list(i) for i in zip(*map(scatter_map, obj))]
49 if isinstance(obj, dict) and len(obj) > 0:
---> 50 return [type(obj)(i) for i in zip(*map(scatter_map, obj.items()))]
51 return [obj for _ in target_gpus]
File /opt/conda/lib/python3.11/site-packages/torch/nn/parallel/scatter_gather.py:46, in scatter.<locals>.scatter_map(obj)
44 return [type(obj)(*args) for args in zip(*map(scatter_map, obj))]
45 if isinstance(obj, tuple) and len(obj) > 0:
---> 46 return list(zip(*map(scatter_map, obj)))
47 if isinstance(obj, list) and len(obj) > 0:
48 return [list(i) for i in zip(*map(scatter_map, obj))]
File /opt/conda/lib/python3.11/site-packages/torch/nn/parallel/scatter_gather.py:42, in scatter.<locals>.scatter_map(obj)
40 def scatter_map(obj):
41 if isinstance(obj, torch.Tensor):
---> 42 return Scatter.apply(target_gpus, None, dim, obj)
43 if _is_namedtuple(obj):
44 return [type(obj)(*args) for args in zip(*map(scatter_map, obj))]
File /opt/conda/lib/python3.11/site-packages/torch/autograd/function.py:553, in Function.apply(cls, *args, **kwargs)
550 if not torch._C._are_functorch_transforms_active():
551 # See NOTE: [functorch vjp and autograd interaction]
552 args = _functorch.utils.unwrap_dead_wrappers(args)
--> 553 return super().apply(*args, **kwargs) # type: ignore[misc]
555 if not is_setup_ctx_defined:
556 raise RuntimeError(
557 "In order to use an autograd.Function with functorch transforms "
558 "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
559 "staticmethod. For more details, please see "
560 "https://pytorch.org/docs/master/notes/extending.func.html"
561 )
File /opt/conda/lib/python3.11/site-packages/torch/nn/parallel/_functions.py:96, in Scatter.forward(ctx, target_gpus, chunk_sizes, dim, input)
93 if torch.cuda.is_available() and ctx.input_device == -1:
94 # Perform CPU to GPU copies in a background stream
95 streams = [_get_stream(torch.device("cuda", device)) for device in target_gpus]
---> 96 outputs = comm.scatter(input, target_gpus, chunk_sizes, ctx.dim, streams)
97 # Synchronize with the copy stream
98 if streams is not None:
File /opt/conda/lib/python3.11/site-packages/torch/nn/parallel/comm.py:187, in scatter(tensor, devices, chunk_sizes, dim, streams, out)
185 if out is None:
186 devices = [_get_device_index(d) for d in devices]
--> 187 return tuple(torch._C._scatter(tensor, devices, chunk_sizes, dim, streams))
188 else:
189 if devices is not None:
RuntimeError: chunk expects at least a 1-dimensional tensor
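Reading the bottom of the trace: Trainer.compute_loss merges num_items_in_batch into the inputs dict (trainer.py:3631-3633), and nn.DataParallel then scatters every tensor kwarg across the two GPUs by splitting it along dim 0, which cannot work for a 0-dimensional (scalar) tensor. My guess is that num_items_in_batch is such a scalar tensor; the sketch below reproduces the underlying error outside the trainer (the tensor value is made up, only the shape matters):

import torch

# nn.DataParallel's scatter splits every tensor kwarg across the GPUs by
# chunking it along dim 0. A scalar (0-dimensional) tensor cannot be chunked,
# which raises the same error as in the traceback above.
scalar = torch.tensor(489)  # hypothetical 0-dim tensor, e.g. num_items_in_batch
try:
    torch.chunk(scalar, chunks=2)  # "split across 2 GPUs"
except RuntimeError as err:
    print(err)  # chunk expects at least a 1-dimensional tensor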
Runtime Environment
- Model: llama-3.1 8b
- Using via huggingface?: yes
- OS: Jupyter Lab (Ubuntu, CUDA 12.0, PyTorch 2.1)
- GPU VRAM: 151 GB * 2
- Number of GPUs: 2
- GPU Make: Nvidia H200
Additional context
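The training runs in a single Jupyter Lab process that sees both GPUs, which appears to be why the Trainer falls back to nn.DataParallel ("Training with DataParallel so batch size has been adjusted to: 2" in the output) rather than DDP/FSDP. As a possible workaround, below is a sketch of spawning one process per GPU from the notebook with accelerate's notebook_launcher; the training_function wrapper is a placeholder, not part of the original script.

from accelerate import notebook_launcher

def training_function():
    # Placeholder: build the tokenizer, model, dataset and SFTTrainer exactly as
    # in the reproducible example above, then call trainer.train() and save.
    ...

# One process per GPU, so Accelerate/Trainer uses DDP/FSDP instead of wrapping
# the model in nn.DataParallel inside a single process.
notebook_launcher(training_function, num_processes=2)

Outside a notebook, launching the same script as two processes with accelerate launch (or torchrun) should have the same effect.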