Memory overflow using scale_by_radam
HGangloff opened this issue · 1 comments
Hi,
My RAM usage grows until it overflows when I use the scale_by_radam gradient transformation (or, equivalently, optax.radam) without JIT-compiling the code. The problem appears on both CPU and GPU, but not when I use JIT compilation. It also does not seem to occur with optax.adam.
Here is an MWE derived from the optax quick start tutorial:
import random
from typing import Tuple
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''  # forces CPU; comment out to run on GPU
import optax
import jax.numpy as jnp
import jax
import numpy as np

BATCH_SIZE = 500
NUM_TRAIN_STEPS = 10000
RAW_TRAINING_DATA = np.random.randint(255, size=(NUM_TRAIN_STEPS, BATCH_SIZE, 1))
TRAINING_DATA = np.unpackbits(RAW_TRAINING_DATA.astype(np.uint8), axis=-1)
LABELS = jax.nn.one_hot(RAW_TRAINING_DATA % 2, 2).astype(jnp.float32).reshape(NUM_TRAIN_STEPS, BATCH_SIZE, 2)

initial_params = {
    'hidden': jax.random.normal(shape=[8, 200], key=jax.random.PRNGKey(0)),
    'hidden2': jax.random.normal(shape=[200, 100], key=jax.random.PRNGKey(0)),
    'output': jax.random.normal(shape=[100, 2], key=jax.random.PRNGKey(1)),
}

def net(x: jnp.ndarray, params: optax.Params) -> jnp.ndarray:
    x = jnp.dot(x, params['hidden'])
    x = jax.nn.relu(x)
    x = jnp.dot(x, params['hidden2'])
    x = jax.nn.relu(x)
    x = jnp.dot(x, params['output'])
    return x

def loss(params: optax.Params, batch: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
    y_hat = net(batch, params)
    # optax also provides a number of common loss functions.
    loss_value = optax.sigmoid_binary_cross_entropy(y_hat, labels).sum(axis=-1)
    return loss_value.mean()

def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
    opt_state = optimizer.init(params)

    # @jax.jit  # left commented out on purpose: the memory growth only shows up without JIT
    def step(params, opt_state, batch, labels):
        loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss_value

    for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
        params, opt_state, loss_value = step(params, opt_state, batch, labels)
        if i % 100 == 0:
            print(f'step {i}, loss: {loss_value}')

    return params

# Finally, we can fit our parametrized function using the RAdam optimizer
# provided by optax.
optimizer = optax.radam(learning_rate=1e-2)
params = fit(initial_params, optimizer)
Of course, this example is small, so it takes a long time to saturate the RAM, but the issue is a serious problem in one of my research projects.
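To make the growth visible without waiting for the RAM to fill up, peak resident memory can be logged every few steps. The following is a minimal sketch, assuming a Unix system (the resource module is not available on Windows). The names peak_rss_mib and log_peak_rss, and the dummy step in the demo, are hypothetical helpers for illustration and are not part of optax or JAX.

```python
import resource
import sys


def peak_rss_mib() -> float:
    """Return this process's peak resident set size in MiB (Unix only)."""
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is reported in bytes on macOS and in kilobytes on Linux.
    divisor = 1024 * 1024 if sys.platform == "darwin" else 1024
    return peak / divisor


def log_peak_rss(step_fn, num_steps: int, every: int = 100) -> list:
    """Call step_fn(i) num_steps times; record (i, peak RSS) every `every` steps."""
    samples = []
    for i in range(num_steps):
        step_fn(i)
        if i % every == 0:
            samples.append((i, peak_rss_mib()))
    return samples


if __name__ == "__main__":
    # Hypothetical stand-in for a training step: it deliberately keeps
    # references alive so that the recorded peak keeps growing.
    leaked = []
    for i, mib in log_peak_rss(lambda i: leaked.append(bytearray(100_000)), 500):
        print(f"step {i}, peak RSS: {mib:.1f} MiB")
```

In the MWE, step_fn would wrap the call to step inside fit. Since this records the peak, the value never decreases: a plateau means no further growth, while a steady climb across steps points to memory being retained.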
The problem seems to be linked to this computation, which is specific to RAdam: https://github.com/deepmind/optax/blob/fc5de3d3951c4dfd87513c6426e30baf505d89ae/optax/_src/transform.py#L685C7-L685C7. However, I do not know how to investigate further.
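For context, what distinguishes RAdam from Adam is, as far as I understand it, the variance rectification term and the branch that depends on it. Below is a minimal pure-Python sketch of that term, following the formula from the RAdam paper. It is not copied from optax; the function name radam_rectification is hypothetical, and the default threshold of 5.0 and the "greater than or equal" comparison are assumptions about how the branch is chosen.

```python
import math


def radam_rectification(step: int, b2: float = 0.999, threshold: float = 5.0):
    """Return (rho_t, r_t) for a 1-based step count.

    r_t is None while rho_t is below the threshold, i.e. while the
    rectified branch would not be taken (assumed behaviour).
    """
    rho_inf = 2.0 / (1.0 - b2) - 1.0
    b2t = b2 ** step
    rho_t = rho_inf - 2.0 * step * b2t / (1.0 - b2t)
    if rho_t < threshold:
        # Early steps: the variance estimate is not yet considered
        # reliable, so no rectification factor is applied.
        return rho_t, None
    r_t = math.sqrt(
        (rho_t - 4.0) * (rho_t - 2.0) * rho_inf
        / ((rho_inf - 4.0) * (rho_inf - 2.0) * rho_t)
    )
    return rho_t, r_t


if __name__ == "__main__":
    for t in (1, 5, 10, 100, 10_000):
        print(t, radam_rectification(t))
```

The point of the sketch is that the update rule switches between two branches depending on the step count, which Adam does not do. Whether that data-dependent branch is what retains memory outside of JIT is only a guess on my side.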
Thanks for your feedback.