cuda IMA using custom triton kernel with compile
RobertCsordas opened this issue · 23 comments
🐛 Describe the bug
Hi,
I'm trying to make my MoE Triton kernel work with torch.compile(). I know that this is not supported in the current stable version, but it is in the nightly (at least it worked with the simple kernels I tried). However, when I try to use my actual kernel, it fails.
There are two independent issues. The first is the "Illegal getattr invocation stride in strict mode" error (see the log below), which seems to prevent compilation but does not appear to be fatal. The second, "RuntimeError: CUDA error: an illegal memory access was encountered", is fatal. Note that both the example and the full kernel work correctly without compile and do not run into illegal memory access issues.
Interestingly, if I remove an unused, static IF (which is never true in this example, and it depends only on external arguments), the code works. Also, if there is only one option in the triton.autotune(), it works as well. I marked both places with a comment starting with "!!!!!!" in the code below.
This is the simplified code I was able to come up with:
import torch
import triton
import triton.language as tl
# CVMM from: https://github.com/RobertCsordas/moe_layer/blob/master/triton_src/moe_layer/cvmm.py
# Based on: https://github.com/openai/triton/blob/main/python/tutorials/03-matrix-multiplication.py
from typing import Union, Optional
from dataclasses import dataclass
@dataclass
class CVMMSel:
    """Bundle of selection tensors handed to ``cvmm``.

    ``cvmm_prepare_sel`` fills the fields as follows: ``raw_sel`` is the
    selection tensor as given, ``sel`` holds the same values sorted (viewed
    back to the original shape), and ``sel_index`` is the permutation
    returned by that sort. ``out_index`` is optional and defaults to ``None``.
    """

    raw_sel: torch.Tensor
    sel: torch.Tensor
    sel_index: torch.Tensor
    out_index: Optional[torch.Tensor] = None
def cvmm_prepare_sel(sel: torch.Tensor, n_experts: int) -> CVMMSel:
    """Sort a selection tensor and package it as a ``CVMMSel``.

    Args:
        sel: Selection tensor of any shape; it is flattened before sorting.
        n_experts: Accepted for interface compatibility; not used here.

    Returns:
        A ``CVMMSel`` holding the original ``sel``, the sorted values viewed
        back to ``sel``'s shape, the sort permutation, and no ``out_index``.
    """
    flat = sel.flatten()
    sorted_vals, order = flat.sort()
    return CVMMSel(
        raw_sel=sel,
        sel=sorted_vals.view_as(sel),
        sel_index=order,
        out_index=None,
    )
# !!!!!! Leaving just one autotune config solves the "RuntimeError: CUDA error: an illegal memory access was
# encountered" problem !!!!!!
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    ],
    key=['M', 'N', 'K']
)
@triton.jit
def cvmm_kernel(
    # Pointers to matrices
    a_ptr, b_ptr, c_ptr, index_ptr, sel_ptr, out_index_ptr,
    # Matrix dimensions
    M, N, K,
    stride_cm, stride_cn,
    stride_index, stride_sel, stride_out_index,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr
):
    # Reduced reproduction kernel: each program stores an all-zero float16
    # tile of shape (BLOCK_SIZE_M, BLOCK_SIZE_N) into C, addressing the rows
    # through indices loaded from index_ptr (or out_index_ptr when given).
    # No matrix product is computed here: a_ptr, b_ptr, sel_ptr, stride_sel
    # and BLOCK_SIZE_K are not referenced in the body, and K only serves as
    # an autotune key.

    # Map the 1D program id to a (pid_m, pid_n) tile, walking the tiles in
    # groups of GROUP_SIZE_M rows.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    # The last group may contain fewer than GROUP_SIZE_M row tiles.
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_n = (pid % num_pid_in_group) // group_size_m
    pid_m = first_pid_m + (pid % group_size_m)
    # Row offsets of this tile, wrapped modulo M so the index load below
    # never reads past the M-th entry.
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    # Look up the remapped row index for every row of the tile.
    remap_offs_am = tl.load(index_ptr + stride_index * offs_am)
    # Create offset pointers
    c = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float16)
    # Unwrapped row offsets; used only to build the store mask.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    # !!!!!! Removing this IF solves the "RuntimeError: CUDA error: an illegal memory access was encountered" problem,
    # even though it is always False in this example !!!!!!
    # To test it, keep the else branch.
    # NOTE: the only launcher visible in this file passes None for
    # out_index_ptr, so the else branch is the one taken.
    if out_index_ptr is not None:
        remap_offs_cm = tl.load(out_index_ptr + stride_out_index * offs_am)
    else:
        remap_offs_cm = remap_offs_am
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    # Output addresses: remapped row index along stride_cm, column offset
    # along stride_cn.
    c_ptrs = c_ptr + stride_cm * remap_offs_cm[:, None] + stride_cn * offs_cn[None, :]
    # Mask out rows >= M and columns >= N of partially filled edge tiles.
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
def cvmm_triton(x: torch.Tensor, sel_index: torch.Tensor, sel: torch.Tensor, keys: torch.Tensor, out_dtype: torch.dtype, out_index: Optional[torch.Tensor] = None):
    """Allocate the output and launch ``cvmm_kernel`` over it.

    Args:
        x: Input tensor; all leading dimensions are flattened into one.
        sel_index: Index tensor passed to the kernel as ``index_ptr``.
        sel: Selection tensor; flattened, its length gives the row count M.
        keys: 3-D tensor; its last two dimensions supply K and N.
        out_dtype: dtype of the allocated output.
        out_index: Accepted for interface compatibility. It is NOT forwarded
            in this reduced example: the kernel always receives ``None`` and
            a stride of 0 for it.

    Returns:
        The ``(M, N)`` output viewed as ``(*sel.shape, N)``.

    Raises:
        AssertionError: if the last dimension of ``x`` differs from
            ``keys.shape[1]``.
    """
    lead_shape = sel.shape

    flat_x = x.flatten(end_dim=-2)
    assert flat_x.shape[-1] == keys.shape[1]

    flat_sel = sel.flatten()
    M = flat_sel.shape[0]
    _, K, N = keys.shape

    # Allocates output.
    result = torch.empty((M, N), device=flat_x.device, dtype=out_dtype)

    # One program per (row tile, column tile) pair; the tile sizes come from
    # whichever autotune config is being launched.
    launch_grid = lambda meta: (
        triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(N, meta['BLOCK_SIZE_N']),
    )

    cvmm_kernel[launch_grid](
        flat_x, keys, result, sel_index, flat_sel, None,
        M, N, K,
        result.stride(0), result.stride(1),
        sel_index.stride(0), flat_sel.stride(0), 0,
    )

    return result.view(*lead_shape, N)
class CVMM(torch.autograd.Function):
    """Autograd wrapper around ``cvmm_triton``.

    The forward pass launches the kernel on ``x``. The backward pass launches
    it again on the incoming gradient with the last two dimensions of
    ``keys`` swapped, and returns a gradient for ``x`` only.
    """

    warned = False

    @staticmethod
    def forward(ctx, x: torch.Tensor, sel_index: torch.Tensor, sel: torch.Tensor, keys: torch.Tensor, out_index: Optional[torch.Tensor] = None):
        """Run the kernel and stash what ``backward`` needs on ``ctx``."""
        ctx.save_for_backward(x, keys, sel, sel_index, out_index)

        # Under autocast the output is float16; otherwise it follows x.
        autocast_on = torch.is_autocast_enabled()
        out_type = torch.float16 if autocast_on else x.dtype

        result = cvmm_triton(x, sel_index, sel, keys, out_type, out_index)

        ctx.op_type = out_type
        ctx.keys_type = keys.dtype
        ctx.is_autocast = autocast_on
        return result

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient w.r.t. ``x``; all other inputs get ``None``."""
        x, keys, sel, sel_index, _ = ctx.saved_tensors

        grad_x = cvmm_triton(
            grad_output, sel_index, sel, keys.transpose(1, 2), ctx.op_type, None
        ).view_as(x)

        return grad_x, None, None, None, None
def cvmm(x: torch.Tensor, sel: Union[torch.Tensor, CVMMSel], keys: torch.Tensor):
    """Apply ``CVMM`` to ``x``, preparing the selection first if needed.

    Args:
        x: Input tensor forwarded to ``CVMM.apply``.
        sel: Either a ready ``CVMMSel`` or a raw selection tensor, which is
            converted with ``cvmm_prepare_sel`` using ``keys.shape[0]``.
        keys: Tensor forwarded to ``CVMM.apply``.

    Returns:
        Whatever ``CVMM.apply`` returns.
    """
    prepared = sel if isinstance(sel, CVMMSel) else cvmm_prepare_sel(sel, keys.shape[0])
    return CVMM.apply(x, prepared.sel_index, prepared.sel, keys, prepared.out_index)
# Compile test
class Model(torch.nn.Module):
    """Minimal module whose forward pass is a single ``cvmm`` call."""

    def forward(self, x, sel, w):
        """Return ``cvmm(x, sel, w)``."""
        out = cvmm(x, sel, w)
        return out
# Driver for the reproduction: compile the model, build random inputs on the
# GPU and run one forward pass.
model = torch.compile(Model().cuda())
# Uncomment to run without torch.compile (reported in the issue text as
# working correctly):
# model = Model().cuda()
# Fix the seed so the random inputs below are reproducible; the order of the
# random calls therefore matters.
torch.manual_seed(0)
n_experts = 8
n_channels = 64
expert_size = 64
bs = 64
device = torch.device("cuda")
dtype = torch.float16
# keys has shape (n_experts, n_channels, expert_size).
keys = torch.nn.Parameter(torch.randn(n_experts, n_channels, expert_size, dtype=dtype, device=device))
# testvec has shape (bs, n_channels).
testvec = torch.randn(bs, n_channels, dtype=dtype, device=device)
# One int32 value in [0, n_experts) per row of testvec.
sel = torch.randint(0, n_experts, (bs,), dtype=torch.int32, device=device)
# This call is the one that fails in the traceback below.
print(model(testvec, sel, keys).shape)
Error logs
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [WARNING] speculate_subgraph: while introspecting the user-defined autograd.Function, we were unable to trace function `trampoline_autograd_bwd` into a single graph. This means that Dynamo was unable to prove safety for this API and will fall back to eager-mode PyTorch, which could lead to a slowdown.
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] Illegal getattr invocation stride in strict mode
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] Traceback (most recent call last):
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 231, in speculate_subgraph
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] output = f.call_function(tx, args, sub_kwargs)
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 248, in call_function
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] return super().call_function(tx, args, kwargs)
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 81, in call_function
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] return tx.inline_user_function_return(
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in inline_user_function_return
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2261, in inline_call
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] return cls.inline_call_(parent, func, args, kwargs)
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2376, in inline_call_
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] tracer.run()
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] and self.step()
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] getattr(self, inst.opname)(inst)
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] return inner_fn(self, inst)
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1252, in CALL_FUNCTION_EX
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] self.call_function(fn, argsvars.items, kwargsvars.items)
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 652, in call_function
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] self.push(fn.call_function(self, args, kwargs))
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/misc.py", line 660, in call_function
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] return self.obj.call_method(tx, self.name, args, kwargs)
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/misc.py", line 505, in call_method
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] return tx.inline_call(tx, backward, args, kwargs)
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2261, in inline_call
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] return cls.inline_call_(parent, func, args, kwargs)
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2376, in inline_call_
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] tracer.run()
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] and self.step()
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] getattr(self, inst.opname)(inst)
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] return inner_fn(self, inst)
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1213, in CALL_FUNCTION
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] self.call_function(fn, args, {})
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 652, in call_function
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] self.push(fn.call_function(self, args, kwargs))
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 248, in call_function
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] return super().call_function(tx, args, kwargs)
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 81, in call_function
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] return tx.inline_user_function_return(
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in inline_user_function_return
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2261, in inline_call
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] return cls.inline_call_(parent, func, args, kwargs)
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2376, in inline_call_
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] tracer.run()
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] and self.step()
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] getattr(self, inst.opname)(inst)
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1303, in LOAD_ATTR
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] result = BuiltinVariable(getattr).call_function(
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 651, in call_function
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] result = handler(tx, *args, **kwargs)
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 1229, in call_getattr
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] return obj.var_getattr(tx, name).clone(source=source)
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/tensor.py", line 218, in var_getattr
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] unimplemented(f"Illegal getattr invocation {name} in strict mode")
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 193, in unimplemented
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] raise Unsupported(msg)
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [WARNING] speculate_subgraph: while introspecting the user-defined autograd.Function, we were unable to trace function `trampoline_autograd_bwd` into a single graph. This means that Dynamo was unable to prove safety for this API and will fall back to eager-mode PyTorch, which could lead to a slowdown.
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] Illegal getattr invocation stride in strict mode
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] Traceback (most recent call last):
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 231, in speculate_subgraph
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] output = f.call_function(tx, args, sub_kwargs)
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 248, in call_function
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] return super().call_function(tx, args, kwargs)
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 81, in call_function
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] return tx.inline_user_function_return(
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in inline_user_function_return
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2261, in inline_call
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] return cls.inline_call_(parent, func, args, kwargs)
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2376, in inline_call_
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] tracer.run()
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] and self.step()
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] getattr(self, inst.opname)(inst)
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] return inner_fn(self, inst)
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1252, in CALL_FUNCTION_EX
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] self.call_function(fn, argsvars.items, kwargsvars.items)
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 652, in call_function
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] self.push(fn.call_function(self, args, kwargs))
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/misc.py", line 660, in call_function
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] return self.obj.call_method(tx, self.name, args, kwargs)
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/misc.py", line 505, in call_method
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] return tx.inline_call(tx, backward, args, kwargs)
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2261, in inline_call
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] return cls.inline_call_(parent, func, args, kwargs)
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2376, in inline_call_
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] tracer.run()
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] and self.step()
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] getattr(self, inst.opname)(inst)
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] return inner_fn(self, inst)
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1213, in CALL_FUNCTION
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] self.call_function(fn, args, {})
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 652, in call_function
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] self.push(fn.call_function(self, args, kwargs))
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 248, in call_function
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] return super().call_function(tx, args, kwargs)
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 81, in call_function
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] return tx.inline_user_function_return(
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in inline_user_function_return
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2261, in inline_call
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] return cls.inline_call_(parent, func, args, kwargs)
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2376, in inline_call_
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] tracer.run()
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] and self.step()
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] getattr(self, inst.opname)(inst)
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1303, in LOAD_ATTR
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] result = BuiltinVariable(getattr).call_function(
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 651, in call_function
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] result = handler(tx, *args, **kwargs)
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 1229, in call_getattr
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] return obj.var_getattr(tx, name).clone(source=source)
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/tensor.py", line 218, in var_getattr
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] unimplemented(f"Illegal getattr invocation {name} in strict mode")
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 193, in unimplemented
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] raise Unsupported(msg)
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode
Traceback (most recent call last):
File "/home/robert/rnn_generalization_test/compile_test3.py", line 167, in <module>
print(model(testvec, sel, keys).shape)
File "/home/robert/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/robert/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
return fn(*args, **kwargs)
File "/home/robert/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/robert/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/home/robert/rnn_generalization_test/compile_test3.py", line 148, in forward
return cvmm(x, sel, w)
File "/home/robert/rnn_generalization_test/compile_test3.py", line 141, in cvmm
return CVMM.apply(x, sel.sel_index, sel.sel, keys, sel.out_index)
File "/home/robert/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 553, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/home/robert/rnn_generalization_test/compile_test3.py", line 114, in forward
@staticmethod
File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
return fn(*args, **kwargs)
File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
File "/home/robert/.local/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 901, in forward
return compiled_fn(full_args)
File "/home/robert/.local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 81, in g
return f(*args)
File "/home/robert/.local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 94, in runtime_wrapper
all_outs = call_func_at_runtime_with_args(
File "/home/robert/.local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 105, in call_func_at_runtime_with_args
out = normalize_as_list(f(args))
File "/home/robert/.local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 118, in rng_functionalization_wrapper
return compiled_fw(args)
File "/home/robert/.local/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 864, in __call__
return self.get_current_callable()(inputs)
File "/home/robert/.local/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 611, in run
return model(new_inputs)
File "/home/robert/.local/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 892, in _run_from_cache
return compiled_graph.compiled_artifact(inputs)
File "/tmp/torchinductor_robert/vr/cvrlscv7n2ne4vcxcvozlnvsvyiddsior53it7ifax2hfs2uosni.py", line 111, in call
cvmm_kernel_0.run(a_ptr=arg0_1, b_ptr=arg1_1, c_ptr=buf0, index_ptr=arg3_1, sel_ptr=arg2_1, out_index_ptr=None, M=64, N=64, K=64, stride_cm=64, stride_cn=1, stride_index=1, stride_sel=1, stride_out_index=0, grid=grid_wrapper_for_cvmm_kernel_0, stream=stream0)
File "/home/robert/.local/lib/python3.10/site-packages/torch/_inductor/triton_heuristics.py", line 540, in run
self.autotune_to_one_config(*args, grid=grid, **kwargs)
File "/home/robert/.local/lib/python3.10/site-packages/torch/_inductor/triton_heuristics.py", line 444, in autotune_to_one_config
timings = self.benchmark_all_configs(*args, **kwargs)
File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
r = func(*args, **kwargs)
File "/home/robert/.local/lib/python3.10/site-packages/torch/_inductor/triton_heuristics.py", line 420, in benchmark_all_configs
timings = {
File "/home/robert/.local/lib/python3.10/site-packages/torch/_inductor/triton_heuristics.py", line 421, in <dictcomp>
launcher: self.bench(launcher, *args, **kwargs)
File "/home/robert/.local/lib/python3.10/site-packages/torch/_inductor/triton_heuristics.py", line 392, in bench
return do_bench(kernel_call, rep=40, fast_flush=True)
File "/home/robert/.local/lib/python3.10/site-packages/torch/_inductor/utils.py", line 167, in do_bench
return triton_do_bench(*args, **kwargs)[0]
File "/home/robert/.local/lib/python3.10/site-packages/triton/testing.py", line 103, in do_bench
torch.cuda.synchronize()
File "/home/robert/.local/lib/python3.10/site-packages/torch/cuda/__init__.py", line 801, in synchronize
return torch._C._cuda_synchronize()
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Minified repro
The minifier creates a directory structure which has no files. I'm not sure what I'm supposed to paste here. The directory structure looks like this:
~/rnn_generalization_test/torch_compile_debug >>> find .
.
./run_2023_12_07_14_23_32_124760-pid_906807
./run_2023_12_07_14_23_32_124760-pid_906807/minifier
./run_2023_12_07_14_23_32_124760-pid_906807/minifier/checkpoints
Versions
Collecting environment information...
PyTorch version: 2.2.0.dev20231206+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A
OS: Manjaro Linux (x86_64)
GCC version: (GCC) 12.2.1 20230201
Clang version: 15.0.7
CMake version: version 3.25.0
Libc version: glibc-2.37
Python version: 3.10.10 (main, Mar 5 2023, 22:26:53) [GCC 12.2.1 20230201] (64-bit runtime)
Python platform: Linux-5.15.109-1-MANJARO-x86_64-with-glibc2.37
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA TITAN V
Nvidia driver version: 530.41.03
cuDNN version: Probably one of the following:
/usr/lib/libcudnn.so.8.8.0
/usr/lib/libcudnn_adv_infer.so.8.8.0
/usr/lib/libcudnn_adv_train.so.8.8.0
/usr/lib/libcudnn_cnn_infer.so.8.8.0
/usr/lib/libcudnn_cnn_train.so.8.8.0
/usr/lib/libcudnn_ops_infer.so.8.8.0
/usr/lib/libcudnn_ops_train.so.8.8.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 39 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 12
On-line CPU(s) list: 0-11
Vendor ID: GenuineIntel
Model name: Intel(R) Core(TM) i7-8700 CPU @ 3.20GHz
CPU family: 6
Model: 158
Thread(s) per core: 2
Core(s) per socket: 6
Socket(s): 1
Stepping: 10
CPU(s) scaling MHz: 25%
CPU max MHz: 3200.0000
CPU min MHz: 800.0000
BogoMIPS: 6402.62
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb invpcid_single pti ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx rdseed adx smap clflushopt intel_pt xsaveopt xsavec xgetbv1 xsaves dtherm arat pln pts hwp hwp_notify hwp_act_window hwp_epp md_clear flush_l1d arch_capabilities
L1d cache: 192 KiB (6 instances)
L1i cache: 192 KiB (6 instances)
L2 cache: 1.5 MiB (6 instances)
L3 cache: 12 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-11
Vulnerability Itlb multihit: KVM: Mitigation: VMX unsupported
Vulnerability L1tf: Mitigation; PTE Inversion
Vulnerability Mds: Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Retbleed: Mitigation; IBRS
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; IBRS, IBPB conditional, STIBP conditional, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Mitigation; Microcode
Vulnerability Tsx async abort: Mitigation; TSX disabled
Versions of relevant libraries:
[pip3] flake8==6.0.0
[pip3] kmeans-pytorch==0.3
[pip3] numpy==1.24.2
[pip3] pytorch-lightning==1.9.0
[pip3] pytorch-triton==2.1.0+bcad9dabe1
[pip3] torch==2.2.0.dev20231206+cu121
[pip3] torch-dct==0.1.6
[pip3] torch-tb-profiler==0.4.1
[pip3] torchaudio==2.2.0.dev20231206+cu121
[pip3] torchdata==0.6.1
[pip3] torchmetrics==0.11.0
[pip3] torchpq==0.3.0.5
[pip3] torchtext==0.15.2
[pip3] torchvision==0.17.0.dev20231206+cu121
[conda] Could not collect
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @muchulee8 @ColinPeppler @amjames @desertfire @oulgen @aakhundov @bdhirsh @anijain2305 @peterbell10 @wconstab
Thanks for the report, I will take a look!
@zou3519 feel free to assign any triton kernel issues directly to me
It looks like calling `tensor.stride()` is not allowed inside an `autograd.Function` (strict mode), which is why you're seeing the first error:
unimplemented(f"Illegal getattr invocation {name} in strict mode")
pytorch/torch/_dynamo/config.py
Lines 334 to 340 in 18d57dd
This results in a graph break. @zou3519 is there any way around this?
I am gonna look at the cuda error next.
We might be able to fix the autograd.Function thing. It's a bit unclear right now. A workaround is to use the torch.library API to create a new PyTorch custom operator wrapping the triton kernel (https://gist.github.com/zou3519/43393fa6807774ca11be7b797fe38a8b). Although this works with torch.compile, we don't end up leveraging our infrastructure for directly compiling user-defined triton kernels in this solution.
@aakhundov and I spent some time looking at this. It looks like when an argument is None, Triton does not generate a real argument for it in the PTX, but the Triton autotuning code emitted by Inductor does. We are looking to see if there's a simple way to fix this, but in the meantime a straightforward workaround is to pass a `tl.constexpr` parameter that tells the kernel whether the tensor is None, instead of checking the tensor's value inside the kernel.
Dear @oulgen @zou3519 (cc @RobertCsordas),
I have adapted the code from @RobertCsordas with the workarounds proposed by @zou3519 and it seems to work. When I try to add a node for the backward function though I am running into trouble.
I expanded the MWE a bit to better reflect the true file/function, but made it self contained, i.e. you can just run that file and you will see the same exception that I see in practice.
from typing import Union, Optional
import torch
from dataclasses import dataclass
import triton
import triton.language as tl
# Based on https://github.com/openai/triton/blob/main/python/tutorials/03-matrix-multiplication.py
@dataclass
class CVMMSel:
    """Bundle of selection tensors that `cvmm` hands to `CVMM.apply`."""

    # Selection tensor as supplied by the caller.
    raw_sel: torch.Tensor
    # Sorted selection, viewed with the shape of `raw_sel` by the prepare helpers.
    sel: torch.Tensor
    # Index tensor used to gather the input rows.
    sel_index: torch.Tensor
    # Optional index tensor for the output rows; when None the kernels reuse `sel_index`.
    out_index: Optional[torch.Tensor] = None
    # Optional weights that CVMM uses to reduce over the selection dimension.
    reduction_weight: Optional[torch.Tensor] = None

    def clone(self) -> 'CVMMSel':
        """Return a new CVMMSel whose fields reference the same objects (shallow copy)."""
        return CVMMSel(
            raw_sel=self.raw_sel,
            sel=self.sel,
            sel_index=self.sel_index,
            out_index=self.out_index,
            reduction_weight=self.reduction_weight,
        )
def cvmm_prepare_sel(sel: torch.Tensor, n_experts: int) -> CVMMSel:
    """Sort the flattened selection and wrap the result in a CVMMSel.

    The sorted values are viewed back with the shape of `sel`; the permutation
    returned by the sort becomes `sel_index`. `n_experts` is not used here.
    """
    sorted_sel, order = sel.flatten().sort()
    return CVMMSel(
        raw_sel=sel,
        sel=sorted_sel.view_as(sel),
        sel_index=order,
        out_index=None,
    )
# Autotuning candidates (tile sizes, pipeline stages, warps). The configs are re-evaluated whenever one
# of the arguments listed in `key` changes value.
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
],
key=['M', 'N', 'K', 'float32', 'allow_tf32']
)
@triton.jit
def cvmm_kernel(
# Pointers to matrices
a_ptr, b_ptr, c_ptr, index_ptr, sel_ptr, out_index_ptr,
# Matrix dimensions
M, N, K,
# The stride variables represent how much to increase the ptr by when moving by 1
# element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
# by to get the element one row down (A has M rows).
stride_am, stride_ak,
# `stride_bo` is multiplied by the matrix id below, i.e. it steps between the stacked B matrices.
stride_bo, stride_bk, stride_bn,
stride_cm, stride_cn,
stride_index, stride_sel, stride_out_index,
# Compile-time flag used instead of testing `out_index_ptr` against None inside the kernel.
out_index_is_none: tl.constexpr,
float32: tl.constexpr, allow_tf32: tl.constexpr,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr
):
"""Kernel for computing the matmul C = A x B.
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
"""
# NOTE: on top of the plain matmul described above, the rows of A are gathered through `index_ptr`,
# the B matrix is picked per row by the id stored in `sel_ptr` (`b_ptr + matrix_id * stride_bo`), and
# the output rows are written at the gathered row index or at `out_index_ptr` when it is provided.
# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse.
# See the `L2 Cache Optimizations` section of the Triton matmul tutorial this kernel is based on.
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_n = (pid % num_pid_in_group) // group_size_m
pid_m = first_pid_m + (pid % group_size_m)
# Matrix id of the first row of this tile and of its last row (clamped to M - 1) ...
sel_first = tl.load(sel_ptr + pid_m * BLOCK_SIZE_M * stride_sel)
sel_last = tl.load(sel_ptr + (min((pid_m + 1) * BLOCK_SIZE_M, M) - 1) * stride_sel)
# ... and the matrix id of every row in the tile (row indices wrap modulo M).
sel_all = tl.load(sel_ptr + stride_sel * ((pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M))
# Visit every matrix id between the two; rows whose id differs are masked out at the store.
# NOTE(review): this presumably relies on `sel` being sorted (the prepare helpers sort it) - confirm.
for matrix_id in range(sel_first, sel_last + 1):
# ----------------------------------------------------------
# Create pointers for the first blocks of A and B.
# We will advance this pointer as we move in the K direction
# and accumulate
# `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
# `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
# See the `Pointer Arithmetics` section of the Triton matmul tutorial for details
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
# Row indices into A are read from the index tensor rather than used directly.
remap_offs_am = tl.load(index_ptr + stride_index * offs_am)
# Create offset pointers
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (remap_offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + matrix_id * stride_bo + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
# `accumulator` is converted to fp16 after the loop unless `float32` is set.
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
# Load the next block of A and B, generate a mask by checking the K dimension.
# If it is out of bounds, set it to 0.
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
# We accumulate along the K dimension.
if not float32:
a = a.to(tl.float16)
b = b.to(tl.float16)
accumulator += tl.dot(a, b, allow_tf32=allow_tf32)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
if not float32:
c = accumulator.to(tl.float16)
else:
c = accumulator
# -----------------------------------------------------------
# Write back the block of the output matrix C with masks.
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
# Without an explicit output index the result rows go to the same remapped rows that were read.
if out_index_is_none:
remap_offs_cm = remap_offs_am
else:
remap_offs_cm = tl.load(out_index_ptr + stride_out_index * offs_am)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * remap_offs_cm[:, None] + stride_cn * offs_cn[None, :]
# Store only in-range rows/columns whose matrix id matches the one processed in this iteration.
c_mask = ((offs_cm[:, None] < M) & (sel_all[:, None] == matrix_id)) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
# Autotuning candidates. `K_BLOCKS` sets how many BLOCK_SIZE_K chunks one program reduces over.
# `reset_to_zero=['c_ptr']` asks the autotuner to zero the output buffer before evaluating configs;
# the kernel accumulates into it with atomic adds.
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8, 'K_BLOCKS': 64}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8, 'K_BLOCKS': 128}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8, 'K_BLOCKS': 32}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8, 'K_BLOCKS': 4}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8, 'K_BLOCKS': 64}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8, 'K_BLOCKS': 128}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8, 'K_BLOCKS': 32}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8, 'K_BLOCKS': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8, 'K_BLOCKS': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8, 'K_BLOCKS': 8}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8, 'K_BLOCKS': 16}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8, 'K_BLOCKS': 16}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8, 'K_BLOCKS': 64}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8, 'K_BLOCKS': 64}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8, 'K_BLOCKS': 32}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8, 'K_BLOCKS': 32}, num_stages=4, num_warps=4),
],
key=['M', 'N', 'K', 'float32_out', 'allow_tf32', 'op_float16'], reset_to_zero = ['c_ptr']
)
@triton.jit
def cvmm_backward_kernel3(
# Pointers to matrices
a_ptr, b_ptr, c_ptr, index_ptr, sel_ptr, out_index_ptr,
# Matrix dimensions
M, N, K,
# The stride variables represent how much to increase the ptr by when moving by 1
# element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
# by to get the element one row down (A has M rows).
stride_am, stride_ak,
stride_bk, stride_bn,
# `stride_co` is multiplied by the matrix id below, i.e. it steps between the output slices.
stride_co, stride_cm, stride_cn,
stride_index, stride_sel, stride_out_index,
# Compile-time flag used instead of testing `out_index_ptr` against None inside the kernel.
out_index_is_none: tl.constexpr,
float32_out: tl.constexpr, allow_tf32: tl.constexpr, op_float16: tl.constexpr,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr, K_BLOCKS: tl.constexpr
):
"""Kernel for computing the matmul C = A x B.
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
"""
# NOTE: the docstring describes the plain matmul this kernel is based on. Here the reduction over K
# is split by the matrix id stored in `sel_ptr`, and each id accumulates into its own output slice
# (`c_ptr + stride_co * matrix_index`) with atomic adds.
# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse.
# See the `L2 Cache Optimizations` section of the Triton matmul tutorial this kernel is based on.
pid = tl.program_id(axis=0)
# Second grid axis: which chunk of the K dimension this program reduces over.
k_block_id = tl.program_id(axis=1)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
# ----------------------------------------------------------
# Create pointers for the first blocks of A and B.
# We will advance this pointer as we move in the K direction
# and accumulate
# `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
# `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
# See the `Pointer Arithmetics` section of the Triton matmul tutorial for details
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
# `accumulator` is converted to fp16 after the loop unless `float32_out` is set.
# Row/column base pointers; the K offsets are added per iteration from the index tensors.
a_ptrs_this = a_ptr + offs_am[:, None] * stride_am
b_ptrs_this = b_ptr + offs_bn[None, :] * stride_bn
# Kactual = end_i - start_i
# Nblocks = (Kactual + BLOCK_SIZE_K - 1) // BLOCK_SIZE_K
# WORK_PER_WORKER = (Nblocks + K_BLOCKS - 1) // K_BLOCKS
# WORK_PER_WORKER = WORK_PER_WORKER if WORK_PER_WORKER > MIN_WORK_SIZE else MIN_WORK_SIZE
# # Kloop_start = (Kactual + BLOCK_SIZE_K - 1) // BLOCK_SIZE_K
# first_block_k = k_block_id * WORK_PER_WORKER
# last_block_k = min((k_block_id+1) * WORK_PER_WORKER, Nblocks)
# This program covers positions [block_start_index, block_end_index] of the K dimension.
block_start_index = k_block_id * BLOCK_SIZE_K * K_BLOCKS
block_end_index = min(block_start_index + BLOCK_SIZE_K * K_BLOCKS, K) - 1
# Matrix ids found at the two ends of the chunk.
first_mat = tl.load(sel_ptr + stride_sel * block_start_index)
last_mat = tl.load(sel_ptr + stride_sel * block_end_index)
# NOTE(review): iterating first_mat..last_mat and the binary searches below presumably rely on
# `sel` being sorted - confirm against the caller.
for matrix_index in range(first_mat, last_mat + 1):
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
# Binary search: first position in the chunk whose matrix id is >= matrix_index.
start_i = block_start_index
end_i = block_end_index + 1
while start_i < end_i:
middle = (start_i + end_i) // 2
middle_matrix = tl.load(sel_ptr + middle * stride_sel)
if middle_matrix < matrix_index:
start_i = middle + 1
else:
end_i = middle
# Continue binary search: find the first matrix that is > matrix_index
start_i2 = start_i
end_i = block_end_index + 1
while start_i2 < end_i:
middle = (start_i2 + end_i) // 2
middle_matrix = tl.load(sel_ptr + middle * stride_sel)
if middle_matrix <= matrix_index:
start_i2 = middle + 1
else:
end_i = middle
# [start_i, end_i) now delimits the positions found for matrix_index; `count` is its length.
end_i = start_i2
count = end_i - start_i
block_mem_indices_f_base = start_i + tl.arange(0, BLOCK_SIZE_K)
if count > 0:
for k in range((count + BLOCK_SIZE_K - 1) // BLOCK_SIZE_K):
# block_mem_indices = (k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)) % K
block_mem_indices_f = block_mem_indices_f_base + k * BLOCK_SIZE_K
# Wrapped modulo K for the index loads; lanes at or past end_i are masked through `sel_ok`.
block_mem_indices = block_mem_indices_f % K
a_index = tl.load(index_ptr + stride_index * block_mem_indices)
# Without an explicit output index, B is addressed with the same indices as A.
if out_index_is_none:
b_index = a_index
else:
b_index = tl.load(out_index_ptr + stride_out_index * block_mem_indices)
sel_ok = block_mem_indices_f < end_i
a_ptrs = a_ptrs_this + a_index[None, :] * stride_ak
b_ptrs = b_ptrs_this + b_index[:, None] * stride_bk
# Load the next block of A and B, generate a mask by checking the K dimension.
# If it is out of bounds, set it to 0.
a = tl.load(a_ptrs, mask=sel_ok[None, :], other=0.0)
b = tl.load(b_ptrs, mask=sel_ok[:, None], other=0.0)
if op_float16:
a = a.to(tl.float16)
b = b.to(tl.float16)
# We accumulate along the K dimension.
accumulator += tl.dot(a, b, allow_tf32=allow_tf32)
if float32_out:
c = accumulator
else:
c = accumulator.to(tl.float16)
# -----------------------------------------------------------
# Write back the block of the output matrix C with masks.
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_co * matrix_index + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
# tl.store(c_ptrs, c, mask=c_mask)
# Atomic add rather than store: the target tile does not depend on k_block_id, so programs
# handling different K chunks can update the same output elements.
tl.atomic_add(c_ptrs, c, mask=c_mask)
torch.library.define("mylib::cvmm_triton", "(Tensor x, Tensor sel_index, Tensor sel, Tensor keys, ScalarType out_dtype, Tensor out_index) -> Tensor")


@torch.library.impl("mylib::cvmm_triton", "default")
def cvmm_triton(
    x: torch.Tensor,
    sel_index: torch.Tensor,
    sel: torch.Tensor,
    keys: torch.Tensor,
    out_dtype: torch.dtype,
    out_index: torch.Tensor
):
    """Implementation of the `mylib::cvmm_triton` op: launch `cvmm_kernel`.

    `x` is flattened to 2-D and `sel` to 1-D before the launch. `out_index`
    must be a tensor; a single-element tensor equal to -1 stands for
    "no output index". The result has shape `(*sel.shape, N)` where
    `keys` has shape `(O, K, N)`.
    """
    x = x.flatten(end_dim=-2)
    assert x.shape[-1] == keys.shape[1]

    # Remember the caller's selection shape; the output is viewed back to it at the end.
    result_shape = sel.shape
    sel = sel.flatten()

    M = sel.shape[0]
    _, K, N = keys.shape

    # Uninitialised output buffer, filled by the kernel.
    out = torch.empty((M, N), device=x.device, dtype=out_dtype)

    def grid(meta):
        # 1-D launch: one program per (row tile, column tile) pair.
        return (
            triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(N, meta['BLOCK_SIZE_N']),
        )

    # Decode the -1 sentinel into the compile-time flag the kernel expects.
    out_index_is_none = bool(out_index.numel() == 1 and out_index == -1)
    out_index_stride = 0 if out_index_is_none else out_index.stride(0)

    cvmm_kernel[grid](
        x, keys, out, sel_index, sel, out_index,
        M, N, K,
        x.stride(0), x.stride(1),
        keys.stride(0), keys.stride(1), keys.stride(2),
        out.stride(0), out.stride(1),
        sel_index.stride(0), sel.stride(0), out_index_stride,
        out_index_is_none=out_index_is_none,
        float32=out.dtype == torch.float32,
        allow_tf32=False,  # torch.backends.cuda.matmul.allow_tf32
    )

    return out.view(*result_shape, N)
# Only the qualified op name is passed here: the second positional parameter of `impl_abstract` is the
# function to register, so the real implementation must not be supplied in that slot.
@torch.library.impl_abstract("mylib::cvmm_triton")
def cvmm_triton_abstract(x, sel_idx, sel, keys, out_dtype, out_index):
    """Abstract (shape/dtype only) implementation of `mylib::cvmm_triton`.

    Returns an uninitialised tensor of shape `(*sel.shape, N)` on `x.device`
    with dtype `out_dtype`, where `keys` has shape `(O, K, N)`. This mirrors
    the shape returned by the real implementation.
    """
    # Capture the shape BEFORE flattening; the output is viewed back to it.
    sel_shape = sel.shape
    M = sel.flatten().shape[0]
    _, _, N = keys.shape
    out = torch.empty((M, N), device=x.device, dtype=out_dtype)
    return out.view(*sel_shape, N)
torch.library.define("mylib::cvmm_triton_backward", "(Tensor x, Tensor sel_index, Tensor sel, Tensor grads, int n_experts, ScalarType key_dtype, bool op_float16, Tensor out_index) -> Tensor")


@torch.library.impl("mylib::cvmm_triton_backward", "default")
def cvmm_triton_backward(
    x: torch.Tensor,
    sel_index: torch.Tensor,
    sel: torch.Tensor,
    grads: torch.Tensor,
    n_experts: int,
    key_dtype: torch.dtype,
    op_float16: bool,
    out_index: torch.Tensor
):
    """Implementation of the `mylib::cvmm_triton_backward` op: launch `cvmm_backward_kernel3`.

    `x` is flattened to 2-D and transposed, `grads` is flattened to 2-D and
    `sel` to 1-D. `out_index` must be a tensor; a single-element tensor equal
    to -1 stands for "no output index". Returns a zero-initialised tensor of
    shape `(n_experts, M, N)` that the kernel accumulates into.
    """
    x = x.flatten(end_dim=-2).transpose(0, 1)
    grads = grads.flatten(end_dim=-2)
    sel = sel.flatten()

    M, _ = x.shape
    K, N = grads.shape

    # Must start from zeros: the kernel adds its partial results atomically.
    out = torch.zeros((n_experts, M, N), device=x.device, dtype=key_dtype)

    def grid(meta):
        # Axis 0: output tiles. Axis 1: chunks of the K dimension.
        return (
            triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(N, meta['BLOCK_SIZE_N']),
            triton.cdiv(K, meta['BLOCK_SIZE_K'] * meta['K_BLOCKS']),
        )

    # Decode the -1 sentinel into the compile-time flag the kernel expects.
    out_index_is_none = bool(out_index.numel() == 1 and out_index == -1)
    out_index_stride = 0 if out_index_is_none else out_index.stride(0)

    cvmm_backward_kernel3[grid](
        x, grads, out, sel_index, sel, out_index,
        M, N, K,
        x.stride(0), x.stride(1),
        grads.stride(0), grads.stride(1),
        out.stride(0), out.stride(1), out.stride(2),
        sel_index.stride(0), sel.stride(0), out_index_stride,
        out_index_is_none=out_index_is_none,
        float32_out=out.dtype == torch.float32,
        op_float16=op_float16,
        allow_tf32=False,  # torch.backends.cuda.matmul.allow_tf32
    )

    return out
# Only the qualified op name is passed here: the second positional parameter of `impl_abstract` is the
# function to register, so the real implementation must not be supplied in that slot.
@torch.library.impl_abstract("mylib::cvmm_triton_backward")
def cvmm_triton_backward_abstract(
    x: torch.Tensor,
    sel_index: torch.Tensor,
    sel: torch.Tensor,
    grads: torch.Tensor,
    n_experts: int,
    key_dtype: torch.dtype,
    op_float16: bool,
    out_index: torch.Tensor
):
    """Abstract (shape/dtype only) implementation of `mylib::cvmm_triton_backward`.

    Applies the same flatten/transpose steps as the real implementation to
    derive `M` and `N`, then returns a tensor of shape `(n_experts, M, N)` on
    `x.device` with dtype `key_dtype`.
    """
    x = x.flatten(end_dim=-2).transpose(0, 1)
    grads = grads.flatten(end_dim=-2)
    M, _ = x.shape
    _, N = grads.shape
    return torch.zeros((n_experts, M, N), device=x.device, dtype=key_dtype)
class CVMM(torch.autograd.Function):
    """Autograd function built on the `mylib::cvmm_triton` and
    `mylib::cvmm_triton_backward` custom ops.

    The op schemas require a Tensor for `out_index`, so a missing index is
    encoded as a scalar tensor holding -1, created on the device of the data.
    """

    warned = False

    @staticmethod
    def forward(
        ctx,
        x: torch.Tensor,
        sel_index: torch.Tensor,
        sel: torch.Tensor,
        keys: torch.Tensor,
        out_index: Optional[torch.Tensor] = None,
        reduction_weight: Optional[torch.Tensor] = None
    ):
        """Run the forward op and, if `reduction_weight` is given, reduce with it.

        Returns the op result, or its weighted reduction over the second-to-last
        dimension when `reduction_weight` is not None.
        """
        # Saved before the sentinel replaces a None `out_index`.
        ctx.save_for_backward(x, keys, sel, sel_index, out_index, reduction_weight)

        out_type = torch.float16 if torch.is_autocast_enabled() else x.dtype
        if out_index is None:
            # Sentinel is allocated directly on x's device (no CPU tensor + copy).
            out_index = torch.tensor(-1, device=x.device)

        res = torch.ops.mylib.cvmm_triton(x, sel_index, sel, keys, out_type, out_index)
        if reduction_weight is not None:
            res = res.view(*reduction_weight.shape, res.shape[-1])
            res = (reduction_weight.unsqueeze(-2).type_as(res) @ res).squeeze(-2)

        ctx.op_type = out_type
        ctx.keys_type = keys.dtype
        ctx.is_autocast = torch.is_autocast_enabled()
        return res

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for `(x, sel_index, sel, keys, out_index, reduction_weight)`.

        Only `x`, `keys` and `reduction_weight` receive gradients; the other
        positions are None.
        """
        x, keys, sel, sel_index, out_index, reduction_weight = ctx.saved_tensors

        # Backward for weight
        if reduction_weight is not None:
            # Project back the grads with the reduction weight, so the grad for the weight matrix is ok
            grad_output_w = reduction_weight.unsqueeze(-1).type_as(grad_output) @ grad_output.unsqueeze(-2)
        else:
            grad_output_w = grad_output

        out_index_is_none = out_index is None
        if out_index_is_none:
            out_index = torch.tensor(-1, device=x.device)

        grad_w = torch.ops.mylib.cvmm_triton_backward(
            x,
            sel_index,
            sel,
            grad_output_w,
            keys.shape[0],
            ctx.keys_type,
            ctx.is_autocast,
            out_index=out_index
        )

        # Backward for input and reduction weight
        grad_w_off = None

        bw_index = sel_index if out_index_is_none else out_index
        bw_index_out = torch.tensor(-1, device=x.device)
        if reduction_weight is not None:
            # Hack the output indices to emulate repeats
            bw_index_out = bw_index
            bw_index = bw_index // reduction_weight.shape[-1]

        grad_x_full = torch.ops.mylib.cvmm_triton(
            grad_output,
            bw_index,
            sel,
            keys.transpose(1, 2),
            ctx.op_type,
            bw_index_out
        )
        grad_x_full = grad_x_full.view(*x.shape[:-1], -1, x.shape[-1])

        if reduction_weight is not None:
            # grad_x_full is the unscaled grad. For the input, we have to scale it, for the reduction weight,
            # we have to compute dot products with the input.
            grad_x = (reduction_weight.view(*grad_x_full.shape[:-1]).unsqueeze(-2).type_as(grad_x_full) @ grad_x_full).squeeze(-2)
            grad_w_off = (grad_x_full.type_as(reduction_weight) @ x.unsqueeze(-1).type_as(reduction_weight)).squeeze(-1).view_as(reduction_weight)
        elif grad_x_full.shape[-2] != 1:
            grad_x = grad_x_full.sum(-2)
        else:
            grad_x = grad_x_full

        grad_x = grad_x.view_as(x)

        return grad_x, None, None, grad_w, None, grad_w_off
def cvmm(x: torch.Tensor, sel: Union[torch.Tensor, CVMMSel], keys: torch.Tensor):
    """Apply `CVMM` to `x` with the selection `sel` and the stacked matrices `keys`.

    A raw selection tensor is first converted with `cvmm_prepare_sel`, using
    `keys.shape[0]` as the number of experts; a `CVMMSel` is used as is.
    """
    prepared = sel if isinstance(sel, CVMMSel) else cvmm_prepare_sel(sel, keys.shape[0])
    return CVMM.apply(
        x,
        prepared.sel_index,
        prepared.sel,
        keys,
        prepared.out_index,
        prepared.reduction_weight,
    )
def cvmm_prepare_sel2(sel: torch.Tensor, w: Optional[torch.Tensor] = None) -> CVMMSel:
    """Build a CVMMSel for a selection that holds several entries per batch element.

    The flattened selection is sorted. The sort permutation becomes
    `out_index`, and the permutation integer-divided by `sel.shape[-1]`
    becomes `sel_index`. `w` is stored as `reduction_weight`.
    """
    # Number of selections per batch element (size of the last dimension).
    per_element = sel.shape[-1]

    sorted_sel, order = sel.flatten().sort()

    return CVMMSel(
        raw_sel=sel,
        sel=sorted_sel.view_as(sel),
        sel_index=order // per_element,
        out_index=order,
        reduction_weight=w,
    )
class Model(torch.nn.Module):
    """Minimal two-layer expert model used to exercise `cvmm` under torch.compile.

    An up-projection through `keys` is followed by a weighted down-projection
    through `values`; the experts are picked per input row by a top-k over a
    sigmoid-activated linear selector.
    """

    def __init__(self):
        super().__init__()
        self.n_heads = 8
        self.n_experts = 8
        self.expert_size = 64
        self.k_vec_dim = 128
        self.v_dim = 128
        self.keys = torch.nn.Parameter(
            torch.empty(self.n_experts, self.k_vec_dim, self.expert_size)
        )
        self.values = torch.nn.Parameter(
            torch.empty(self.n_experts, self.expert_size, self.v_dim)
        )
        self.expert_sel = torch.nn.Linear(self.k_vec_dim, self.n_experts, bias=False)
        self.sel_activation = torch.nn.Sigmoid()
        self._init_expert_weights()

    def _init_expert_weights(self) -> None:
        """Give `keys` and `values` defined values.

        `torch.empty` leaves the storage uninitialised, so without this step the
        parameters hold arbitrary memory contents.
        """
        torch.nn.init.normal_(self.keys, std=self.k_vec_dim ** -0.5)
        torch.nn.init.normal_(self.values, std=self.expert_size ** -0.5)

    def compute_scores(self, input: torch.Tensor, index: "CVMMSel") -> torch.Tensor:
        """Up-projection: apply `cvmm` to `input` with `self.keys`."""
        return cvmm(input, index, self.keys)

    def forward(self, input: torch.Tensor):
        """Select `n_heads` experts per row, then run the up- and down-projection."""
        sel = self.sel_activation(self.expert_sel(input))
        sel_val, sel_index = sel.topk(self.n_heads, dim=-1, sorted=False)

        # Preprocess the selection indices. They are needed for both layers.
        sel_indices = cvmm_prepare_sel2(sel_index.int())

        # "Up-projection" layer for each head
        scores = self.compute_scores(input, sel_indices)

        # Down projection layer for each head: reuse the selection on a copy, weight the
        # heads by their selection score, and read the inputs through the former out_index.
        sel_indices = sel_indices.clone()
        sel_indices.reduction_weight = sel_val
        sel_indices.sel_index = sel_indices.out_index
        sel_indices.out_index = None

        out = cvmm(scores, sel_indices, self.values)
        return out
# Repro driver: build the model in float16 on the GPU, wrap it with
# torch.compile, then run one forward and one backward pass.
model = Model().to(torch.float16).cuda()
model = torch.compile(model)
# The seed is set after the model has been constructed, so it affects the
# random test input created below but not the model's construction.
torch.manual_seed(0)
# NOTE(review): n_experts and expert_size are not referenced again in this
# script; they mirror the values hard-coded in Model.__init__.
n_experts = 8
n_channels = 128
expert_size = 64
bs = 64
device = torch.device("cuda")
dtype = torch.float16
# Random (bs, n_channels) float16 input on the GPU.
testvec = torch.randn(bs, n_channels, dtype=dtype, device=device)
out = model(testvec)
# Reduce the output to a scalar so that backward() can be called on it.
loss = out.sum()
loss.backward()
print(model.keys.grad.shape)
print(out.shape)
I am using the current nightly build. The exception I get is:
Traceback (most recent call last):
File "/u/julianb/transformers/src/transformers/models/sigma_moe/triton_src/moe_layer/cvmm.py", line 624, in <module>
loss.backward()
File "/dccstor/broccoli/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_tensor.py", line 522, in backward
torch.autograd.backward(
File "/dccstor/broccoli/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/autograd/__init__.py", line 267, in backward
_engine_run_backward(
File "/dccstor/broccoli/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/autograd/graph.py", line 681, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/dccstor/broccoli/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/autograd/function.py", line 289, in apply
return user_fn(self, *args)
File "/dccstor/broccoli/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 754, in backward
out = call_compiled_backward()
File "/dccstor/broccoli/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 706, in call_compiled_backward
out = call_func_at_runtime_with_args(
File "/dccstor/broccoli/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 105, in call_func_at_runtime_with_args
out = normalize_as_list(f(args))
File "/dccstor/broccoli/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 417, in _fn
return fn(*args, **kwargs)
File "/dccstor/broccoli/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
File "/dccstor/broccoli/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 863, in __call__
return self.get_current_callable()(inputs)
File "/dccstor/broccoli/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 608, in run
return model(new_inputs)
File "/dccstor/broccoli/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 891, in _run_from_cache
return compiled_graph.compiled_artifact(inputs)
File "/tmp/torchinductor_julianb/xa/cxajzptzw7uzhcqg2mfni6rnorv2jmmlin4yd6qqvdq4saagxfl6.py", line 149, in call
assert_size_stride(cvmm_triton, (512, 64), (64, 1))
AssertionError: wrong number of dimensions
My guess is that, in the backward pass, the function cvmm_triton gets called with different input shapes, but I don't understand why that is not handled.
Any help would be appreciated.
I can repro this. The inductor generated code is: https://gist.github.com/oulgen/68f41bd6d28e1547a80081097bce9a86
I'll take a look at this but it might be next week or later since I am currently out on vacation
Not yet, I have been preoccupied with a silent accuracy bug. Hoping to take a look at this next week.
@jubueche I took a look at this today and it looks like the bug is not related to triton kernels. So I removed all the user defined triton kernels and the repro still fails with
AssertionError: wrong number of dimensions
on assert_size_stride(cvmm_triton, (512, 64), (64, 1))
The repro passes in eager mode and fails in inductor. So it looks like there's an overall PT2 issue here that has nothing to do with user-defined Triton kernels.
https://gist.github.com/oulgen/3d40cfc65077179654f3d0f71d706008
Removed triaged
so that the oncall can re-triage this for faster turn around time
The repro passes on eager mode and fails in inductor. So it looks like there's an overall PT2 issue here nothing to do with user defined triton kernels.
Removed triaged so that the oncall can re-triage this for faster turn around time
Not familiar with that. Should I open a new issue somewhere else or is someone else looking at this?
I removed the triaged
tag so the PT2 oncall will retriage it. I will make sure it gets assigned. No need to create a new issue.
Tentatively assigned to Brian.
https://gist.github.com/oulgen/3d40cfc65077179654f3d0f71d706008
The repro fails with backend="eager", so it may not be an inductor issue
I'm going to temporarily un-assign myself. This issue repro's with dynamo, so it is (likely?) a fake tensor issue (between other commitments and PTO, I won't be able to look at this for at least 3-4 weeks).
If someone else can take a look before then feel free - otherwise, I will look again when I have more bandwidth
This no longer repros as an assert_size_stride error, but now fails in the custom Triton kernel. cc @oulgen, would you take a look at this?
Okay, using @oulgen's gist without the custom Triton kernel, it now repros. It does fail on both the eager and aot_eager backends.
TL;DR, since the thread is long:
I cleaned up all the repros. You can find the latest repros here.
- this (raw triton kernel): causes a CUDA IMA: https://gist.github.com/zou3519/ff7da848d63d82d3324982cf1cd0ac0f
- this (triton kernel + custom ops) segfaults: https://gist.github.com/zou3519/09c4732a8aeea45e33960a3076745f6d
It's unclear if this is due to user error, a triton bug, or a pytorch error.
From internal triage meeting: oulgen still planning to look at this
I've been busy with some other tasks, haven't had a chance to look at this.
cc. @aakhundov would you be interested in picking this up? (from triage meeting)