Memory leak in optax.radam
lukekulik opened this issue · 2 comments
The number of live buffers (and memory use) continues to grow when using the optax.radam optimizer. The issue goes away when the optimizer is switched to optax.adam.
Potentially related to #580
CUDA Version: 12.0
JAX version: 0.4.1
Optax version: 0.1.4
Repro:
import jax
import jax.numpy as jnp
import optax
from flax import linen as nn
from collections import namedtuple
client = jax.lib.xla_bridge.get_backend()
class MLP(nn.Module):
    features: namedtuple('features', ['outputsize'])

    def setup(self):
        self.linear = nn.Dense(self.features[0])
        self.relu = nn.activation.relu

    def __call__(self, x):
        return self.relu(self.linear(x)).max(axis=1)
key = jax.random.PRNGKey(0)
model = MLP(features = [10])
test_input = jax.numpy.ones((10,10,10))
test_label = jax.numpy.ones((10))
model_params = model.init(key, test_input)
optimizer = optax.radam(learning_rate=1e-3)  # this optimizer leaks memory
# optimizer = optax.adam(learning_rate=1e-3)  # this optimizer does not leak
opt_state = optimizer.init(model_params)
def mse_for_model(model):
    @jax.jit
    def mse(params, x_batched, y_batched):
        def squared_error(x, y):
            pred = model.apply(params, x)
            return jnp.inner(y - pred, y - pred) / 2.
        return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)
    return mse
loss_grad_fn = jax.value_and_grad(mse_for_model(model))
for i in range(100):
    loss_val, grads = jax.block_until_ready(
        loss_grad_fn(model_params, test_input, test_label))
    bs = [b.shape for b in client.live_buffers()]
    print('num_live_buffers=', len(bs))
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(model_params, updates)
Hello @lukekulik,
Thanks for catching this and giving a reproducible example.
I think there is a simple fix for this: replace the jax.lax.cond here with a jnp.where.
That means a slight increase in computation time (the radam_update branch will be computed at every iteration, so a few more arithmetic operations), but it should solve the memory issue.
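For illustration, the pattern I have in mind would look roughly like this (a minimal sketch; select_update, ratio_is_large, radam_update and plain_update are placeholder names, not the actual optax internals):

import jax.numpy as jnp

def select_update(ratio_is_large, radam_update, plain_update):
    # Before (schematically):
    #   jax.lax.cond(ratio_is_large, lambda: radam_update, lambda: plain_update)
    # After: compute both candidate updates every step and pick one with
    # jnp.where, avoiding the cond-related buffer growth reported above.
    return jnp.where(ratio_is_large, radam_update, plain_update)

# Dummy usage:
u = select_update(jnp.asarray(True), jnp.ones((3,)), jnp.zeros((3,)))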
I will mark this as a good first issue since the fix should be simple. If it is not tackled within a week, I can take care of it.