lucidrains/x-transformers

How to build an optimizer

pfeatherstone opened this issue · 9 comments

Looking at

https://github.com/karpathy/nanoGPT/blob/eba36e84649f3c6d840a93092cb779a260544d08/model.py#L263

https://github.com/karpathy/minGPT/blob/37baab71b9abea1b76ab957409a1cc2fbfba8a26/mingpt/model.py#L215

https://github.com/ultralytics/ultralytics/blob/d021524e850acfa393ec25d4ecb9c4c761cca688/ultralytics/engine/trainer.py#L688

a few repositories carefully build optimizers by splitting parameters into two groups: those that receive weight decay and those that don't. All of them agree that biases of any kind don't decay, while kernel weights from nn.Linear and nn.ConvNd do.
This repository has many kinds of parameters.
My question is: where do they fall?

A shortlist of parameters I'm not sure about (see the inspection sketch just after the list):

  • ScaleNorm.g
  • RMSNorm.g
  • TransformerWrapper.memory_tokens
  • Attention.mem_k and Attention.mem_v
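
To see where each parameter would land, I'm just instantiating a small model and printing every parameter's name and shape. The hyperparameters below are arbitrary, and (if I have the kwarg names right) num_memory_tokens and attn_num_mem_kv are only set so that memory_tokens / mem_k / mem_v actually show up:

from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 128,
    num_memory_tokens = 4,        # creates TransformerWrapper.memory_tokens
    attn_layers = Decoder(
        dim = 64,
        depth = 2,
        heads = 4,
        attn_num_mem_kv = 2       # creates Attention.mem_k / Attention.mem_v
    )
)

# print every parameter with its shape, to decide which decay group it belongs to
for name, param in model.named_parameters():
    print(f'{name:60s} shape={tuple(param.shape)}')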

Thank you

Currently I'm using:

import torch
from torch import nn
from x_transformers.x_transformers import ScaleNorm, RMSNorm  # norm variants defined in this repo

def createOptimizer(model: nn.Module, betas=(0.9, 0.95), lr=0.001, decay=0.1):
    # module types whose parameters never decay: every torch "*Norm*" layer,
    # embeddings, and x-transformers' own norm variants
    blacklistModules = tuple(v for k, v in nn.__dict__.items() if "Norm" in k) + (nn.Embedding, ScaleNorm, RMSNorm)
    # parameter-name substrings that never decay
    blacklistNames   = ["bias", "memory_tokens", "mem_k", "mem_v"]

    decay_params   = []
    nodecay_params = []
    for module_name, module in model.named_modules():
        for param_name, param in module.named_parameters(recurse=False):
            if not param.requires_grad:
                continue
            fullname = f"{module_name}.{param_name}" if module_name else param_name
            if any(substr in fullname for substr in blacklistNames) or isinstance(module, blacklistModules):
                nodecay_params.append(param)
            else:
                decay_params.append(param)

    # sanity check: every trainable parameter landed in exactly one group
    ndecayed            = len(decay_params)
    nnodecayed          = len(nodecay_params)
    ntotal              = len(list(filter(lambda p: p.requires_grad, model.parameters())))
    assert ndecayed + nnodecayed == ntotal, f"bad split: {ndecayed} + {nnodecayed} != {ntotal}"

    optim_groups = [
        {'params': decay_params,   'weight_decay': decay},
        {'params': nodecay_params, 'weight_decay': 0.0}
    ]
    # fused AdamW requires all parameters to be on a CUDA device
    optimizer = torch.optim.AdamW(optim_groups, lr=lr, betas=betas, fused=True)
    return optimizer
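
and I'm calling it roughly like this; the model hyperparameters are just placeholders, and .cuda() is only there because fused=True needs CUDA parameters:

from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
).cuda()   # fused AdamW expects parameters on a CUDA device

optimizer = createOptimizer(model, lr = 3e-4, decay = 0.1)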

I've put memory tokens in the blacklist, i.e. among the parameters that don't decay. I'm not sure if that's correct. Layers like ScaleNorm and RMSNorm I'm treating like the other PyTorch normalization layers, so their gains don't decay either.

Basically, I've only just started playing with optimizers and found that they have a massive influence on convergence rate and stability. Duh.

Can anybody think of any other layers/parameters that shouldn't decay?

pip install pytorch-custom-utils

from pytorch_custom_utils import get_adam_optimizer
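
e.g. a rough usage sketch; the lr / wd kwarg names here are from memory, so check the package for the exact signature:

from x_transformers import TransformerWrapper, Decoder
from pytorch_custom_utils import get_adam_optimizer

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

# the helper groups decayed vs non-decayed parameters internally;
# lr / wd kwarg names are an assumption -- check pytorch-custom-utils for the exact signature
optimizer = get_adam_optimizer(model.parameters(), lr = 3e-4, wd = 1e-2)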

@pfeatherstone and yeah, typically you just filter out any parameters with ndims <= 1, however, i've also heard from some researchers that it doesn't matter, ymmv
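
i.e. something like this minimal sketch in plain PyTorch (param_groups_by_ndim is just an illustrative helper name, not anything in this repo):

import torch

def param_groups_by_ndim(model: torch.nn.Module, weight_decay: float = 0.1):
    # decay matrices and higher-rank tensors (ndim >= 2);
    # biases, norm gains and other 1-d / scalar parameters get no decay
    params   = [p for p in model.parameters() if p.requires_grad]
    decay    = [p for p in params if p.ndim >= 2]
    no_decay = [p for p in params if p.ndim < 2]
    return [
        {'params': decay,    'weight_decay': weight_decay},
        {'params': no_decay, 'weight_decay': 0.0},
    ]

# optimizer = torch.optim.AdamW(param_groups_by_ndim(model), lr = 3e-4)

note that under this rule memory_tokens (2-d) and mem_k / mem_v (3-d) end up in the decay group, unlike with the name-based blacklist above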

this is out of the scope for this repository though, recommend you just read some papers and decide for yourself

@pfeatherstone or hop on eleutherai and consult the crowd intelligence there. everyone has their own opinions about optimizers

@lucidrains Thank you. It looks like you are doing what nanogpt is doing. That does mean you are decaying normalization weights. I'll have a fiddle. Sorry if this is out of scope.

@pfeatherstone well, it isn't that i'm doing what Karpathy is doing; we are both following an early practice from the original transformer training at Brain. however, whether it really matters, or is just passed-down superstition, is still up for a future research paper to decide