Loss mask uses torch.float32 instead of bool
pilot7747 commented
Hi! I was implementing a custom data loader and hit a deadlock caused by a mismatch in tensor dtypes. Digging in, I found that Megatron-LM uses `torch.float32` instead of `bool` for loss masks:
`Megatron-LM/megatron/training/utils.py`, line 320 at commit `9de386d`
It's not strictly a bug, but is there any reasoning behind it? A boolean mask seems more natural and would likely reduce the communication load between devices, since a `torch.bool` tensor stores 1 byte per element versus 4 for `torch.float32` (see the sketch below).
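For reference, here is a minimal sketch of the pattern in question, with hypothetical helper names rather than the actual Megatron-LM code. One common rationale for the `float32` choice is that a float mask folds directly into the loss arithmetic without a cast:

```python
import torch

def build_loss_mask(tokens: torch.Tensor, pad_id: int) -> torch.Tensor:
    # float32 mask: 1.0 where a token contributes to the loss, 0.0 otherwise.
    loss_mask = torch.ones_like(tokens, dtype=torch.float32)
    loss_mask[tokens == pad_id] = 0.0
    return loss_mask

def masked_mean_loss(per_token_loss: torch.Tensor,
                     loss_mask: torch.Tensor) -> torch.Tensor:
    # A float mask multiplies straight into the per-token loss; a bool mask
    # would need a cast (or indexing) before the sum/division below.
    return (per_token_loss * loss_mask).sum() / loss_mask.sum()

tokens = torch.tensor([[5, 9, 2, 0, 0]])  # 0 = hypothetical pad id
mask = build_loss_mask(tokens, pad_id=0)
print(mask.element_size())         # 4 bytes per element (float32)
print(mask.bool().element_size())  # 1 byte per element (bool)
```

The `element_size()` comparison illustrates the communication point: a `bool` mask moves a quarter of the bytes of a `float32` one when shipped between devices.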
github-actions commented
Marking as stale. No activity in 60 days.