Facing a scaling issue on CPU (ARM and x86).
Opened this issue · 1 comments
Description
I was using a custom script and it was slower than I expected; on inspection I observed that it is not scaling properly across CPU cores.
jax -> 0.4.35
jaxlib -> 0.4.35
ubuntu -> 20.04
script
System info (python version, jaxlib version, accelerator, etc.)
script
"""
import jax
import jax.numpy as jnp
from jax import random
import time
def initialize_params(rng, input_size, hidden_size, output_size):
    """Build He-initialized parameters for a 2-layer MLP.

    Weights are drawn from N(0, 2/fan_in) (He initialization for ReLU);
    biases start at zero.

    Args:
        rng: a `jax.random` PRNG key.
        input_size: input feature dimension.
        hidden_size: hidden layer width.
        output_size: output dimension.

    Returns:
        Dict with keys "W_hidden", "b_hidden", "W_output", "b_output".
    """
    rng_hidden, rng_output = random.split(rng)
    hidden_scale = jnp.sqrt(2 / input_size)
    output_scale = jnp.sqrt(2 / hidden_size)
    params = {}
    params["W_hidden"] = hidden_scale * random.normal(rng_hidden, (input_size, hidden_size))
    params["b_hidden"] = jnp.zeros(hidden_size)
    params["W_output"] = output_scale * random.normal(rng_output, (hidden_size, output_size))
    params["b_output"] = jnp.zeros(output_size)
    return params
def forward(params, x):
    """Apply the 2-layer MLP to a single input vector.

    Computes ReLU(x @ W_hidden + b_hidden) @ W_output + b_output.

    Args:
        params: parameter dict as produced by `initialize_params`.
        x: input feature vector of shape (input_size,).

    Returns:
        Output vector of shape (output_size,).
    """
    pre_activation = x @ params["W_hidden"] + params["b_hidden"]
    activation = jax.nn.relu(pre_activation)
    return activation @ params["W_output"] + params["b_output"]
# Vectorize `forward` over the leading (batch) axis of x; `params` is
# shared across the batch (in_axes=None), not mapped.
batched_forward = jax.vmap(forward, in_axes=(None, 0))
def main(batch_size: int = 1_000_000):
    """Benchmark one batched forward pass of the 2-layer MLP.

    Args:
        batch_size: number of input rows pushed through the network.

    Returns:
        The network outputs, shape (batch_size, 10).
    """
    input_size = 512
    hidden_size = 1024
    output_size = 10

    rng = random.PRNGKey(0)
    params = initialize_params(rng, input_size, hidden_size, output_size)
    # NOTE(review): reusing `rng` for the inputs correlates them with the
    # weights; split the key first if that matters for the benchmark.
    inputs = random.normal(rng, (batch_size, input_size))

    start_time = time.time()
    outputs = batched_forward(params, inputs)
    # JAX dispatches asynchronously: without this, the timer measures only
    # dispatch latency, not the actual compute — the benchmark that
    # motivated this issue was mis-timed for exactly this reason.
    outputs.block_until_ready()
    end_time = time.time()

    print("Inference output shape:", outputs.shape)
    print(f"Batched inference time for {batch_size} samples: {end_time - start_time:.4f} seconds")
    return outputs


# Original read `if name == "main":` — a NameError as written; the dunders
# were stripped by markdown rendering.
if __name__ == "__main__":
    main()
"""
@dougalm, @mjwillson, @kashif, @dlwh, @d0k — could one of you please look into this? I have been stuck on it for a long time and do not understand the problem.