syncdoth/RetNet

Initialize word embedding layer

hyunwoongko opened this issue · 7 comments

I was training a RetNet model using your codebase, but I found that there is no initialization of the word embedding layers,
so the loss scale was very poor (the 7B model's initial loss was 3000+).
I think we need to add word embedding initialization to this codebase.
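
For reference, here is a minimal sketch of the kind of embedding initialization I mean (the helper name and the std=0.02 default are my own assumptions, not code from this repo):

    import torch.nn as nn

    def init_word_embedding(embed_tokens: nn.Embedding, std: float = 0.02) -> None:
        # nn.Embedding's default init draws weights from N(0, 1), which makes the
        # initial logits (and therefore the loss) huge at 7B scale; re-draw the
        # weights from N(0, std^2) with a small std instead.
        nn.init.normal_(embed_tokens.weight, mean=0.0, std=std)
        if embed_tokens.padding_idx is not None:
            # keep the padding embedding at zero
            embed_tokens.weight.data[embed_tokens.padding_idx].zero_()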

Also, the reset_parameters method of the FFN and GLU layers is never called.
Shouldn't we call it in the constructors of those classes?
Or we could use the _init_weights method in the RetNetPreTrainedModel class, like the following:

    # inside RetNetPreTrainedModel._init_weights(self, module):
    if hasattr(module, "reset_parameters"):
        module.reset_parameters()

True. Feel free to use the commented out portion of the code in the _init_weights function to do embedding init.

For ffn and glu, reset_parameters just calls reset_parameters on the children (nn.Linear). This is already called within the __init__ function of nn.Linear, so yeah, calling it explicitly would be better in terms of readability, but the output would be essentially the same.
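
If the explicit call were added just for readability, it might look roughly like this (a sketch only; the layer names and structure are illustrative, not copied from the repo):

    import torch.nn as nn

    class FFN(nn.Module):
        def __init__(self, embed_dim: int, ffn_dim: int):
            super().__init__()
            self.fc1 = nn.Linear(embed_dim, ffn_dim)
            self.fc2 = nn.Linear(ffn_dim, embed_dim)
            self.activation = nn.GELU()
            self.reset_parameters()  # redundant, but makes the init explicit

        def reset_parameters(self):
            # nn.Linear.__init__ already calls reset_parameters once; calling it
            # again just re-samples from the same default distribution.
            self.fc1.reset_parameters()
            self.fc2.reset_parameters()

        def forward(self, x):
            return self.fc2(self.activation(self.fc1(x)))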

Btw, there's a PR for a variant of this model on huggingface, so this repo is currently kinda outdated 😅 Some updates are incoming (bugfixes, improvements, etc.), and pretrained weights are also incoming!

Thanks for the kind answer.
I think most people use this codebase for pre-training, like me.
I understand you want to reproduce the torchscale codebase, but I suggest the following:

  • using torch.nn.init.normal_(weight, mean=0, std=std) for common nn.Linear weights
  • using torch.nn.init.normal_(weight, mean=0, std=std / math.sqrt(2.0 * num_layers)) for output layers (fc2 and out_proj).

This is a more popular way to initialize language model parameters (Megatron-LM and Transformers use it, for example).
Would you mind changing the weight initialization like this?
This approach actually gives a much better loss scale during the pre-training stage than the initialization method in the current code; a rough sketch is below.
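
A minimal sketch of that scheme, assuming fc2/out_proj are the projections that write into the residual stream and that num_layers is known (the names and defaults here are illustrative, not taken from this repo):

    import math
    import torch.nn as nn

    def init_weights_megatron_style(model: nn.Module, std: float = 0.02,
                                    num_layers: int = 24) -> None:
        # Plain N(0, std) for Linear/Embedding weights, with std scaled down by
        # sqrt(2 * num_layers) for the residual output projections, as in Megatron-LM.
        scaled_std = std / math.sqrt(2.0 * num_layers)
        for name, module in model.named_modules():
            if isinstance(module, (nn.Linear, nn.Embedding)):
                is_out_proj = name.endswith(("fc2", "out_proj"))  # assumed layer names
                nn.init.normal_(module.weight, mean=0.0,
                                std=scaled_std if is_out_proj else std)
                if isinstance(module, nn.Linear) and module.bias is not None:
                    nn.init.zeros_(module.bias)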

Ah, the transformers version already uses that approach.
Thanks! I'll use that code!

@syncdoth You mentioned you'll publish a paper about NucleusX.
I'm curious about its performance and about ablation studies on hyperparameters like lr, weight initialization, batch size, etc.

Would you mind summarizing the performance of the RetNet architecture compared to the vanilla transformer architecture, and have you tried any ablation studies on hyperparameters?

Please answer this only if you feel comfortable doing so!

I got a response via personal messages.
Closing the issue.