lucidrains/mixture-of-experts

Error reported under FP16 training

SefaZeng opened this issue · 1 comment

I'm trying to use MoE in the standard Transformer for machine translation with the fairseq codebase, and I get this error when I enable half-precision (FP16) training:

Traceback (most recent call last):
  File "/data/mine/moe_test/fairseq-master_moe/fairseq_cli/train.py", line 357, in <module>
    cli_main()
  File "/data/mine/moe_test/fairseq-master_moe/fairseq_cli/train.py", line 353, in cli_main
    distributed_utils.call_main(args, main)
  File "/data/mine/moe_test/fairseq-master_moe/fairseq/distributed_utils.py", line 189, in call_main
    main(args, **kwargs)
  File "/data/mine/moe_test/fairseq-master_moe/fairseq_cli/train.py", line 121, in main
    valid_losses, should_stop = train(args, trainer, task, epoch_itr)
  File "/data/mine/anaconda3/envs/pytorch/lib/python3.6/contextlib.py", line 52, in inner
    return func(*args, **kwds)
  File "/data/mine/moe_test/fairseq-master_moe/fairseq_cli/train.py", line 218, in train
    log_output = trainer.train_step(samples)
  File "/data/mine/anaconda3/envs/pytorch/lib/python3.6/contextlib.py", line 52, in inner
    return func(*args, **kwds)
  File "/data/mine/moe_test/fairseq-master_moe/fairseq/trainer.py", line 457, in train_step
    raise e
  File "/data/mine/moe_test/fairseq-master_moe/fairseq/trainer.py", line 431, in train_step
    ignore_grad=is_dummy_batch,
  File "/data/mine/moe_test/fairseq-master_moe/fairseq/tasks/fairseq_task.py", line 347, in train_step
    loss, sample_size, logging_output = criterion(model, sample)
  File "/data/mine/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/data/mine/moe_test/fairseq-master_moe/fairseq/criterions/label_smoothed_cross_entropy.py", line 56, in forward
    net_output = model(**sample['net_input'])
  File "/data/mine/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/data/mine/moe_test/fairseq-master_moe/fairseq/models/transformer.py", line 296, in forward
    src_tokens, src_lengths=src_lengths, return_all_hiddens=return_all_hiddens
  File "/data/mine/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/data/mine/moe_test/fairseq-master_moe/fairseq/models/transformer.py", line 476, in forward
    x = layer(x, encoder_padding_mask)
  File "/data/mine/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/data/mine/moe_test/fairseq-master_moe/fairseq/modules/transformer_layer.py", line 186, in forward
    x = self.moe_layer(x)
  File "/data/mine/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/data/mine/moe_test/fairseq-master_moe/fairseq/modules/moe.py", line 255, in forward
    expert_inputs = torch.einsum('bnd,bnec->ebcd', inputs, dispatch_tensor)
  File "/data/mine/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/functional.py", line 241, in einsum
    return torch._C._VariableFunctions.einsum(equation, operands)
RuntimeError: Expected object of scalar type Half but got scalar type Float for argument #2 'mat2' in call to _th_bmm
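The error means the two `einsum` operands have different dtypes: with fairseq's `--fp16` the model runs in half precision, so `inputs` is float16, while `dispatch_tensor` is apparently still float32. One plausible cause, sketched here under assumptions: gating code that builds its masks with `.float()` (or another float32 default) ignores the module's dtype, because such tensors are neither parameters nor buffers and are not converted by `.half()`. `ToyGate` is a hypothetical stand-in, not the code in `moe.py`, and `.double()` stands in for `.half()` so the sketch runs on CPU.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyGate(nn.Module):
    """Hypothetical top-1 gate; illustrates the dtype pattern, not the repo's code."""

    def __init__(self, dim, num_experts):
        super().__init__()
        self.w_gating = nn.Linear(dim, num_experts, bias=False)
        self.num_experts = num_experts

    def forward(self, x):
        # The projection follows the module's dtype (.half() / .double()).
        logits = self.w_gating(x)
        index = logits.argmax(dim=-1)
        # .float() always produces float32, whatever dtype the module is in.
        mask = F.one_hot(index, self.num_experts).float()
        return logits, mask


# float64 plays the role of float16 here so the example runs on CPU.
gate = ToyGate(dim=8, num_experts=4).double()
x = torch.randn(2, 5, 8, dtype=torch.float64)
logits, mask = gate(x)
print(logits.dtype, mask.dtype)  # torch.float64 torch.float32
```

Feeding `x` and a tensor derived from `mask` into the same `einsum` reproduces the kind of mismatch shown in the traceback.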

The error disappears when I disable FP16 training. Does this code not support FP16 training, or is there something wrong with the way I'm using it?
Any help is appreciated. Thanks!

Changing the dtype so that both einsum operands match (both float16 or both float32) works for me.
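A minimal sketch of that fix, assuming the failing call is the `einsum` at `moe.py` line 255: cast both operands to one dtype right before the contraction. `dispatch_to_experts` and its `dtype` argument are hypothetical names for illustration. By default it follows the activations (float16 under FP16 training); passing `dtype=torch.float32` runs the contraction in full precision instead, which matches the two options mentioned above. The example again uses float64 as a CPU-friendly stand-in for float16.

```python
import torch


def dispatch_to_experts(inputs, dispatch_tensor, dtype=None):
    """Hypothetical helper: route tokens to expert slots with matching dtypes.

    inputs:          (batch, tokens, dim)
    dispatch_tensor: (batch, tokens, experts, capacity)
    returns:         (experts, batch, capacity, dim)
    """
    # Default to the activations' dtype; pass torch.float32 to upcast instead.
    dtype = inputs.dtype if dtype is None else dtype
    return torch.einsum('bnd,bnec->ebcd', inputs.to(dtype), dispatch_tensor.to(dtype))


inputs = torch.randn(2, 5, 8, dtype=torch.float64)
dispatch = torch.zeros(2, 5, 4, 3)  # float32 by default, like a mask built with .float()
dispatch[:, :, 0, 0] = 1.0          # send every token to expert 0, slot 0
out = dispatch_to_experts(inputs, dispatch)
print(out.dtype, tuple(out.shape))  # torch.float64 (4, 2, 3, 8)
```

The dispatch tensor only holds 0/1 values, which float16 represents exactly, so casting it down loses nothing. A combine tensor that carries gate probabilities would lose some precision in float16, which is one reason to compute in float32 and cast the result back instead.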