lucidrains/PaLM-rlhf-pytorch

norm.gamma not used during backprop

conceptofmind opened this issue

Hi @lucidrains ,

I am almost ready to deploy the distributed training run. One thing I noticed is that norm.gamma is an unused parameter.

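    import torch
    from torch import nn
    import torch.nn.functional as F

    class LayerNorm(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.gamma = nn.Parameter(torch.ones(dim))      # learned scale (a parameter)
            self.register_buffer("beta", torch.zeros(dim))  # fixed zero bias (a buffer, not trainable)

        def forward(self, x):
            return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)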
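Only gamma is registered as a parameter here; beta is a buffer, so it never receives a gradient and DDP does not include it in gradient reduction (by default DDP broadcasts buffers from rank 0 instead). That is why norm.gamma, and not norm.beta, is the name that can show up as unused. A minimal sketch that reuses the class above and just lists the two registries:

```python
import torch
from torch import nn
import torch.nn.functional as F


class LayerNorm(nn.Module):
    # Same module as in the issue
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))      # trainable scale
        self.register_buffer("beta", torch.zeros(dim))  # fixed zero bias, not trainable

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)


norm = LayerNorm(4)
param_names = [name for name, _ in norm.named_parameters()]
buffer_names = [name for name, _ in norm.named_buffers()]
out = norm(torch.randn(2, 4))  # forward still works with the buffer as bias

print(param_names)   # ['gamma']
print(buffer_names)  # ['beta']
```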

During distributed training, this throws the following error:

    This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by passing the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`, and by making sure all `forward` function outputs participate in calculating loss.

To find the unused parameters, I ran:

    for _ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader), return_loss = True)
        accelerator.backward(loss / GRADIENT_ACCUMULATE_EVERY)

    for name, param in model.named_parameters():
        if param.grad is None:
            print("NONE")
            print(name)

Output:

NONE
norm.gamma

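Setting find_unused_parameters=True makes the error go away, but it adds overhead: DDP then has to traverse the autograd graph after every forward pass to find the parameters that will not receive gradients.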

I was wondering if you had any idea why this may be the case or if there is a proper way to resolve this issue.

I greatly appreciate your input as always.

Thank you,

Enrico

@dmahan93 noticed that the normalized embeds are never passed to to_logits. This may be the cause.

Currently, to_logits takes in x:

        # final norm

        embeds = self.norm(x)

        if return_only_embedding:
            return embeds

        # to logits

        logits = self.to_logits(x)
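The wiring above can be reproduced in isolation. In this minimal sketch, ToyLM and params_without_grad are hypothetical names, and the model is only a stand-in for the end of the real forward pass: the final norm is computed, but to_logits consumes x, so norm.gamma never enters the autograd graph of the loss and its .grad stays None after backward.

```python
import torch
from torch import nn
import torch.nn.functional as F


class LayerNorm(nn.Module):
    # Same module as in the issue: gamma is a parameter, beta a buffer
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)


class ToyLM(nn.Module):
    # Hypothetical stand-in for the tail of the PaLM forward pass
    def __init__(self, num_tokens=16, dim=8):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.norm = LayerNorm(dim)
        self.to_logits = nn.Linear(dim, num_tokens, bias=False)

    def forward(self, tokens):
        x = self.token_emb(tokens)
        embeds = self.norm(x)      # final norm is computed...
        return self.to_logits(x)   # ...but bypassed here (the bug)


def params_without_grad(model):
    # Same check as the loop in the issue, returned as a list
    return [name for name, p in model.named_parameters() if p.grad is None]


torch.manual_seed(0)
model = ToyLM()
tokens = torch.randint(0, 16, (2, 5))
logits = model(tokens)
loss = F.cross_entropy(logits.reshape(-1, 16), tokens.reshape(-1))
loss.backward()

unused = params_without_grad(model)
print(unused)  # ['norm.gamma']
```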

Should to_logits take in embeds instead?

        # final norm

        embeds = self.norm(x)

        if return_only_embedding:
            return embeds

        # to logits

        logits = self.to_logits(embeds)
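To check that the proposed change removes the unused parameter, here is a minimal sketch comparing both wirings. TinyHead and unused_parameters are hypothetical names, and torch.nn.LayerNorm stands in for the repo's LayerNorm for brevity, so its parameters are named norm.weight and norm.bias rather than norm.gamma.

```python
import torch
from torch import nn
import torch.nn.functional as F


class TinyHead(nn.Module):
    # Hypothetical toy: final norm followed by the logits projection
    def __init__(self, dim=8, num_tokens=16, use_embeds=True):
        super().__init__()
        self.norm = nn.LayerNorm(dim)  # stand-in for the repo's LayerNorm
        self.to_logits = nn.Linear(dim, num_tokens, bias=False)
        self.use_embeds = use_embeds

    def forward(self, x):
        embeds = self.norm(x)  # final norm
        # fixed wiring feeds the normalized embeds; buggy wiring bypasses the norm
        return self.to_logits(embeds if self.use_embeds else x)


def unused_parameters(use_embeds):
    torch.manual_seed(0)
    model = TinyHead(use_embeds=use_embeds)
    x = torch.randn(2, 5, 8)
    targets = torch.randint(0, 16, (2, 5))
    loss = F.cross_entropy(model(x).reshape(-1, 16), targets.reshape(-1))
    loss.backward()
    # parameters that autograd never reached from the loss
    return [name for name, p in model.named_parameters() if p.grad is None]


print(unused_parameters(use_embeds=False))  # ['norm.weight', 'norm.bias']
print(unused_parameters(use_embeds=True))   # []
```

With every parameter reachable from the loss, DDP should no longer need find_unused_parameters=True for this module.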

Thank you,

Enrico

@conceptofmind @dmahan93 Oh yes, thanks for catching this! Put in a quick fix.