Masking certain parameters for weight decay in adamw
AakashKumarNain opened this issue · 10 comments
I have a model built in Equinox, and I want to filter out parameters in a way that weight decay is applied only to a certain subset of the original pytree. But it seems that optax has a problem with pytrees passed as a mask. Here is an MWE:
import jax
import equinox as eqx
class MLP(eqx.Module):
    """Feed-forward stub made of two linear layers (32 -> 64 -> 64).

    Only the parameter structure matters for the reproduction; the forward
    pass is intentionally left unimplemented.
    """

    fc1: eqx.nn.Linear
    fc2: eqx.nn.Linear

    def __init__(self, key, dtype=jnp.bfloat16):
        """Create both linear layers, each with its own PRNG sub-key."""
        fc1_key, fc2_key = jax.random.split(key, 2)
        self.fc1 = eqx.nn.Linear(32, 64, key=fc1_key, dtype=dtype)
        self.fc2 = eqx.nn.Linear(64, 64, key=fc2_key, dtype=dtype)

    def __call__(self, x):
        """Forward pass placeholder; returns None."""
        return None
class Attention(eqx.Module):
    """Attention stub holding a fused qkv projection, an output projection and dropout.

    Only the parameter structure matters for the reproduction; the forward
    pass is intentionally left unimplemented.
    """

    wqkv: eqx.nn.Linear
    proj: eqx.nn.Linear
    drop: eqx.nn.Dropout

    def __init__(self, key, dtype=jnp.bfloat16):
        """Create the projections, each with its own PRNG sub-key.

        NOTE(review): `dtype` is accepted but not forwarded to the layers.
        """
        qkv_key, proj_key = jax.random.split(key, 2)
        # One projection produces q, k and v at once, hence the factor 3.
        self.wqkv = eqx.nn.Linear(64, 3 * 64, key=qkv_key)
        self.proj = eqx.nn.Linear(64, 64, key=proj_key)
        self.drop = eqx.nn.Dropout()

    def __call__(self, x, mask=None):
        """Forward pass placeholder; returns None."""
        return None
class TransformerBlock(eqx.Module):
    """Transformer block stub: two layer norms, an attention module and an MLP.

    Only the parameter structure matters for the reproduction; the forward
    pass is intentionally left unimplemented.
    """

    norm_1: eqx.nn.LayerNorm
    norm_2: eqx.nn.LayerNorm
    attn: Attention
    mlp: MLP

    def __init__(self, key, dtype=jnp.bfloat16):
        """Build the sub-modules; attention and MLP each get their own sub-key."""
        attn_key, mlp_key = jax.random.split(key, 2)
        self.norm_1 = eqx.nn.LayerNorm(64)
        self.norm_2 = eqx.nn.LayerNorm(64)
        self.attn = Attention(key=attn_key, dtype=dtype)
        self.mlp = MLP(key=mlp_key, dtype=dtype)

    def __call__(self, x, mask=None):
        """Forward pass placeholder; returns None."""
        return None
class Transformer(eqx.Module):
    """Top-level model stub: positional embedding, a list of blocks and a final norm.

    Because the class defines `__call__`, every instance (including a mask
    pytree with this structure) is callable.
    """

    pos_embed: eqx.nn.Embedding
    # Holds one TransformerBlock per layer, so the field is a list.
    tf_blocks: list[TransformerBlock]
    norm: eqx.nn.LayerNorm

    def __init__(self, key, num_layers=2, dtype=jnp.bfloat16):
        """Build the embedding, `num_layers` transformer blocks and the final norm.

        Args:
            key: PRNG key that is split into per-module sub-keys.
            num_layers: number of transformer blocks to create.
            dtype: forwarded to every transformer block.
        """
        # The split size is kept at num_layers + 3 so each sub-key keeps the
        # value it had before; slots 1 and 2 are simply left unused.
        keys = jax.random.split(key, num_layers + 3)
        embed_key, block_keys = keys[0], keys[3:]
        self.tf_blocks = [
            TransformerBlock(block_keys[i], dtype=dtype) for i in range(num_layers)
        ]
        self.norm = eqx.nn.LayerNorm(64)
        self.pos_embed = eqx.nn.Embedding(64, 64, key=embed_key)

    def __call__(self, x, y, mask=None):
        """Forward pass stub: embeds `y` position-wise and returns None."""
        pos_embed = jax.vmap(self.pos_embed)(y)
def is_layer(x):
    """Return True when `x` is an `eqx.nn.Linear` or an `eqx.nn.LayerNorm`."""
    return isinstance(x, (eqx.nn.Linear, eqx.nn.LayerNorm))
def set_mask(x):
    """Build a boolean mask pytree with the same structure as `x`.

    - `eqx.nn.Linear`: every leaf is True.
    - `eqx.nn.LayerNorm`: every leaf is False except `bias`, which is True.
    - anything else: every leaf is False.
    """
    if isinstance(x, eqx.nn.LayerNorm):
        all_false = jtu.tree_map(lambda _: False, x)
        return eqx.tree_at(lambda m: m.bias, all_false, True)
    # Linear layers are fully selected; every other node is fully deselected.
    flag = isinstance(x, eqx.nn.Linear)
    return jtu.tree_map(lambda _: flag, x)
# Build the model and keep only its array leaves as the parameters.
model = Transformer(jax.random.PRNGKey(1))
params = eqx.filter(model, eqx.is_array)
# `is_layer` stops the traversal at Linear/LayerNorm modules, so `set_mask`
# receives whole layers rather than individual arrays.
mask = jtu.tree_map(set_mask, params, is_leaf=is_layer)
optim = optax.adamw(learning_rate=1e-4, mask=mask)
# This is the failing call. Per the traceback below, optax evaluates
# `mask(params) if callable(mask) else mask`, and the resulting TypeError
# comes from `Transformer.__call__`, i.e. the mask pytree itself is being
# treated as a callable.
opt_state = optim.init(params)
Traceback
---> 83 opt_state = optim.init(params)
File ~/miniconda3/envs/jaxenv/lib/python3.11/site-packages/optax/_src/combine.py:64, in chain.<locals>.init_fn(params)
63 def init_fn(params):
---> 64 return tuple(fn(params) for fn in init_fns)
File ~/miniconda3/envs/jaxenv/lib/python3.11/site-packages/optax/_src/combine.py:64, in <genexpr>(.0)
63 def init_fn(params):
---> 64 return tuple(fn(params) for fn in init_fns)
File ~/miniconda3/envs/jaxenv/lib/python3.11/site-packages/optax/_src/wrappers.py:544, in masked.<locals>.init_fn(params)
541 if isinstance(params, _state_utils._ParamsPlaceholder): # pylint:disable=protected-access
542 return MaskedState(inner_state=inner.init(params))
--> 544 mask_tree = mask(params) if callable(mask) else mask
545 masked_params = mask_pytree(params, mask_tree)
546 return MaskedState(inner_state=inner.init(masked_params))
TypeError: Transformer.__call__() missing 1 required positional argument: 'y'
You can find the related discussion: patrick-kidger/equinox#771
Hello @AakashKumarNain
This was raised in #913 and @JadM133 found a solution. PR #1015 should fix this.
Thanks @vroulet for pointing out the PR. I hope it gets merged soon because this has been a huge blocker for the Equinox users. Also, do you have any immediate suggestion to make it work for now?
Hello @AakashKumarNain , in the meantime, you can modify two lines in _src/wrappers.py:
Line 544: mask_tree = mask (instead of "mask(params) if callable(mask) else mask")
Line 549: mask_tree = mask (same change)
This should get your code to run as expected until the pull request is merged.
Thanks @JadM133 for the suggestion. I will try it out
The PR has been merged. You'll need to install optax locally to use it (we may not release a new version soon).
Thank you. I will do a local install
Hello @AakashKumarNain
Could you send the exact bug and a minimal reproducing example?
I've tried out the current code (see below) and don't get errors.
import equinox as eqx
import jax.numpy as jnp
import jax.random as jrd
import jax.tree_util as jtu
import jax
import optax
import optax.tree_utils as otu
# With standard pytrees
def test_mask(mask):
    """Print whether `mask` restricts adamw's weight decay as expected.

    Two adamw optimizers that differ only in `weight_decay` (1. vs 0.) are
    run for one update on the module-level `params`. The mask is reported
    as working when the updates for entry 1 match and the updates for
    entry 0 differ.

    Args:
        mask: value forwarded to the `mask` argument of `optax.adamw`.
    """

    def squared_norm(tree):
        return otu.tree_l2_norm(tree, squared=True)

    with_decay = optax.adamw(1., mask=mask, weight_decay=1.)
    without_decay = optax.adamw(1., mask=mask, weight_decay=0.)
    with_decay_state = with_decay.init(params)
    without_decay_state = without_decay.init(params)

    grad = jax.grad(squared_norm)(params)
    updates_decay, _ = with_decay.update(grad, with_decay_state, params)
    updates_plain, _ = without_decay.update(grad, without_decay_state, params)
    # The result is discarded; the call only checks that the updates apply.
    optax.apply_updates(params, updates_decay)

    entry1_matches = jnp.allclose(updates_decay[1], updates_plain[1])
    worked = entry1_matches and not jnp.allclose(updates_decay[0], updates_plain[0])
    print(f'Did mask work?: {worked}')
# Two-leaf list pytree: a 2-D matrix followed by a 1-D vector.
params = [jnp.array([[1., 2.], [3., 4.]]), jnp.array([5., 6.])]
# The mask is True for every leaf whose number of dimensions is not 1.
mask_fn = lambda p: jtu.tree_map(lambda x: x.ndim != 1, p)
mask = mask_fn(params)
# Exercise both mask forms: a concrete pytree and a callable.
test_mask(mask)
test_mask(mask_fn)
# Equinox setting
@eqx.filter_value_and_grad
def grad_loss(model, input, output):
    """Mean squared error between `model(input)` and `output`."""
    residual = model(input) - output
    return jnp.mean(jnp.square(residual))
@eqx.filter_jit
def make_step(input, output, model, states):
    """Run one training step with `optim1` and check the mask against `optim2`.

    Args:
        input: model input passed to `grad_loss`.
        output: target passed to `grad_loss`.
        model: the model to differentiate and update.
        states: pair `(state for optim1, state for optim2)`.

    Returns:
        `(loss, updated model, new optim1 state)`.
    """
    loss, grads = grad_loss(model, input, output)
    u1, opt_state = optim1.update(grads, states[0], model)
    # optim2 is evaluated only for comparison. Its new state is discarded so
    # it cannot overwrite the state of optim1, whose updates are the ones
    # actually applied to the model below.
    u2, _ = optim2.update(grads, states[1], model)
    bias_matches = jnp.allclose(u1.layers[0].bias, u2.layers[0].bias)
    weight_matches = jnp.allclose(u1.layers[0].weight, u2.layers[0].weight)
    is_working = bias_matches & (~weight_matches)
    jax.debug.print('Did mask work?: {}', is_working)
    model = eqx.apply_updates(model, u1)
    return loss, model, opt_state
# Data and model for the Equinox check.
key, subkey = jax.random.split(jrd.PRNGKey(0))
xs = jnp.ones((100,))
ys = jax.random.normal(key, (1,))
model = eqx.nn.MLP(xs.shape[-1], ys.shape[-1], 10, 1, key=subkey)
lr = 1e-2

# Mask: True for every leaf, except the biases of layers[0] and layers[1].
filter_spec = jtu.tree_map(lambda _: True, model)
filter_spec = eqx.tree_at(
    lambda tree: (tree.layers[0].bias, tree.layers[1].bias),
    filter_spec,
    replace=(False, False),
)

# Two optimizers that differ only in weight_decay, so any difference in
# their updates comes from the (masked) weight decay.
optim1 = optax.adamw(1., mask=filter_spec, weight_decay=1.)
optim2 = optax.adamw(1., mask=filter_spec, weight_decay=0.)
trainable = eqx.filter(model, eqx.is_inexact_array)
state1 = optim1.init(trainable)
# Each state must come from its own optimizer (was optim1.init by mistake).
state2 = optim2.init(trainable)
loss, model, opt_state = make_step(xs, ys, model, (state1, state2))
Did mask work?: True
Did mask work?: True
Did mask work?: True
Hello @vroulet , @AakashKumarNain , I think the issue is not with the mask. I assume @AakashKumarNain is getting the following error:
ValueError: You are using a transformation that requires the current value of parameters, but you are not passing `params` when calling `update`."""
The problem is that the update function of adamw differs from that of the other optimizers and requires `params` (as mentioned in the documentation). So reusing the same code as for adam and only changing the name of the optimizer won't work. The code needs to be changed along the lines of the example written by @vroulet above.