open-mmlab/mmengine

[Bug] MMDeepSpeedEngineWrapper bf16 bug


Prerequisite

Environment

master branch

Reproduces the problem - code sample

new_inputs.append(v.half())  # MMDeepSpeedEngineWrapper always casts floating-point inputs to fp16
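For context, a standalone PyTorch snippet (not the mmengine code path, and assuming a CUDA device) reproduces the same dtype mismatch: a module whose weights are bf16, as with DeepSpeed bf16, rejects an input that was cast with .half().

```python
import torch
import torch.nn as nn

# Standalone illustration of the mismatch, not the mmengine code path; assumes a CUDA device.
conv = nn.Conv2d(3, 8, kernel_size=3).cuda().to(torch.bfloat16)  # weights in bf16, as with DeepSpeed bf16
x = torch.randn(1, 3, 32, 32).cuda().half()                      # input cast with .half() -> fp16
conv(x)  # RuntimeError: input and weight dtypes must match
```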

Reproduces the problem - command or script

DeepSpeed config with bf16 enabled

Reproduces the problem - error message

RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (CUDABFloat16Type) should be the same

Additional information

When using DeepSpeed bf16, v.half() should be changed to v.to(torch.bfloat16).

new_inputs.append(v.half())              # current behavior: always casts to fp16
new_inputs.append(v.to(torch.bfloat16))  # expected behavior when bf16 is enabled
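A minimal sketch of the intended behavior, selecting the cast dtype from the DeepSpeed config instead of hard-coding .half(); the helper name and config access below are assumptions for illustration, not the actual mmengine fix in #1400.

```python
import torch

# Illustrative sketch only: pick the cast dtype from the DeepSpeed config
# ("bf16"/"fp16" sections with an "enabled" flag) rather than hard-coding fp16.
def cast_inputs(inputs, ds_config):
    """Cast floating-point tensors to the dtype DeepSpeed runs the model in."""
    if ds_config.get('bf16', {}).get('enabled', False):
        dtype = torch.bfloat16
    elif ds_config.get('fp16', {}).get('enabled', False):
        dtype = torch.float16
    else:
        dtype = None  # full precision: leave inputs untouched

    new_inputs = []
    for v in inputs:
        if dtype is not None and torch.is_tensor(v) and v.is_floating_point():
            new_inputs.append(v.to(dtype))  # was: new_inputs.append(v.half())
        else:
            new_inputs.append(v)
    return new_inputs
```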

Thanks very much! We have fixed it in #1400.