Large Difference in Loss between JAX and FLAX Two-Layer Linear Autoencoder
yCobanoglu opened this issue · 1 comment
yCobanoglu commented
System information
- OS: Linux (Manjaro)
- Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib): 0.8.5, 0.4.30, 0.4.30
- Python version: 3.12.4
- GPU/TPU model and memory: Nvidia GeForce 2070, 8GB VRAM
- CUDA version: 12.5
- The same problem occurs on CPU.
Problem you have encountered:
There is a big difference in loss convergence between the plain JAX model and the Flax model. Both models use the optax library for the loss and the optimizer.
I created a minimal example which demonstrates that the loss for the JAX model goes to zero much faster (for the same learning rate) than for the Flax model. After 10 epochs the JAX loss is 1.0138 and the Flax loss is 1.8357.
What you expected to happen:
I expected the convergence speed to be roughly similar. I went through the code several times to make sure it is not a bug.
Steps to reproduce:
Run the code below. You can also copy the Flax model into a separate file and run it on its own; the issue still exists.
Code (exported from the Colab notebook)
# -*- coding: utf-8 -*-
"""Untitled0.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/10fBrXY8kEf4qZ8hz-LCp218vJOx91L9C
"""
# !pip install -q flax>=0.7.5
########################### JAX program ########################
import jax
import optax
import jax.numpy as np
from jax import random
from jax import value_and_grad
from keras.src.datasets import mnist
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm
jax.config.update("jax_debug_nans", True)
input = 784
epochs = 10
latent = 256
learning_rate = 0.001
seed = 42
optimizer = optax.adam(learning_rate)
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train.reshape(60000, 784)
X_test = X_test.reshape(10000, 784)
X_train = X_train.astype("float32")
X_test = X_test.astype("float32")
standardize = StandardScaler().fit(X_train)
X_train = standardize.transform(X_train)
X_test = standardize.transform(X_test)
key = random.PRNGKey(seed)
k1, k2 = random.split(key)
# Weight matrices of the two linear layers, scaled by 1/sqrt(fan_in).
W1 = random.normal(k1, (input, latent)) / np.sqrt(input)
W2 = random.normal(k2, (latent, input)) / np.sqrt(latent)
params = (W1, W2)
opt_state = optimizer.init(params)
# Mean squared reconstruction error of the two-layer linear autoencoder.
@jax.jit
def loss_fn(params, X):
    W1, W2 = params
    return optax.squared_error(X @ W1 @ W2, X).mean()

epochs = tqdm(range(epochs))
test_loss_descr = ""
for epoch in epochs:
    # Full-batch gradient step with the Adam optimizer defined above.
    loss, grads = value_and_grad(loss_fn)(params, X_train)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if epoch % 10 == 0 and epoch > 0:  # note: never true with epochs = 10 (epoch runs 0..9)
        test_loss = loss_fn(params, X_test)
        test_loss_descr = f" Test loss: {test_loss:.4f}"
    epochs.set_description(f"Loss: {loss:.4f}" + test_loss_descr)
############# Flax model ###########################
import jax
from flax import linen as nn
import jax.numpy as jnp
import optax
from flax.training import train_state
from jax import random
from keras.src.datasets import mnist
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm
jax.config.update("jax_debug_nans", True)
input_size = 784
epochs = 10
latent = 256
learning_rate = 0.001
seed = 42
optimizer = optax.adam(learning_rate)
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train.reshape(60000, 784)
X_test = X_test.reshape(10000, 784)
X_train = X_train.astype("float32")
X_test = X_test.astype("float32")
standardize = StandardScaler().fit(X_train)
X_train = standardize.transform(X_train)
X_test = standardize.transform(X_test)
def create_train_state(model, rng, learning_rate, input_size):
    """Instantiate the state outside of the training loop."""
    params = model.init(rng, jnp.ones([1, input_size]))["params"]
    opti = optax.sgd(learning_rate)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=opti)

@jax.jit
def train_step(state, X):
    """Train for a single step."""
    def loss_fn(params):
        logits = state.apply_fn({"params": params}, X)
        loss = optax.squared_error(logits, X).mean()
        return loss
    # Update parameters with gradient descent
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    return loss, state
class Autoencoder(nn.Module):
    latent: int

    @nn.compact
    def __call__(self, x):
        x1 = nn.Dense(features=self.latent, use_bias=False)(x)
        return nn.Dense(features=x.shape[1], use_bias=False)(x1)
autoencoder = Autoencoder(latent)
init_rng = random.PRNGKey(seed)
state = create_train_state(autoencoder, init_rng, learning_rate, input_size)
del init_rng
epochs_iterator = tqdm(range(epochs))
for epoch in epochs_iterator:
    loss, state = train_step(state, X_train)
    epochs_iterator.set_description(f"Epoch: {epoch + 1}/{epochs} - Train Loss: {loss:.4f}")
yCobanoglu commented
Found the bug in my code: in create_train_state I had replaced optax.adam with optax.sgd, so the Flax model was trained with plain SGD while the JAX model used Adam.
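For reference, a minimal sketch of the corrected helper (the only change is constructing optax.adam instead of optax.sgd; the rest of the Flax script stays as posted):

def create_train_state(model, rng, learning_rate, input_size):
    """Instantiate the state outside of the training loop."""
    params = model.init(rng, jnp.ones([1, input_size]))["params"]
    # Match the plain JAX script by optimizing with Adam at the same learning rate.
    opti = optax.adam(learning_rate)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=opti)

With this change both scripts run full-batch Adam at the same learning rate, so their convergence should be comparable.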