pytorch/pytorch

cuda IMA using custom triton kernel with compile

RobertCsordas opened this issue · 23 comments

๐Ÿ› Describe the bug

Hi,

I'm trying to make my MoE Triton kernel work with torch.compile(). I know this is not supported in the current stable release, but it is supported in the nightly builds (at least, it worked with the simple kernels I tried). However, my actual kernel fails.

There are two independent issues. The first is the "Illegal getattr invocation stride in strict mode" error (see the log below): Dynamo fails to trace the autograd.Function's backward into a single graph and falls back to eager mode, which does not seem to be fatal. The second, "RuntimeError: CUDA error: an illegal memory access was encountered", is fatal. Note that both this example and the full kernel run correctly without compile and never hit an illegal memory access.

Interestingly, the code works if I remove an unused, static if statement (it is never true in this example and depends only on external arguments). It also works if triton.autotune() is given only one config. I marked both places in the code below with a comment starting with "!!!!!!".

This is the simplified code I was able to come up with:

import torch

import triton
import triton.language as tl

# CVMM from: https://github.com/RobertCsordas/moe_layer/blob/master/triton_src/moe_layer/cvmm.py
# Based on: https://github.com/openai/triton/blob/main/python/tutorials/03-matrix-multiplication.py

from typing import Union, Optional
from dataclasses import dataclass


@dataclass
class CVMMSel:
    raw_sel: torch.Tensor
    sel: torch.Tensor
    sel_index: torch.Tensor
    out_index: Optional[torch.Tensor] = None


def cvmm_prepare_sel(sel: torch.Tensor, n_experts: int) -> CVMMSel:
    fsel = sel.flatten()
    ssel, sel_index = fsel.sort()
    return CVMMSel(sel, ssel.view_as(sel), sel_index, None)



# !!!!!! Leaving just one autotune config solves the "RuntimeError: CUDA error: an illegal memory access was
# encountered" problem !!!!!!
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    ],
    key=['M', 'N', 'K']
)
@triton.jit
def cvmm_kernel(
    # Pointers to matrices
    a_ptr, b_ptr, c_ptr, index_ptr, sel_ptr, out_index_ptr,
    # Matrix dimensions
    M, N, K,
    stride_cm, stride_cn,
    stride_index, stride_sel, stride_out_index,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr
):
    pid = tl.program_id(axis=0)

    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_n = (pid % num_pid_in_group) // group_size_m

    pid_m = first_pid_m + (pid % group_size_m)

    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M

    remap_offs_am = tl.load(index_ptr + stride_index * offs_am)

    # Output tile: all zeros, since this reduced repro only exercises the remapped store
    c = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float16)

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)


    # !!!!!! Removing this IF solves the "RuntimeError: CUDA error: an illegal memory access was encountered" problem,
    # even though it is always False in this example !!!!!!
    # To test it, keep the else branch.
    if out_index_ptr is not None:
        remap_offs_cm = tl.load(out_index_ptr + stride_out_index * offs_am)
    else:
        remap_offs_cm = remap_offs_am

    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * remap_offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)



def cvmm_triton(x: torch.Tensor, sel_index: torch.Tensor, sel: torch.Tensor, keys: torch.Tensor, out_dtype: torch.dtype, out_index: Optional[torch.Tensor] = None):
    x = x.flatten(end_dim=-2)
    assert x.shape[-1] == keys.shape[1]

    sel_shape = sel.shape
    sel = sel.flatten()

    M = sel.shape[0]
    O, K, N = keys.shape
    # Allocates output.
    out = torch.empty((M, N), device=x.device, dtype=out_dtype)

    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
    )

    cvmm_kernel[grid](
        x, keys, out, sel_index, sel, None,
        M, N, K,
        out.stride(0), out.stride(1),
        sel_index.stride(0), sel.stride(0), 0,
    )

    return out.view(*sel_shape, N)


class CVMM(torch.autograd.Function):
    warned = False

    @staticmethod
    def forward(ctx, x: torch.Tensor, sel_index: torch.Tensor, sel: torch.Tensor, keys: torch.Tensor, out_index: Optional[torch.Tensor] = None):
        ctx.save_for_backward(x, keys, sel, sel_index, out_index)

        out_type = torch.float16 if torch.is_autocast_enabled() else x.dtype
        res = cvmm_triton(x, sel_index, sel, keys, out_type, out_index)
        ctx.op_type = out_type
        ctx.keys_type = keys.dtype
        ctx.is_autocast = torch.is_autocast_enabled()
        return res

    @staticmethod
    def backward(ctx, grad_output):
        x, keys, sel, sel_index, out_index = ctx.saved_tensors

        keys_dt = keys

        grad_x_full = cvmm_triton(grad_output, sel_index, sel, keys_dt.transpose(1,2), ctx.op_type, None)
        grad_x = grad_x_full.view_as(x)

        return grad_x, None, None, None, None


def cvmm(x: torch.Tensor, sel: Union[torch.Tensor, CVMMSel], keys: torch.Tensor):
    if not isinstance(sel, CVMMSel):
        sel = cvmm_prepare_sel(sel, keys.shape[0])

    return CVMM.apply(x, sel.sel_index, sel.sel, keys, sel.out_index)

# Compile test


class Model(torch.nn.Module):
    def forward(self, x, sel, w):
        return cvmm(x, sel, w)

model = torch.compile(Model().cuda())
# model = Model().cuda()


torch.manual_seed(0)
n_experts = 8
n_channels = 64
expert_size = 64
bs = 64

device = torch.device("cuda")
dtype = torch.float16

keys = torch.nn.Parameter(torch.randn(n_experts, n_channels, expert_size, dtype=dtype, device=device))
testvec = torch.randn(bs, n_channels, dtype=dtype, device=device)
sel = torch.randint(0, n_experts, (bs,), dtype=torch.int32, device=device)

print(model(testvec, sel, keys).shape)
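The program-id arithmetic at the top of cvmm_kernel is the grouped ordering from the Triton matrix-multiplication tutorial linked in the code. As a sanity check that the 1-D grid launched by cvmm_triton visits every output tile exactly once for both autotune configs, the same integer math can be replayed on the host. This is a plain-Python sketch; the helper names (cdiv, tile_for_pid, tiles_covered) are hypothetical, and it mirrors only the index arithmetic, not Triton's execution model.

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division for positive ints, like triton.cdiv / tl.cdiv
    return (a + b - 1) // b


def tile_for_pid(pid, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M):
    """Host-side replay of the pid -> (pid_m, pid_n) math in cvmm_kernel."""
    num_pid_m = cdiv(M, BLOCK_SIZE_M)
    num_pid_n = cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    # The last group may hold fewer than GROUP_SIZE_M rows of tiles
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_n = (pid % num_pid_in_group) // group_size_m
    pid_m = first_pid_m + (pid % group_size_m)
    return pid_m, pid_n


def tiles_covered(M, N, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M):
    """Tiles visited by the 1-D grid that cvmm_triton launches."""
    grid = cdiv(M, BLOCK_SIZE_M) * cdiv(N, BLOCK_SIZE_N)
    return [
        tile_for_pid(pid, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M)
        for pid in range(grid)
    ]


# The two autotune configs from the repro, at the logged sizes M = N = 64
for bm, bn in [(32, 64), (64, 64)]:
    print((bm, bn), tiles_covered(64, 64, bm, bn, GROUP_SIZE_M=8))
# prints:
# (32, 64) [(0, 0), (1, 0)]
# (64, 64) [(0, 0)]
```

Under these assumptions the first config launches two programs and the second launches one, and neither visits a tile twice or skips one, so the grid itself does not look like a source of out-of-range tiles.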

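The addresses touched by the masked tl.store in cvmm_kernel can also be enumerated on the host. The sketch below assumes the argument values from the logged Inductor launch (M=64, N=64, stride_cm=64, stride_cn=1, out_index_ptr=None, so the else branch is taken) and uses a hypothetical stand-in for cvmm_prepare_sel built on Python's sorted() rather than torch's sort. Because sel_index comes from a sort, it is a permutation of range(M), so every remapped row stays inside the output buffer; this suggests, without proving, that the kernel's indexing is sound on correct inputs.

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division for positive ints, like triton.cdiv
    return (a + b - 1) // b


def prepare_sel(sel):
    """Stand-in for cvmm_prepare_sel: sorted values plus the argsort indices."""
    # torch's sort need not be stable, but any sort yields a permutation of range(len(sel))
    sel_index = sorted(range(len(sel)), key=lambda i: sel[i])
    ssel = [sel[i] for i in sel_index]
    return ssel, sel_index


def store_offsets(sel_index, M, N, stride_cm, stride_cn, BLOCK_SIZE_M, BLOCK_SIZE_N):
    """Element offsets from c_ptr written by the masked tl.store, tile by tile."""
    offsets = []
    for pid_m in range(cdiv(M, BLOCK_SIZE_M)):
        for pid_n in range(cdiv(N, BLOCK_SIZE_N)):
            for i in range(BLOCK_SIZE_M):
                offs_cm = pid_m * BLOCK_SIZE_M + i
                offs_am = offs_cm % M            # load offsets wrap with % M
                row = sel_index[offs_am]         # else branch: remap_offs_cm = remap_offs_am
                for j in range(BLOCK_SIZE_N):
                    offs_cn = pid_n * BLOCK_SIZE_N + j
                    if offs_cm < M and offs_cn < N:   # c_mask
                        offsets.append(stride_cm * row + stride_cn * offs_cn)
    return offsets


# Logged launch: M=64, N=64, stride_cm=64, stride_cn=1; expert ids in [0, 8)
M, N = 64, 64
sel = [(7 * i + 3) % 8 for i in range(M)]   # deterministic stand-in for torch.randint
_, sel_index = prepare_sel(sel)
for bm, bn in [(32, 64), (64, 64)]:
    offs = store_offsets(sel_index, M, N, 64, 1, bm, bn)
    # every element of the M x N output is written exactly once
    assert sorted(offs) == list(range(M * N))
```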
Error logs

[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [WARNING] speculate_subgraph: while introspecting the user-defined autograd.Function, we were unable to trace function `trampoline_autograd_bwd` into a single graph. This means that Dynamo was unable to prove safety for this API and will fall back to eager-mode PyTorch, which could lead to a slowdown.
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] Illegal getattr invocation stride in strict mode
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] Traceback (most recent call last):
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 231, in speculate_subgraph
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     output = f.call_function(tx, args, sub_kwargs)
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 248, in call_function
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     return super().call_function(tx, args, kwargs)
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 81, in call_function
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     return tx.inline_user_function_return(
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in inline_user_function_return
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2261, in inline_call
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     return cls.inline_call_(parent, func, args, kwargs)
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2376, in inline_call_
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     tracer.run()
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     and self.step()
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     getattr(self, inst.opname)(inst)
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     return inner_fn(self, inst)
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1252, in CALL_FUNCTION_EX
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     self.call_function(fn, argsvars.items, kwargsvars.items)
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 652, in call_function
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     self.push(fn.call_function(self, args, kwargs))
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/misc.py", line 660, in call_function
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     return self.obj.call_method(tx, self.name, args, kwargs)
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/misc.py", line 505, in call_method
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     return tx.inline_call(tx, backward, args, kwargs)
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2261, in inline_call
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     return cls.inline_call_(parent, func, args, kwargs)
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2376, in inline_call_
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     tracer.run()
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     and self.step()
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     getattr(self, inst.opname)(inst)
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     return inner_fn(self, inst)
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1213, in CALL_FUNCTION
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     self.call_function(fn, args, {})
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 652, in call_function
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     self.push(fn.call_function(self, args, kwargs))
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 248, in call_function
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     return super().call_function(tx, args, kwargs)
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 81, in call_function
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     return tx.inline_user_function_return(
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in inline_user_function_return
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2261, in inline_call
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     return cls.inline_call_(parent, func, args, kwargs)
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2376, in inline_call_
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     tracer.run()
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     and self.step()
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     getattr(self, inst.opname)(inst)
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1303, in LOAD_ATTR
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     result = BuiltinVariable(getattr).call_function(
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 651, in call_function
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     result = handler(tx, *args, **kwargs)
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 1229, in call_getattr
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     return obj.var_getattr(tx, name).clone(source=source)
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/tensor.py", line 218, in var_getattr
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     unimplemented(f"Illegal getattr invocation {name} in strict mode")
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 193, in unimplemented
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     raise Unsupported(msg)
[2023-12-07 14:21:05,882] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [WARNING] speculate_subgraph: while introspecting the user-defined autograd.Function, we were unable to trace function `trampoline_autograd_bwd` into a single graph. This means that Dynamo was unable to prove safety for this API and will fall back to eager-mode PyTorch, which could lead to a slowdown.
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] Illegal getattr invocation stride in strict mode
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] Traceback (most recent call last):
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 231, in speculate_subgraph
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]     output = f.call_function(tx, args, sub_kwargs)
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 248, in call_function
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]     return super().call_function(tx, args, kwargs)
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 81, in call_function
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]     return tx.inline_user_function_return(
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in inline_user_function_return
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2261, in inline_call
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]     return cls.inline_call_(parent, func, args, kwargs)
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2376, in inline_call_
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]     tracer.run()
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]     and self.step()
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]     getattr(self, inst.opname)(inst)
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]     return inner_fn(self, inst)
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1252, in CALL_FUNCTION_EX
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]     self.call_function(fn, argsvars.items, kwargsvars.items)
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 652, in call_function
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]     self.push(fn.call_function(self, args, kwargs))
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/misc.py", line 660, in call_function
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]     return self.obj.call_method(tx, self.name, args, kwargs)
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/misc.py", line 505, in call_method
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]     return tx.inline_call(tx, backward, args, kwargs)
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2261, in inline_call
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]     return cls.inline_call_(parent, func, args, kwargs)
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2376, in inline_call_
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]     tracer.run()
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]     and self.step()
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]     getattr(self, inst.opname)(inst)
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]     return inner_fn(self, inst)
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1213, in CALL_FUNCTION
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]     self.call_function(fn, args, {})
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 652, in call_function
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]     self.push(fn.call_function(self, args, kwargs))
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 248, in call_function
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]     return super().call_function(tx, args, kwargs)
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 81, in call_function
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]     return tx.inline_user_function_return(
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in inline_user_function_return
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2261, in inline_call
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]     return cls.inline_call_(parent, func, args, kwargs)
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2376, in inline_call_
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]     tracer.run()
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]     and self.step()
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]     getattr(self, inst.opname)(inst)
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1303, in LOAD_ATTR
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]     result = BuiltinVariable(getattr).call_function(
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 651, in call_function
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]     result = handler(tx, *args, **kwargs)
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 1229, in call_getattr
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]     return obj.var_getattr(tx, name).clone(source=source)
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/variables/tensor.py", line 218, in var_getattr
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]     unimplemented(f"Illegal getattr invocation {name} in strict mode")
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 193, in unimplemented
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR]     raise Unsupported(msg)
[2023-12-07 14:21:05,929] [1/0] torch._dynamo.variables.higher_order_ops: [ERROR] torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode
Traceback (most recent call last):
  File "/home/robert/rnn_generalization_test/compile_test3.py", line 167, in <module>
    print(model(testvec, sel, keys).shape)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/robert/rnn_generalization_test/compile_test3.py", line 148, in forward
    return cvmm(x, sel, w)
  File "/home/robert/rnn_generalization_test/compile_test3.py", line 141, in cvmm
    return CVMM.apply(x, sel.sel_index, sel.sel, keys, sel.out_index)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 553, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/robert/rnn_generalization_test/compile_test3.py", line 114, in forward
    @staticmethod
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 901, in forward
    return compiled_fn(full_args)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 81, in g
    return f(*args)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 94, in runtime_wrapper
    all_outs = call_func_at_runtime_with_args(
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 105, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 118, in rng_functionalization_wrapper
    return compiled_fw(args)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 864, in __call__
    return self.get_current_callable()(inputs)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 611, in run
    return model(new_inputs)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 892, in _run_from_cache
    return compiled_graph.compiled_artifact(inputs)
  File "/tmp/torchinductor_robert/vr/cvrlscv7n2ne4vcxcvozlnvsvyiddsior53it7ifax2hfs2uosni.py", line 111, in call
    cvmm_kernel_0.run(a_ptr=arg0_1, b_ptr=arg1_1, c_ptr=buf0, index_ptr=arg3_1, sel_ptr=arg2_1, out_index_ptr=None, M=64, N=64, K=64, stride_cm=64, stride_cn=1, stride_index=1, stride_sel=1, stride_out_index=0, grid=grid_wrapper_for_cvmm_kernel_0, stream=stream0)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_inductor/triton_heuristics.py", line 540, in run
    self.autotune_to_one_config(*args, grid=grid, **kwargs)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_inductor/triton_heuristics.py", line 444, in autotune_to_one_config
    timings = self.benchmark_all_configs(*args, **kwargs)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_inductor/triton_heuristics.py", line 420, in benchmark_all_configs
    timings = {
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_inductor/triton_heuristics.py", line 421, in <dictcomp>
    launcher: self.bench(launcher, *args, **kwargs)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_inductor/triton_heuristics.py", line 392, in bench
    return do_bench(kernel_call, rep=40, fast_flush=True)
  File "/home/robert/.local/lib/python3.10/site-packages/torch/_inductor/utils.py", line 167, in do_bench
    return triton_do_bench(*args, **kwargs)[0]
  File "/home/robert/.local/lib/python3.10/site-packages/triton/testing.py", line 103, in do_bench
    torch.cuda.synchronize()
  File "/home/robert/.local/lib/python3.10/site-packages/torch/cuda/__init__.py", line 801, in synchronize
    return torch._C._cuda_synchronize()
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Minified repro

The minifier creates a directory structure that contains no files, so I'm not sure what I'm supposed to paste here. The directory structure looks like this:

~/rnn_generalization_test/torch_compile_debug >>> find .
.
./run_2023_12_07_14_23_32_124760-pid_906807
./run_2023_12_07_14_23_32_124760-pid_906807/minifier
./run_2023_12_07_14_23_32_124760-pid_906807/minifier/checkpoints

Versions

Collecting environment information...
PyTorch version: 2.2.0.dev20231206+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Manjaro Linux (x86_64)
GCC version: (GCC) 12.2.1 20230201
Clang version: 15.0.7
CMake version: version 3.25.0
Libc version: glibc-2.37

Python version: 3.10.10 (main, Mar  5 2023, 22:26:53) [GCC 12.2.1 20230201] (64-bit runtime)
Python platform: Linux-5.15.109-1-MANJARO-x86_64-with-glibc2.37
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA TITAN V
Nvidia driver version: 530.41.03
cuDNN version: Probably one of the following:
/usr/lib/libcudnn.so.8.8.0
/usr/lib/libcudnn_adv_infer.so.8.8.0
/usr/lib/libcudnn_adv_train.so.8.8.0
/usr/lib/libcudnn_cnn_infer.so.8.8.0
/usr/lib/libcudnn_cnn_train.so.8.8.0
/usr/lib/libcudnn_ops_infer.so.8.8.0
/usr/lib/libcudnn_ops_train.so.8.8.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Address sizes:                   39 bits physical, 48 bits virtual
Byte Order:                      Little Endian
CPU(s):                          12
On-line CPU(s) list:             0-11
Vendor ID:                       GenuineIntel
Model name:                      Intel(R) Core(TM) i7-8700 CPU @ 3.20GHz
CPU family:                      6
Model:                           158
Thread(s) per core:              2
Core(s) per socket:              6
Socket(s):                       1
Stepping:                        10
CPU(s) scaling MHz:              25%
CPU max MHz:                     3200.0000
CPU min MHz:                     800.0000
BogoMIPS:                        6402.62
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb invpcid_single pti ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx rdseed adx smap clflushopt intel_pt xsaveopt xsavec xgetbv1 xsaves dtherm arat pln pts hwp hwp_notify hwp_act_window hwp_epp md_clear flush_l1d arch_capabilities
L1d cache:                       192 KiB (6 instances)
L1i cache:                       192 KiB (6 instances)
L2 cache:                        1.5 MiB (6 instances)
L3 cache:                        12 MiB (1 instance)
NUMA node(s):                    1
NUMA node0 CPU(s):               0-11
Vulnerability Itlb multihit:     KVM: Mitigation: VMX unsupported
Vulnerability L1tf:              Mitigation; PTE Inversion
Vulnerability Mds:               Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Meltdown:          Mitigation; PTI
Vulnerability Mmio stale data:   Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Retbleed:          Mitigation; IBRS
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; IBRS, IBPB conditional, STIBP conditional, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:             Mitigation; Microcode
Vulnerability Tsx async abort:   Mitigation; TSX disabled

Versions of relevant libraries:
[pip3] flake8==6.0.0
[pip3] kmeans-pytorch==0.3
[pip3] numpy==1.24.2
[pip3] pytorch-lightning==1.9.0
[pip3] pytorch-triton==2.1.0+bcad9dabe1
[pip3] torch==2.2.0.dev20231206+cu121
[pip3] torch-dct==0.1.6
[pip3] torch-tb-profiler==0.4.1
[pip3] torchaudio==2.2.0.dev20231206+cu121
[pip3] torchdata==0.6.1
[pip3] torchmetrics==0.11.0
[pip3] torchpq==0.3.0.5
[pip3] torchtext==0.15.2
[pip3] torchvision==0.17.0.dev20231206+cu121
[conda] Could not collect

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @muchulee8 @ColinPeppler @amjames @desertfire @oulgen @aakhundov @bdhirsh @anijain2305 @peterbell10 @wconstab

oulgen commented

Thanks for the report, I will take a look!

@zou3519 feel free to assign any triton kernel issues directly to me

oulgen commented

It looks like calling tensor.stride is banned inside autograd.Function in strict mode, which is why you're seeing the first error:

unimplemented(f"Illegal getattr invocation {name} in strict mode")

_autograd_backward_strict_mode_banned_ops = [
    "stride",
    "requires_grad",
    "storage_offset",
    "layout",
    "data",
]

This results in a graph break. @zou3519 is there any way around this?

I am going to look at the CUDA error next.

We might be able to fix the autograd.Function issue, but it's unclear right now. A workaround is to use the torch.library API to create a new PyTorch custom operator that wraps the Triton kernel (https://gist.github.com/zou3519/43393fa6807774ca11be7b797fe38a8b). Although this works with torch.compile, it does not take advantage of our infrastructure for directly compiling user-defined Triton kernels.

@RobertCsordas

@aakhundov and I spent some time looking at this. It looks like when an argument is None, Triton does not generate a real parameter for it in the PTX, but the Triton autotuning code emitted by Inductor still includes it. We are looking into whether there's a simple fix, but in the meantime a straightforward workaround is to pass a tl.constexpr parameter that tells the kernel whether the tensor is None, instead of checking the tensor argument itself.
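A toy model may make the mismatch easier to picture. This is only a sketch under stated assumptions: `compiled_params` and `bind_launch` are hypothetical names, and real Triton/Inductor launchers work on PTX parameter lists, not Python dicts. It assumes that arguments which are None at specialization time are dropped from the compiled signature, while the launcher still forwards every argument positionally (the generated launcher call in the traceback above does pass `out_index_ptr=None`):

```python
# Toy model only: hypothetical helpers, not Triton's or Inductor's real code.

def compiled_params(params, spec_args):
    """Parameters kept by the compiled kernel: None-valued ones are dropped."""
    return [p for p, a in zip(params, spec_args) if a is not None]


def bind_launch(params, spec_args, launch_args):
    """Pair positional launch arguments with the compiled parameter list."""
    return dict(zip(compiled_params(params, spec_args), launch_args))


params = ["c_ptr", "out_index_ptr", "M", "N"]

# Broken: the kernel was compiled without the None slot, but the launcher
# still passes it, so every later argument shifts by one position.
args = ["buf0", None, 64, 32]
broken = bind_launch(params, args, args)
print(broken)  # {'c_ptr': 'buf0', 'M': None, 'N': 64}

# One way to apply the workaround: never pass None. Hand the kernel a
# placeholder tensor plus a separate tl.constexpr flag (e.g. out_index_is_none)
# that decides at compile time whether the pointer is ever dereferenced.
fixed_args = ["buf0", "placeholder_tensor", 64, 32]
fixed = bind_launch(params, fixed_args, fixed_args)
print(fixed)  # {'c_ptr': 'buf0', 'out_index_ptr': 'placeholder_tensor', 'M': 64, 'N': 32}
```

In a real kernel the shifted values would be pointers and strides rather than strings, which is consistent with an illegal memory access showing up while Inductor benchmarks the autotuning configs.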

Dear @oulgen @zou3519 (cc @RobertCsordas),
I have adapted the code from @RobertCsordas with the workarounds proposed by @zou3519, and it seems to work. However, when I try to add a node for the backward function, I run into trouble.
I expanded the MWE a bit to better reflect the real file/function but kept it self-contained, i.e. you can just run the file and you will see the same exception that I see in practice.

from typing import Union, Optional
import torch
from dataclasses import dataclass
import triton
import triton.language as tl

# Based on https://github.com/openai/triton/blob/main/python/tutorials/03-matrix-multiplication.py


@dataclass
class CVMMSel:
    raw_sel: torch.Tensor
    sel: torch.Tensor
    sel_index: torch.Tensor
    out_index: Optional[torch.Tensor] = None
    reduction_weight: Optional[torch.Tensor] = None

    def clone(self) -> 'CVMMSel':
        return CVMMSel(self.raw_sel, self.sel, self.sel_index, self.out_index, self.reduction_weight)


def cvmm_prepare_sel(sel: torch.Tensor, n_experts: int) -> CVMMSel:
    fsel = sel.flatten()
    ssel, sel_index = fsel.sort()
    return CVMMSel(sel, ssel.view_as(sel), sel_index, None)


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    ],
    key=['M', 'N', 'K', 'float32', 'allow_tf32']
)
@triton.jit
def cvmm_kernel(
    # Pointers to matrices
    a_ptr, b_ptr, c_ptr, index_ptr, sel_ptr, out_index_ptr,
    # Matrix dimensions
    M, N, K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
    # by to get the element one row down (A has M rows).
    stride_am, stride_ak,
    stride_bo, stride_bk, stride_bn,
    stride_cm, stride_cn,
    stride_index, stride_sel, stride_out_index,
    out_index_is_none: tl.constexpr,
    float32: tl.constexpr, allow_tf32: tl.constexpr,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr
):
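    """Kernel for the expert-selected matmul. For each sorted position i < M it computes
    C[out_index[i]] = A[index[i]] @ B[sel[i]], where out_index = index when out_index_is_none.
    A has K columns, B has shape (O, K, N) (one (K, N) matrix per expert) and C has shape (M, N).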
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See the `L2 Cache Optimizations` section of the Triton matmul tutorial linked above for details.
    pid = tl.program_id(axis=0)

    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_n = (pid % num_pid_in_group) // group_size_m

    pid_m = first_pid_m + (pid % group_size_m)

    sel_first = tl.load(sel_ptr + pid_m * BLOCK_SIZE_M * stride_sel)
    sel_last = tl.load(sel_ptr + (min((pid_m + 1) * BLOCK_SIZE_M, M) - 1) * stride_sel)
    sel_all = tl.load(sel_ptr + stride_sel * ((pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M))

    for matrix_id in range(sel_first, sel_last + 1):
        # ----------------------------------------------------------
        # Create pointers for the first blocks of A and B.
        # We will advance this pointer as we move in the K direction
        # and accumulate
        # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
        # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
        # See the `Pointer Arithmetics` section of the Triton matmul tutorial linked above for details
        offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
        offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N

        remap_offs_am = tl.load(index_ptr + stride_index * offs_am)

        # Create offset pointers
        offs_k = tl.arange(0, BLOCK_SIZE_K)
        a_ptrs = a_ptr + (remap_offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
        b_ptrs = b_ptr + matrix_id * stride_bo + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

        # -----------------------------------------------------------
        # Iterate to compute a block of the C matrix.
        # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
        # of fp32 values for higher accuracy.
        # `accumulator` is converted back to fp16 after the loop unless `float32` is set.
        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
            # Load the next block of A and B, generate a mask by checking the K dimension.
            # If it is out of bounds, set it to 0.
            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
            # We accumulate along the K dimension.

            if not float32:
                a = a.to(tl.float16)
                b = b.to(tl.float16)

            accumulator += tl.dot(a, b, allow_tf32=allow_tf32)

            # Advance the ptrs to the next K block.
            a_ptrs += BLOCK_SIZE_K * stride_ak
            b_ptrs += BLOCK_SIZE_K * stride_bk


        if not float32:
            c = accumulator.to(tl.float16)
        else:
            c = accumulator

        # -----------------------------------------------------------
        # Write back the block of the output matrix C with masks.
        offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)

        if out_index_is_none:
            remap_offs_cm = remap_offs_am
        else:
            remap_offs_cm = tl.load(out_index_ptr + stride_out_index * offs_am)

        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        c_ptrs = c_ptr + stride_cm * remap_offs_cm[:, None] + stride_cn * offs_cn[None, :]
        c_mask = ((offs_cm[:, None] < M) & (sel_all[:, None] == matrix_id)) & (offs_cn[None, :] < N)
        tl.store(c_ptrs, c, mask=c_mask)
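

# Illustration only (hypothetical helper, not part of the original repro and
# not called anywhere): a plain-Python mirror of the grouped program-id mapping
# at the top of cvmm_kernel. Iterating it over the host-side grid size,
# cdiv(M, BLOCK_SIZE_M) * cdiv(N, BLOCK_SIZE_N), is a cheap CPU check that every
# (pid_m, pid_n) output tile is visited exactly once.
def grouped_pid_to_tile(pid, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M):
    num_pid_m = -(-M // BLOCK_SIZE_M)  # ceil division, like tl.cdiv
    num_pid_n = -(-N // BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_n = (pid % num_pid_in_group) // group_size_m
    pid_m = first_pid_m + (pid % group_size_m)
    return pid_m, pid_n
# e.g. sorted(grouped_pid_to_tile(p, 300, 200, 64, 64, 8) for p in range(5 * 4))
# lists every (pid_m, pid_n) with pid_m < 5 and pid_n < 4 exactly once.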

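# Illustration only (hypothetical helper, not part of the original repro and
# not called anywhere): which expert ids one row block of cvmm_kernel loops over.
# Because `sel` is sorted, the block covering rows
# [pid_m * BLOCK_SIZE_M, min((pid_m + 1) * BLOCK_SIZE_M, M)) only has to visit
# ids sel[first]..sel[last]; rows whose id differs from the current matrix_id
# are masked out of the store. An id with no rows in the block (2 in the example
# below) still gets a loop iteration whose store is fully masked.
def row_block_experts(sel, pid_m, BLOCK_SIZE_M):
    M = len(sel)
    first = pid_m * BLOCK_SIZE_M
    last = min((pid_m + 1) * BLOCK_SIZE_M, M) - 1
    return {
        matrix_id: [r for r in range(first, last + 1) if sel[r] == matrix_id]
        for matrix_id in range(sel[first], sel[last] + 1)
    }
# e.g. row_block_experts([0, 0, 1, 3, 3, 3, 4], pid_m=0, BLOCK_SIZE_M=4)
# -> {0: [0, 1], 1: [2], 2: [], 3: [3]}
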

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8, 'K_BLOCKS': 64}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8, 'K_BLOCKS': 128}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8, 'K_BLOCKS': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8, 'K_BLOCKS': 4}, num_stages=4, num_warps=4),

        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8, 'K_BLOCKS': 64}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8, 'K_BLOCKS': 128}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8, 'K_BLOCKS': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8, 'K_BLOCKS': 8}, num_stages=4, num_warps=4),

        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8, 'K_BLOCKS': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8, 'K_BLOCKS': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8, 'K_BLOCKS': 16}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8, 'K_BLOCKS': 16}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8, 'K_BLOCKS': 64}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8, 'K_BLOCKS': 64}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8, 'K_BLOCKS': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8, 'K_BLOCKS': 32}, num_stages=4, num_warps=4),
    ],
    key=['M', 'N', 'K', 'float32_out', 'allow_tf32', 'op_float16'], reset_to_zero = ['c_ptr']
)
@triton.jit
def cvmm_backward_kernel3(
    # Pointers to matrices
    a_ptr, b_ptr, c_ptr, index_ptr, sel_ptr, out_index_ptr,
    # Matrix dimensions
    M, N, K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
    # by to get the element one row down (A has M rows).
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_co, stride_cm, stride_cn,
    stride_index, stride_sel, stride_out_index,
    out_index_is_none: tl.constexpr,
    float32_out: tl.constexpr, allow_tf32: tl.constexpr, op_float16: tl.constexpr,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr, K_BLOCKS: tl.constexpr
):
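    """Kernel for the per-expert weight gradient. For every expert e it accumulates
    C[e] += A[:, index[k]] (outer) B[out_index[k], :] over the sorted positions k with sel[k] == e,
    where out_index = index when out_index_is_none. A has shape (M, K), B has shape (K, N)
    and C has shape (O, M, N).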
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See the `L2 Cache Optimizations` section of the Triton matmul tutorial linked above for details.
    pid = tl.program_id(axis=0)
    k_block_id = tl.program_id(axis=1)

    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    # See the `Pointer Arithmetics` section of the Triton matmul tutorial linked above for details
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` is converted back to fp16 after the loop unless `float32_out` is set.

    a_ptrs_this = a_ptr + offs_am[:, None] * stride_am
    b_ptrs_this = b_ptr + offs_bn[None, :] * stride_bn

    # Kactual = end_i - start_i
    # Nblocks = (Kactual + BLOCK_SIZE_K - 1) // BLOCK_SIZE_K

    # WORK_PER_WORKER = (Nblocks + K_BLOCKS - 1) // K_BLOCKS
    # WORK_PER_WORKER = WORK_PER_WORKER if WORK_PER_WORKER > MIN_WORK_SIZE else MIN_WORK_SIZE


    # # Kloop_start = (Kactual + BLOCK_SIZE_K - 1) // BLOCK_SIZE_K

    # first_block_k = k_block_id * WORK_PER_WORKER
    # last_block_k = min((k_block_id+1) * WORK_PER_WORKER, Nblocks)

    block_start_index = k_block_id * BLOCK_SIZE_K * K_BLOCKS
    block_end_index = min(block_start_index + BLOCK_SIZE_K * K_BLOCKS, K) - 1

    first_mat = tl.load(sel_ptr + stride_sel * block_start_index)
    last_mat = tl.load(sel_ptr + stride_sel * block_end_index)


    for matrix_index in range(first_mat, last_mat + 1):
        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

        start_i = block_start_index
        end_i = block_end_index + 1
        while start_i < end_i:
            middle = (start_i + end_i) // 2
            middle_matrix = tl.load(sel_ptr + middle * stride_sel)
            if middle_matrix < matrix_index:
                start_i = middle + 1
            else:
                end_i = middle


        # # Continue binary search: find the first matrix that is > matrix_index
        start_i2 = start_i
        end_i = block_end_index + 1
        while start_i2 < end_i:
            middle = (start_i2 + end_i) // 2
            middle_matrix = tl.load(sel_ptr + middle * stride_sel)
            if middle_matrix <= matrix_index:
                start_i2 = middle + 1
            else:
                end_i = middle

        end_i = start_i2

        count = end_i - start_i

        block_mem_indices_f_base = start_i + tl.arange(0, BLOCK_SIZE_K)

        if count > 0:
            for k in range((count + BLOCK_SIZE_K - 1) // BLOCK_SIZE_K):
                # block_mem_indices = (k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)) % K
                block_mem_indices_f = block_mem_indices_f_base + k * BLOCK_SIZE_K
                block_mem_indices = block_mem_indices_f % K
                a_index = tl.load(index_ptr + stride_index * block_mem_indices)
                if out_index_is_none:
                    b_index = a_index
                else:
                    b_index = tl.load(out_index_ptr + stride_out_index * block_mem_indices)
                sel_ok = block_mem_indices_f < end_i

                a_ptrs = a_ptrs_this + a_index[None, :] * stride_ak
                b_ptrs = b_ptrs_this + b_index[:, None] * stride_bk

                # Load the next block of A and B, generate a mask by checking the K dimension.
                # If it is out of bounds, set it to 0.
                a = tl.load(a_ptrs, mask=sel_ok[None, :], other=0.0)
                b = tl.load(b_ptrs, mask=sel_ok[:, None], other=0.0)

                if op_float16:
                    a = a.to(tl.float16)
                    b = b.to(tl.float16)

                # We accumulate along the K dimension.
                accumulator += tl.dot(a, b, allow_tf32=allow_tf32)

            if float32_out:
                c = accumulator
            else:
                c = accumulator.to(tl.float16)

            # -----------------------------------------------------------
            # Write back the block of the output matrix C with masks.
            offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
            offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
            c_ptrs = c_ptr + stride_co * matrix_index + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
            c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
            # tl.store(c_ptrs, c, mask=c_mask)
            tl.atomic_add(c_ptrs, c, mask=c_mask)


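# Illustration only (hypothetical helper, not part of the original repro and
# not called anywhere): the two binary searches in cvmm_backward_kernel3 in
# plain Python. Given the sorted expert ids in `sel`, they return the half-open
# range [start, end) of positions inside [lo, hi) that belong to matrix_index:
# first a lower bound (first id >= matrix_index), then an upper bound (first
# id > matrix_index), i.e. what bisect_left / bisect_right compute.
def expert_row_range(sel, matrix_index, lo, hi):
    start, end = lo, hi
    while start < end:
        middle = (start + end) // 2
        if sel[middle] < matrix_index:
            start = middle + 1
        else:
            end = middle
    start2, end = start, hi
    while start2 < end:
        middle = (start2 + end) // 2
        if sel[middle] <= matrix_index:
            start2 = middle + 1
        else:
            end = middle
    return start, start2
# e.g. expert_row_range([0, 0, 1, 1, 1, 3, 3, 4], 1, 0, 8) -> (2, 5); an expert
# with no rows, such as 2, gives an empty range (5, 5), so count == 0.
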
torch.library.define("mylib::cvmm_triton", "(Tensor x, Tensor sel_index, Tensor sel, Tensor keys, ScalarType out_dtype, Tensor out_index) -> Tensor")

@torch.library.impl("mylib::cvmm_triton", "default")
def cvmm_triton(
    x: torch.Tensor,
    sel_index: torch.Tensor,
    sel: torch.Tensor,
    keys: torch.Tensor,
    out_dtype: torch.dtype,
    out_index: torch.Tensor
):
    x = x.flatten(end_dim=-2)
    assert x.shape[-1] == keys.shape[1]

    sel_shape = sel.shape
    sel = sel.flatten()

    M = sel.shape[0]
    O, K, N = keys.shape
    # Allocates output.
    out = torch.empty((M, N), device=x.device, dtype=out_dtype)
    # out = torch.zeros((M, N), device=x.device, dtype=out_dtype)
    # 1D launch kernel where each block gets its own program.

    # expected_m_per_matrix = int(math.ceil(M / O * 1.5))
    # expected_m_per_matrix = M

    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
    )

    out_index_is_none = False
    if out_index.numel() == 1 and out_index == -1:
        out_index_is_none = True

    cvmm_kernel[grid](
        x, keys, out, sel_index, sel, out_index,
        M, N, K,
        x.stride(0), x.stride(1),
        keys.stride(0), keys.stride(1), keys.stride(2),
        out.stride(0), out.stride(1),
        sel_index.stride(0), sel.stride(0), 0 if out_index_is_none else out_index.stride(0),
        out_index_is_none=out_index_is_none,
        float32=out.dtype==torch.float32, allow_tf32=False, #torch.backends.cuda.matmul.allow_tf32
    )

    return out.view(*sel_shape, N)


@torch.library.impl_abstract("mylib::cvmm_triton", cvmm_triton)
def cvmm_triton_abstract(x, sel_idx, sel, keys, out_dtype, out_index):
    sel_shape = sel.shape
    sel = sel.flatten()
    M = sel.shape[0]
    O, K, N = keys.shape
    out = torch.empty((M, N), device=x.device, dtype=out_dtype)
    return out.view(*sel_shape, N)


torch.library.define("mylib::cvmm_triton_backward", "(Tensor x, Tensor sel_index, Tensor sel, Tensor grads, int n_experts, ScalarType key_dtype, bool op_float16, Tensor out_index) -> Tensor")

@torch.library.impl("mylib::cvmm_triton_backward", "default")
def cvmm_triton_backward(
    x: torch.Tensor,
    sel_index: torch.Tensor,
    sel: torch.Tensor,
    grads: torch.Tensor,
    n_experts: int,
    key_dtype: torch.dtype,
    op_float16: bool,
    out_index: torch.Tensor
):
    x = x.flatten(end_dim=-2)
    x = x.transpose(0, 1)
    grads = grads.flatten(end_dim=-2)
    sel = sel.flatten()
    M, _ = x.shape
    K, N = grads.shape
    out = torch.zeros((n_experts, M, N), device=x.device, dtype=key_dtype)
    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), triton.cdiv(K, META['BLOCK_SIZE_K'] * META['K_BLOCKS'])
    )
    out_index_is_none = False
    if out_index.numel() == 1 and out_index == -1:
        out_index_is_none = True

    cvmm_backward_kernel3[grid](
        x, grads, out, sel_index, sel, out_index,
        M, N, K,
        x.stride(0), x.stride(1),
        grads.stride(0), grads.stride(1),
        out.stride(0), out.stride(1), out.stride(2),
        sel_index.stride(0), sel.stride(0), 0 if out_index_is_none else out_index.stride(0),
        out_index_is_none=out_index_is_none,
        float32_out=out.dtype == torch.float32,
        op_float16=op_float16,
        allow_tf32=False #torch.backends.cuda.matmul.allow_tf32
    )
    return out


@torch.library.impl_abstract("mylib::cvmm_triton_backward", cvmm_triton_backward)
def cvmm_triton_backward_abstract(
    x: torch.Tensor,
    sel_index: torch.Tensor,
    sel: torch.Tensor,
    grads: torch.Tensor,
    n_experts: int,
    key_dtype: torch.dtype,
    op_float16: bool,
    out_index: torch.Tensor
):
    x = x.flatten(end_dim=-2)
    x = x.transpose(0, 1)
    grads = grads.flatten(end_dim=-2)
    M, _ = x.shape
    _, N = grads.shape
    out = torch.zeros((n_experts, M, N), device=x.device, dtype=key_dtype)
    return out


class CVMM(torch.autograd.Function):
    warned = False

    @staticmethod
    def forward(
        ctx,
        x: torch.Tensor,
        sel_index: torch.Tensor,
        sel: torch.Tensor,
        keys: torch.Tensor,
        out_index: Optional[torch.Tensor] = None,
        reduction_weight: Optional[torch.Tensor] = None
    ):
        ctx.save_for_backward(x, keys, sel, sel_index, out_index, reduction_weight)
        out_type = torch.float16 if torch.is_autocast_enabled() else x.dtype
        if out_index is None:
            out_index = torch.tensor(-1).cuda()

        res = torch.ops.mylib.cvmm_triton(x, sel_index, sel, keys, out_type, out_index)
        # res = cvmm_triton(x, sel_index, sel, keys, out_type, out_index)
        if reduction_weight is not None:
            res = res.view(*reduction_weight.shape, res.shape[-1])
            res = (reduction_weight.unsqueeze(-2).type_as(res) @ res).squeeze(-2)

        ctx.op_type = out_type
        ctx.keys_type = keys.dtype
        ctx.is_autocast = torch.is_autocast_enabled()
        return res

    @staticmethod
    def backward(ctx, grad_output):
        x, keys, sel, sel_index, out_index, reduction_weight = ctx.saved_tensors
        keys_dt = keys

        # Backward for weight
        if reduction_weight is not None:
            # Project the grads back with the reduction weight so that the grad for the weight matrix is correct
            grad_output_w = reduction_weight.unsqueeze(-1).type_as(grad_output) @ grad_output.unsqueeze(-2)
            # print("no none", grad_output_w.shape)
        else:
            grad_output_w = grad_output
            # print("none", grad_output_w.shape)

        out_index_is_none = False
        if out_index is None:
            out_index_is_none = True
            out_index = torch.tensor(-1).cuda()

        # grad_w = cvmm_triton_backward(
        #     x,
        #     sel_index,
        #     sel,
        #     grad_output_w,
        #     keys_dt.shape[0],
        #     ctx.keys_type,
        #     ctx.is_autocast,
        #     out_index=out_index
        # ) 

        grad_w = torch.ops.mylib.cvmm_triton_backward(
            x,
            sel_index,
            sel,
            grad_output_w,
            keys_dt.shape[0],
            ctx.keys_type,
            ctx.is_autocast,
            out_index=out_index
        )

        # Backward for input and reduction weight
        grad_w_off = None

        bw_index = sel_index if out_index_is_none else out_index
        bw_index_out = torch.tensor(-1).cuda()
        if reduction_weight is not None:
            # Hack the output indices to emulate repeats
            bw_index_out = bw_index
            bw_index = bw_index // reduction_weight.shape[-1]

        grad_x_full = torch.ops.mylib.cvmm_triton(
            grad_output,
            bw_index,
            sel,
            keys_dt.transpose(1,2),
            ctx.op_type,
            bw_index_out
        )

        grad_x_full = grad_x_full.view(*x.shape[:-1], -1, x.shape[-1])
        if reduction_weight is not None:
            # grad_x_full is the unscaled grad. For the input, we have to scale it; for the reduction weight,
            # we have to compute dot products with the input.
            grad_x = (reduction_weight.view(*grad_x_full.shape[:-1]).unsqueeze(-2).type_as(grad_x_full) @ grad_x_full).squeeze(-2)
            grad_w_off = (grad_x_full.type_as(reduction_weight) @ x.unsqueeze(-1).type_as(reduction_weight)).squeeze(-1).view_as(reduction_weight)
        elif grad_x_full.shape[-2] != 1:
            grad_x = grad_x_full.sum(-2)
        else:
            grad_x = grad_x_full

        grad_x = grad_x.view_as(x)

        return grad_x, None, None, grad_w, None, grad_w_off


def cvmm(x: torch.Tensor, sel: Union[torch.Tensor, CVMMSel], keys: torch.Tensor):
    if not isinstance(sel, CVMMSel):
        sel = cvmm_prepare_sel(sel, keys.shape[0])

    return CVMM.apply(x, sel.sel_index, sel.sel, keys, sel.out_index, sel.reduction_weight)


def cvmm_prepare_sel2(sel: torch.Tensor, w: Optional[torch.Tensor] = None) -> CVMMSel:
    # Has multiple selections for each batch element
    n_per_batch = sel.shape[-1]

    # indices = torch.arange(sel.nelement() // n_per_batch, device=sel.device, dtype=torch.int32)
    # indices = indices.repeat_interleave(n_per_batch).flatten()

    fsel = sel.flatten()
    ssel, sel_index = fsel.sort()

    # in_index = indices[sel_index]
    in_index = sel_index // n_per_batch

    return CVMMSel(sel, ssel.view_as(sel), in_index, sel_index, w)


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.n_heads = 8
        self.n_experts = 8
        self.expert_size = 64
        self.k_vec_dim = 128
        self.v_dim = 128
        self.keys = torch.nn.Parameter(
            torch.empty(self.n_experts, self.k_vec_dim, self.expert_size)
        )
        self.values = torch.nn.Parameter(
            torch.empty(self.n_experts, self.expert_size, self.v_dim)
        )
        self.expert_sel = torch.nn.Linear(self.k_vec_dim, self.n_experts, bias=False)
        self.sel_activation = torch.nn.Sigmoid()

    def compute_scores(self, input: torch.Tensor, index: CVMMSel) -> torch.Tensor:
        scores = cvmm(input, index, self.keys)
        return scores

    def forward(self, input: torch.Tensor):
        sel = sel_raw = self.expert_sel(input)
        sel = self.sel_activation(sel)
        sel_val, sel_index = sel.topk(self.n_heads, dim=-1, sorted=False)
        # Preprocess the selection indices. They will be needed for both layers and save some time
        sel_indices = cvmm_prepare_sel2(sel_index.int())
        # "Up-projection" layer for each head
        scores = self.compute_scores(input, sel_indices)
        # Down projection layer for each head
        sel_indices = sel_indices.clone()
        sel_indices.reduction_weight = sel_val
        sel_indices.sel_index = sel_indices.out_index
        sel_indices.out_index = None
        out = cvmm(scores, sel_indices, self.values)
        return out


model = Model().to(torch.float16).cuda()
model = torch.compile(model)

torch.manual_seed(0)
n_experts = 8
n_channels = 128
expert_size = 64
bs = 64

device = torch.device("cuda")
dtype = torch.float16

testvec = torch.randn(bs, n_channels, dtype=dtype, device=device)

out = model(testvec)
loss = out.sum()
loss.backward()

print(model.keys.grad.shape)

print(out.shape)

I am using the current nightly build. The exception I get is:

Traceback (most recent call last):
  File "/u/julianb/transformers/src/transformers/models/sigma_moe/triton_src/moe_layer/cvmm.py", line 624, in <module>
    loss.backward()
  File "/dccstor/broccoli/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_tensor.py", line 522, in backward
    torch.autograd.backward(
  File "/dccstor/broccoli/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/autograd/__init__.py", line 267, in backward
    _engine_run_backward(
  File "/dccstor/broccoli/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/autograd/graph.py", line 681, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/dccstor/broccoli/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/autograd/function.py", line 289, in apply
    return user_fn(self, *args)
  File "/dccstor/broccoli/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 754, in backward
    out = call_compiled_backward()
  File "/dccstor/broccoli/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 706, in call_compiled_backward
    out = call_func_at_runtime_with_args(
  File "/dccstor/broccoli/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 105, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
  File "/dccstor/broccoli/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 417, in _fn
    return fn(*args, **kwargs)
  File "/dccstor/broccoli/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/dccstor/broccoli/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 863, in __call__
    return self.get_current_callable()(inputs)
  File "/dccstor/broccoli/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 608, in run
    return model(new_inputs)
  File "/dccstor/broccoli/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 891, in _run_from_cache
    return compiled_graph.compiled_artifact(inputs)
  File "/tmp/torchinductor_julianb/xa/cxajzptzw7uzhcqg2mfni6rnorv2jmmlin4yd6qqvdq4saagxfl6.py", line 149, in call
    assert_size_stride(cvmm_triton, (512, 64), (64, 1))
AssertionError: wrong number of dimensions

My guess is that in the backward pass, cvmm_triton gets called with different input shapes, but I don't understand why that case is not handled.
Any help would be appreciated.
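For context on the failing line: the Inductor-generated wrapper appears to compare the tensor returned by `cvmm_triton` against the size `(512, 64)` and stride `(64, 1)` recorded at trace time, and "wrong number of dimensions" indicates that the runtime tensor's rank differs from the recorded one. The sketch below is a hypothetical pure-Python stand-in (`check_size_stride` is not a PyTorch API) that reproduces that kind of check; the 3-D shape it uses only illustrates a rank mismatch and is not the shape actually observed in this issue.

```python
import torch


def check_size_stride(t: torch.Tensor, size: tuple, stride: tuple) -> None:
    """Hypothetical stand-in for the size/stride assertion in Inductor's generated code.

    Compares a runtime tensor against shape metadata recorded at trace time
    (for a custom op, that metadata comes from its fake/meta implementation).
    """
    if t.dim() != len(size) or t.dim() != len(stride):
        raise AssertionError("wrong number of dimensions")
    if tuple(t.shape) != tuple(size):
        raise AssertionError(f"size mismatch: expected {tuple(size)}, got {tuple(t.shape)}")
    if tuple(t.stride()) != tuple(stride):
        raise AssertionError(f"stride mismatch: expected {tuple(stride)}, got {tuple(t.stride())}")


# Trace-time metadata says 2-D (512, 64); suppose the real kernel returned a 3-D tensor instead.
expected_size, expected_stride = (512, 64), (64, 1)
real_out = torch.empty(64, 8, 64)  # illustrative shape only

try:
    check_size_stride(real_out, expected_size, expected_stride)
except AssertionError as e:
    print(e)  # wrong number of dimensions

# Viewing the same contiguous storage as 2-D satisfies the recorded metadata.
check_size_stride(real_out.view(512, 64), expected_size, expected_stride)
```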

I can repro this. The Inductor-generated code is: https://gist.github.com/oulgen/68f41bd6d28e1547a80081097bce9a86

I'll take a look at this, but it might be next week or later since I am currently out on vacation.

@oulgen any updates on this?

Not yet, I have been preoccupied with a silent accuracy bug. Hoping to take a look at this next week.

@jubueche I took a look at this today, and it looks like the bug is not related to Triton kernels: I removed all the user-defined Triton kernels and the repro still fails with

AssertionError: wrong number of dimensions
on assert_size_stride(cvmm_triton, (512, 64), (64, 1))

The repro passes in eager mode and fails in Inductor, so it looks like a general PT2 issue, unrelated to user-defined Triton kernels.

https://gist.github.com/oulgen/3d40cfc65077179654f3d0f71d706008

Removed the triaged label so that the oncall can re-triage this for a faster turnaround.

Not familiar with that. Should I open a new issue somewhere else or is someone else looking at this?

I removed the triaged tag so the PT2 oncall will retriage it. I will make sure it gets assigned. No need to create a new issue.

Tentatively assigned to Brian.

The repro fails with backend="eager", so it may not be an Inductor issue.
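The backend comparison used in this thread can be scripted. `torch.compile` accepts `backend="eager"` (Dynamo graph capture only), `backend="aot_eager"` (Dynamo plus AOTAutograd, no code generation) and the default `"inductor"`, so a failure that already shows up with `"eager"` points at graph capture or fake-tensor propagation rather than at Inductor's codegen. A minimal sketch, assuming a hypothetical helper `compare_backends` that runs forward and backward under each backend and compares the results with plain PyTorch:

```python
import torch
import torch._dynamo


def compare_backends(fn, make_inputs, backends=("eager", "aot_eager"), atol=1e-5):
    """Hypothetical helper: run fn forward + backward in plain PyTorch and under each
    torch.compile backend, and report how each backend compares with plain PyTorch.

    make_inputs() must return a fresh list of leaf tensors on every call so that
    gradients from one run do not accumulate into the next.
    Returns {backend: "ok" | "mismatch" | "error: ..."}.
    """

    def run(f):
        inputs = make_inputs()
        out = f(*inputs)
        out.sum().backward()  # also exercise the backward graph
        return out.detach(), [t.grad for t in inputs if t.requires_grad]

    ref_out, ref_grads = run(fn)
    results = {}
    for backend in backends:
        torch._dynamo.reset()  # drop graphs cached by the previous backend
        try:
            out, grads = run(torch.compile(fn, backend=backend))
        except Exception as e:  # record the failure instead of aborting the sweep
            results[backend] = f"error: {type(e).__name__}: {e}"
            continue
        same = torch.allclose(out, ref_out, atol=atol) and all(
            torch.allclose(g, r, atol=atol) for g, r in zip(grads, ref_grads)
        )
        results[backend] = "ok" if same else "mismatch"
    return results
```

Adding `"inductor"` to `backends` covers the default path as well. Running the backward pass matters here because the traceback in this issue is raised from the compiled backward.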

I'm going to temporarily unassign myself. This issue reproduces with Dynamo, so it is (likely?) a fake tensor issue. Between other commitments and PTO, I won't be able to look at this for at least 3-4 weeks.

If someone else can take a look before then, feel free; otherwise, I will look again when I have more bandwidth.
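If the suspicion is a fake-tensor problem, one way to narrow it down is to run the same function on real tensors and on fake copies of them, then compare the resulting output shapes and strides. A sketch, assuming `torch._subclasses.fake_tensor.FakeTensorMode` (an internal API whose import path may change between releases); `compare_fake_and_real` is a hypothetical helper:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode


def _shape_stride(outs):
    # Normalize a single tensor or a tuple/list of tensors to [(shape, stride), ...].
    outs = outs if isinstance(outs, (tuple, list)) else (outs,)
    return [(tuple(o.shape), tuple(o.stride())) for o in outs]


def compare_fake_and_real(fn, *inputs):
    """Hypothetical debugging helper: run fn on real inputs and on fake copies of them.

    Returns (real, fake), each a list of (shape, stride) per output. A difference
    means shape propagation under fake tensors disagrees with eager execution.
    """
    real = _shape_stride(fn(*inputs))
    mode = FakeTensorMode()
    fake_inputs = [mode.from_tensor(t) for t in inputs]  # same metadata, no real data
    with mode:
        fake = _shape_stride(fn(*fake_inputs))
    return real, fake


a, b = torch.randn(4, 3), torch.randn(3, 5)
real, fake = compare_fake_and_real(lambda x, y: (x @ y).t(), a, b)
print(real == fake)  # True
```

A mismatch between the two lists would localize the problem to shape propagation rather than to the kernel itself.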

This no longer reproduces as an assert_size_stride error; it now fails in the custom Triton kernel. cc @oulgen, would you take a look at this?

Okay, using @oulgen's gist without the custom Triton kernel, it no longer reproduces. It does fail with both the eager and aot_eager backends.

TL;DR, since the thread is long:

I cleaned up all the repros. You can find the latest repros here.

It's unclear whether this is due to user error, a Triton bug, or a PyTorch bug.

To clarify @zou3519's comment: I looked into the gist that uses a custom op, and its meta function was specified incorrectly, which caused the stride assertion error. At this point I think the issue is either in custom ops or in the Triton op lowering. Deferring to @zou3519 and @oulgen for custom op/Triton op support.

From internal triage meeting: oulgen still planning to look at this

cc @oulgen, curious whether there's any update on this? Thanks!

I've been busy with some other tasks, haven't had a chance to look at this.

cc @aakhundov, would you be interested in picking this up? (from the triage meeting)