[Bug] MMDeepSpeedEngineWrapper bf16 bug
felixfuu opened this issue · 1 comment
Prerequisite
- I have searched Issues and Discussions but cannot get the expected help.
- The bug has not been fixed in the latest version (https://github.com/open-mmlab/mmengine).
Environment
master branch
Reproduces the problem - code sample
mmengine/mmengine/_strategy/deepspeed.py, line 195 (commit e43bbb5)
Reproduces the problem - command or script
Train with the DeepSpeed strategy with bf16 enabled.
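The following sketch shows one way to set this up as an mmengine strategy config. It is an assumed reproduction, not the reporter's actual config. It relies on DeepSpeedStrategy forwarding bf16 to DeepSpeed's "bf16" config section and on inputs_to_half selecting which inputs the wrapper casts. Other required settings (optimizer, ZeRO stage, runner) are omitted.

```python
# Assumed minimal reproduction config (not taken from the issue).
strategy = dict(
    type='DeepSpeedStrategy',
    bf16=dict(enabled=True),  # run the DeepSpeed engine in bfloat16
    inputs_to_half=[0],       # ask the wrapper to cast the first positional input
)
```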
Reproduces the problem - error message
RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (CUDABFloat16Type) should be the same
Additional information
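When DeepSpeed bf16 is enabled, MMDeepSpeedEngineWrapper still casts the inputs with v.half(), i.e. to float16, while the model weights are bfloat16. This mismatch raises the error above. In the bf16 case the cast should be v.to(torch.bfloat16) instead.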
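The next sketch generalizes the proposed fix. It is a hypothetical helper, cast_selected_inputs, and not mmengine's actual code or API. Instead of hard-coding v.half(), it casts the selected inputs with v.to(dtype), so the caller can pass torch.bfloat16 under bf16 and torch.float16 under fp16. The selection rules are assumptions made for this sketch: list/tuple inputs are selected by position and dict inputs by key.

```python
from typing import Any, Collection, Optional


def cast_selected_inputs(inputs: Any, indices: Optional[Collection], dtype: Any) -> Any:
    """Cast the inputs selected by ``indices`` to ``dtype`` via ``v.to(dtype)``.

    Hypothetical helper (not mmengine's API): unlike a hard-coded ``v.half()``,
    the target dtype is chosen by the caller.
    """
    if indices is None:
        # Nothing selected for casting: return the inputs unchanged.
        return inputs
    if isinstance(inputs, (list, tuple)):
        # Positional inputs: select by position (assumption of this sketch).
        cast = [v.to(dtype) if i in indices else v for i, v in enumerate(inputs)]
        return tuple(cast) if isinstance(inputs, tuple) else cast
    if isinstance(inputs, dict):
        # Keyword inputs: select by key (assumption of this sketch).
        return {k: (v.to(dtype) if k in indices else v) for k, v in inputs.items()}
    # Any other input type is passed through untouched.
    return inputs
```

With PyTorch, the target dtype can be read from the model itself, e.g. dtype = next(model.parameters()).dtype. The same call then covers both fp16 and bf16 engines without branching on the DeepSpeed config.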