intel/mlir-extensions

Add function decoration FuncParamAttr support in SPIRV dialect.

Opened this issue · 1 comments

The FuncParamAttr decoration is not yet supported in the SPIR-V dialect, but it is required to link some functions in the Intel Math Libraries.

The SPIR-V signature of the function to be imported, as exported by the library, is:

OpName %__devicelib_imf_float2bfloat16 "__devicelib_imf_float2bfloat16" 
OpDecorate %__devicelib_imf_float2bfloat16 LinkageAttributes "__devicelib_imf_float2bfloat16" Export 
OpDecorate %__devicelib_imf_float2bfloat16 FuncParamAttr Zext 

The corresponding import function declaration in the SPIR-V dialect is:

 spirv.func @__devicelib_imf_float2bfloat16(f32) -> i16 "Inline" attributes {FuncParamAttr = "Zext", libname = "libdevice", libpath = "", linkage_attributes = ["__devicelib_imf_float2bfloat16", "Import"]}

The error reported when serializing the SPIR-V dialect module is:
error: unhandled decoration FuncParamAttr

Here is the SPIR-V dialect module that reproduces the problem; it links against the function declared with the FuncParamAttr = "Zext" decoration.


// -----// IR Dump After CSE (cse) //----- //
module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume]>, api=OpenCL, #spirv.resource_limits<>>, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 0 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
  spirv.GlobalVariable @__builtin_var_LocalInvocationId__ built_in("LocalInvocationId") : !spirv.ptr<vector<3xi64>, Input>
  spirv.func @__devicelib_imf_bfloat162float(i16) -> f32 "Inline" attributes {libname = "libdevice", libpath = "", linkage_attributes = ["__devicelib_imf_bfloat162float", "Import"]}
  spirv.func @__devicelib_imf_float2bfloat16(f32) -> i16 "Inline" attributes {FuncParamAttr = "Zext", libname = "libdevice", libpath = "", linkage_attributes = ["__devicelib_imf_float2bfloat16", "Import"]}
  spirv.func @kernel_0d1d2d(%arg0: !spirv.ptr<f32, CrossWorkgroup> {tt.divisibility = 16 : i32}, %arg1: !spirv.ptr<i16, CrossWorkgroup> {tt.divisibility = 16 : i32}, %arg2: !spirv.ptr<i16, CrossWorkgroup> {tt.divisibility = 16 : i32}) "None" attributes {noinline = false, spirv.entry_point_abi = #spirv.entry_point_abi<>, sym_visibility = "public"} {
    %__builtin_var_LocalInvocationId___addr = spirv.mlir.addressof @__builtin_var_LocalInvocationId__ : !spirv.ptr<vector<3xi64>, Input>
    %0 = spirv.Load "Input" %__builtin_var_LocalInvocationId___addr : vector<3xi64>
    %1 = spirv.CompositeExtract %0[0 : i32] : vector<3xi64>
    %2 = spirv.SConvert %1 : i64 to i32
    %cst32_i32 = spirv.Constant 32 : i32
    %3 = spirv.UMod %2, %cst32_i32 : i32
    %4 = spirv.UDiv %2, %cst32_i32 : i32
    %cst4_i32 = spirv.Constant 4 : i32
    %5 = spirv.UMod %4, %cst4_i32 : i32
    %cst128_i32 = spirv.Constant 128 : i32
    %6 = spirv.UMod %3, %cst128_i32 : i32
    %cst1_i32 = spirv.Constant 1 : i32
    %7 = spirv.IMul %5, %cst32_i32 : i32
    %8 = spirv.IAdd %6, %7 : i32
    %9 = spirv.IMul %cst1_i32, %8 : i32
    %10 = spirv.Undef : !spirv.struct<(i32)>
    %11 = spirv.Undef : !spirv.struct<(!spirv.ptr<i16, CrossWorkgroup>)>
    %12 = spirv.PtrAccessChain %arg1[%9] : !spirv.ptr<i16, CrossWorkgroup>, i32
    %true = spirv.Constant true
    %13 = spirv.Undef : i16
    spirv.BranchConditional %true, ^bb1, ^bb2(%13 : i16)
  ^bb1:  // pred: ^bb0
    %14 = spirv.Load "CrossWorkgroup" %12 : i16
    spirv.Branch ^bb2(%14 : i16)
  ^bb2(%15: i16):  // 2 preds: ^bb0, ^bb1
    %16 = spirv.Undef : !spirv.struct<(i16)>
    %17 = spirv.PtrAccessChain %arg2[%9] : !spirv.ptr<i16, CrossWorkgroup>, i32
    spirv.BranchConditional %true, ^bb3, ^bb4(%13 : i16)
  ^bb3:  // pred: ^bb2
    %18 = spirv.Load "CrossWorkgroup" %17 : i16
    spirv.Branch ^bb4(%18 : i16)
  ^bb4(%19: i16):  // 2 preds: ^bb2, ^bb3
    %20 = spirv.FunctionCall @__devicelib_imf_bfloat162float(%15) : (i16) -> f32
    %21 = spirv.FunctionCall @__devicelib_imf_bfloat162float(%19) : (i16) -> f32
    %22 = spirv.FAdd %20, %21 : f32
    %23 = spirv.FunctionCall @__devicelib_imf_float2bfloat16(%22) : (f32) -> i16
    %24 = spirv.Undef : !spirv.struct<(!spirv.ptr<f32, CrossWorkgroup>)>
    %25 = spirv.PtrAccessChain %arg0[%9] : !spirv.ptr<f32, CrossWorkgroup>, i32
    %26 = spirv.FunctionCall @__devicelib_imf_bfloat162float(%23) : (i16) -> f32
    %27 = spirv.Undef : !spirv.struct<(f32)>
    %28 = spirv.Load "Input" %__builtin_var_LocalInvocationId___addr : vector<3xi64>
    %29 = spirv.CompositeExtract %28[0 : i32] : vector<3xi64>
    %30 = spirv.SConvert %29 : i64 to i32
    spirv.BranchConditional %true, ^bb5, ^bb6
  ^bb5:  // pred: ^bb4
    %31 = spirv.Bitcast %26 : f32 to i32
    %32 = spirv.Bitcast %25 : !spirv.ptr<f32, CrossWorkgroup> to !spirv.ptr<i32, CrossWorkgroup>
    spirv.Store "CrossWorkgroup" %32, %31 : i32
    spirv.Branch ^bb6
  ^bb6:  // 2 preds: ^bb4, ^bb5
    spirv.Return
  }
}