LouisDesdoigts/dLux

Check out `jit(inline=True)`.

Closed this issue · 4 comments

Hi all,
if this works like normal function inlining in compiled languages such as C, it could be really useful. Inlining means that when a call site is encountered, the compiler copies the function body directly into place instead of creating a new stack frame and pushing the arguments onto it before entering the function. It is typically used for small functions written for readability, for example a radians_to_arcseconds helper, which would reduce to a single multiplication. Long story short, I think this could guarantee performance when calling compiled functions from within compiled functions. Please note that this is a reminder to investigate; I don't have any concrete evidence yet. I came across it in the jax/jax/_src/numpy/ufuncs.py:311 code for the power function.
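For illustration, a helper of the kind described above might be marked for inlining like this (a minimal sketch; the `radians_to_arcseconds` function is hypothetical, not part of dLux or JAX):

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical tiny helper: with inline=True, its jaxpr is spliced directly
# into the caller's jaxpr rather than appearing as a nested call.
@partial(jax.jit, inline=True)
def radians_to_arcseconds(theta):
    # 1 radian = (180 / pi) degrees, 1 degree = 3600 arcseconds
    return theta * (180.0 / jnp.pi) * 3600.0

print(radians_to_arcseconds(jnp.pi))  # pi radians = 648000 arcseconds
```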
Regards
Jordan

jit(inline=True) doesn't bring any performance benefit (in a compiled context); it just generates slightly cleaner jaxprs. Later in the compilation pipeline everything is inlined into the generated HLO module anyway, as can be seen from the following script (both variants produce identical HLO):

from functools import partial

import jax.numpy as jnp
import jaxlib.xla_extension as xe
from jax import jit

opt = xe.HloPrintOptions.short_parsable()

x = jnp.arange(100)

for inline in [False, True]:
  @partial(jit, inline=inline)
  def f(x, n):
    return jnp.sin(x)**n

  def g(x):
    for i in range(2):
      x = f(x, i)
    return x

  print(jit(g).lower(x).compile().compiler_ir()[0].to_string(opt))

Output:

HloModule jit_g, entry_computation_layout={(s32[100]{0})->f32[100]{0}}, allow_spmd_sharding_propagation_to_output=true

ENTRY main.21 {
  Arg_0.1 = s32[100]{0} parameter(0), sharding={replicated}
  constant.6 = f32[] constant(0.841470957)
  ROOT broadcast.4 = f32[100]{0} broadcast(constant.6), dimensions={}
}


HloModule jit_g, entry_computation_layout={(s32[100]{0})->f32[100]{0}}, allow_spmd_sharding_propagation_to_output=true

ENTRY main.11 {
  Arg_0.1 = s32[100]{0} parameter(0), sharding={replicated}
  constant.3 = f32[] constant(0.841470957)
  ROOT broadcast.2 = f32[100]{0} broadcast(constant.3), dimensions={}
}

Ta thanks @soraros,
Since the jaxpr is cleaner, does it speed up the compile time considerably?

MWE of nicer jaxpr.

import jax
import functools

@functools.partial(jax.jit, inline=True)
def add_one_inline(x: float) -> float:
    return x + 1.

@jax.jit
def mul_two_then_add_one_inline(x: float) -> float:
    return add_one_inline(2. * x)

jax.make_jaxpr(mul_two_then_add_one_inline)(jax.numpy.zeros(100))
>>> { lambda ; a:f32[100]. let
...    b:f32[100] = xla_call[
...      call_jaxpr={ lambda ; c:f32[100]. let
...          d:f32[100] = mul 2.0 c
...          e:f32[100] = add d 1.0
...        in (e,) }
...      name=mul_two_then_add_one_inline
...    ] a
...  in (b,) }

@functools.partial(jax.jit, inline=False)
def add_one(x: float) -> float:
    return x + 1.

@jax.jit
def mul_two_then_add_one(x: float) -> float:
    return add_one(2. * x)

jax.make_jaxpr(mul_two_then_add_one)(jax.numpy.zeros(100))
>>> { lambda ; a:f32[100]. let
...     b:f32[100] = xla_call[
...       call_jaxpr={ lambda ; c:f32[100]. let
...           d:f32[100] = mul 2.0 c
...           e:f32[100] = xla_call[
...             call_jaxpr={ lambda ; f:f32[100]. let
...                 g:f32[100] = add f 1.0
...               in (g,) }
...             name=add_one
...          ] d
...         in (e,) }
...       name=mul_two_then_add_one
...     ] a
...   in (b,) }


I don't have much data, but I also don't think it will make that much of a difference.
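One way to check this empirically is to time ahead-of-time compilation of the same caller with and without inlining, using JAX's `lower(...).compile()` AOT API (a rough sketch; timings will vary by machine and JAX version, and the functions here are made up for the benchmark):

```python
import time
from functools import partial

import jax
import jax.numpy as jnp

x = jnp.zeros(100)
times = {}

for inline in (False, True):
    # Hypothetical small helper, jitted with and without inline=True.
    @partial(jax.jit, inline=inline)
    def add_one(v):
        return v + 1.0

    @jax.jit
    def g(v):
        # Many calls to the helper, so any per-call jaxpr overhead adds up.
        for _ in range(50):
            v = add_one(v)
        return v

    t0 = time.perf_counter()
    g.lower(x).compile()  # trace + lower + XLA compile, no execution
    times[inline] = time.perf_counter() - t0
    print(f"inline={inline}: compiled in {times[inline]:.3f}s")
```

In practice the difference tends to be small, since XLA inlines the calls itself either way, which matches the comment above.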