thuml/depyf

Inductor Op Lowering

jeromeku opened this issue · 3 comments

Thanks for the great project!

I'm trying to understand how inductor lowers the ops in the fx graph to actual kernels -- specifically, the optimization / tuning that determines which kernel implementations get codegen'ed.

For example, in this blog post, it is mentioned that the GEMV kernels generated by torch.compile are faster than handwritten / proprietary kernels from cuBLAS and FlashAttention.

I'd like to better understand the lowering passes that enable this:

  • Stepping through the compilation process in a debugger gets a bit muddled across the various layers of abstraction (more likely that I just need to get better at debugging).
  • I've reviewed select_algorithm.py, triton_heuristics.py, the mm-specific kernels directory within inductor, etc., but am having trouble putting it all together.

Can depyf provide greater visibility into this process?
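
For reference, the closest I've gotten on my own is flipping inductor's config knobs to make the algorithm-selection step observable. A rough sketch (the config names below are from recent PyTorch releases and may differ across versions, so treat this as an approximation rather than a recipe):

import torch
import torch._inductor.config as inductor_config

# Ask inductor to benchmark Triton templates against extern (ATen / cuBLAS) kernels
# during lowering, and to dump per-compile debug artifacts (fx graphs, IR, output code).
inductor_config.max_autotune = True
inductor_config.trace.enabled = True

@torch.compile
def mv(a, x):
    return a @ x  # a GEMV-shaped matmul for select_algorithm.py to pick an implementation for

a = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
x = torch.randn(4096, device="cuda", dtype=torch.float16)
mv(a, x)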

Hi, thanks for your interest!

depyf can show you the final results of inductor lowering, which might help you reason about the inner process.

For example, this file shows how inductor generates OpenMP code. On CUDA GPUs, you will find Triton code instead, like the following:

from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align

from torch import device, empty, empty_strided
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels

aten = torch.ops.aten
inductor_ops = torch.ops.inductor
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
async_compile = AsyncCompile()


# kernel path: /tmp/torchinductor_youkaichao/o4/co4orw7b7wc36g42g5jan6qohnuadq5mditczxn7irfu57wam4bu.py
# Source Nodes: [add, exp, gt, mean, neg, x_1], Original ATen: [aten.add, aten.exp, aten.gt, aten.mean, aten.mul, aten.neg, aten.reciprocal]
# add => add
# exp => exp
# gt => gt
# mean => mean
# neg => neg
# x_1 => mul, reciprocal
triton_per_fused_add_exp_gt_mean_mul_neg_reciprocal_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_heuristics import AutotuneHint, persistent_reduction
from torch._inductor.utils import instance_descriptor
from torch._inductor import triton_helpers

@persistent_reduction(
    size_hints=[1, 128],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*i1', 3: 'i32', 4: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_add_exp_gt_mean_mul_neg_reciprocal_0', 'mutated_arg_names': []}
)
@triton.jit
def triton_(in_ptr0, out_ptr0, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 100
    RBLOCK: tl.constexpr = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    rmask = rindex < rnumel
    r0 = rindex
    tmp0 = tl.load(in_ptr0 + (r0), rmask, other=0.0)
    tmp1 = -tmp0
    tmp2 = tl.exp(tmp1)
    tmp3 = 5.0
    tmp4 = tmp2 + tmp3
    tmp5 = 1 / tmp4
    tmp6 = 1.0
    tmp7 = tmp5 * tmp6
    tmp8 = tl.broadcast_to(tmp7, [XBLOCK, RBLOCK])
    tmp10 = tl.where(rmask, tmp8, 0)
    tmp11 = tl.sum(tmp10, 1)[:, None]
    tmp12 = 100.0
    tmp13 = tmp11 / tmp12
    tmp14 = 0.5
    tmp15 = tmp13 > tmp14
    tl.store(out_ptr0 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp7, rmask)
    tl.store(out_ptr2 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp15, None)
''')

import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream


async_compile.wait(globals())
del async_compile

def call(args):
    primals_1, primals_2 = args
    args.clear()
    assert_size_stride(primals_1, (100, ), (1, ))
    assert_size_stride(primals_2, (100, ), (1, ))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0) # no-op to ensure context
        buf0 = empty((100, ), device='cuda', dtype=torch.float32)
        buf2 = empty((), device='cuda', dtype=torch.bool)
        # Source Nodes: [add, exp, gt, mean, neg, x_1], Original ATen: [aten.add, aten.exp, aten.gt, aten.mean, aten.mul, aten.neg, aten.reciprocal]
        stream0 = get_cuda_stream(0)
        triton_per_fused_add_exp_gt_mean_mul_neg_reciprocal_0.run(primals_1, buf0, buf2, 1, 100, grid=grid(1), stream=stream0)
        return (primals_2, buf0, buf2, primals_1, )


def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    primals_1 = rand_strided((100, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_2 = rand_strided((100, ), (1, ), device='cuda:0', dtype=torch.float32)
    return print_performance(lambda: call([primals_1, primals_2]), times=times, repeat=repeat)


if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)

The usage is simple: just wrap your code with a context manager:

import depyf
with depyf.prepare_debug(None, dump_src_dir="/path/to/a/directory/you/want/to/save"):
    ...  # your original code that uses `torch.compile` goes here
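
As a slightly fuller sketch, here is a toy function roughly reconstructed from the fused ops in the dump above (neg, exp, add, reciprocal, mean, gt); the function and the dump directory are illustrative placeholders, not the exact code behind that dump:

import torch
import depyf

@torch.compile
def toy(x):
    # roughly the fused computation in the kernel above: 1 / (exp(-x) + 5), then mean > 0.5
    return (1.0 / (torch.exp(-x) + 5.0)).mean() > 0.5

with depyf.prepare_debug(None, dump_src_dir="./depyf_dump"):
    toy(torch.randn(100, device="cuda"))

# After the context exits, ./depyf_dump contains the captured bytecode, the fx graphs,
# and inductor's output code (a wrapper file like the one shown above).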

Thanks -- I can already generate those files by setting the requisite TORCH_INDUCTOR / TORCH_DYNAMO log / debug flags. I was more interested in interactively understanding the lowering process, specifically how inductor maps the optimized fx graph to the codegen'ed Triton kernels.
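
For completeness, by "log / debug flags" I mean environment variables along these lines (names are as documented for recent PyTorch versions and may change over time):

import os

# Set these before importing torch so the logging / debug config is picked up.
os.environ["TORCH_COMPILE_DEBUG"] = "1"  # dumps a torch_compile_debug/ tree with fx graphs and output code
os.environ["TORCH_LOGS"] = "+dynamo,+inductor,output_code"  # verbose dynamo/inductor logs plus the generated code

import torch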