google-deepmind/optax

Memory leak in optax.radam

lukekulik opened this issue · 2 comments

The number of live buffers (and memory use) keeps growing when using the optax.radam optimizer. The issue goes away when the optimizer is switched to optax.adam.

Potentially related to #580

CUDA Version: 12.0
JAX version: 0.4.1
Optax version: 0.1.4

Repro:

```python
from typing import Sequence

import jax
import jax.numpy as jnp
import optax
from flax import linen as nn

# Backend handle used to list live device buffers (JAX 0.4.x API).
client = jax.lib.xla_bridge.get_backend()

class MLP(nn.Module):
    features: Sequence[int]

    def setup(self):
        self.linear = nn.Dense(self.features[0])
        self.relu = nn.activation.relu

    def __call__(self, x):
        return self.relu(self.linear(x)).max(axis=1)

key = jax.random.PRNGKey(0)

model = MLP(features=[10])

test_input = jnp.ones((10, 10, 10))
test_label = jnp.ones((10,))

model_params = model.init(key, test_input)

optimizer = optax.radam(learning_rate=1e-3)  # live buffer count keeps growing
# optimizer = optax.adam(learning_rate=1e-3)  # live buffer count stays constant

opt_state = optimizer.init(model_params)

def mse_for_model(model):
  @jax.jit
  def mse(params, x_batched, y_batched):
    def squared_error(x, y):
      pred = model.apply(params, x)
      return jnp.inner(y - pred, y - pred) / 2.
    return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)
  return mse

loss_grad_fn = jax.value_and_grad(mse_for_model(model))

for i in range(100):
    loss_val, grads = jax.block_until_ready(
        loss_grad_fn(model_params, test_input, test_label))

    bs = [b.shape for b in client.live_buffers()]
    print('num_live_buffers=', len(bs))

    updates, opt_state = optimizer.update(grads, opt_state)
    # Feed the updated parameters back into the next iteration.
    model_params = optax.apply_updates(model_params, updates)
```
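The repro counts buffers through `jax.lib.xla_bridge.get_backend().live_buffers()`, which is the API available in JAX 0.4.1; newer JAX releases may no longer expose that path. A minimal sketch of the same check, assuming `jax.live_arrays()` is available in your JAX version, is below. The helper `count_live_arrays_per_step` and the toy `toy_step` are hypothetical illustrations, not part of JAX or optax; to check an optimizer, pass a step function that computes the gradients, calls `optimizer.update`, and returns the new state.

```python
import jax
import jax.numpy as jnp


def count_live_arrays_per_step(step_fn, state, num_steps):
    """Hypothetical helper: run step_fn repeatedly and record live device arrays.

    A count that keeps growing with the step index suggests that something is
    holding on to device buffers between iterations.
    """
    counts = []
    for _ in range(num_steps):
        # Wait for the computation so the count reflects finished work.
        state = jax.block_until_ready(step_fn(state))
        counts.append(len(jax.live_arrays()))
    return state, counts


@jax.jit
def toy_step(x):
    # Stand-in for a real training step (hypothetical).
    return x * 0.5 + 1.0


final, counts = count_live_arrays_per_step(toy_step, jnp.ones((4,)), num_steps=5)
print(counts)
```

A count that grows roughly linearly with the step index, as reported above for optax.radam, points at a leak; a count that stays flat does not.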

Hello @lukekulik,
Thanks for catching this and giving a reproducible example.
I think there is a simple fix for this: replace the jax.lax.cond in the RAdam update with a jnp.where.
This slightly increases computation time, since radam_update will then be computed at every iteration (a few extra arithmetic operations), but it should solve the memory issue.
I will label this as a good first issue since the fix should be simple. If nobody tackles it within a week, I can take care of it.
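To illustrate the proposed change, here is a simplified single-array sketch; it is not the optax source, and the names `rho_t`, `rectified_update`, `update_with_cond`, `update_with_where` and the constants are hypothetical. It assumes the RAdam rectification factor from the RAdam paper and a threshold of 5.0, assumed here to match optax's default. With `jax.lax.cond` only the selected branch runs; with `jnp.where` both candidates are computed and one is selected elementwise, which is where the few extra arithmetic operations come from.

```python
import jax
import jax.numpy as jnp

# Hypothetical constants for this sketch (THRESHOLD assumed to match optax's default).
B2 = 0.999
EPS = 1e-8
THRESHOLD = 5.0
RHO_INF = 2.0 / (1.0 - B2) - 1.0


def rho_t(step):
    """Length of the approximated SMA at step t (RAdam paper notation)."""
    b2t = B2 ** step
    return RHO_INF - 2.0 * step * b2t / (1.0 - b2t)


def rectified_update(rho, mu_hat, nu_hat):
    """Rectified Adam step; only meaningful when rho >= THRESHOLD."""
    # Clamp so the unused branch stays finite when 2 < rho < 4.
    ratio = jnp.maximum(
        (rho - 4.0) * (rho - 2.0) * RHO_INF
        / ((RHO_INF - 4.0) * (RHO_INF - 2.0) * rho),
        0.0,
    )
    r = jnp.sqrt(ratio)
    return r * mu_hat / (jnp.sqrt(nu_hat) + EPS)


def update_with_cond(rho, mu_hat, nu_hat):
    # lax.cond executes only the branch selected by the predicate.
    return jax.lax.cond(
        rho >= THRESHOLD,
        lambda ops: rectified_update(*ops),
        lambda ops: ops[1],  # fall back to the bias-corrected momentum
        (rho, mu_hat, nu_hat),
    )


def update_with_where(rho, mu_hat, nu_hat):
    # Both candidates are computed, then jnp.where selects elementwise.
    return jnp.where(
        rho >= THRESHOLD, rectified_update(rho, mu_hat, nu_hat), mu_hat
    )
```

Both versions return the same values; the trade-off is only in what gets executed. In the real optimizer the selection would be applied across the whole parameter pytree rather than a single array.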

Fixed by #974