google/flax

Large Difference in Loss between JAX and FLAX Two-Layer Linear Autoencoder

yCobanoglu opened this issue · 1 comments

System information

  • Linux (Manjaro):
  • Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib: 0.8.5, 0.4.30, 0.4.30
  • Python version: 3.12.4
  • GPU/TPU model and memory: Nvidia Gefore 2070, 8GB VRAM
  • CUDA version: 12.5

same problen on cpu

Problem you have encountered:

Big difference in loss convergence for Jax model vs. Flax model. Both models use optax library for loss and optimization.
I created a minimal examples which demonstrate that the loss for the jax model is going to zero much faster (for the same learning rate) than the Flax model. Loss for jax after 10 epochs "Loss: 1.0138" and flax "Loss: 1.8357".

What you expected to happen:

Convergence speed to be somewhat similar. I went through the code several times to make sure it is not a bug.

Steps to reproduce:

Colab

You can also copy the Flax model and run it on a seperate file. The issue still exists.

Code (exported from the colab)

# -*- coding: utf-8 -*-
"""Untitled0.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/10fBrXY8kEf4qZ8hz-LCp218vJOx91L9C
"""

# !pip install -q flax>=0.7.5

########################### JAX program ########################

import jax
import optax
import jax.numpy as np
from jax import random
from jax import value_and_grad
from keras.src.datasets import mnist
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm

jax.config.update("jax_debug_nans", True)

input = 784
epochs = 10
latent = 256
learning_rate = 0.001
seed = 42

optimizer = optax.adam(learning_rate)

(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train.reshape(60000, 784)
X_test = X_test.reshape(10000, 784)

X_train = X_train.astype("float32")
X_test = X_test.astype("float32")
standardize = StandardScaler().fit(X_train)

X_train = standardize.transform(X_train)
X_test = standardize.transform(X_test)

key = random.PRNGKey(seed)
k1, k2 = random.split(key)
W1 = random.normal(k1, (input, latent)) / np.sqrt(input)
W2 = random.normal(k2, (latent, input)) / np.sqrt(latent)
params = (W1, W2)

opt_state = optimizer.init(params)


@jax.jit
def loss_fn(params, X):
    W1, W2 = params
    return optax.squared_error(X @ W1 @ W2, X).mean()


epochs = tqdm(range(epochs))
test_loss_descr = ""
for epoch in epochs:
    loss, grads = value_and_grad(loss_fn)(params, X_train)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if epoch % 10 == 0 and epoch > 0:
        test_loss = loss_fn(params, X_test)
        test_loss_descr = f" Test loss: {test_loss:.4f}"
    epochs.set_description(f"Loss: {loss:.4f}" + test_loss_descr)

############# Flax model ###########################
import jax
from flax import linen as nn
import jax.numpy as jnp
import optax
from flax.training import train_state
from jax import random
from keras.src.datasets import mnist
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm

jax.config.update("jax_debug_nans", True)

input_size = 784
epochs = 10
latent = 256
learning_rate = 0.001
seed = 42

optimizer = optax.adam(learning_rate)

(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train.reshape(60000, 784)
X_test = X_test.reshape(10000, 784)

X_train = X_train.astype("float32")
X_test = X_test.astype("float32")
standardize = StandardScaler().fit(X_train)

X_train = standardize.transform(X_train)
X_test = standardize.transform(X_test)


def create_train_state(model, rng, learning_rate, input_size):
    """Instanciate the state outside of the training loop"""
    params = model.init(rng, jnp.ones([1, input_size]))["params"]
    opti = optax.sgd(learning_rate)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=opti)


@jax.jit
def train_step(state, X):
    """Train for a single step"""
    def loss_fn(params):
        logits = state.apply_fn({"params": params}, X)
        loss = optax.squared_error(logits, X).mean()
        return loss

    # Update parameters with gradient descent
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    return loss, state


class Autoencoder(nn.Module):
    latent: int

    @nn.compact
    def __call__(self, x):
        x1 = nn.Dense(features=self.latent, use_bias=False)(x)
        return nn.Dense(features=x.shape[1], use_bias=False)(x1)


autoencoder = Autoencoder(latent)
init_rng = random.PRNGKey(seed)
state = create_train_state(autoencoder, init_rng, learning_rate, input_size)
del init_rng

epochs_iterator = tqdm(range(epochs))
for epoch in epochs_iterator:
    loss, state = train_step(state, X_train)
    epochs_iterator.set_description(f"Epoch: {epoch + 1}/{epochs} - Train Loss: {loss:.4f}")

Found the bug in my code. I had in train_state overwritten optax.adam with optax.sgd.