RuntimeError when fine-tuning GLM-4 on multiple GPUs
shutter-cp opened this issue · 3 comments
shutter-cp commented
0%| | 0/7224 [00:00<?, ?it/s]Traceback (most recent call last):
File "/home/jovyan/demo/chat/glm4/test.py", line 86, in <module>
trainer.train()
File "/home/jovyan/demo/conda-home/miniconda/envs/glm4/lib/python3.10/site-packages/transformers/trainer.py", line 1932, in train
return inner_training_loop(
File "/home/jovyan/demo/conda-home/miniconda/envs/glm4/lib/python3.10/site-packages/transformers/trainer.py", line 2268, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs)
File "/home/jovyan/demo/conda-home/miniconda/envs/glm4/lib/python3.10/site-packages/transformers/trainer.py", line 3307, in training_step
loss = self.compute_loss(model, inputs)
File "/home/jovyan/demo/conda-home/miniconda/envs/glm4/lib/python3.10/site-packages/transformers/trainer.py", line 3338, in compute_loss
outputs = model(**inputs)
File "/home/jovyan/demo/conda-home/miniconda/envs/glm4/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/jovyan/demo/conda-home/miniconda/envs/glm4/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/jovyan/demo/conda-home/miniconda/envs/glm4/lib/python3.10/site-packages/peft/peft_model.py", line 1129, in forward
return self.base_model(
File "/home/jovyan/demo/conda-home/miniconda/envs/glm4/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/jovyan/demo/conda-home/miniconda/envs/glm4/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/jovyan/demo/conda-home/miniconda/envs/glm4/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 161, in forward
return self.model.forward(*args, **kwargs)
File "/home/jovyan/demo/conda-home/miniconda/envs/glm4/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
output = module._old_forward(*args, **kwargs)
File "/root/.cache/huggingface/modules/transformers_modules/glm-4-9b-chat/modeling_chatglm.py", line 1022, in forward
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
File "/home/jovyan/demo/conda-home/miniconda/envs/glm4/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/jovyan/demo/conda-home/miniconda/envs/glm4/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/jovyan/demo/conda-home/miniconda/envs/glm4/lib/python3.10/site-packages/torch/nn/modules/loss.py", line 1179, in forward
return F.cross_entropy(input, target, weight=self.weight,
File "/home/jovyan/demo/conda-home/miniconda/envs/glm4/lib/python3.10/site-packages/torch/nn/functional.py", line 3053, in cross_entropy
return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:3 and cuda:1! (when checking argument for argument target in method wrapper_CUDA_nll_loss_forward)
0%| | 0/7224 [00:03<?, ?it/s]
KMnO4-zx commented
Setting device_map="auto" when loading the model should avoid this problem.
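With device_map="auto" the layers are spread across the visible GPUs, so it can help to check which device each part of the model landed on. Below is a minimal sketch, assuming the loaded model exposes the `hf_device_map` dictionary that Accelerate sets when a device map is used. The helper `summarize_device_map` and the example placement are hypothetical and only for illustration.

```python
from collections import defaultdict


def summarize_device_map(device_map):
    """Group module names by the device they were placed on."""
    by_device = defaultdict(list)
    for module_name, device in device_map.items():
        by_device[str(device)].append(module_name)
    return dict(by_device)


# Hypothetical placement, for illustration only. With a real model this
# dictionary would come from `model.hf_device_map` after loading with
# device_map="auto".
example_map = {
    "transformer.embedding": 0,
    "transformer.encoder.layers.0": 0,
    "transformer.encoder.layers.1": 1,
    "transformer.output_layer": 3,
}

summary = summarize_device_map(example_map)
for device, modules in summary.items():
    print(device, modules)
```

If the output layer sits on a different GPU than the one holding the labels, the loss computation sees tensors on two devices, which matches the error above.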
shutter-cp commented
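Setting device_map="auto" when loading the model should avoid this problem.
That is exactly what I get when using auto. device_map="cuda" runs without error, but then training cannot use multiple GPUs.
Found the cause.
Editing the model file modeling_chatglm.py around line 1020 makes multi-GPU training work with device_map="auto".
Original: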
if labels is not None:
    lm_logits = lm_logits.to(torch.float32)
    # Shift so that tokens < n predict n
    shift_logits = lm_logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Flatten the tokens
    loss_fct = CrossEntropyLoss(ignore_index=-100)
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
    lm_logits = lm_logits.to(hidden_states.dtype)
    loss = loss.to(hidden_states.dtype)
Modified:
if labels is not None:
    lm_logits = lm_logits.to(torch.float32)
    # Shift so that tokens < n predict n
    shift_logits = lm_logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Put logits and labels on the same device before computing the loss
    device = shift_labels.device
    shift_logits = shift_logits.to(device)
    shift_labels = shift_labels.to(device)
    # Flatten the tokens
    loss_fct = CrossEntropyLoss(ignore_index=-100)
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
    lm_logits = lm_logits.to(hidden_states.dtype)
    loss = loss.to(hidden_states.dtype)
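The same idea can be tried outside the model file. This is a standalone sketch, not the exact patch above: the function name `shifted_lm_loss` is hypothetical, and it moves the labels to the device of the logits instead of the other way round, which should copy less data since the labels tensor is much smaller. It runs on CPU here purely to check the arithmetic.

```python
import torch
from torch.nn import CrossEntropyLoss


def shifted_lm_loss(lm_logits, labels, ignore_index=-100):
    """Next-token cross-entropy with logits and labels on one device."""
    lm_logits = lm_logits.to(torch.float32)
    # Shift so that tokens < n predict n
    shift_logits = lm_logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Variant of the fix: move the (small) labels tensor to wherever the
    # logits live, so both loss arguments share a device.
    shift_labels = shift_labels.to(shift_logits.device)
    # Flatten the tokens
    loss_fct = CrossEntropyLoss(ignore_index=ignore_index)
    return loss_fct(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
    )


# Toy inputs: batch of 2, sequence length 5, vocabulary of 11 tokens.
torch.manual_seed(0)
logits = torch.randn(2, 5, 11)
labels = torch.randint(0, 11, (2, 5))
labels[0, :2] = -100  # masked positions are skipped by the loss

loss = shifted_lm_loss(logits, labels)
print(loss.item())
```

Either direction removes the device mismatch. Which one is preferable probably depends on where later code expects the loss to live.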
KMnO4-zx commented
👍👍👍