google/flax

nnx.jit(aux_fn) is slower than directly using nnx.jit(model.__call__)

JunhongXu opened this issue · 3 comments


System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 22.04
  • Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib): flax: 0.9.0, jax: 0.4.30, jaxlib: 0.4.30
  • Python version: 3.11
  • GPU/TPU model and memory: GPU: Nvidia RTX 4090
  • CUDA version (if applicable): 12.2

Problem you have encountered:

nnx.jit(aux_fn) is slower than directly using nnx.jit(model.__call__), where aux_fn is defined by

def aux_fn(model, x):
    return model(x)

From what I can tell, using an auxiliary function with nnx.jit is a common practice and is required if we want to modify the internal state of the model (#3998). However, it seems slower than wrapping model.__call__ directly with nnx.jit.

See the colab link below to reproduce.

Steps to reproduce:

Colab link: https://colab.research.google.com/drive/1cGpcaBaJABUxhZuywgLZELZRwFsT5zve?usp=sharing

For completeness, I also copy the code here:

import time
import jax
from flax import nnx


class MLP(nnx.Module):
    def __init__(self, din: int, dout: int, rngs: nnx.Rngs) -> None:
        # super().__init__()
        self.fc1 = nnx.Linear(din, 128, rngs=rngs)
        self.fc2 = nnx.Linear(128, 128, rngs=rngs)
        self.fc3 = nnx.Linear(128, 128, rngs=rngs)
        self.out = nnx.Linear(128, dout, rngs=rngs)

    def __call__(self, x):
        x = self.fc1(x)
        x = nnx.relu(x)
        x = self.fc2(x)
        x = nnx.relu(x)
        x = self.fc3(x)
        x = nnx.relu(x)
        x = self.out(x)
        return x


def nn_forward(model, x):
    return model(x)


def benchmark_jax():
    rngs = nnx.Rngs(0)
    din, dout = 29, 7  # Example dimensions
    mlp = MLP(din, dout, rngs)
    nn_forward_call_no_aux = nnx.jit(mlp.__call__)

    # Prepare data
    x = jax.random.normal(rngs(), shape=(1, din))
    num_iterations = 1000
    warmup_iters = 100

    for _ in range(warmup_iters):
        _ = nn_forward_call_no_aux(x)

    start_time = time.time()
    for _ in range(num_iterations):
        _ = nn_forward_call_no_aux(x)
    end_time = time.time()

    print(f"JAX forward pass time for {num_iterations} iterations: {end_time - start_time:.5f} seconds")
    print(f"JAX forward pass average time: {(end_time - start_time) / num_iterations:.5f} seconds")

    print("-------------------")
    nn_forward_jit = nnx.jit(nn_forward)
    for _ in range(warmup_iters):
        _ = nn_forward_jit(mlp, x)

    start_time = time.time()
    for _ in range(num_iterations):
        _ = nn_forward_jit(mlp, x)
    end_time = time.time()
    print(f"JAX forward pass time while using auxiliary functions for {num_iterations} iterations: {end_time - start_time:.5f} seconds")
    print(f"JAX forward pass average while using auxiliary functions time: {(end_time - start_time) / num_iterations:.5f} seconds")

The outputs on an RTX 4090 are:

JAX forward pass time for 1000 iterations: 0.10531 seconds
JAX forward pass average time: 0.00011 seconds
-------------------
JAX forward pass time while using auxiliary functions for 1000 iterations: 0.59596 seconds
JAX forward pass average time while using auxiliary functions: 0.00060 seconds

mlp.__call__ is not recommended, as you are passing self as a capture. Try MLP.__call__ and passing mlp as the first input.
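
A minimal sketch of that suggestion, reusing the names from the benchmark above (the jitted-function name is mine):

    # Jit the unbound method; mlp is passed explicitly, so nnx.jit sees it as
    # an input to trace rather than a Python-level capture on the callable.
    forward = nnx.jit(MLP.__call__)
    y = forward(mlp, x)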

Just to clarify, what is happening is that mlp.__call__ is not traversing self, so it is faster, a lot faster in this case.
We are going to be developing a Rust extension (see #4196), so in the future nnx.jit should be fast. For now, consider using this pattern to remove the Python overhead.
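
For context, one way to sidestep that overhead today (my sketch of the split/merge approach described in the performance guide referenced below; the guide's exact recommendation may differ) is to split the module once outside the hot loop and jit a pure function over its state:

    import jax
    from flax import nnx

    # Split once, outside the loop: graphdef is static (no arrays),
    # state is a pytree of arrays that jax.jit can handle directly.
    graphdef, state = nnx.split(mlp)

    @jax.jit
    def forward(state, x):
        # Rebuild the module inside the jitted function and call it; the
        # Python-level graph traversal cost is only paid at trace time.
        model = nnx.merge(graphdef, state)
        return model(x)

    y = forward(state, x)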

I've created this mini guide to clarify the situation around performance: #4224.