
Memory leak in optax.radam

The number of live buffers (and memory use) keeps growing with every training step when using the optax.radam optimizer. The issue goes away when the optimizer is switched to optax.adam.

CUDA Version: 12.0
JAX version: 0.4.1
Optax version: 0.1.4


import jax
import jax.numpy as jnp
import optax

from flax import linen as nn
from typing import Sequence

# Handle to the default backend; used below to inspect live device buffers.
client = jax.lib.xla_bridge.get_backend()

class MLP(nn.Module):
    features: Sequence[int]  # output size of the Dense layer, e.g. [10]

    def setup(self):
        self.linear = nn.Dense(self.features[0])

    def __call__(self, x):
        # Dense -> ReLU, then take the max over axis 1.
        return nn.relu(self.linear(x)).max(axis=1)

key = jax.random.PRNGKey(0)

model = MLP(features=[10])

test_input = jnp.ones((10, 10, 10))  # batch of 10 inputs, each 10x10
test_label = jnp.ones((10,))         # one scalar target per input

model_params = model.init(key, test_input)

optimizer = optax.radam(learning_rate=1e-3)  # live buffer count keeps growing
# optimizer = optax.adam(learning_rate=1e-3)  # live buffer count does not grow

opt_state = optimizer.init(model_params)

def mse_for_model(model):
    # Build an MSE loss closure for the given Flax model.
    def mse(params, x_batched, y_batched):
        def squared_error(x, y):
            pred = model.apply(params, x)
            return jnp.inner(y - pred, y - pred) / 2.
        # Average the per-example squared error over the batch.
        return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)
    return mse

loss_grad_fn = jax.value_and_grad(mse_for_model(model))

for i in range(100):
    loss_val, grads = jax.block_until_ready(
        loss_grad_fn(model_params, test_input, test_label))
    # Count the device buffers currently alive on the backend.
    bs = [b.shape for b in client.live_buffers()]
    print('num_live_buffers=', len(bs))
    updates, opt_state = optimizer.update(grads, opt_state)
    # Assign back to model_params so the next step uses the updated weights.
    model_params = optax.apply_updates(model_params, updates)
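The measurement in the loop above can be turned into a small regression check. This is a minimal sketch, assuming a JAX version that provides jax.live_arrays() (the public way to list live arrays in newer JAX releases; the report used 0.4.1, which may predate it). The helper live_buffer_counts and the momentum_step stand-in are hypothetical; to test a real optimizer, wrap optimizer.update in the step function instead.

```python
import jax
import jax.numpy as jnp


def live_buffer_counts(step_fn, state, *args, num_steps=20):
    """Run step_fn repeatedly and record len(jax.live_arrays()) after each step."""
    counts = []
    for _ in range(num_steps):
        # Rebinding `state` drops the previous step's arrays.
        state = jax.block_until_ready(step_fn(state, *args))
        counts.append(len(jax.live_arrays()))
    return counts


@jax.jit
def momentum_step(velocity, grad):
    # Hypothetical stand-in for optimizer.update: a plain momentum accumulator.
    return 0.9 * velocity + grad


counts = live_buffer_counts(momentum_step, jnp.zeros(10), jnp.ones(10))
# A leak shows up as a count that keeps increasing; a healthy loop plateaus.
print(counts[:3], counts[-3:])
```

Skipping the first couple of steps before comparing counts leaves room for compilation and warm-up.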

Hello @lukekulik,
Thanks for catching this and giving a reproducible example.
I think there is a simple fix for this: replace the jax.lax.cond in the RAdam update with a jnp.where.
This slightly increases computation time (radam_update will then be computed at every iteration, adding a few arithmetic operations), but it should solve the memory issue.
I'll label this as a good first issue since the fix should be simple. If nobody picks it up within a week, I can take care of it.
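The change proposed above can be sketched as follows. This is a minimal illustration under stated assumptions, not optax's actual code: the function name rectified_update and its arguments are hypothetical, it handles a single array rather than a pytree, and it uses the standard RAdam rectification formula with a threshold of 5.0 (assumed to match optax's default). Both branches are computed on every step and jnp.where selects the result, so there is no jax.lax.cond branch.

```python
import jax
import jax.numpy as jnp


def rectified_update(mu_hat, nu_hat, count, b2=0.999, eps=1e-8, threshold=5.0):
    """RAdam-style rectified step for one array (hypothetical sketch).

    mu_hat / nu_hat: bias-corrected first / second moment estimates.
    count: 1-based step number.
    """
    count = jnp.asarray(count, dtype=jnp.float32)
    ro_inf = 2.0 / (1.0 - b2) - 1.0
    b2t = b2 ** count
    ro = ro_inf - 2.0 * count * b2t / (1.0 - b2t)
    # Both branches are evaluated, so clamp ro to keep the unused branch finite
    # (the sqrt argument would be negative for 2 < ro < 4).
    ro_safe = jnp.where(ro >= threshold, ro, threshold)
    r = jnp.sqrt((ro_safe - 4.0) * (ro_safe - 2.0) * ro_inf
                 / ((ro_inf - 4.0) * (ro_inf - 2.0) * ro_safe))
    rectified = r * mu_hat / (jnp.sqrt(nu_hat) + eps)
    # Elementwise select instead of branching with jax.lax.cond.
    return jnp.where(ro >= threshold, rectified, mu_hat)
```

Early in training (ro below the threshold) this returns mu_hat unchanged; later it returns the rectified Adam-style step. The function also works under jax.jit, since no Python-level branching depends on traced values.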

Fixed by #974