"RuntimeError: einsum(): operands do not broadcast with remapped shapes [original->remapped]" during reproducing mlm pre-training
Duemoo opened this issue · 5 comments
Hello,
I was trying to pre-train the ATLAS model (base & large size) by running the provided example script in atlas/example_scripts/mlm/train.sh on four 40GB A100 GPUs, but I got this error:
Traceback (most recent call last):
File "/home/work/atlas/atlas/train.py", line 223, in <module>
train(
File "/home/work/atlas/atlas/train.py", line 77, in train
reader_loss, retriever_loss = model(
File "/home/work/.conda/envs/atlas/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/home/work/atlas/atlas/src/atlas.py", line 432, in forward
passages, _ = self.retrieve(
File "/home/work/.conda/envs/atlas/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/home/work/atlas/atlas/src/atlas.py", line 181, in retrieve
passages, scores = retrieve_func(*args, **kwargs)[:2]
File "/home/work/.conda/envs/atlas/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/home/work/atlas/atlas/src/atlas.py", line 170, in retrieve_with_rerank
retriever_scores = torch.einsum("id, ijd->ij", [query_emb, passage_emb])
File "/home/work/.conda/envs/atlas/lib/python3.10/site-packages/torch/functional.py", line 328, in einsum
return einsum(equation, *_operands)
File "/home/work/.conda/envs/atlas/lib/python3.10/site-packages/torch/functional.py", line 330, in einsum
return _VF.einsum(equation, operands) # type: ignore[attr-defined]
RuntimeError: einsum(): operands do not broadcast with remapped shapes [original->remapped]: [2, 768]->[2, 1, 768] [2, 100, 1536]->[2, 100, 1536]
srun: error: localhost: task 0: Exited with exit code 1
srun: error: localhost: task 3: Exited with exit code 1
srun: error: localhost: task 1: Exited with exit code 1
srun: error: localhost: task 2: Exited with exit code 1
I used the provided passages (Wikipedia Dec2018 dump), and ran the script without any changes in training arguments.
Judging from the shapes in the error message ([2, 768]->[2, 1, 768] and [2, 100, 1536]->[2, 100, 1536]), the batch size per device was 2 and the retriever returned 100 documents; note that the query embedding dimension (768) does not match the last dimension of the passage embeddings (1536).
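For illustration, the failure can be reproduced outside the training code. This is a minimal sketch using numpy.einsum (which follows the same contraction rules as torch.einsum); the shapes are taken from the error message, and the "fixed" passage shape is only an assumption about what a matching embedding dimension would look like, not the actual fix from the pull request:

```python
import numpy as np

# Shapes from the error message:
# query_emb:   (batch=2, d=768)
# passage_emb: (batch=2, n_docs=100, d=1536)
query_emb = np.zeros((2, 768))
passage_emb = np.zeros((2, 100, 1536))

# "id,ijd->ij" contracts over d, so both operands must share the same
# last dimension; 768 != 1536 triggers the broadcast error.
try:
    np.einsum("id,ijd->ij", query_emb, passage_emb)
except ValueError as e:
    print("einsum failed:", e)

# Once the embedding dimensions agree, the contraction yields one
# reranker score per (query, passage) pair:
passage_emb_fixed = np.zeros((2, 100, 768))
scores = np.einsum("id,ijd->ij", query_emb, passage_emb_fixed)
print(scores.shape)  # (2, 100)
```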
In addition, I found that the script and the overall pre-training process worked fine after removing the following line from the script, i.e., re-indexing all passages instead of re-ranking, although this resulted in lower few-shot performance than the scores reported in Table 19 of the paper. (However, I think the performance issue may be unrelated to removing this line.)
--retrieve_with_rerank --n_to_rerank_with_retrieve_with_rerank 100 \
Could you provide any hints to solve this issue? Thank you in advance!
Hi @Duemoo, you may refer to this pull request, which fixes the reranking issue.
I also cannot reproduce the 64-shot results reported in Table 2 with the pre-training/fine-tuning scripts and the settings described in the paper (without reranking during pre-training).
Thank you! I applied your commits and they fixed the problem :)