exo-lang/exo

Exo Front-end literals and C literals type-inference at codegen

Closed this issue · 3 comments

Currently, the Exo front end only accepts Python floating-point literals in data computations, and codegen emits each literal in C exactly as written. This presents multiple issues:

  1. This potentially violates Exo's semantics, which do not allow computation on two data types of different precision. As a result, the generated C code does mixed-precision computation.
  2. This has drastic performance implications, since we emit a literal like 2.0 in C, which C treats as a double. So if your computation is originally in single precision, 8-bit integers, 32-bit integers, etc., the C compiler will generate very expensive conversion operations so that the computation is done in double precision.
  3. It is generally confusing to write front-end code on integers but have to specify data literals as floating-point values.

For example, the following proc:

@proc
def multiply_float_with_double_literal_store_to_double(n: size, x: f32[n], y: f64[n]):
    for i in seq(0, n):
        y[i] = x[i] * 17.0

is compiled by Exo's codegen to:

void multiply_float_with_double_literal_store_to_double( void *ctxt, int_fast32_t n, const float* x, double* y ) {
for (int_fast32_t i = 0; i < n; i++) {
  y[i] = (double)(x[i] * 17.0);
}
}

The C compiler turns this into the following x86-64 assembly loop:

.L4:
        movups  (%rsi,%rax), %xmm2
        cvtps2pd        %xmm2, %xmm0
        mulpd   %xmm3, %xmm0
        movhlps %xmm2, %xmm1
        movups  %xmm0, (%rdx,%rax,2)
        cvtps2pd        %xmm1, %xmm0
        mulpd   %xmm3, %xmm0
        movups  %xmm0, 16(%rdx,%rax,2)
        addq    $16, %rax
        cmpq    %rcx, %rax
        jne     .L4

This converts each value from single to double precision (cvtps2pd) and then multiplies in double precision (mulpd). What was actually intended (at least according to the semantics of the original Exo proc) is a single-precision multiplication, followed by a conversion of the result to double precision.

On the other hand, consider the same code with 17.0 changed to 17.0f:

void multiply_float_with_float_literal_store_to_double( void *ctxt, int_fast32_t n, float* x, double *y) {
for (int_fast32_t i = 0; i < n; i++) {
  y[i] = (double)(x[i] * 17.0f);
}
}

This compiles to the following assembly loop:

.L13:
        movups  (%rsi,%rax), %xmm0
        mulps   %xmm3, %xmm0
        movhlps %xmm0, %xmm1
        cvtps2pd        %xmm0, %xmm2
        movups  %xmm2, (%rdx,%rax,2)
        cvtps2pd        %xmm1, %xmm0
        movups  %xmm0, 16(%rdx,%rax,2)
        addq    $16, %rax
        cmpq    %rcx, %rax
        jne     .L13

Now the multiplication is done in single precision (mulps instead of mulpd), and only the product is converted to double.

Proposed solution: keep the Exo front end and scheduling oblivious to the types of the literals involved in data computation, but at codegen perform the following pass:

  1. For each statement, scan the rhs expression to determine its type from the reads of data tensors.
  2. Then, do another pass and set the correct type for all the literals.
  3. If the rhs is just a literal, give it the type of the lhs; that is, emit the literal as the user wrote it in the front end, with an explicit type conversion to the lhs type.
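The three steps above can be sketched as a small pass over a toy expression tree. This is a minimal sketch, not Exo's codegen: the node classes (Read, Const, BinOp, USub, Assign), the precision strings, and the function names are hypothetical stand-ins for Exo's IR. It enforces the restriction that one expression uses a single precision, and for step 3 it takes the "give the literal the lhs type" option, also applying it when the whole rhs is made of literals.

```python
from dataclasses import dataclass
from typing import Optional, Union

# Hypothetical, simplified IR; these are NOT Exo's actual IR node classes.
@dataclass
class Read:               # read of a data tensor, e.g. x[i]
    name: str
    prec: str             # e.g. "f32", "f64", "i8", "i32"

@dataclass
class Const:              # data literal from the front end, e.g. 17.0
    val: float
    prec: Optional[str] = None    # filled in by the pass below

@dataclass
class BinOp:
    op: str
    lhs: "Expr"
    rhs: "Expr"

@dataclass
class USub:               # unary minus
    arg: "Expr"

Expr = Union[Read, Const, BinOp, USub]

@dataclass
class Assign:             # y[i] = rhs (a reduction would be handled the same way)
    lhs_prec: str
    rhs: Expr

def infer_prec(e: Expr) -> Optional[str]:
    """Step 1: an expression's precision comes from its tensor reads."""
    if isinstance(e, Read):
        return e.prec
    if isinstance(e, Const):
        return None                   # literals carry no precision of their own
    if isinstance(e, USub):
        return infer_prec(e.arg)
    l, r = infer_prec(e.lhs), infer_prec(e.rhs)
    if l is not None and r is not None and l != r:
        raise TypeError(f"mixed precision in one expression: {l} vs {r}")
    return l if l is not None else r

def stamp_literals(e: Expr, prec: str) -> None:
    """Step 2: give every literal in the expression the inferred precision."""
    if isinstance(e, Const):
        e.prec = prec
    elif isinstance(e, USub):
        stamp_literals(e.arg, prec)
    elif isinstance(e, BinOp):
        stamp_literals(e.lhs, prec)
        stamp_literals(e.rhs, prec)

def type_literals(s: Assign) -> None:
    prec = infer_prec(s.rhs)
    # Step 3: an rhs with no tensor reads (e.g. a bare literal) takes the lhs type.
    stamp_literals(s.rhs, prec if prec is not None else s.lhs_prec)
```

For the motivating proc, y[i] = x[i] * 17.0 with x: f32 and y: f64, the literal ends up tagged f32, so codegen could print it as 17.0f and keep the multiplication in single precision.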

We should also fix the front-end code to allow users to specify integer literals in data computations.

gilbo commented

Good catch @SamirDroubi !

It might be worth taking a look at our casting semantics more generally. Because we allow for arbitrarily changing precision via scheduling, it would be nice if the original algorithm could remain "precision agnostic."

The problem arises because we currently infer casts only from writing some expression (of one precision) to a buffer (of a different precision). This can in principle express any cast, but it works especially poorly for literals.

A different design approach for casting (which maybe we should revisit now) would be to allow explicit casts to be inserted in the middle of expressions. If we did this, it would be easier to change the precision of a literal via scheduling.

Alternatively, we could avoid changing the IR that drastically by giving every literal a "precision" tag and adding a scheduling operation that allows for twiddling it.
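To make the precision-tag alternative concrete, here is a minimal sketch of what a tagged literal and a scheduling operation that retags it might look like. Everything here is hypothetical: the node classes, the None-means-unspecified convention, and set_literal_precision are illustrations, not Exo APIs, and selecting literals by value stands in for whatever pattern or cursor mechanism a real operation would use. The operation returns a new tree instead of mutating its input, in the same spirit as Exo scheduling operations returning a new proc.

```python
from dataclasses import dataclass, replace
from typing import Optional, Union

# Hypothetical IR nodes; frozen so that "scheduling" must build a new tree.
@dataclass(frozen=True)
class Read:
    name: str

@dataclass(frozen=True)
class Const:
    val: float
    prec: Optional[str] = None    # the proposed precision tag; None = unspecified

@dataclass(frozen=True)
class BinOp:
    op: str
    lhs: "Expr"
    rhs: "Expr"

Expr = Union[Read, Const, BinOp]

def set_literal_precision(e: Expr, val: float, prec: str) -> Expr:
    """Hypothetical scheduling op: retag every literal equal to `val`.

    Returns a new expression; the input expression is left unchanged.
    """
    if isinstance(e, Const) and e.val == val:
        return replace(e, prec=prec)
    if isinstance(e, BinOp):
        return replace(e,
                       lhs=set_literal_precision(e.lhs, val, prec),
                       rhs=set_literal_precision(e.rhs, val, prec))
    return e
```

The appeal of this design is that the IR only gains one field on literals, while the algorithm as written stays precision agnostic until a schedule decides otherwise.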

Maybe there are other possible ways to balance concerns here. Thoughts?

The solution suggested above does keep the compiler precision-agnostic for the most part. At codegen, we infer the type of the literals:

  1. By looking at the expressions they are used in and taking the precision of that expression as the precision of the literal. This works as long as we keep the restriction that computation within one expression must use a single precision.
  2. If the literal isn't part of a larger expression (e.g. a binop or unary minus), we infer its type from the lhs of the assignment or reduction.

I think adding type casting in the middle of expressions isn't that necessary (for now). You can already work around it by calling bind_expr on that expression and setting the precision of the resulting scalar buffer to the type you want to cast to. The only difference is that this slightly changes the semantics of the program from the C standard's perspective, so you are relying on the C compiler to do more optimization (which should be fine in most cases). That said, we have been seeing C compilers have more and more trouble doing what we expect as the source code gets more complicated. In this particular case, though, I don't think we have run into any need for general type casting, and by extension we haven't seen any problems with this workaround.

So, I think for now we don't need explicit type casting.


One part of my post that may have made you think I was planning to make the compiler more precision-aware is where I talked about the Exo front end. I was simply saying that it would be nice to allow users to write integer literals in front-end Exo rather than always having to use floating-point literals.

Currently, the front end would reject this:

@proc
def foo(n: size, x: i32[n], y: i32[n]):
    for i in seq(0, n):
        y[i] = x[i] * 17

because 17 is an integer literal. Instead, you would have to write:

@proc
def foo(n: size, x: i32[n], y: i32[n]):
    for i in seq(0, n):
        y[i] = x[i] * 17.0

Codegen accepts this as well, because from Exo's point of view literals have no precision.
I think this may just be a bug, though.
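A relaxed front-end check along these lines could accept integer as well as floating-point literals in data expressions. This is a hypothetical sketch built on Python's standard ast module, not Exo's actual parser; the DataLiteralCollector class, the data_literals helper, and the rule that subscript indices are skipped (they are index expressions, not data) are assumptions for illustration.

```python
import ast

class DataLiteralCollector(ast.NodeVisitor):
    """Hypothetical check: collect the numeric literals in a data expression."""

    def __init__(self):
        self.lits = []

    def visit_Subscript(self, node):
        # x[i] or x[0]: the index is a control expression, not data,
        # so only the indexed tensor is visited.
        self.visit(node.value)

    def visit_Constant(self, node):
        v = node.value
        # bool is a subclass of int in Python, so reject it explicitly.
        if isinstance(v, bool) or not isinstance(v, (int, float)):
            raise TypeError(f"unsupported literal in data computation: {v!r}")
        self.lits.append(v)   # ints stay ints and floats stay floats

def data_literals(src: str) -> list:
    collector = DataLiteralCollector()
    collector.visit(ast.parse(src, mode="eval"))
    return collector.lits
```

For example, data_literals("x[i] * 17") keeps 17 as an int rather than rejecting it, leaving codegen free to give it the precision of x later.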

This is actually how we noticed this literal problem. We were working on a kernel with 8-bit integer buffers, but we were emitting operations involving double-precision literals (because codegen takes the literal from the front end and emits its string representation as is). Unsurprisingly, this made that computation much slower than you would expect.
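Instead of emitting the front-end literal's string as is, codegen could print it according to the precision inferred for it. A minimal sketch, assuming a hypothetical c_literal helper and a hand-written mapping from Exo precision names to C types (only f32, f64, i8, and i32 are covered, and only finite values): f32 literals get the f suffix, f64 literals stay unsuffixed since C already treats them as double, and integer literals are range-checked and cast so that no double is introduced. C's usual integer promotions still apply to the cast value.

```python
# Hypothetical mapping from Exo integer precisions to C types and value ranges.
INT_TYPES = {
    "i8": ("int8_t", -128, 127),
    "i32": ("int32_t", -2**31, 2**31 - 1),
}

def c_literal(val, prec: str) -> str:
    """Render a finite front-end literal as C source text of the given precision."""
    if prec == "f64":
        return repr(float(val))          # unsuffixed floating literals are double in C
    if prec == "f32":
        return repr(float(val)) + "f"    # the f suffix makes the literal a float
    ctype, lo, hi = INT_TYPES[prec]
    if float(val) != int(val) or not lo <= int(val) <= hi:
        raise ValueError(f"{val!r} is not representable as {prec}")
    return f"(({ctype}) {int(val)})"
```

With this, x[i] * 17.0 on an f32 buffer would be printed as x[i] * 17.0f, and on an i8 buffer as x[i] * ((int8_t) 17), instead of multiplying by a double.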

Addressed by #581