google-deepmind/optax

Masking certain parameters for weight decay in adamw

AakashKumarNain opened this issue · 10 comments

I have a model built in Equinox, and I want to filter its parameters so that weight decay is applied only to a certain subset of the original pytree. However, optax seems to have a problem with pytrees passed as `mask`. Here is an MWE:

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import equinox as eqx
import optax

class MLP(eqx.Module):
    fc1: eqx.nn.Linear
    fc2: eqx.nn.Linear
    
    def __init__(self, key, dtype=jnp.bfloat16):
        key1, key2 = jax.random.split(key, 2)
        self.fc1 = eqx.nn.Linear(32, 64, key=key1, dtype=dtype)
        self.fc2 = eqx.nn.Linear(64, 64, key=key2, dtype=dtype)

    def __call__(self, x):
        pass


class Attention(eqx.Module):
    wqkv: eqx.nn.Linear
    proj: eqx.nn.Linear
    drop: eqx.nn.Dropout
    
    def __init__(self, key, dtype=jnp.bfloat16):
        key1, key2 = jax.random.split(key, 2)
        self.wqkv = eqx.nn.Linear(64, 3 * 64, key=key1, dtype=dtype)  # 3 for q, k, v
        self.proj = eqx.nn.Linear(64, 64, key=key2, dtype=dtype)
        self.drop = eqx.nn.Dropout()

    def __call__(self, x, mask=None):
        pass


class TransformerBlock(eqx.Module):
    norm_1: eqx.nn.LayerNorm
    norm_2: eqx.nn.LayerNorm
    attn: Attention
    mlp: MLP

    def __init__(self, key, dtype=jnp.bfloat16):
        key1, key2 = jax.random.split(key, 2)
        self.norm_1 = eqx.nn.LayerNorm(64)
        self.attn = Attention(key=key1, dtype=dtype)
        self.norm_2 = eqx.nn.LayerNorm(64)
        self.mlp = MLP(key=key2, dtype=dtype)

    def __call__(self, x, mask=None):
        pass


class Transformer(eqx.Module):
    pos_embed: eqx.nn.Embedding
    tf_blocks: list[TransformerBlock]
    norm: eqx.nn.LayerNorm

    def __init__(self, key, num_layers=2, dtype=jnp.bfloat16):
        keys = jax.random.split(key, num_layers + 3)
        key1, key2, key3, tf_keys = keys[0], keys[1], keys[2], keys[3:]

        self.tf_blocks = [TransformerBlock(tf_keys[i], dtype=dtype) for i in range(num_layers)]
        self.norm = eqx.nn.LayerNorm(64)
        self.pos_embed = eqx.nn.Embedding(64, 64, key=key1)

    def __call__(self, x, y, mask=None):
        pos_embed = jax.vmap(self.pos_embed)(y)




def is_layer(x):
    return isinstance(x, eqx.nn.Linear) or isinstance(x, eqx.nn.LayerNorm)

def set_mask(x):
    if isinstance(x, eqx.nn.Linear):
        return jtu.tree_map(lambda _: True, x)
    elif isinstance(x, eqx.nn.LayerNorm):
        mask = jtu.tree_map(lambda _: False, x)
        mask = eqx.tree_at(lambda m: m.bias, mask, True)
        return mask
    else:
        return jtu.tree_map(lambda _: False, x)


model = Transformer(jax.random.PRNGKey(1))
params = eqx.filter(model, eqx.is_array)
mask = jtu.tree_map(set_mask, params, is_leaf=is_layer)
optim = optax.adamw(learning_rate=1e-4, mask=mask)
opt_state = optim.init(params)

Traceback

---> 83 opt_state = optim.init(params)

File ~/miniconda3/envs/jaxenv/lib/python3.11/site-packages/optax/_src/combine.py:64, in chain.<locals>.init_fn(params)
     63 def init_fn(params):
---> 64   return tuple(fn(params) for fn in init_fns)

File ~/miniconda3/envs/jaxenv/lib/python3.11/site-packages/optax/_src/combine.py:64, in <genexpr>(.0)
     63 def init_fn(params):
---> 64   return tuple(fn(params) for fn in init_fns)

File ~/miniconda3/envs/jaxenv/lib/python3.11/site-packages/optax/_src/wrappers.py:544, in masked.<locals>.init_fn(params)
    541 if isinstance(params, _state_utils._ParamsPlaceholder):  # pylint:disable=protected-access
    542   return MaskedState(inner_state=inner.init(params))
--> 544 mask_tree = mask(params) if callable(mask) else mask
    545 masked_params = mask_pytree(params, mask_tree)
    546 return MaskedState(inner_state=inner.init(masked_params))

TypeError: Transformer.__call__() missing 1 required positional argument: 'y'

The related discussion is in patrick-kidger/equinox#771.

Hello @AakashKumarNain

This was raised in #913 and @JadM133 found a solution. PR #1015 should fix this.
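A minimal sketch of the failure mode, assuming the pre-fix logic shown in the traceback (`mask(params) if callable(mask) else mask`). `jtu.tree_map` rebuilds the mask as an instance of the model's own class, so a mask built from a module that defines `__call__` is itself callable, and optax calls it instead of treating it as a pytree. `CallableModel` and `resolve_mask_old` are hypothetical stand-ins written in plain JAX (no Equinox) so the behaviour can be reproduced in isolation.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class CallableModel:
    """Hypothetical stand-in for an eqx.Module that defines __call__."""

    def __init__(self, w, b):
        self.w = w
        self.b = b

    def __call__(self, x, y):
        return x @ self.w + self.b + y

    def tree_flatten(self):
        return (self.w, self.b), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def resolve_mask_old(mask, params):
    # Mirrors the pre-fix line quoted in the traceback (wrappers.py:544).
    return mask(params) if callable(mask) else mask


params = CallableModel(jnp.ones((2, 2)), jnp.ones((2,)))
# tree_map rebuilds a CallableModel whose leaves are Python bools.
mask = jax.tree_util.tree_map(lambda x: x.ndim > 1, params)

print(callable(mask))  # True, so the mask gets called like a model

error_message = None
try:
    resolve_mask_old(mask, params)
except TypeError as err:
    # Same kind of error as in the traceback: the 'y' argument is missing.
    error_message = str(err)
print(error_message)
```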

Thanks @vroulet for pointing out the PR. I hope it gets merged soon, because this has been a huge blocker for Equinox users. In the meantime, do you have any suggestion for a workaround?

Hello @AakashKumarNain, in the meantime you can modify two lines in _src/wrappers.py:

Line 544: mask_tree = mask (instead of "mask(params) if callable(mask) else mask")
Line 549: mask_tree = mask (same change)

This should get your code to run as expected until the pull request is merged.

Thanks @JadM133 for the suggestion. I will try it out.

The PR has been merged. You'll need to install optax locally to use it (we may not release a new version soon).

Thank you. I will do a local install

@vroulet @JadM133 I did a local install, and although masking works, it broke the optim.update(...) call for adamw. It works fine for other optimizers such as adam. We should reopen this issue.

Hello @AakashKumarNain

Could you send the exact error and a minimal reproducible example?
I've tried the current code (see below) and don't get any errors.

import equinox as eqx
import jax.numpy as jnp
import jax.random as jrd
import jax.tree_util as jtu
import jax

import optax
import optax.tree_utils as otu

# With standard pytrees
def test_mask(mask):
  opt1 = optax.adamw(1., mask=mask, weight_decay=1.)
  opt2 = optax.adamw(1., mask=mask, weight_decay=0.)
  state1 = opt1.init(params)
  state2 = opt2.init(params)
  def fun(x):
    return otu.tree_l2_norm(x, squared=True)
  grad = jax.grad(fun)(params)
  u1, _ = opt1.update(grad, state1, params)
  u2, _ = opt2.update(grad, state2, params)
  optax.apply_updates(params, u1)
  print(f'Did mask work?: {jnp.allclose(u1[1], u2[1]) and not jnp.allclose(u1[0], u2[0])}')

params = [jnp.array([[1., 2.], [3., 4.]]), jnp.array([5., 6.])]
mask_fn = lambda p: jtu.tree_map(lambda x: x.ndim != 1, p)
mask = mask_fn(params)

test_mask(mask)
test_mask(mask_fn)

# Equinox setting
@eqx.filter_value_and_grad
def grad_loss(model, input, output):
    pred = model(input)
    mse = lambda x, y : jnp.mean(jnp.square(x-y))
    return mse(pred, output)

@eqx.filter_jit
def make_step(input, output, model, states):
    loss, grads = grad_loss(model, input, output)
    # adamw needs the current parameters, so `model` is passed as `params`.
    u1, state1 = optim1.update(grads, states[0], model)
    u2, state2 = optim2.update(grads, states[1], model)
    is_working = jnp.allclose(u1.layers[0].bias, u2.layers[0].bias) & (~jnp.allclose(u1.layers[0].weight, u2.layers[0].weight))
    jax.debug.print('Did mask work?: {}', is_working)
    model = eqx.apply_updates(model, u1)
    return loss, model, (state1, state2)

key, subkey = jax.random.split(jrd.PRNGKey(0))
xs = jnp.ones((100,))
ys = jax.random.normal(key, (1,))

model = eqx.nn.MLP(xs.shape[-1], ys.shape[-1], 10, 1, key=subkey)

lr = 1e-2
filter_spec = jtu.tree_map(lambda _: True, model)
filter_spec = eqx.tree_at(
    lambda tree: (tree.layers[0].bias, tree.layers[1].bias),
    filter_spec,
    replace=(False, False),
)

optim1 = optax.adamw(1., mask=filter_spec, weight_decay=1.)
optim2 = optax.adamw(1., mask=filter_spec, weight_decay=0.)

state1 = optim1.init(eqx.filter(model, eqx.is_inexact_array))
state2 = optim2.init(eqx.filter(model, eqx.is_inexact_array))

loss, model, opt_state = make_step(xs, ys, model, (state1, state2))

Output:
Did mask work?: True
Did mask work?: True
Did mask work?: True
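Besides writing the mask by hand with `eqx.tree_at` or deciding by `ndim` as `mask_fn` does above, a mask can also be derived from leaf paths. A sketch, assuming a plain nested-dict parameter tree and a hypothetical rule that only leaves named `kernel` are decayed; `decay_mask` is an illustrative helper, not an optax API.

```python
import jax
import jax.numpy as jnp

params = {
    "dense": {"kernel": jnp.ones((4, 3)), "bias": jnp.zeros((3,))},
    "norm": {"scale": jnp.ones((3,)), "bias": jnp.zeros((3,))},
}


def decay_mask(tree, decay_names=("kernel",)):
    """Hypothetical helper: True for leaves whose last path entry is in decay_names."""

    def leaf_mask(path, leaf):
        last = path[-1]
        # Dict entries expose `.key`; attribute entries (GetAttrKey) expose `.name`.
        name = getattr(last, "key", getattr(last, "name", None))
        return name in decay_names

    return jax.tree_util.tree_map_with_path(leaf_mask, tree)


mask = decay_mask(params)
print(mask)
```

Because optax accepts either a pytree or a function returning one, `optax.adamw(1e-3, mask=decay_mask)` and `optax.adamw(1e-3, mask=decay_mask(params))` should behave the same for plain dict parameters.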

Hello @vroulet, @AakashKumarNain, I think the issue is not with the mask. I assume @AakashKumarNain is getting the following error:

ValueError: You are using a transformation that requires the current value of parameters, but you are not passing `params` when calling `update`.

The problem is that, unlike adam, adamw's update function requires the current params (as mentioned in the documentation), because its weight decay term is computed from the parameter values. So reusing code written for adam and only changing the optimizer name won't work; the code has to pass params to update, as in the code written by @vroulet above.

Yup, that's the part I missed! Thanks @vroulet @JadM133 for the help.