patrick-kidger/equinox

Recommended way of filtering params for weight decay

Closed this issue · 20 comments

Apologies if this has been asked before but I couldn't find any example that demonstrates this in a simple manner.

I have a model built in Equinox. Now, I want to use the AdamW optimizer where:

  1. I want to apply weight_decay to certain parameters (i.e. divide the pytree into two groups and apply weight decay to only one of them). For example, what if I want to apply weight decay only to weights (except for the normalization layers)?
  2. Perform gradient accumulation (see the sketch right after this list).
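For point 2, here is a rough sketch of the kind of thing I have in mind, using optax.MultiSteps as one possible way to do gradient accumulation (the hyperparameters are placeholders):

import optax

base_optimizer = optax.adamw(learning_rate=1e-3)
# Accumulate gradients over 4 calls to update() before actually applying them.
optimizer = optax.MultiSteps(base_optimizer, every_k_schedule=4)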

My approach would be to define a filter spec for that set of parameters and then just define the optimizer over that specific set. Similar to https://docs.kidger.site/equinox/tricks/#custom-per-parameter-behaviour, I would first try something like:

import jax
import jax.numpy as jnp
import equinox as eqx
import optax

class Weight(eqx.Module):
  w: jax.Array

  def get(self):
    return self.w
  
  def __matmul__(self, o):
    return self.w @ o


model = eqx.nn.MLP(2, 'scalar', 10, 3, key=jax.random.PRNGKey(0))

is_linear = lambda x: isinstance(x, eqx.nn.Linear)
get_weights = lambda m: [x.weight
                         for x in jax.tree_util.tree_leaves(m, is_leaf=is_linear)
                         if is_linear(x)]
weighted_model = eqx.tree_at(get_weights, model, replace_fn=Weight)

weight_optimizer = optax.adamw(0.1)
other_optimizer = optax.sgd(0.1)

w_filter = lambda x: isinstance(x, Weight)
leaf = lambda x: isinstance(x, Weight)
o_filter = lambda x: not isinstance(x, Weight) and eqx.is_array(x)
w_opt_state = weight_optimizer.init(eqx.filter(weighted_model, w_filter, is_leaf=leaf))
o_opt_state = other_optimizer.init(eqx.filter(weighted_model, o_filter, is_leaf=leaf))

grads = eqx.filter_grad(lambda x, y: x(y))(weighted_model, jnp.ones(2))
w_grads = eqx.filter(grads, w_filter, is_leaf=leaf)
o_grads = eqx.filter(grads, o_filter, is_leaf=leaf)
# adamw applies weight decay using the params, so pass the same filtered trees
# that the optimizer states were initialised with (their structures match the
# filtered grads).
w_params = eqx.filter(weighted_model, w_filter, is_leaf=leaf)
o_params = eqx.filter(weighted_model, o_filter, is_leaf=leaf)
updates1, w_opt_state = weight_optimizer.update(w_grads, w_opt_state, w_params)
updates2, o_opt_state = other_optimizer.update(o_grads, o_opt_state, o_params)
weighted_model = eqx.apply_updates(weighted_model, updates1)
weighted_model = eqx.apply_updates(weighted_model, updates2)

Thanks for the detailed response @lockwo. A few comments, though:

  1. Filtering this way can become quite messy in many cases. For example, let's say you have a Transformer with 32 TransformerLayers, each of which consists of MHA, Linear and LayerNorm layers. Ideally, I should filter my transformer based on the layer type as well as the shapes of the parameters.
  2. Here we are using two optimizers, which is a stretch. Why? Because optax.adamw comes with a mask argument where you can specify the group to which weight decay won't be applied (see the sketch at the end of this comment). Link to the documentation for your reference.

PS: I think there should be a cleaner way of filtering params and distributing them into different groups. I don't have a clear-cut answer to this right away, but it should be possible.
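For reference, this is roughly the shape of the optax API I am referring to (a sketch only; decay_mask is a placeholder for whatever boolean pytree, or callable returning one, we end up building):

import optax

# Weight decay is only applied where the mask is True.
optimizer = optax.adamw(learning_rate=1e-3, weight_decay=1e-2, mask=decay_mask)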

Your points are generally true, I think; I was just highlighting the general framework I would use.

  1. If you know what layers you want, I think it would be pretty manageable to just define a filter spec for those layers (basically looking for those leaves and then getting the weights, probably not very many LoC; see the sketch at the end of this comment).
  2. Using two optimizers was just to show total flexibility, but you could definitely use a mask. I actually didn't know that was there, but you might be able to just reuse the filter spec for it.

I don't know if I could write a universal thing that would definitely work for your codebase, but this approach is general. If you have a more complicated MWE, I could also take a look at that.
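For example, something in this spirit (a rough sketch, reusing model and is_linear from the snippet above; the same idea extends to whatever layer types you care about):

# Start from an all-False pytree with the same structure as the model, then
# flip the selected weights to True to obtain a boolean filter spec.
filter_spec = jax.tree_util.tree_map(lambda _: False, model)
filter_spec = eqx.tree_at(
    lambda m: [x.weight
               for x in jax.tree_util.tree_leaves(m, is_leaf=is_linear)
               if is_linear(x)],
    filter_spec,
    replace_fn=lambda _: True,
)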

Gotcha! Thanks for the clarification. Also, I am not looking for a universal pattern but rather a cleaner one.

Re the more complicated MWE: yeah, let me do that, because it will give everyone more clarity on what I am trying to achieve. Give me a couple of hours.

@lockwo here is a better MWE:

class MLP(eqx.Module):
    fc1: eqx.nn.Linear
    fc2: eqx.nn.Linear
    
    def __init__(self, key, dtype=jnp.bfloat16):
        key1, key2 = jax.random.split(key, 2)
        self.fc1 = eqx.nn.Linear(32, 64, key=key1, dtype=dtype)
        self.fc2 = eqx.nn.Linear(64, 64, key=key2, dtype=dtype)

    def __call__(self, x):
       pass


class Attention(eqx.Module):
    wqkv: eqx.nn.Linear
    proj: eqx.nn.Linear
    drop: eqx.nn.Dropout
    
    def __init__(self, key, dtype=jnp.bfloat16):
        key1, key2 = jax.random.split(key, 2)
        self.wqkv = eqx.nn.Linear(64, 3 * 64, key=key1) # 3 for qkv
        self.proj = eqx.nn.Linear(64, 64, key=key2)
        self.drop = eqx.nn.Dropout()

    def __call__(self, x, mask=None):
        pass


class TransformerBlock(eqx.Module):
    norm_1: eqx.nn.LayerNorm
    norm_2: eqx.nn.LayerNorm
    attn: Attention
    mlp: MLP

    def __init__(self, key, dtype=jnp.bfloat16):
        key1, key2 = jax.random.split(key, 2)
        self.norm_1 = eqx.nn.LayerNorm(64)
        self.attn = Attention(key=key1, dtype=dtype)
        self.norm_2 = eqx.nn.LayerNorm(64)
        self.mlp = MLP(key=key2, dtype=dtype)

    def __call__(self, x, mask=None):
        pass


class Transformer(eqx.Module):
    pos_embed: eqx.nn.Embedding
    tf_blocks: list[TransformerBlock]
    norm: eqx.nn.LayerNorm

    def __init__(self, key, num_layers=2, dtype=jnp.bfloat16):
        keys = jax.random.split(key, num_layers + 3)
        key1, key2, key3, tf_keys = keys[0], keys[1], keys[2], keys[3:]

        self.tf_blocks = [TransformerBlock(tf_keys[i]) for i in range(num_layers)]
        self.norm = eqx.nn.LayerNorm(64)
        self.pos_embed = eqx.nn.Embedding(64, 64, key=key1)

    def __call__(self, x):
        pass


model = Transformer(jax.random.PRNGKey(1))

Here, I would like to apply weight decay to all of the weight parameters, except for the normalization layers and the biases.

So in that case, it would be something like:

exclude = lambda x: isinstance(x, eqx.nn.LayerNorm)
# Treat as a leaf any module that has a `weight` attribute and isn't a LayerNorm.
leaf = lambda x: hasattr(x, "weight") and not exclude(x)
get_weights = lambda m: [x.weight for x in jax.tree_util.tree_leaves(m, is_leaf=leaf) if leaf(x)]
weighted_model = eqx.tree_at(get_weights, model, replace_fn=Weight)

and you don't need two optimizers, because now you could filter for Weight/use the mask

Perfect, thanks a lot. Let me try this and I will report my findings. BTW, did you mean leaf instead of Weight in the replace_fn, @lockwo?

Looks like I still can't get it working (either I am misunderstanding something here, or it is just too much work to achieve this). Here is what I want, but it doesn't seem to be working for me:

def set_weights(weight, bias):
    # Set both the weights and bias to None
    # if they are to be masked from weight decay
    return None, None

def is_layer(x):
    return isinstance(x, eqx.Module)
    
def get_weights(model):
    weights = []
    biases = []
    for x in jax.tree.leaves(model, is_leaf=is_layer):
        # For each layer, check the shape of the parameters.
        # If any param is 1D, store it and return to replace it with None
        if hasattr(x, "weight") and x.weight.ndim < 2:
            weights.append(x.weight)
        if hasattr(x, "bias"):
            biases.append(x.bias)
    return weights, biases


masked_params = eqx.tree_at(get_weights, model, set_weights)

Would appreciate any help on this

Any suggestions? @patrick-kidger

I've run into this issue myself. Would it be possible to, say, extend eqx.field to add some sort of tagging system to the field metadata, which we could then use in different filters?

If I understand correctly you want a boolean mask (to pass to optax.adamw(..., mask=...)), indicating which parameters you want to apply weight decay to?

From your description it sounds like you want to apply weight decay to (a) the weights and biases of all linear layers, and (b) just the biases of all LayerNorm layers? If so, then that would correspond to:

import jax.tree_util as jtu

params = ...

def is_layer(x):
    return isinstance(x, eqx.nn.Linear) or isinstance(x, eqx.nn.LayerNorm)

def set_mask(x):
    if isinstance(x, eqx.nn.Linear):
        return jtu.tree_map(lambda _: True, x)
    elif isinstance(x, eqx.nn.LayerNorm):
        mask = jtu.tree_map(lambda _: False, x)
        mask = eqx.tree_at(lambda m: m.bias, mask, True)
        return mask
    else:
        return jtu.tree_map(lambda _: False, x)

mask = jtu.tree_map(set_mask, params, is_leaf=is_layer)

Note that here, params is whatever you have asked Optax to optimize. Typically this is just the parameters of your model, as obtained by something like params = eqx.filter(model, eqx.is_array); see also this FAQ entry.

Also note how I'm not using hasattr, or checking ndim, or anything like that. Just isinstance checks! I'd recommend this approach as usually being much more reliable.


As for the follow-up question on extending eqx.field: this is an idea that people like to suggest every now and again, but I'm afraid it doesn't work in general. (You can probably find some past discussions if you go looking back through this issue tracker.) There are several issues, the two most notable being:

  • The field is something owned by the surrounding class, but when you tree-map you get the individual parameters.
  • From JAX's point of view, eqx.Module is just a PyTree. Nothing more or less. Keeping this rule makes it easy to reason about Equinox code. Changing that to add special behaviour -- here some field metadata -- seems like a step in the wrong direction.

But that said, you are free to add your own metadata to eqx.field! Just like dataclasses.field. So if you want to construct something that works in your individual use-case then you are free to do so :)
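For example, a rough sketch (the "no_weight_decay" metadata key is just a name made up for illustration):

import dataclasses

import equinox as eqx
import jax

class MyLinear(eqx.Module):
    weight: jax.Array
    # Attach custom metadata, exactly as you would with dataclasses.field.
    bias: jax.Array = eqx.field(metadata={"no_weight_decay": True})

def tagged_field_names(module, key="no_weight_decay"):
    # eqx.Module is a dataclass, so the metadata can be read back with
    # dataclasses.fields and used in your own filtering logic.
    return {f.name for f in dataclasses.fields(module) if f.metadata.get(key, False)}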

Thanks @patrick-kidger for the detailed info, this is very helpful. One last query on this: in the above code you have defined the set_mask function to set the mask for the weights and biases, but you haven't actually used it. Did you miss a line or two by any chance?

Typo'd. Fixed!

This still doesn't work. Here is an MWE that you can copy-paste and try:

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import equinox as eqx
import optax


class MLP(eqx.Module):
    fc1: eqx.nn.Linear
    fc2: eqx.nn.Linear
    
    def __init__(self, key, dtype=jnp.bfloat16):
        key1, key2 = jax.random.split(key, 2)
        self.fc1 = eqx.nn.Linear(32, 64, key=key1, dtype=dtype)
        self.fc2 = eqx.nn.Linear(64, 64, key=key2, dtype=dtype)

    def __call__(self, x):
       pass


class Attention(eqx.Module):
    wqkv: eqx.nn.Linear
    proj: eqx.nn.Linear
    drop: eqx.nn.Dropout
    
    def __init__(self, key, dtype=jnp.bfloat16):
        key1, key2 = jax.random.split(key, 2)
        self.wqkv = eqx.nn.Linear(64, 3 * 64, key=key1) # 3 for qkv
        self.proj = eqx.nn.Linear(64, 64, key=key2)
        self.drop = eqx.nn.Dropout()

    def __call__(self, x, mask=None):
        pass


class TransformerBlock(eqx.Module):
    norm_1: eqx.nn.LayerNorm
    norm_2: eqx.nn.LayerNorm
    attn: Attention
    mlp: MLP

    def __init__(self, key, dtype=jnp.bfloat16):
        key1, key2 = jax.random.split(key, 2)
        self.norm_1 = eqx.nn.LayerNorm(64)
        self.attn = Attention(key=key1, dtype=dtype)
        self.norm_2 = eqx.nn.LayerNorm(64)
        self.mlp = MLP(key=key2, dtype=dtype)

    def __call__(self, x, mask=None):
        pass


class Transformer(eqx.Module):
    pos_embed: eqx.nn.Embedding
    tf_blocks: list[TransformerBlock]
    norm: eqx.nn.LayerNorm

    def __init__(self, key, num_layers=2, dtype=jnp.bfloat16):
        keys = jax.random.split(key, num_layers + 3)
        key1, key2, key3, tf_keys = keys[0], keys[1], keys[2], keys[3:]

        self.tf_blocks = [TransformerBlock(tf_keys[i]) for i in range(num_layers)]
        self.norm = eqx.nn.LayerNorm(64)
        self.pos_embed = eqx.nn.Embedding(64, 64, key=key1)

    def __call__(self, x, y, mask=None):
        pos_embed = jax.vmap(self.pos_embed)(y)




def is_layer(x):
    return isinstance(x, eqx.nn.Linear) or isinstance(x, eqx.nn.LayerNorm)

def set_mask(x):
    if isinstance(x, eqx.nn.Linear):
        return jtu.tree_map(lambda _: True, x)
    elif isinstance(x, eqx.nn.LayerNorm):
        mask = jtu.tree_map(lambda _: False, x)
        mask = eqx.tree_at(lambda m: m.bias, mask, True)
        return mask
    else:
        return jtu.tree_map(lambda _: False, x)


model = Transformer(jax.random.PRNGKey(1))
params = eqx.filter(model, eqx.is_array)
mask = jtu.tree_map(set_mask, params, is_leaf=is_layer)
optim = optax.adamw(learning_rate=1e-4, mask=mask)
opt_state = optim.init(params)

Traceback

---> 83 opt_state = optim.init(params)

File ~/miniconda3/envs/jaxenv/lib/python3.11/site-packages/optax/_src/combine.py:64, in chain.<locals>.init_fn(params)
     63 def init_fn(params):
---> 64   return tuple(fn(params) for fn in init_fns)

File ~/miniconda3/envs/jaxenv/lib/python3.11/site-packages/optax/_src/combine.py:64, in <genexpr>(.0)
     63 def init_fn(params):
---> 64   return tuple(fn(params) for fn in init_fns)

File ~/miniconda3/envs/jaxenv/lib/python3.11/site-packages/optax/_src/wrappers.py:544, in masked.<locals>.init_fn(params)
    541 if isinstance(params, _state_utils._ParamsPlaceholder):  # pylint:disable=protected-access
    542   return MaskedState(inner_state=inner.init(params))
--> 544 mask_tree = mask(params) if callable(mask) else mask
    545 masked_params = mask_pytree(params, mask_tree)
    546 return MaskedState(inner_state=inner.init(masked_params))

TypeError: Transformer.__call__() missing 1 required positional argument: 'y'

Probably a known Optax oddity: they check to see if certain values are callable, and if they are, they call them. This takes precedence over the pytree-ness of an object.

Wrap your masks/gradients/etc. into a length-1 list to defeat the callable check.
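i.e. something like the following (an untested sketch, reusing the mask and params from your MWE):

optim = optax.adamw(learning_rate=1e-4, mask=[mask])
opt_state = optim.init([params])
# ...and wrap the grads the same way at update time:
# updates, opt_state = optim.update([grads], opt_state, [params])
# model = eqx.apply_updates(model, updates[0])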

Should I raise an issue on the optax repo regarding this? This seems very limiting and unnecessary, IMHO.

Wrap your masks/gradients/etc. into a length-1 list to defeat the callable check.

This didn't work btw

@patrick-kidger IMO it would be good if you could take a look at the issue and the corresponding PR raised in Optax for this. They fixed it, but the fix broke the update workflow of adamw without a mask. I am suggesting this because there isn't going to be a new release of Optax anytime soon, so it would be worth fixing this now.

What are you proposing that Equinox does differently?

What are you proposing that Equinox does differently?

Sorry for not being very clear in my earlier comment. I am not suggesting any change on the Equinox side. I am asking if you could share some thoughts in the thread (issue) I opened on the Optax side.