PennyLaneAI/catalyst

Store correct JAX representation in QJIT object

erick-xanadu opened this issue · 9 comments

In this function:

def trace_to_mlir(func, static_argnums, abstracted_axes, *args, **kwargs):
    # ... snip ...
    with EvaluationContext(EvaluationMode.CLASSICAL_COMPILATION):
        make_jaxpr_kwargs = {"static_argnums": static_argnums, "abstracted_axes": abstracted_axes}
        jaxpr, out_type, out_tree = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)

    # We remove implicit Jaxpr result values since we are compiling a top-level jaxpr program.
    jaxpr2, out_type2 = jaxpr_remove_implicit(jaxpr, out_type)
    module, context = jaxpr_to_mlir(func.__name__, jaxpr2)
    return module, context, jaxpr, out_type2, out_tree

we obtain a jaxpr representation from make_jaxpr2 and then do some post-processing on it.

I think we are returning the wrong jaxpr (it should be jaxpr2), and we can rename the variables so that everything consistently refers to the post-processed versions:

def trace_to_mlir(func, static_argnums, abstracted_axes, *args, **kwargs):
    # ... snip ...
    with EvaluationContext(EvaluationMode.CLASSICAL_COMPILATION):
        make_jaxpr_kwargs = {"static_argnums": static_argnums, "abstracted_axes": abstracted_axes}
        jaxpr, out_type, out_tree = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)

    # We remove implicit Jaxpr result values since we are compiling a top-level jaxpr program.
    jaxpr, out_type = jaxpr_remove_implicit(jaxpr, out_type)
    module, context = jaxpr_to_mlir(func.__name__, jaxpr)
    return module, context, jaxpr, out_type, out_tree

@grwlf Any thoughts on this?

@erick-xanadu Wouldn't the treedef (out_tree) need updating along with the abstract values (out_type)?

@dime10 I am not sure. I think out_tree should be preserved as that is what the user expects.

@grwlf Any thoughts on this?

I think there are no errors here, but different viewpoints are possible. Consider the following program:

@qjit(keep_intermediate=True)
def fun(a):
    r = jnp.ones((a + 1,), dtype=int)
    return r

The full Jaxpr of it is

{ lambda ; a:i64[]. let
    b:i64[] = add a 1
    c:i64[b] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 1 b
  in (b, c) }  // <--------- Note two return values.

Here, b is a computed dimension variable that is returned along with its tensor. If this were not a top-level program, we might want to use b in subsequent calculations. But since it is a top-level Jaxpr, we know the dimension variables are no longer needed, so we remove the implicit outputs in order to get the desired StableHLO code. The corresponding IR is

module @fun {
  func.func public @jit_fun(%arg0: tensor<i64>) -> tensor<?xi64> attributes {llvm.emit_c_interface} {
    %0 = stablehlo.constant dense<1> : tensor<i64>
    %1 = stablehlo.add %arg0, %0 : tensor<i64>
    %2 = stablehlo.convert %1 : (tensor<i64>) -> tensor<i32>
    %3 = stablehlo.reshape %2 : (tensor<i32>) -> tensor<1xi32>
    %4 = stablehlo.dynamic_broadcast_in_dim %0, %3, dims = [] : (tensor<i64>, tensor<1xi32>) -> tensor<?xi64>
    return %4 : tensor<?xi64>  // <-------  Note: only one return value
  }
}

The question is: which version of the Jaxpr should we call the correct one? I suggest thinking of jaxpr2 as an intermediate artifact of the StableHLO lowering, and returning jaxpr as the correct representation of the program.

Thanks @grwlf! I'd be inclined to say that jaxpr2 should be the output of the "generate jaxpr" stage, and used in all subsequent processing. Unless the jaxpr version is actually needed for anything?

Another question, is out_tree (it only has one version) compatible with jaxpr or jaxpr2?

I'd be inclined to say that jaxpr2 should be the output of the "generate jaxpr" stage, and used in all subsequent processing.

Do you have some arguments for this? Do you think we can call jaxpr2 a valid Jaxpr program? I am not sure: its output might contain unlisted Jaxpr variables (in tensor shapes).

Another question, is out_tree (it only has one version) compatible with jaxpr or jaxpr2?

out_tree describes the pytree structure of the explicit results, so there is only one version shared by both jaxprs.
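
As a small illustration (assuming out_tree is just the pytree structure of the user-visible outputs, as returned by make_jaxpr-style tracing): for the fun example above the output is a single array, so the treedef never mentions the implicit dimension output b.

import jax.numpy as jnp
from jax.tree_util import tree_flatten

# The user-visible output of `fun` is one array, so its treedef has a single
# leaf; removing the implicit dimension result cannot change this structure.
leaves, out_tree = tree_flatten(jnp.ones((3,), dtype=int))
print(out_tree)  # PyTreeDef(*)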

Note that the same is not true for out_type. The original out_type lists the implicit results while out_type2 does not (I assume that jaxpr_remove_implicit removes the implicit part). Moreover, out_type2 might even be incorrect in a strict sense: its OutDBIdx values might refer to non-existent positions in the list of outputs (one needs to double-check this).
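
As a schematic illustration of that concern (strings stand in for abstract values here; this is a sketch of the (abstract value, explicit) pairs that out_type is assumed to hold for the fun example above, not real JAX objects):

# Schematic only: OutDBIdx(0) stands for a de Bruijn-style reference to
# position 0 of the full output list.
out_type = (
    ("i64[]",            False),  # implicit dimension variable b
    ("i64[OutDBIdx(0)]", True),   # explicit tensor c, whose shape refers to b
)

# After the implicit entry is filtered out, position 0 of the outputs is c
# itself, so the OutDBIdx(0) inside c's shape no longer points at b.
out_type2 = (
    ("i64[OutDBIdx(0)]", True),
)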

I am not sure: its output might contain unlisted Jaxpr variables (in tensor shapes).

By unlisted you mean that the jaxpr contains a variable that will be eliminated via dead code elimination? (Similarly to why we have _no_cleanup_deadvars?)

I think the jaxpr we produce should ideally be valid. I understand that the moment we decided to remove this return value, we deviated from that. Do you remember why this return value was removed?

I am not sure: its output might contain unlisted Jaxpr variables (in tensor shapes).

By unlisted you mean that the jaxpr contains a variable that will be eliminated via dead code elimination? (Similarly to why we have _no_cleanup_deadvars?)

Not exactly that. Consider the Jaxpr after the implicit-output removal has been applied:

{ lambda ; a:i64[]. let
    b:i64[] = add a 1
    c:i64[b] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 1 b
  in c }

Here, c has type i64[b], but b is no longer among the outputs, so from the outer scope there is no b variable to refer to.

Do you remember why this return value was removed?

I think it is removed solely to keep the StableHLO lowering code satisfied. StableHLO does not need dimension variables, so I believe (I didn't look very carefully there) that its lowering code permissively ignores them, so everything keeps working.

To summarize the resolution we came to (a sketch of the resulting code follows the list):

  • the jaxpr with implicit results will be the canonical representation at that level, in order to avoid a potentially "incorrect" jaxpr in downstream applications
  • filtering the implicit results from the jaxpr is a pre-processing step for the MLIR lowering only
  • the out_type after filtering will be dropped since it is redundant (it only contains (..., True) entries)
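
A minimal sketch of what trace_to_mlir could look like under this resolution, reusing the names from the snippet at the top of the issue (illustrative only, not the actual patch):

def trace_to_mlir(func, static_argnums, abstracted_axes, *args, **kwargs):
    # ... snip ...
    with EvaluationContext(EvaluationMode.CLASSICAL_COMPILATION):
        make_jaxpr_kwargs = {"static_argnums": static_argnums, "abstracted_axes": abstracted_axes}
        jaxpr, out_type, out_tree = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)

    # Removing implicit results is a pre-processing step for the MLIR lowering
    # only; the filtered out_type is discarded since it only contains
    # (..., True) entries.
    jaxpr_no_implicit, _ = jaxpr_remove_implicit(jaxpr, out_type)
    module, context = jaxpr_to_mlir(func.__name__, jaxpr_no_implicit)

    # The jaxpr with implicit results remains the canonical representation
    # returned to downstream consumers.
    return module, context, jaxpr, out_type, out_tree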