How to build an optimizer
pfeatherstone opened this issue · 9 comments
Looking at
https://github.com/karpathy/nanoGPT/blob/eba36e84649f3c6d840a93092cb779a260544d08/model.py#L263
a few repositories carefully build optimizers by splitting parameters into groups that either experience weight decay or not. They all agree that biases of any kind should not decay, while kernel weights from nn.Linear and nn.ConvNd should.
This repository has many kinds of parameters. My question is: where do they fall?
A shortlist of parameters I'm not sure about:
ScaleNorm.g
RMSNorm.g
TransformerWrapper.memory_tokens
Attention.mem_k
and Attention.mem_v
Thank you
Currently I'm using:
import torch
from torch import nn
# ScaleNorm / RMSNorm come from this repo (x_transformers/x_transformers.py); adjust the import path if needed
from x_transformers.x_transformers import ScaleNorm, RMSNorm

def createOptimizer(model: torch.nn.Module, betas=(0.9, 0.95), lr=0.001, decay=0.1):
    # modules whose parameters should never decay (normalization layers, embeddings)
    blacklistModules = tuple(v for k, v in nn.__dict__.items() if "Norm" in k) + (nn.Embedding, ScaleNorm, RMSNorm)
    # parameter-name substrings that should never decay
    blacklistNames = ["bias", "memory_tokens", "mem_k", "mem_v"]
    decay_params = []
    nodecay_params = []
    for module_name, module in model.named_modules():
        for param_name, param in module.named_parameters(recurse=False):
            # keep the split consistent with the requires_grad filter used in the assert below
            if not param.requires_grad:
                continue
            fullname = f"{module_name}.{param_name}" if module_name else param_name
            if any(substr in fullname for substr in blacklistNames) or isinstance(module, blacklistModules):
                nodecay_params.append(param)
            else:
                decay_params.append(param)
    ndecayed = len(decay_params)
    nnodecayed = len(nodecay_params)
    ntotal = len(list(filter(lambda p: p.requires_grad, model.parameters())))
    assert ndecayed + nnodecayed == ntotal, f"bad split: {ndecayed} + {nnodecayed} != {ntotal}"
    optim_groups = [
        {'params': decay_params, 'weight_decay': decay},
        {'params': nodecay_params, 'weight_decay': 0.0},
    ]
    optimizer = torch.optim.AdamW(optim_groups, lr=lr, betas=betas, fused=True)
    return optimizer
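For context, a hypothetical call might look like the sketch below; the TransformerWrapper/Decoder settings are placeholders, not a recommendation:

from x_transformers import TransformerWrapper, Decoder

# placeholder model configuration, purely for illustration
model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)
model = model.cuda()  # fused AdamW expects CUDA parameters

optimizer = createOptimizer(model, lr = 3e-4, decay = 0.1)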
I've put the memory tokens in the blacklist, i.e. parameters that don't decay. Not sure if that's correct. Layers like ScaleNorm and RMSNorm I'm treating like the other PyTorch normalization layers, so they also don't decay.
Basically, I've only just started playing with optimizers and found that they have a massive influence on convergence rate and stability. Duh.
Can anybody think of any other layers/parameters that shouldn't decay?
@pfeatherstone just use https://github.com/lucidrains/pytorch-custom-utils/blob/main/pytorch_custom_utils/get_adam_optimizer.py#L15, which will suit 95% of your optimizer needs
pip install pytorch-custom-utils
from pytorch_custom_utils import get_adam_optimizer
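A minimal call might look like the sketch below; the lr and wd keyword names are my assumption based on the linked file, so double-check them against the actual signature (model here is whatever module you are training):

from pytorch_custom_utils import get_adam_optimizer

# assumed keyword names (lr, wd) -- verify against the linked get_adam_optimizer source
optimizer = get_adam_optimizer(model.parameters(), lr = 1e-4, wd = 1e-2)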
@pfeatherstone and yeah, typically you just filter out any parameters with ndim <= 1. however, i've also heard from some researchers that it doesn't matter, ymmv
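As a concrete illustration of that heuristic (a sketch of the common ndim-based split, not code from this repo):

import torch
from torch.optim import AdamW

def split_by_ndim(model: torch.nn.Module):
    # parameters with ndim <= 1 (biases, norm gains, etc.) get no weight decay
    params = [p for p in model.parameters() if p.requires_grad]
    decay = [p for p in params if p.ndim > 1]
    no_decay = [p for p in params if p.ndim <= 1]
    return decay, no_decay

decay, no_decay = split_by_ndim(model)
optimizer = AdamW([
    {'params': decay, 'weight_decay': 1e-2},
    {'params': no_decay, 'weight_decay': 0.0},
], lr = 3e-4)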
this is out of scope for this repository though; recommend you just read some papers and decide for yourself
@pfeatherstone or hop on EleutherAI and consult the crowd intelligence there. everyone has their own opinions about optimizers
@lucidrains Thank you. It looks like you are doing what nanoGPT is doing. That does mean you are decaying normalization weights. I'll have a fiddle. Sorry if this is out of scope.
@pfeatherstone well, it isn't that i'm doing what Karpathy is doing; we are both following an early practice from the original transformer training at Brain. however, whether it really matters, or is just passed-down superstition, is still up for a future research paper to decide