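To fix the bug, replace the `x[label]` indexing (which indexes along the first dimension) with per-sample indexing along the class dimension (dim 1), using `gather` to read and `scatter_` to write: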
TinyZeaMays opened this issue · 0 comments
TinyZeaMays commented
a[label] = torch.clamp_min(- inp[label] + 1 + self.m, min=0).detach()
->
src = torch.clamp_min(
- inp.gather(dim=1, index=label.unsqueeze(1)) + 1 + self.m,
min=0,
).detach()
a.scatter_(1, label.unsqueeze(1), src)
sigma[label] = 1 - self.m
->
src = torch.ones_like(label.unsqueeze(1),
dtype=inp.dtype, device=inp.device) - self.m
sigma.scatter_(1, label.unsqueeze(1), src)