Branch: pytorch1.6 RuntimeError: expected scalar type Float but found Half
ShellyLingling commented
Please help me!
PyTorch version: 1.6.0
When I put the @torch.cuda.amp.autocast decorator before "def forward():", this error occurs.
I don't know how to solve this problem.
I couldn't find a substitute for @apex.amp.float_function in PyTorch 1.6.
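For reference, a minimal sketch of a hand-rolled substitute, assuming PyTorch's native AMP (torch.cuda.amp, available since 1.6). The decorator float32_function and the example function scaled_sum are hypothetical names, not part of PyTorch or apex. The decorator casts floating-point tensor arguments to float32 and runs the wrapped function with autocast disabled, which is similar in spirit to apex.amp.float_function. For custom torch.autograd.Function subclasses, PyTorch 1.6 also documents torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) for this purpose.

```python
import functools

import torch


def float32_function(fn):
    """Hypothetical stand-in for apex.amp.float_function (not a PyTorch API).

    Casts floating-point tensor arguments to float32 and runs ``fn`` with
    native autocast disabled, so ops inside ``fn`` stay in float32.
    """

    def to_float32(value):
        # Only floating-point tensors are cast; integer/bool tensors and
        # non-tensor arguments pass through unchanged.
        if isinstance(value, torch.Tensor) and value.is_floating_point():
            return value.float()
        return value

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        args = tuple(to_float32(a) for a in args)
        kwargs = {k: to_float32(v) for k, v in kwargs.items()}
        # Disable autocast inside fn so its ops are not cast back to float16.
        with torch.cuda.amp.autocast(enabled=False):
            return fn(*args, **kwargs)

    return wrapper


@float32_function
def scaled_sum(x, y, scale=1.0):
    # Hypothetical example of an op you might want to keep in float32.
    return (x + y) * scale
```

Inside a forward decorated with @torch.cuda.amp.autocast(), calling scaled_sum on a float16 activation and a float32 tensor would then compute in float32 instead of mixing dtypes. The simpler inline alternative is to wrap just the sensitive lines in "with torch.cuda.amp.autocast(enabled=False):" and call .float() on their inputs.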
lbin commented
You should install apex from NVIDIA's GitHub repository.