abertsch72/unlimiformer

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, ....

shi-kejian opened this issue · 0 comments

Hi,

Thank you for this great effort.

I'm running into an issue with multi-GPU training:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument index in method wrapper_CUDA__index_select)

Here's my entry command. Two notes:

  1. I'm using local data files.
  2. base_training_args.json is the default one, unmodified.

python src/run.py \
    src/configs/training/base_training_args.json \
    --model_name_or_path facebook/bart-large \
    --train_file ... \
    --validation_file ... \
    --test_file ... \
    --input_column ... \
    --input_prefix_column ... \
    --output_column ... \
    --overwrite_cache \
    --output_dir ... \
    --overwrite_output_dir \
    --max_source_length 1024 \
    --eval_max_source_length 999999 \
    --generation_max_length 640 \
    --max_target_length 640 \
    --max_prefix_length 96 \
    --pad_prefix=True \
    --do_eval=True \
    --learning_rate 1e-5 \
    --per_device_eval_batch_size 1 \
    --per_device_train_batch_size 2 \
    --unlimiformer_training=True \
    --test_unlimiformer \
    --eval_steps 30 --save_steps 30 \
    --num_train_epochs 10 \
    --metric_names rouge \
    --extra_metrics bertscore \
    --metric_for_best_model bertscore

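In case it is relevant: I'm launching with plain python (not torchrun), and more than one GPU is visible, so the Trainer wraps the model in torch.nn.DataParallel, which is what shows up in the trace below. As a sanity check I plan to rerun with a single visible GPU, which should bypass DataParallel entirely:

CUDA_VISIBLE_DEVICES=0 python src/run.py \
    src/configs/training/base_training_args.json \
    ... (all other arguments unchanged)
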
The error arises in the forward pass, at:

File "/ext3/miniconda3/lib/python3.11/site-packages/transformers/models/bart/modeling_bart.py", line 810, in forward
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

which raises the RuntimeError quoted above.
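
If I read the trace correctly, replica 1's copy of embed_tokens sits on cuda:1 while the input_ids it receives are still on cuda:0. Just to show what I understand the error to mean, here is a tiny standalone sketch (made-up vocabulary size and shapes, needs at least two GPUs) that fails the same way:

import torch
import torch.nn as nn

# Embedding weights on one device, indices on another: same error class as above.
emb = nn.Embedding(50265, 1024).to("cuda:1")                   # like replica 1's embed_tokens
input_ids = torch.randint(0, 50265, (2, 16), device="cuda:0")  # batch left behind on cuda:0
emb(input_ids)  # RuntimeError: Expected all tensors to be on the same device ...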

Could you please give some clues on where to look for debugging? I don't think this is related to the custom dataset itself; I'm aware the issue could instead be traced to the index, the datastore, batching, etc. This part of the work is inherently complex, and unfortunately my knowledge of the internals is limited.
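
One untested guess, in case it points in the right direction: in pre_forward_hook (unlimiformer.py:551 in the trace below), would moving the tensors onto the device of the replica that is actually running, before calling original_forward_func, be the right fix? Roughly:

# Untested sketch; replica_module is a placeholder for however the hook can
# reach the module it is attached to on the current DataParallel replica.
device = next(replica_module.parameters()).device
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device) if attention_mask is not None else None
labels = labels.to(device) if labels is not None else None
result = self.original_forward_func(input_ids=input_ids, labels=labels,
                                    attention_mask=attention_mask, **kwargs)

I'm not sure whether this would be enough, since the index and datastore presumably also need to end up on the right device, so any pointers are appreciated.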

Thank you very much!

Attached is a full stack trace:

Traceback (most recent call last):
File "unlimiformer/src/run.py", line 1183, in
main()
File "unlimiformer/src/run.py", line 803, in main
train_result = trainer.train(resume_from_checkpoint=checkpoint)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ext3/miniconda3/lib/python3.11/site-packages/transformers/trainer.py", line 1539, in train
return inner_training_loop(
^^^^^^^^^^^^^^^^^^^^
File "/ext3/miniconda3/lib/python3.11/site-packages/transformers/trainer.py", line 1809, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ext3/miniconda3/lib/python3.11/site-packages/transformers/trainer.py", line 2654, in training_step
loss = self.compute_loss(model, inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ext3/miniconda3/lib/python3.11/site-packages/transformers/trainer.py", line 2679, in compute_loss
outputs = model(**inputs)
^^^^^^^^^^^^^^^
File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py", line 185, in forward
outputs = self.parallel_apply(replicas, inputs, module_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py", line 200, in parallel_apply
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py", line 110, in parallel_apply
output.reraise()
File "/ext3/miniconda3/lib/python3.11/site-packages/torch/_utils.py", line 693, in reraise
raise exception
RuntimeError: Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py", line 85, in _worker
output = module(*input, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/scratch/ks4765/research/unlimiformer_ODMDS/src/unlimiformer.py", line 551, in pre_forward_hook
result = self.original_forward_func(input_ids=input_ids, labels=labels, attention_mask=attention_mask, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ext3/miniconda3/lib/python3.11/site-packages/transformers/models/bart/modeling_bart.py", line 1380, in forward
outputs = self.model(
^^^^^^^^^^^
File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ext3/miniconda3/lib/python3.11/site-packages/transformers/models/bart/modeling_bart.py", line 1248, in forward
encoder_outputs = self.encoder(
^^^^^^^^^^^^^
File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ext3/miniconda3/lib/python3.11/site-packages/transformers/models/bart/modeling_bart.py", line 810, in forward
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/sparse.py", line 162, in forward
return F.embedding(
^^^^^^^^^^^^
File "/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/functional.py", line 2235, in embedding
return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument index in method wrapper_CUDA__index_select)