datawhalechina/self-llm

RuntimeError when fine-tuning GLM-4 on multiple GPUs

shutter-cp opened this issue · 3 comments

  0%|          | 0/7224 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/jovyan/demo/chat/glm4/test.py", line 86, in <module>
    trainer.train()
  File "/home/jovyan/demo/conda-home/miniconda/envs/glm4/lib/python3.10/site-packages/transformers/trainer.py", line 1932, in train
    return inner_training_loop(
  File "/home/jovyan/demo/conda-home/miniconda/envs/glm4/lib/python3.10/site-packages/transformers/trainer.py", line 2268, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/home/jovyan/demo/conda-home/miniconda/envs/glm4/lib/python3.10/site-packages/transformers/trainer.py", line 3307, in training_step
    loss = self.compute_loss(model, inputs)
  File "/home/jovyan/demo/conda-home/miniconda/envs/glm4/lib/python3.10/site-packages/transformers/trainer.py", line 3338, in compute_loss
    outputs = model(**inputs)
  File "/home/jovyan/demo/conda-home/miniconda/envs/glm4/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jovyan/demo/conda-home/miniconda/envs/glm4/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jovyan/demo/conda-home/miniconda/envs/glm4/lib/python3.10/site-packages/peft/peft_model.py", line 1129, in forward
    return self.base_model(
  File "/home/jovyan/demo/conda-home/miniconda/envs/glm4/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jovyan/demo/conda-home/miniconda/envs/glm4/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jovyan/demo/conda-home/miniconda/envs/glm4/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 161, in forward
    return self.model.forward(*args, **kwargs)
  File "/home/jovyan/demo/conda-home/miniconda/envs/glm4/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/root/.cache/huggingface/modules/transformers_modules/glm-4-9b-chat/modeling_chatglm.py", line 1022, in forward
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
  File "/home/jovyan/demo/conda-home/miniconda/envs/glm4/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jovyan/demo/conda-home/miniconda/envs/glm4/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jovyan/demo/conda-home/miniconda/envs/glm4/lib/python3.10/site-packages/torch/nn/modules/loss.py", line 1179, in forward
    return F.cross_entropy(input, target, weight=self.weight,
  File "/home/jovyan/demo/conda-home/miniconda/envs/glm4/lib/python3.10/site-packages/torch/nn/functional.py", line 3053, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:3 and cuda:1! (when checking argument for argument target in method wrapper_CUDA_nll_loss_forward)
  0%|          | 0/7224 [00:03<?, ?it/s] 

Setting device_map="auto" when loading the model should avoid this problem.
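When a model is loaded with device_map="auto", Accelerate splits its modules across the visible GPUs, and transformers records the placement in model.hf_device_map. Below is a minimal pure-Python sketch for reading such a map. The helper group_by_device and the example map are hypothetical (the module names only loosely follow GLM-4's modeling_chatglm.py), but they show how the embedding and the output layer that produces lm_logits can end up on different GPUs.

```python
from collections import defaultdict


def group_by_device(device_map):
    """Group module names by the device they were assigned to in a device map."""
    groups = defaultdict(list)
    for module_name, device in device_map.items():
        groups[device].append(module_name)
    # Plain dict; keys appear in order of first use in the map.
    return dict(groups)


# Hypothetical map with the shape of model.hf_device_map (device indices as ints);
# real module names and placements depend on the checkpoint, GPU count, and free memory.
example_map = {
    "transformer.embedding": 0,
    "transformer.encoder.layers.0": 0,
    "transformer.encoder.layers.1": 1,
    "transformer.encoder.layers.2": 2,
    "transformer.encoder.final_layernorm": 3,
    "transformer.output_layer": 3,
}

for device, modules in group_by_device(example_map).items():
    print(f"cuda:{device} -> {modules}")
```

In a real run, printing model.hf_device_map right after from_pretrained(..., device_map="auto") gives the actual placement to pass to the helper.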


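With "auto" I get exactly this error. device_map="cuda" works fine, but then only a single GPU is used.
Found the cause: with device_map="auto", lm_logits come out on the GPU that holds the last layers while labels are on a different GPU, so CrossEntropyLoss receives tensors from two devices.
Editing the model file modeling_chatglm.py around line 1020 fixes it, so device_map="auto" works across multiple GPUs.
Original: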

if labels is not None:
    lm_logits = lm_logits.to(torch.float32)

    # Shift so that tokens < n predict n
    shift_logits = lm_logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # Flatten the tokens
    loss_fct = CrossEntropyLoss(ignore_index=-100)
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

    lm_logits = lm_logits.to(hidden_states.dtype)
    loss = loss.to(hidden_states.dtype)
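The block above computes the standard next-token loss: the logits at position i are scored against the label at position i + 1, and positions labelled -100 are skipped. A pure-Python sketch of the same arithmetic on made-up numbers (no torch; shifted_cross_entropy is a hypothetical helper), using the mean over non-ignored positions that CrossEntropyLoss applies by default:

```python
import math

IGNORE_INDEX = -100  # same value as CrossEntropyLoss(ignore_index=-100) above


def shifted_cross_entropy(logits, labels, ignore_index=IGNORE_INDEX):
    """Mean next-token cross-entropy for a single sequence.

    logits: list of per-position score lists (seq_len x vocab_size)
    labels: list of token ids (seq_len); ignore_index marks positions to skip
    """
    # Shift so that tokens < n predict n: position i is scored against label i + 1.
    shift_logits = logits[:-1]
    shift_labels = labels[1:]
    total, count = 0.0, 0
    for scores, target in zip(shift_logits, shift_labels):
        if target == ignore_index:
            continue  # e.g. prompt or padding tokens excluded from the loss
        # -log softmax(scores)[target], using log-sum-exp with max subtraction for stability
        m = max(scores)
        log_z = m + math.log(sum(math.exp(s - m) for s in scores))
        total += log_z - scores[target]
        count += 1
    # Mean reduction; like PyTorch, yields nan if every position is ignored.
    return total / count if count else float("nan")


# Uniform scores over a 3-token vocabulary: each scored position costs log(3).
print(shifted_cross_entropy([[0.0, 0.0, 0.0]] * 3, [IGNORE_INDEX, 1, 2]))  # ≈ 1.0986
```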

Modified:

if labels is not None:
    lm_logits = lm_logits.to(torch.float32)

    # Shift so that tokens < n predict n
    shift_logits = lm_logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # With device_map="auto", lm_logits may sit on a different GPU than labels;
    # put both on the labels' device before computing the loss.
    device = shift_labels.device
    shift_logits = shift_logits.to(device)
    shift_labels = shift_labels.to(device)  # already on `device`, so this is a no-op
    # Flatten the tokens
    loss_fct = CrossEntropyLoss(ignore_index=-100)
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

    lm_logits = lm_logits.to(hidden_states.dtype)
    loss = loss.to(hidden_states.dtype)
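The same guard can also be written as a standalone loss helper. This is a sketch, not GLM-4's code: causal_lm_loss is a hypothetical name. Instead of moving the logits to the labels' device as the fix above does, it moves the labels to the logits' device, which is the pattern several transformers causal-LM implementations use. Either direction puts both tensors on one device before the loss is computed. Moving the integer labels avoids copying the full float32 logits tensor (batch × sequence × vocabulary) between GPUs.

```python
import torch
import torch.nn.functional as F


def causal_lm_loss(lm_logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> torch.Tensor:
    """Next-token cross-entropy that tolerates logits and labels on different devices.

    lm_logits: (batch, seq_len, vocab_size); labels: (batch, seq_len),
    with ignore_index marking positions excluded from the loss.
    """
    # Shift so that tokens < n predict n, computing the loss in float32.
    shift_logits = lm_logits[..., :-1, :].contiguous().float()
    shift_labels = labels[..., 1:].contiguous()
    # Move the (small, integer) labels to wherever the output layer produced the logits.
    shift_labels = shift_labels.to(shift_logits.device)
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=ignore_index,
    )


# CPU usage example: uniform logits over 4 tokens give log(4) per scored position.
logits = torch.zeros(1, 3, 4, requires_grad=True)
labels = torch.tensor([[-100, 1, 2]])
loss = causal_lm_loss(logits, labels)
loss.backward()  # gradients flow back through the slice and .float() to the original logits
print(loss.item())  # ≈ 1.3863
```

On a single device the .to() call does nothing, so the helper behaves like a plain shifted cross-entropy; it only matters when the model is split across GPUs.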

👍👍👍