bug
afterall204168 opened this issue · 2 comments
afterall204168 commented
RuntimeError: result type Float can't be cast to the desired output type Long
kevinghst commented
Has anyone found a solution to this bug? I also encountered it.
afterall204168 commented
> Has anyone found a solution to this bug? I also encountered it.
I solved this issue with the following code (adding an "if" statement):
for param, ema_param in zip(self.params, self.ema_params):
    if ema_param.dtype == torch.float32:
        ema_param.mul_(self.alpha)
        ema_param.add_(param * one_minus_alpha)
        # customized weight decay
        param.mul_(1 - self.wd)
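
For anyone hitting the same error: a likely cause, assuming self.params and self.ema_params are built from model.state_dict().values() (the snippet above does not show this), is that state_dict() also contains integer buffers such as BatchNorm's num_batches_tracked. Multiplying a Long tensor in place by a float raises this RuntimeError in recent PyTorch versions. Below is a minimal sketch that reproduces it on a toy model. The model, the alpha and wd values, and the ema_step helper are illustrative and not taken from the original code. It checks is_floating_point() rather than comparing against torch.float32, which is a small variation on the fix above so that float16/float64 weights would also be updated.

```python
import torch
import torch.nn as nn

# Hypothetical toy models; any module containing BatchNorm shows the same effect.
model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
ema_model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))

# state_dict() holds buffers as well as weights. BatchNorm's
# num_batches_tracked counter is an int64 (Long) tensor.
non_float = [name for name, t in model.state_dict().items()
             if not t.is_floating_point()]
print(non_float)

# Reproduce the failure in isolation: in-place float multiply on a Long tensor.
raised = False
try:
    torch.zeros((), dtype=torch.long).mul_(0.999)
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)


def ema_step(params, ema_params, alpha, wd):
    """EMA update that skips non-floating-point entries (illustrative helper)."""
    one_minus_alpha = 1.0 - alpha
    for param, ema_param in zip(params, ema_params):
        if not ema_param.is_floating_point():
            continue  # e.g. num_batches_tracked; leave integer buffers alone
        ema_param.mul_(alpha)
        ema_param.add_(param * one_minus_alpha)
        # customized weight decay, as in the snippet above
        param.mul_(1 - wd)


# Assumed hyperparameters, chosen only for the demonstration.
ema_step(list(model.state_dict().values()),
         list(ema_model.state_dict().values()),
         alpha=0.999, wd=4e-5)
```

With the guard in place the step runs through, and the integer counter is left untouched instead of being scaled by alpha.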