Check out `jit(inline=True)`.
Closed this issue · 4 comments
Hi all,

If this works like normal function inlining in compiled languages such as C, it could be really useful. Normally, when a function call is encountered, a new stack frame is created and the arguments are pushed onto it before entering the function; with inlining, the function body is instead copied directly into the call site before compiling. It is often used for small functions that exist mainly for readability, for example a `radians_to_arcseconds` helper that is just a single multiplication. Long story short, I think this could guarantee performance when calling compiled functions from within compiled functions. Please note that this is a reminder to investigate and that I do not have any concrete evidence as of yet. I encountered this in the `power` function at `jax/jax/_src/numpy/ufuncs.py:311`.

Regards,
Jordan
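As a concrete sketch of the readability case above (the `radians_to_arcseconds` helper here is hypothetical, just illustrating the kind of tiny function I mean):

```python
from functools import partial

import jax
import jax.numpy as jnp

# A tiny helper kept only for readability; `inline=True` asks JAX to
# splice its body directly into the caller's jaxpr instead of emitting
# a separate call.
@partial(jax.jit, inline=True)
def radians_to_arcseconds(x):
    return x * (180.0 * 3600.0 / jnp.pi)

@jax.jit
def observe(angle_rad):
    # Calls the helper; with inline=True no nested call primitive
    # should appear in this function's jaxpr.
    return radians_to_arcseconds(angle_rad)

print(jax.make_jaxpr(observe)(1.0))
```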
`jit(inline=True)` doesn't bring any performance benefit (in a compiled context); it just generates slightly cleaner jaxpr. Later in the compile pipeline, everything is inlined in the generated HLO module anyway, as can be seen from the following script:
```python
from functools import partial

import jax.numpy as jnp
import jaxlib.xla_extension as xe
from jax import jit

opt = xe.HloPrintOptions.short_parsable()
x = jnp.arange(100)

for inline in [False, True]:
    @partial(jit, inline=inline)
    def f(x, n):
        return jnp.sin(x) ** n

    def g(x):
        for i in range(2):
            x = f(x, i)
        return x

    print(jit(g).lower(x).compile().compiler_ir()[0].to_string(opt))
```
Output (identical apart from instruction numbering):

```
HloModule jit_g, entry_computation_layout={(s32[100]{0})->f32[100]{0}}, allow_spmd_sharding_propagation_to_output=true

ENTRY main.21 {
  Arg_0.1 = s32[100]{0} parameter(0), sharding={replicated}
  constant.6 = f32[] constant(0.841470957)
  ROOT broadcast.4 = f32[100]{0} broadcast(constant.6), dimensions={}
}

HloModule jit_g, entry_computation_layout={(s32[100]{0})->f32[100]{0}}, allow_spmd_sharding_propagation_to_output=true

ENTRY main.11 {
  Arg_0.1 = s32[100]{0} parameter(0), sharding={replicated}
  constant.3 = f32[] constant(0.841470957)
  ROOT broadcast.2 = f32[100]{0} broadcast(constant.3), dimensions={}
}
```
Thanks @soraros!

Since the jaxpr is cleaner, does it speed up the compile time considerably?
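A rough way to check could be timing `.lower(...).compile()` with and without inlining on a function that calls a small jitted helper many times (sketch below; the helper names are made up, and any difference would only suggest, not prove, a compile-time effect):

```python
import time
from functools import partial

import jax
import jax.numpy as jnp

def build(inline):
    # `f` is a small jitted helper; `inline` toggles whether its body is
    # spliced directly into the caller's jaxpr.
    @partial(jax.jit, inline=inline)
    def f(x):
        return jnp.sin(x) + 1.0

    @jax.jit
    def g(x):
        for _ in range(50):
            x = f(x)
        return x

    return g

x = jnp.zeros(100)
for inline in (False, True):
    g = build(inline)
    t0 = time.perf_counter()
    g.lower(x).compile()  # trace + compile only, no execution
    print(f"inline={inline}: compiled in {time.perf_counter() - t0:.3f}s")
```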
MWE of the nicer jaxpr:

```python
import functools

import jax

@functools.partial(jax.jit, inline=True)
def add_one_inline(x: float) -> float:
    return x + 1.

@jax.jit
def mul_two_then_add_one_inline(x: float) -> float:
    return add_one_inline(2. * x)

jax.make_jaxpr(mul_two_then_add_one_inline)(jax.numpy.zeros(100))
```

```
{ lambda ; a:f32[100]. let
    b:f32[100] = xla_call[
      call_jaxpr={ lambda ; c:f32[100]. let
          d:f32[100] = mul 2.0 c
          e:f32[100] = add d 1.0
        in (e,) }
      name=mul_two_then_add_one
    ] a
  in (b,) }
```

```python
@functools.partial(jax.jit, inline=False)
def add_one(x: float) -> float:
    return x + 1.

@jax.jit
def mul_two_then_add_one(x: float) -> float:
    return add_one(2. * x)

jax.make_jaxpr(mul_two_then_add_one)(jax.numpy.zeros(100))
```

```
{ lambda ; a:f32[100]. let
    b:f32[100] = xla_call[
      call_jaxpr={ lambda ; c:f32[100]. let
          d:f32[100] = mul 2.0 c
          e:f32[100] = xla_call[
              call_jaxpr={ lambda ; f:f32[100]. let
                  g:f32[100] = add f 1.0
                in (g,) }
              name=add_one
            ] d
        in (e,) }
      name=mul_two_then_add_one
    ] a
  in (b,) }
```