nnx.jit(aux_fn) is slower than directly using nnx.jit(model.__call__)
JunhongXu opened this issue · 3 comments
System information
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 22.04
- Flax, jax, jaxlib versions (obtain with `pip show flax jax jaxlib`): flax 0.9.0, jax 0.4.30, jaxlib 0.4.30
- Python version: 3.11
- GPU/TPU model and memory: GPU: Nvidia RTX 4090
- CUDA version (if applicable): 12.2
Problem you have encountered:
`nnx.jit(aux_fn)` is slower than directly using `nnx.jit(model.__call__)`, where `aux_fn` is defined by

    def aux_fn(model, x):
        return model(x)

From my understanding, using an auxiliary function with `nnx.jit` is common practice and is required if we want to modify the internal state of the model (#3998). However, it seems slower than directly wrapping the `model.__call__` method with `nnx.jit`.
See the Colab link below to reproduce.
Steps to reproduce:
Colab link: https://colab.research.google.com/drive/1cGpcaBaJABUxhZuywgLZELZRwFsT5zve?usp=sharing
For completeness, I also copy the code here:
import time

import jax
from flax import nnx


class MLP(nnx.Module):
    def __init__(self, din: int, dout: int, rngs: nnx.Rngs) -> None:
        # super().__init__()
        self.fc1 = nnx.Linear(din, 128, rngs=rngs)
        self.fc2 = nnx.Linear(128, 128, rngs=rngs)
        self.fc3 = nnx.Linear(128, 128, rngs=rngs)
        self.out = nnx.Linear(128, dout, rngs=rngs)

    def __call__(self, x):
        x = self.fc1(x)
        x = nnx.relu(x)
        x = self.fc2(x)
        x = nnx.relu(x)
        x = self.fc3(x)
        x = nnx.relu(x)
        x = self.out(x)
        return x

def nn_forward(model, x):
    # Auxiliary function from the issue description: run the model's forward pass.
    return model(x)

def benchmark_jax():
    rngs = nnx.Rngs(0)
    din, dout = 29, 7  # Example dimensions
    mlp = MLP(din, dout, rngs)
    nn_forward_call_no_aux = nnx.jit(mlp.__call__)

    # Prepare data
    x = jax.random.normal(rngs(), shape=(1, din))
    num_iterations = 1000
    warmup_iters = 100

    for _ in range(warmup_iters):
        _ = nn_forward_call_no_aux(x)
    start_time = time.time()
    for _ in range(num_iterations):
        _ = nn_forward_call_no_aux(x)
    end_time = time.time()
    print(f"JAX forward pass time for {num_iterations} iterations: {end_time - start_time:.5f} seconds")
    print(f"JAX forward pass average time: {(end_time - start_time) / num_iterations:.5f} seconds")
    print("-------------------")

    nn_forward_jit = nnx.jit(nn_forward)
    for _ in range(warmup_iters):
        _ = nn_forward_jit(mlp, x)
    start_time = time.time()
    for _ in range(num_iterations):
        _ = nn_forward_jit(mlp, x)
    end_time = time.time()
    print(f"JAX forward pass time while using auxiliary functions for {num_iterations} iterations: {end_time - start_time:.5f} seconds")
    print(f"JAX forward pass average while using auxiliary functions time: {(end_time - start_time) / num_iterations:.5f} seconds")
The outputs on an RTX 4090 are:
JAX forward pass time for 1000 iterations: 0.10531 seconds
JAX forward pass average time: 0.00011 seconds
-------------------
JAX forward pass time while using auxiliary functions for 1000 iterations: 0.59596 seconds
JAX forward pass average while using auxiliary functions time: 0.00060 seconds
`mlp.__call__` is not recommended as you are passing `self` as a capture. Try `MLP.__call__` and passing `mlp` as the first input.
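A minimal sketch of that suggestion, reusing `mlp` and `x` from the repro above:

```python
# Jit the unbound method: the module is now an explicit argument whose state
# nnx.jit can traverse, instead of being captured through `self`.
forward = nnx.jit(MLP.__call__)
y = forward(mlp, x)
```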
Just to clarify, what is happening is that `mlp.__call__` is not traversing `self`, so it's faster, a lot faster in this case.
We are going to be developing a Rust extension (see #4196), so in the future `nnx.jit` should be fast. For now, consider using this pattern to remove the Python overhead.
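The comment does not spell out the pattern, so here is one common way to cut the per-call Python traversal (my own sketch, not necessarily the pattern meant above): split the module once with `nnx.split` and jit a purely functional forward with plain `jax.jit`, reusing `mlp` and `x` from the repro:

```python
import jax
from flax import nnx

# Split once, outside the hot loop: graphdef is the static structure,
# state holds the parameters as a pytree.
graphdef, state = nnx.split(mlp)

@jax.jit
def forward(state, x):
    # Rebuild the module from the static graphdef and the (traced) state.
    model = nnx.merge(graphdef, state)
    return model(x)

y = forward(state, x)
```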