Was it a clerical error? ScaleNorm.g is initialized from dim ** -0.5. I think it should be dim ** 0.5
junphine opened this issue · 1 comment
junphine commented
class ScaleNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1) * (dim ** -0.5))
lucidrains commented
@junphine hey, thank you for catching this! indeed the sign was not correct
it should be identical to rmsnorm except that the gain is a single learned parameter rather than one per model dimension
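To see why the sign matters, here is a small numeric check in plain Python (no torch, so it runs anywhere). It is only a sketch: the helper names `rms_norm` and `scale_norm` are made up for illustration, and the forward formula (divide by the L2 norm, then multiply by `g`) is an assumption, since the thread only shows the constructor. Under that assumption, a scalar gain of `dim ** 0.5` reproduces RMSNorm with unit gains, while `dim ** -0.5` gives an output that is `dim` times too small.

```python
import math

def rms_norm(x, gains, eps=1e-5):
    """RMSNorm: divide by the root-mean-square, then apply one gain per dimension."""
    rms = math.sqrt(sum(v * v for v in x) / len(x))
    return [v / max(rms, eps) * g for v, g in zip(x, gains)]

def scale_norm(x, g, eps=1e-5):
    """ScaleNorm (assumed form): divide by the L2 norm, then apply one scalar gain."""
    l2 = math.sqrt(sum(v * v for v in x))
    return [v / max(l2, eps) * g for v in x]

dim = 4
x = [1.0, -2.0, 3.0, -4.0]

reference = rms_norm(x, [1.0] * dim)   # RMSNorm with all gains at 1
fixed = scale_norm(x, dim ** 0.5)      # gain initialized with the corrected sign
buggy = scale_norm(x, dim ** -0.5)     # gain initialized as in the reported code

# The corrected init matches RMSNorm element by element.
print(all(math.isclose(a, b) for a, b in zip(reference, fixed)))
# The reported init is smaller than RMSNorm by a factor of dim.
print(all(math.isclose(a * dim, b) for a, b in zip(buggy, reference)))
```

The identity behind this is that the L2 norm equals the RMS times `dim ** 0.5`, so dividing by the L2 norm and multiplying by `dim ** 0.5` is the same as dividing by the RMS.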