
"RuntimeError: einsum(): operands do not broadcast with remapped shapes [original->remapped]" during reproducing mlm pre-training

Duemoo opened this issue · 5 comments

I was trying to pre-train the ATLAS model (base and large sizes) by running the provided example script in atlas/example_scripts/mlm/train.sh on four 40GB A100 GPUs, but I got this error:

Traceback (most recent call last):                                                            
  File "/home/work/atlas/atlas/train.py", line 223, in <module>                                                                                                                             
  File "/home/work/atlas/atlas/train.py", line 77, in train                                                                                                                                 
    reader_loss, retriever_loss = model(                                                      
  File "/home/work/.conda/envs/atlas/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl                                                                     
    return forward_call(*input, **kwargs)                                                     
  File "/home/work/atlas/atlas/src/atlas.py", line 432, in forward                                                                                                                          
    passages, _ = self.retrieve(                                                              
  File "/home/work/.conda/envs/atlas/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context                                                                
    return func(*args, **kwargs)                                                              
  File "/home/work/atlas/atlas/src/atlas.py", line 181, in retrieve                                                                                                                         
    passages, scores = retrieve_func(*args, **kwargs)[:2]                                                                                                                                   
  File "/home/work/.conda/envs/atlas/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context                                                                
    return func(*args, **kwargs)                                                              
  File "/home/work/atlas/atlas/src/atlas.py", line 170, in retrieve_with_rerank                                                                                                             
    retriever_scores = torch.einsum("id, ijd->ij", [query_emb, passage_emb])                                                                                                                
  File "/home/work/.conda/envs/atlas/lib/python3.10/site-packages/torch/functional.py", line 328, in einsum                                                                                 
    return einsum(equation, *_operands)                                                       
  File "/home/work/.conda/envs/atlas/lib/python3.10/site-packages/torch/functional.py", line 330, in einsum                                                                                 
    return _VF.einsum(equation, operands)  # type: ignore[attr-defined]                                                                                                                     
RuntimeError: einsum(): operands do not broadcast with remapped shapes [original->remapped]: [2, 768]->[2, 1, 768] [2, 100, 1536]->[2, 100, 1536]                                           
srun: error: localhost: task 0: Exited with exit code 1                                                                                                                                     
srun: error: localhost: task 3: Exited with exit code 1                                                                                                                                     
srun: error: localhost: task 1: Exited with exit code 1                                                                                                                                     
srun: error: localhost: task 2: Exited with exit code 1 

I used the provided passages (Wikipedia Dec2018 dump), and ran the script without any changes in training arguments.
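So the batch size per device was 2 and the retriever returned 100 documents, which matches the shapes [2, 768]->[2, 1, 768] and [2, 100, 1536]->[2, 100, 1536] in the error message above. The mismatch is in the last dimension: the query embedding has 768 features, while the passage embeddings have 1536.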
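To make the shape mismatch concrete, here is a minimal pure-Python sketch of the check that an einsum equation implies: every index label must map to one size across all operands (size-1 dimensions aside). The helper name `einsum_label_sizes` is hypothetical and this is not PyTorch's actual implementation; it only illustrates why the shapes from the error message cannot be combined under "id, ijd->ij".

```python
def einsum_label_sizes(equation, shapes):
    """Map each einsum index label to its size, raising on a conflict.

    Simplified illustration: no ellipsis support; a size of 1 is treated
    as broadcastable against any other size.
    """
    # Keep only the input side of the equation, one subscript string per operand.
    inputs = equation.replace(" ", "").split("->")[0].split(",")
    if len(inputs) != len(shapes):
        raise ValueError("number of operands does not match the equation")
    sizes = {}
    for labels, shape in zip(inputs, shapes):
        if len(labels) != len(shape):
            raise ValueError(f"subscripts {labels!r} do not match shape {shape}")
        for label, size in zip(labels, shape):
            known = sizes.get(label, size)
            if known != size and 1 not in (known, size):
                raise ValueError(
                    f"label {label!r} has conflicting sizes {known} and {size}"
                )
            sizes[label] = max(known, size)
    return sizes


# Shapes taken from the error message above: query [2, 768], passages [2, 100, 1536].
try:
    einsum_label_sizes("id, ijd->ij", [(2, 768), (2, 100, 1536)])
except ValueError as err:
    print(err)  # label 'd' has conflicting sizes 768 and 1536
```

With matching embedding sizes, e.g. `[(2, 768), (2, 100, 768)]`, the same call returns a consistent mapping for `i`, `j`, and `d`, so the contraction over `d` is well defined.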
In addition, I found that the script and the overall pre-training process worked well after removing the following line from the script, i.e., re-indexing all of the passages instead of re-ranking:
--retrieve_with_rerank --n_to_rerank_with_retrieve_with_rerank 100 \
However, this resulted in lower few-shot performance than the scores reported in Table 19 of the paper (although I think the performance gap might be unrelated to removing this line).

Could you provide any hints to solve this issue? Thank you in advance!

Hi @Duemoo, you may refer to this pull request to fix the issue for reranking.

I also cannot reproduce the 64-shot results reported in Table 2 with the pretraining/finetuning script and the settings described in the paper (without reranking in pretraining).

Thank you! I applied your commits and it fixed the problem :)

@jeffhj Thank you for the information. I suspect that also using the CCNet indices could be an important factor in the performance, so I'm trying to reproduce the results with the CCNet texts as well.

Thank you @jeffhj for the fix, I've merged it into master, so I am closing this issue.