Inductor Op Lowering
jeromeku opened this issue · 3 comments
Thanks for the great project!

I'm trying to understand how inductor lowers the ops in the fx graph to actual kernels -- specifically the optimization / tuning that determines the actual kernel implementations that are codegen'ed.

For example, in this blogpost, it is mentioned that the GEMV kernels generated by torch.compile are faster than handwritten / proprietary kernels from cuBLAS and FlashAttention.

I'd like to better understand the lowering passes that enable this:
- Stepping through the compilation process in the debugger gets a bit muddled across the various layers of abstraction (more likely that I need to get better at debugging).
- I've reviewed select_algorithm.py, triton_heuristics.py, the mm-specific kernels directory within inductor, etc., but am having trouble putting it all together.

Can depyf provide greater visibility into this process?
Hi, thanks for your interest!

depyf can help you understand the final results of inductor lowering, which might in turn help you understand the inner process. For example, this file shows how inductor generates OpenMP files. On CUDA GPUs you will find Triton code instead, like the following:
```python
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align

from torch import device, empty, empty_strided
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels

aten = torch.ops.aten
inductor_ops = torch.ops.inductor
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
async_compile = AsyncCompile()


# kernel path: /tmp/torchinductor_youkaichao/o4/co4orw7b7wc36g42g5jan6qohnuadq5mditczxn7irfu57wam4bu.py
# Source Nodes: [add, exp, gt, mean, neg, x_1], Original ATen: [aten.add, aten.exp, aten.gt, aten.mean, aten.mul, aten.neg, aten.reciprocal]
# add => add
# exp => exp
# gt => gt
# mean => mean
# neg => neg
# x_1 => mul, reciprocal
triton_per_fused_add_exp_gt_mean_mul_neg_reciprocal_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_heuristics import AutotuneHint, persistent_reduction
from torch._inductor.utils import instance_descriptor
from torch._inductor import triton_helpers

@persistent_reduction(
    size_hints=[1, 128],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*i1', 3: 'i32', 4: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_add_exp_gt_mean_mul_neg_reciprocal_0', 'mutated_arg_names': []}
)
@triton.jit
def triton_(in_ptr0, out_ptr0, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr):
    xnumel = 1
    rnumel = 100
    RBLOCK: tl.constexpr = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    rmask = rindex < rnumel
    r0 = rindex
    tmp0 = tl.load(in_ptr0 + (r0), rmask, other=0.0)
    tmp1 = -tmp0
    tmp2 = tl.exp(tmp1)
    tmp3 = 5.0
    tmp4 = tmp2 + tmp3
    tmp5 = 1 / tmp4
    tmp6 = 1.0
    tmp7 = tmp5 * tmp6
    tmp8 = tl.broadcast_to(tmp7, [XBLOCK, RBLOCK])
    tmp10 = tl.where(rmask, tmp8, 0)
    tmp11 = tl.sum(tmp10, 1)[:, None]
    tmp12 = 100.0
    tmp13 = tmp11 / tmp12
    tmp14 = 0.5
    tmp15 = tmp13 > tmp14
    tl.store(out_ptr0 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp7, rmask)
    tl.store(out_ptr2 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp15, None)
''')

import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream

async_compile.wait(globals())
del async_compile


def call(args):
    primals_1, primals_2 = args
    args.clear()
    assert_size_stride(primals_1, (100, ), (1, ))
    assert_size_stride(primals_2, (100, ), (1, ))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)  # no-op to ensure context
        buf0 = empty((100, ), device='cuda', dtype=torch.float32)
        buf2 = empty((), device='cuda', dtype=torch.bool)
        # Source Nodes: [add, exp, gt, mean, neg, x_1], Original ATen: [aten.add, aten.exp, aten.gt, aten.mean, aten.mul, aten.neg, aten.reciprocal]
        stream0 = get_cuda_stream(0)
        triton_per_fused_add_exp_gt_mean_mul_neg_reciprocal_0.run(primals_1, buf0, buf2, 1, 100, grid=grid(1), stream=stream0)
    return (primals_2, buf0, buf2, primals_1, )


def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    primals_1 = rand_strided((100, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_2 = rand_strided((100, ), (1, ), device='cuda:0', dtype=torch.float32)
    return print_performance(lambda: call([primals_1, primals_2]), times=times, repeat=repeat)


if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
```
The usage is simple: just wrap your code with the context manager:

```python
import depyf

with depyf.prepare_debug(None, dump_src_dir="/path/to/a/directory/you/want/to/save"):
    ...  # your code as before, with `torch.compile`
```
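For concreteness, here is a minimal end-to-end sketch; the toy function, tensor size, and dump directory are placeholders of mine rather than anything from this issue, and the `prepare_debug` call follows the signature shown above (newer depyf versions may differ):

```python
import torch
import depyf

def toy(x):
    # A toy graph that loosely mirrors the fused ops in the dump above
    # (neg, exp, add, reciprocal, mul, mean, gt) -- not the exact original model.
    x = 1.0 / (torch.exp(-x) + 5.0)
    return x, x.mean() > 0.5

compiled_toy = torch.compile(toy)

with depyf.prepare_debug(None, dump_src_dir="./depyf_dump"):
    # Compilation is triggered inside the context, so depyf can capture it.
    compiled_toy(torch.randn(100, device="cuda"))

# ./depyf_dump should now contain the captured bytecode, the fx graphs, and the
# inductor-generated wrapper / Triton code, similar to the file shown above
# (exact contents depend on the depyf and PyTorch versions).
```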
Thanks -- I can already generate those files by setting the requisite TORCH_INDUCTOR / TORCH_DYNAMO log / debug flags. I was interested more in interactively understanding the lowering process, focused on how inductor maps the optimized fx graph to codegen'ed Triton kernels.
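(For reference, the flags alluded to above are along these lines; the exact names and accepted values vary by PyTorch version, so treat this as an assumed example rather than a canonical list:)

```python
# Assumed example of the log / debug flags mentioned above; set them before
# importing torch so they take effect.
import os
os.environ["TORCH_COMPILE_DEBUG"] = "1"            # dump fx graphs, inductor IR, and output code under torch_compile_debug/
os.environ["TORCH_LOGS"] = "output_code,inductor"  # log the generated Triton / C++ wrapper code

import torch  # imported after setting the environment variables
```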
Maybe you are interested in the https://github.com/pytorch/pytorch/blob/main/torch/_inductor/lowering.py file? The entrypoint of inductor is the compile_fx function in torch/_inductor/compile_fx.py, and you can set a breakpoint there for debugging.
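To make the shape of that file concrete, here is a heavily simplified, hypothetical sketch (my paraphrase, not PyTorch source) of the pattern lowering.py uses: each ATen op is mapped, via register_lowering, to a function that builds inductor IR, and the graph walker consults that registry for every node of the fx graph; the scheduler then fuses the resulting IR nodes and hands them to the Triton / C++ codegen backends.

```python
# Hypothetical, stripped-down sketch of the register_lowering pattern --
# illustrative only, not the real torch/_inductor/lowering.py implementation.
lowerings = {}  # maps an ATen op to a function that builds inductor IR


def register_lowering(aten_fn):
    def decorator(fn):
        lowerings[aten_fn] = fn
        return fn
    return decorator


@register_lowering("aten.add.Tensor")  # the real code registers torch.ops.aten.* objects
def lower_add(x, y):
    # The real lowering returns IR objects (e.g. Pointwise wrapped in TensorBox);
    # the scheduler later fuses neighbouring pointwise/reduction nodes into a
    # single Triton kernel like the one shown earlier in this thread.
    return ("pointwise", "add", x, y)


# During compilation, the graph walker looks up lowerings[node.target] for each
# call_function node in the fx graph and calls it with the node's arguments.
```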