HigherOrderCO/Bend

Performance degradation on Bitonic Sort output since release

VictorTaelin opened this issue · 8 comments

Reproducing the behavior

On the release version, the Bitonic Sort example:

def gen(d, x):
  switch d:
    case 0:
      return x
    case _:
      return (gen(d-1, x * 2 + 1), gen(d-1, x * 2))

def sum(d, t):
  switch d:
    case 0:
      return t
    case _:
      (t.a, t.b) = t
      return sum(d-1, t.a) + sum(d-1, t.b)

def swap(s, a, b):
  switch s:
    case 0:
      return (a, b)
    case _:
      return (b, a)

def warp(d, s, a, b):
  switch d:
    case 0:
      return swap(s ^ (a > b), a, b)
    case _:
      (a.a, a.b) = a
      (b.a, b.b) = b
      (A.a, A.b) = warp(d-1, s, a.a, b.a)
      (B.a, B.b) = warp(d-1, s, a.b, b.b)
      return ((A.a, B.a), (A.b, B.b))

def flow(d, s, t):
  switch d:
    case 0:
      return t
    case _:
      (t.a, t.b) = t
      return down(d, s, warp(d-1, s, t.a, t.b))

def down(d, s, t):
  switch d:
    case 0:
      return t
    case _:
      (t.a, t.b) = t
      return (flow(d-1, s, t.a), flow(d-1, s, t.b))

def sort(d, s, t):
  switch d:
    case 0:
      return t
    case _:
      (t.a, t.b) = t
      return flow(d, s, (sort(d-1, 0, t.a), sort(d-1, 1, t.b)))

def main:
  return sum(20, sort(20, 0, gen(20, 0)))
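
For orientation (my annotation, not part of the original report), a rough summary of each function:

# gen d x     : depth-d tree whose leaves are the range [x*2^d, (x+1)*2^d), in descending order
# sum d t     : sums the leaves of the depth-d tree t
# swap s a b  : returns (a, b) when s == 0, else (b, a)
# warp d s a b: leaf-wise compare-and-swap between the trees a and b, in direction s
# flow / down : the bitonic merge butterfly over a depth-d tree
# sort d s t  : bitonic sort; halves are sorted in opposite directions, then merged by flow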

This generated the following HVM output:

@down = (?(((a (* a)) @down__C0) (b (c d))) (c (b d)))

@down__C0 = ({a e} ((c g) ({b f} (d h))))
  &! @flow ~ (a (b (c d)))
  &! @flow ~ (e (f (g h)))

@flow = (?(((a (* a)) @flow__C0) (b (c d))) (c (b d)))

@flow__C0 = ({$([+1] a) c} ((e f) ({b d} h)))
  & @down ~ (a (b (g h)))
  & @warp ~ (c (d (e (f g))))

@gen = (?(((a a) @gen__C0) b) b)

@gen__C0 = ({a d} ({$([*2] $([+1] b)) $([*2] e)} (c f)))
  &! @gen ~ (a (b c))
  &! @gen ~ (d (e f))

@main = a
  & @sum ~ (20 (@main__C1 a))

@main__C0 = a
  & @gen ~ (20 (0 a))

@main__C1 = a
  & @sort ~ (20 (0 (@main__C0 a)))

@sort = (?(((a (* a)) @sort__C0) (b (c d))) (c (b d)))

@sort__C0 = ({$([+1] a) {c f}} ((d g) (b i)))
  & @flow ~ (a (b ((e h) i)))
  &! @sort ~ (c (0 (d e)))
  &! @sort ~ (f (1 (g h)))

@sum = (?(((a a) @sum__C0) b) b)

@sum__C0 = ({a c} ((b d) f))
  &! @sum ~ (a (b $([+] $(e f))))
  &! @sum ~ (c (d e))

@swap = (?((@swap__C0 @swap__C1) (a (b c))) (b (a c)))

@swap__C0 = (b (a (a b)))

@swap__C1 = (* (a (b (a b))))

@warp = (?((@warp__C0 @warp__C1) (a (b (c d)))) (c (b (a d))))

@warp__C0 = ({a e} ({$([>] $(a b)) d} ($([^] $(b c)) f)))
  & @swap ~ (c (d (e f)))

@warp__C1 = ({a f} ((d i) ((c h) ({b g} ((e j) (k l))))))
  &! @warp ~ (f (g (h (i (j l)))))
  &! @warp ~ (a (b (c (d (e k)))))

This runs at about 12000 MIPS (millions of interactions per second) on an RTX 4090.

Currently, Bend generates the following output:

@down = (?(((* (a a)) @down__C0) b) b)

@down__C0 = ({a e} ({b f} ((c g) (d h))))
  &!@flow ~ (a (b (c d)))
  &!@flow ~ (e (f (g h)))

@flow = (?(((* (a a)) @flow__C0) b) b)

@flow__C0 = ({$([+0x0000001] a) c} ({b d} ((e f) h)))
  & @down ~ (a (b (g h)))
  & @warp ~ (c (d (e (f g))))

@gen = (?(((a a) @gen__C0) b) b)

@gen__C0 = ({a d} ({$([*0x0000002] $([+0x0000001] b)) $([*0x0000002] e)} (c f)))
  &!@gen ~ (a (b c))
  &!@gen ~ (d (e f))

@main = c
  & @sum ~ (20 (b c))
  & @sort ~ (20 (0 (a b)))
  & @gen ~ (20 (0 a))

@sort = (?(((* (a a)) @sort__C0) b) b)

@sort__C0 = ({$([+0x0000001] a) {c f}} (b ((d g) i)))
  & @flow ~ (a (b ((e h) i)))
  &!@sort ~ (c (0 (d e)))
  &!@sort ~ (f (1 (g h)))

@sum = (?(((a a) @sum__C0) b) b)

@sum__C0 = ({a c} ((b d) f))
  &!@sum ~ (a (b $([+] $(e f))))
  &!@sum ~ (c (d e))

@swap = (?((@swap__C0 @swap__C1) a) a)

@swap__C0 = (a (b (a b)))

@swap__C1 = (* (b (a (a b))))

@warp = (?((@warp__C0 @warp__C1) a) a)

@warp__C0 = ($([^] $(b c)) ({$([>] $(a b)) d} ({a e} f)))
  & @swap ~ (c (d (e f)))

@warp__C1 = ({a f} ({b g} ((c h) ((d i) ((e j) (k l))))))
  &!@warp ~ (f (g (h (i (j l)))))
  &!@warp ~ (a (b (c (d (e k)))))

This runs at about 6000 MIPS, i.e. 50% slower.

Tested using the latest version of HVM.

Can you please investigate?

@developedby @kings177


Skimming over the output, I notice two changes:

=============

Some differences seem to come from changes in either the linearization or the eta-reduction. The new version generates smaller definitions in some cases, probably by reordering some lambdas in a way that allows them to be eta-reduced (basically producing a better linearization).

We can see this in the swap function:

# old
swap = λa λb λc ((switch a {0: swap_c0; _: swap_c1}) c b)
swap_c0 = λa λb (b, a)
swap_c1 = λa λb (a, b)

# new
swap = λa switch a {0: swap_c0; _: swap_c1}
swap_c0 = λa λb (a, b)
swap_c1 = λa λb (b, a)
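
Spelled out (my reading of the example above): the new linearization applies the wrapper's arguments in order, so the wrapper lambdas eta-reduce away (λx (f x) => f, when x doesn't occur in f):

swap = λa λb λc ((switch a {0: swap_c0; _: swap_c1}) b c)
     = λa λb ((switch a {0: swap_c0; _: swap_c1}) b)   # eta on c
     = λa (switch a {0: swap_c0; _: swap_c1})          # eta on b

The old version applies the arguments out of order, (... c b), which blocks both steps.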

The same thing happens in down, flow and warp.

If this is the source of the slowdown, it would be problematic, because the new version is strictly smaller, with fewer rewrites. I don't think I can predict exactly what effect the ordering of the arguments has on performance, so it would be hard to always generate the fastest code.

================

In the old version, we lifted the calls to sort and gen in main to their own separate functions.
This lets sum start one iteration earlier than sort and both start earlier than gen.

We have to do this because otherwise, in a lot of cases, a program would just return a lazy reference without calling it.
For example, the three calls in main could be lifted together into their own definition and replaced by a reference. If we didn't inline that reference, the program would do nothing.

Something like this:

main = do_main
do_main = sum(20, sort(20, 0, gen(20, 0)))
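
In HVM terms (a hypothetical sketch, reusing the naming style of the outputs above), the fully-lifted form would leave a bare lazy reference at the root:

@main = @do_main

and, per the above, a reference that nothing interacts with would just be returned instead of being called.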

Here it's quite stupid, but in a more general setting this was happening automatically a lot. It's not easy to decide when we do and when we don't need to do this, so it's always done.

This is quite easy to test: we just have to manually write the old version.

If this is the source of the slowdown, then we need to figure out a better heuristic for when to lift the functions in main, one that maximizes execution speed. In this case, though, I don't think that's the problem, since both sum and sort depend on the output generated by gen to do anything, so they should become somewhat synchronized automatically.

@kings177 can you test the second thing I said and see if it makes any difference?
I mean writing main like this:

main = sum(20, sort(20, 0, gen(20, 0)))

vs. like this (which is what the old output's @main, @main__C1, and @main__C0 definitions encode):

main = sum(20, main_1)
main_1 = sort(20, 0, main_2)
main_2 = gen(20, 0)

@developedby tried with:

def main:
  return sum(20, sort(20, 0, gen(20, 0)))

def main:
  main = sum(20, sort(20, 0, gen(20, 0)))
  return main

def main:
  main_2 = gen(20, 0)
  main_1 = sort(20, 0, main_2)
  main = sum(20, main_1)
  return main

No difference at all when running.

Also, they all generate the exact same .hvm file.

What I meant was this:

def main():
  return sum(20, main_1)

def main_1():
  return sort(20, 0, main_2)

def main_2():
  return gen(20, 0)

Tried doing it that way too; same result, ~6200 MIPS.