MCZhi/DIPP

RuntimeError


Hello, thank you very much for your great work. I ran into the following error during training and I do not know what causes it. Could you please tell me what is going wrong? I would be very grateful.

Traceback (most recent call last):
  File "train.py", line 254, in <module>
    model_training()
  File "train.py", line 211, in model_training
    val_loss, val_metrics = valid_epoch(valid_loader, predictor, planner, args.use_planning)
  File "train.py", line 110, in valid_epoch
    plans, predictions, scores, cost_function_weights = predictor(ego, neighbors, map_lanes, map_crosswalks)
  File "/home/hou/anaconda3/envs/DIPP/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/hou/DIPP/model/predictor.py", line 236, in forward
    agent_agent = self.agent_agent(actors, actor_mask)
  File "/home/hou/anaconda3/envs/DIPP/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/hou/DIPP/model/predictor.py", line 106, in forward
    output = self.interaction_net(inputs, src_key_padding_mask=mask)
  File "/home/hou/anaconda3/envs/DIPP/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/hou/anaconda3/envs/DIPP/lib/python3.8/site-packages/torch/nn/modules/transformer.py", line 238, in forward
    output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
  File "/home/hou/anaconda3/envs/DIPP/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/hou/anaconda3/envs/DIPP/lib/python3.8/site-packages/torch/nn/modules/transformer.py", line 437, in forward
    return torch._transformer_encoder_layer_fwd(
RuntimeError: Mask shape should match input shape; transformer_mask is not supported in the fallback case.

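Hi @git0929, thank you for your interest in our work. If you are using a newer version of PyTorch, please pass the argument enable_nested_tensor=False when constructing the encoder, i.e. change self.interaction_net = nn.TransformerEncoder(encoder_layer, num_layers=2) to self.interaction_net = nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=False).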
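
For anyone who wants to check the change in isolation, here is a minimal standalone sketch. It is not the code from model/predictor.py: the layer settings (d_model, nhead, dim_feedforward, batch_first) and the tensor sizes are assumptions chosen for illustration. Only the nn.TransformerEncoder call with num_layers=2 and enable_nested_tensor=False comes from the suggestion above. The encoder is put in eval mode because the traceback comes from the validation pass.

```python
import torch
import torch.nn as nn

# Hypothetical sizes for illustration; not taken from the DIPP repository.
d_model, n_heads, batch, n_agents, n_valid = 256, 8, 4, 11, 6

# Assumed layer configuration (batch-first inputs).
encoder_layer = nn.TransformerEncoderLayer(
    d_model=d_model, nhead=n_heads, dim_feedforward=1024, batch_first=True
)

# The suggested change: disable the nested-tensor conversion of padded inputs.
interaction_net = nn.TransformerEncoder(
    encoder_layer, num_layers=2, enable_nested_tensor=False
)
interaction_net.eval()  # mimic the validation pass where the error was raised

inputs = torch.randn(batch, n_agents, d_model)

# Boolean padding mask: True marks padded (ignored) agent slots.
mask = torch.zeros(batch, n_agents, dtype=torch.bool)
mask[:, n_valid:] = True

with torch.no_grad():
    output = interaction_net(inputs, src_key_padding_mask=mask)

print(output.shape)
```

With the nested-tensor path disabled, the output keeps the padded layout of the input, so downstream code that indexes by agent slot should not need to change.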