Store correct JAX representation in QJIT object
erick-xanadu opened this issue · 9 comments
def trace_to_mlir(func, static_argnums, abstracted_axes, *args, **kwargs):
    # ... snip ...
    with EvaluationContext(EvaluationMode.CLASSICAL_COMPILATION):
        make_jaxpr_kwargs = {"static_argnums": static_argnums, "abstracted_axes": abstracted_axes}
        jaxpr, out_type, out_tree = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)

    # We remove implicit Jaxpr result values since we are compiling a top-level jaxpr program.
    jaxpr2, out_type2 = jaxpr_remove_implicit(jaxpr, out_type)
    module, context = jaxpr_to_mlir(func.__name__, jaxpr2)

    return module, context, jaxpr, out_type2, out_tree
Here we obtain a jaxpr representation from make_jaxpr and then do some post-processing on it. I think we are returning the wrong jaxpr (it should be jaxpr2), and we can rename the variables accordingly:
def trace_to_mlir(func, static_argnums, abstracted_axes, *args, **kwargs):
    # ... snip ...
    with EvaluationContext(EvaluationMode.CLASSICAL_COMPILATION):
        make_jaxpr_kwargs = {"static_argnums": static_argnums, "abstracted_axes": abstracted_axes}
        jaxpr, out_type, out_tree = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)

    # We remove implicit Jaxpr result values since we are compiling a top-level jaxpr program.
    jaxpr, out_type = jaxpr_remove_implicit(jaxpr, out_type)
    module, context = jaxpr_to_mlir(func.__name__, jaxpr)

    return module, context, jaxpr, out_type, out_tree
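For context on the issue title, here is a minimal sketch of where these return values could be stored on the QJIT object; the attribute and method names below are illustrative assumptions, not necessarily Catalyst's actual ones:

# Hypothetical sketch of the call site inside the QJIT object (names are illustrative).
class QJIT:
    def capture(self, *args, **kwargs):
        module, context, jaxpr, out_type, out_tree = trace_to_mlir(
            self.user_function, self.static_argnums, self.abstracted_axes, *args, **kwargs
        )
        # Whatever trace_to_mlir returns here is what users see when they inspect
        # the compiled object, hence the question of which jaxpr is the "correct" one.
        self.jaxpr = jaxpr
        self.out_type = out_type
        self.out_tree = out_tree
        self.mlir_module = module
        return module, context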
@erick-xanadu Wouldn't the treedef (out_tree) need updating along with the abstract values (out_type)?
@dime10 I am not sure. I think out_tree should be preserved as that is what the user expects.
@grwlf Any thoughts on this?
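As a small aside on what out_tree is: it is the PyTreeDef describing the user-visible output structure, which can be flattened and unflattened independently of any implicit dimension outputs. A self-contained illustration with plain JAX pytree utilities (the result structure here is just an example, unrelated to any particular program):

import jax

# A hypothetical user-level result structure; the treedef records only the
# container layout of the explicit outputs, not implicit dimension values.
user_result = {"counts": (2, 3), "sample": 1.0}
leaves, treedef = jax.tree_util.tree_flatten(user_result)

print(leaves)   # [2, 3, 1.0]
print(treedef)  # PyTreeDef({'counts': (*, *), 'sample': *})

# Reconstructing the user-facing structure only needs the explicit leaves.
print(jax.tree_util.tree_unflatten(treedef, leaves))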
I think there are no errors here, but different viewpoints are possible. Consider the following program:
import jax.numpy as jnp
from catalyst import qjit

@qjit(keep_intermediate=True)
def fun(a):
    r = jnp.ones((a + 1,), dtype=int)
    return r
The full Jaxpr of it is
{ lambda ; a:i64[]. let
    b:i64[] = add a 1
    c:i64[b] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 1 b
  in (b, c) }  // <--------- Note: two return values.
Here, b is a calculated dimension variable to be returned along with its tensor. If it were not a top-level program, we might want to use b in subsequent calculations. But since it is a top-level Jaxpr, we know that we don't need the dimensions anymore, so we remove the implicit outputs in order to get the desired StableHLO code. The corresponding IR is:
module @fun {
  func.func public @jit_fun(%arg0: tensor<i64>) -> tensor<?xi64> attributes {llvm.emit_c_interface} {
    %0 = stablehlo.constant dense<1> : tensor<i64>
    %1 = stablehlo.add %arg0, %0 : tensor<i64>
    %2 = stablehlo.convert %1 : (tensor<i64>) -> tensor<i32>
    %3 = stablehlo.reshape %2 : (tensor<i32>) -> tensor<1xi32>
    %4 = stablehlo.dynamic_broadcast_in_dim %0, %3, dims = [] : (tensor<i64>, tensor<1xi32>) -> tensor<?xi64>
    return %4 : tensor<?xi64>  // <------- Note: only one return value
  }
}
The question is: which version of the Jaxpr should we call the correct one? I suggest thinking of jaxpr2 as an intermediate artifact of the StableHLO lowering, and returning jaxpr as the correct representation of the program.
Thanks @grwlf! I'd be inclined to say that jaxpr2 should be the output of the "generate jaxpr" stage, and used in all subsequent processing. Unless the jaxpr version is actually needed for anything?
Another question: is out_tree (it only has one version) compatible with jaxpr or jaxpr2?
I'd be inclined to say that jaxpr2 should be the output of the "generate jaxpr" stage, and used in all subsequent processing.

Do you have some arguments for this? Do you think we can call jaxpr2 a valid Jaxpr program? I am not sure: its output might contain unlisted Jaxpr variables (in tensor shapes).
Another question: is out_tree (it only has one version) compatible with jaxpr or jaxpr2?
out_tree describes the set of explicit arguments, so there is only one version shared by both jaxprs.
Note that this is not true for out_type: the original type lists implicit results, while out_type2 does not (I assume that jaxpr_remove_implicit removes the implicit part). Again, out_type2 might even be incorrect strictly speaking: its OutDBIdx values might refer to non-existent positions in the list of outputs (one needs to double-check this).
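To make the out_type discussion more concrete, here is a purely illustrative toy model of the (aval, explicit) pairs for the fun example above; it only mimics the idea of JAX's out_type/OutDBIdx bookkeeping and does not use JAX's actual classes:

from dataclasses import dataclass

@dataclass
class OutDBIdx:
    """Toy stand-in: a reference to another *output* position holding a dimension value."""
    val: int

@dataclass
class Aval:
    dtype: str
    shape: tuple  # entries are ints or OutDBIdx references

# Full out_type for `fun`: outputs (b, c), where c's dynamic dimension refers to
# output 0 (i.e. b). The bool marks whether the result is explicit (user-visible).
out_type = [
    (Aval("i64", ()), False),             # b: implicit dimension output
    (Aval("i64", (OutDBIdx(0),)), True),  # c: explicit result of shape i64[b]
]

# What a jaxpr_remove_implicit-style filter conceptually does: drop implicit entries.
out_type2 = [(aval, explicit) for (aval, explicit) in out_type if explicit]

# Every remaining entry is (..., True), and c's OutDBIdx(0) now points at position 0
# of the *filtered* outputs (c itself) instead of b -- the "strictly speaking
# incorrect" part mentioned above.
print(out_type2)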
I am not sure: its output might contain unlisted Jaxpr variables (in tensor shapes).

By unlisted you mean that the jaxpr contains a variable that will be eliminated via dead code elimination? (Similarly to why we have _no_cleanup_deadvars?)
I think ideally the jaxpr we produce should be valid. I understand that the moment we decided to remove this return value, we deviated from that. Do you remember why this return value was removed?
I am not sure: its output might contain unlisted Jaxpr variables (in tensor shapes).

By unlisted you mean that the jaxpr contains a variable that will be eliminated via dead code elimination? (Similarly to why we have _no_cleanup_deadvars?)
Not exactly that. Consider the Jaxpr after the implicit-outputs reduction is applied:
{ lambda ; a:i64[]. let
    b:i64[] = add a 1
    c:i64[b] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 1 b
  in c }
Here, c has type i64[b], but there is no b variable anymore in the outer scope.
Do you remember why this return value was removed?

I think it is removed solely to keep the StableHLO lowering code satisfied. StableHLO does not need dimension variables, so I believe (I didn't look very carefully there) that its lowering code permissively ignores them and everything keeps working.
To summarize the resolution we came to (see the sketch after this list):
- the jaxpr with implicit results will be the canonical representation at that level, in order to avoid a potentially "incorrect" jaxpr in downstream applications
- filtering implicit args from the jaxpr is a pre-processing step for the MLIR lowering only
- the out_type after filtering will be removed since it is redundant (it only contains (..., True) entries)
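A minimal sketch of what trace_to_mlir could look like under this resolution; the exact variable names, and whether jaxpr_to_mlir needs anything beyond the filtered jaxpr, are assumptions:

def trace_to_mlir(func, static_argnums, abstracted_axes, *args, **kwargs):
    # ... snip ...
    with EvaluationContext(EvaluationMode.CLASSICAL_COMPILATION):
        make_jaxpr_kwargs = {"static_argnums": static_argnums, "abstracted_axes": abstracted_axes}
        jaxpr, out_type, out_tree = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)

    # Filtering implicit results is only a pre-processing step for the MLIR lowering;
    # the canonical representation we keep is the full `jaxpr` and `out_type`.
    jaxpr_no_implicit, _ = jaxpr_remove_implicit(jaxpr, out_type)
    module, context = jaxpr_to_mlir(func.__name__, jaxpr_no_implicit)

    # The filtered out_type is redundant (it would only contain (..., True) entries),
    # so it is not returned.
    return module, context, jaxpr, out_type, out_tree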