RuntimeError: Expected all tensors to be on the same device, but found at least two devices, ....
shi-kejian opened this issue
Hi,
Thank you for this great effort.
I'm running into an issue with multi-GPU training. Here's my entry command:
- I'm using local data files.
- The base_training_args.json config is the default one.
python src/run.py \
src/configs/training/base_training_args.json \
--model_name_or_path facebook/bart-large \
--train_file ... \
--validation_file ... \
--test_file ... \
--input_column ... \
--input_prefix_column ... \
--output_column ... \
--overwrite_cache \
--output_dir ... \
--overwrite_output_dir \
--max_source_length 1024 \
--eval_max_source_length 999999 \
--generation_max_length 640 \
--max_target_length 640 \
--max_prefix_length 96 \
--pad_prefix=True \
--do_eval=True \
--learning_rate 1e-5 \
--per_device_eval_batch_size 1 \
--per_device_train_batch_size 2 \
--unlimiformer_training=True \
--test_unlimiformer \
--eval_steps 30 --save_steps 30 \
--num_train_epochs 10 \
--metric_names rouge \
--extra_metrics bertscore \
--metric_for_best_model bertscore
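
For context: more than one GPU is visible on this node, and from the stack trace below it looks like the Trainer wraps the model in torch.nn.DataParallel (the failing call goes through torch/nn/parallel/data_parallel.py). Here is a quick check, as a sketch only (training_args is just the name I assume run.py gives to the parsed TrainingArguments):

```python
# Sketch: count visible GPUs. When this is > 1, the HF Trainer wraps the model
# in torch.nn.DataParallel (one replica per device), which is the code path
# that shows up in the stack trace below.
import torch

print(torch.cuda.device_count())
# Inside run.py (variable name assumed), the parsed TrainingArguments exposes
# the same information:
# print(training_args.n_gpu)
```
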
The error arises in the encoder's forward pass:

  File "/ext3/miniconda3/lib/python3.11/site-packages/transformers/models/bart/modeling_bart.py", line 810, in forward
    inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

and the error is:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument index in method wrapper_CUDA__index_select)
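
For what it's worth, the error class itself reproduces in isolation on a machine with at least two GPUs, so it really does look like the token ids and the embedding weights end up on different devices. Here is a standalone sketch (not using any repo code; 50265 and 1024 are just BART-large's vocab and hidden sizes):

```python
# Standalone sketch: trigger the same class of error on a machine with >= 2 GPUs.
# The embedding weights live on cuda:0 while the token ids live on cuda:1.
import torch
import torch.nn as nn

embed = nn.Embedding(50265, 1024).to("cuda:0")                 # weights on cuda:0
input_ids = torch.randint(0, 50265, (2, 16), device="cuda:1")  # ids on cuda:1

# Raises: RuntimeError: Expected all tensors to be on the same device,
# but found at least two devices, cuda:0 and cuda:1!
embed(input_ids)
```
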
Could you please give some clues on where to look for debugging? I don't think this is related to the custom dataset itself. I understand the issue could be traced to the index, the datastore, batching, etc., but the nature of this work makes those parts complex, and unfortunately I have limited knowledge of the internals.
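
In the meantime, the next thing I plan to try is a forward pre-hook on the encoder's token embedding, so each replica reports where its ids and weights live. This is only a sketch of my own debugging idea, not repo code; the model is just loaded the way --model_name_or_path would:

```python
# Sketch of a device probe: hook the encoder's embed_tokens so every call prints
# which device the incoming input_ids and the embedding weights are on.
from transformers import BartForConditionalGeneration

model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")

def report_devices(module, args):
    input_ids = args[0]
    print(f"ids on {input_ids.device}, embedding weights on {module.weight.device}")

model.model.encoder.embed_tokens.register_forward_pre_hook(report_devices)
```

If the replicas were set up correctly, I would expect each one to print a matching pair of devices right before the failing forward call.
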
Thank you very much!
Attached is a full stack trace:
Traceback (most recent call last):
  File "unlimiformer/src/run.py", line 1183, in <module>
    main()
  File "unlimiformer/src/run.py", line 803, in main
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ext3/miniconda3/lib/python3.11/site-packages/transformers/trainer.py", line 1539, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/ext3/miniconda3/lib/python3.11/site-packages/transformers/trainer.py", line 1809, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ext3/miniconda3/lib/python3.11/site-packages/transformers/trainer.py", line 2654, in training_step
    loss = self.compute_loss(model, inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ext3/miniconda3/lib/python3.11/site-packages/transformers/trainer.py", line 2679, in compute_loss
    outputs = model(**inputs)
              ^^^^^^^^^^^^^^^
  File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py", line 185, in forward
    outputs = self.parallel_apply(replicas, inputs, module_kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py", line 200, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py", line 110, in parallel_apply
    output.reraise()
  File "/ext3/miniconda3/lib/python3.11/site-packages/torch/_utils.py", line 693, in reraise
    raise exception
RuntimeError: Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py", line 85, in _worker
    output = module(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/ks4765/research/unlimiformer_ODMDS/src/unlimiformer.py", line 551, in pre_forward_hook
    result = self.original_forward_func(input_ids=input_ids, labels=labels, attention_mask=attention_mask, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ext3/miniconda3/lib/python3.11/site-packages/transformers/models/bart/modeling_bart.py", line 1380, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ext3/miniconda3/lib/python3.11/site-packages/transformers/models/bart/modeling_bart.py", line 1248, in forward
    encoder_outputs = self.encoder(
                      ^^^^^^^^^^^^^
  File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ext3/miniconda3/lib/python3.11/site-packages/transformers/models/bart/modeling_bart.py", line 810, in forward
    inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/sparse.py", line 162, in forward
    return F.embedding(
           ^^^^^^^^^^^^
  File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/functional.py", line 2235, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument index in method wrapper_CUDA__index_select)