YU1ut/MixMatch-pytorch

bug

afterall204168 opened this issue · 2 comments

RuntimeError: result type Float can't be cast to the desired output type Long
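A likely cause, assuming the EMA helper iterates over `state_dict()` values: the state dict also contains integer buffers such as BatchNorm's `num_batches_tracked` (dtype `torch.long`), and an in-place multiply by a Python float produces a Float result that recent PyTorch releases refuse to store back into a Long tensor. The minimal reproduction below uses a hypothetical toy model, not the repository's network.

```python
import torch
import torch.nn as nn

# Hypothetical toy model: any model with BatchNorm carries a Long buffer.
model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))

# Collect the state-dict entries that are not floating point.
int_entries = {
    name: t.dtype
    for name, t in model.state_dict().items()
    if not t.dtype.is_floating_point
}
print(int_entries)

# An EMA step multiplies every entry in place by a float (alpha); on the
# Long buffer this raises the RuntimeError reported in this issue.
buf = model.state_dict()["1.num_batches_tracked"]
try:
    buf.mul_(0.999)
    message = None
except RuntimeError as err:
    message = str(err)
print(message)
```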

Has anyone found a solution to this bug? I also encountered it.

I solved this issue with the following change, which adds an "if" statement to the EMA update loop so that only float32 tensors are updated:

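    for param, ema_param in zip(self.params, self.ema_params):
        # Skip non-float32 entries such as BatchNorm's Long
        # num_batches_tracked buffer, which cannot take a float in-place multiply.
        if ema_param.dtype == torch.float32:
            ema_param.mul_(self.alpha)
            ema_param.add_(param * one_minus_alpha)
            # customized weight decay
            param.mul_(1 - self.wd)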
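A self-contained sketch of the same guard, assuming an EMA helper shaped like the loop above. The class layout, the toy model, and the `wd=0.0` default are hypothetical illustrations, not the repository's code. It checks `dtype.is_floating_point` instead of `== torch.float32`, so float16 or float64 weights would still be averaged while Long buffers are skipped.

```python
import torch
import torch.nn as nn


def make_model():
    # Hypothetical toy model; BatchNorm adds a Long num_batches_tracked buffer.
    return nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))


class WeightEMA:
    """Sketch of an EMA-of-weights helper; names and defaults are illustrative."""

    def __init__(self, model, ema_model, alpha=0.999, wd=0.0):
        self.alpha = alpha
        self.wd = wd  # hypothetical default: no extra weight decay
        # state_dict() tensors share storage with the live parameters and
        # buffers, so the in-place updates below modify the models directly.
        self.params = list(model.state_dict().values())
        self.ema_params = list(ema_model.state_dict().values())
        with torch.no_grad():
            for param, ema_param in zip(self.params, self.ema_params):
                ema_param.copy_(param)  # start the EMA at the current weights

    @torch.no_grad()
    def step(self):
        one_minus_alpha = 1.0 - self.alpha
        for param, ema_param in zip(self.params, self.ema_params):
            # Skip integer buffers: a float in-place multiply cannot be
            # stored back into a Long tensor.
            if not ema_param.dtype.is_floating_point:
                continue
            ema_param.mul_(self.alpha)
            ema_param.add_(param * one_minus_alpha)
            # customized weight decay applied to the live model
            param.mul_(1 - self.wd)


model, ema_model = make_model(), make_model()
ema = WeightEMA(model, ema_model, alpha=0.5)
with torch.no_grad():
    model[0].weight.fill_(1.0)  # pretend training moved the weights
    ema_model[0].weight.fill_(0.0)
ema.step()  # completes without the RuntimeError
```

Skipping the Long buffer means the EMA model keeps its own `num_batches_tracked` count rather than an averaged one.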