Add function decoration FuncParamAttr support in SPIRV dialect.
Opened this issue · 1 comments
chengjunlu commented
The FuncParamAttr decoration, which is required to link some functions in the Intel Math Libraries, is not yet supported by the SPIR-V dialect.
The SPIR-V signature of the function to be imported (as exported by the library) is:
OpName %__devicelib_imf_float2bfloat16 "__devicelib_imf_float2bfloat16"
OpDecorate %__devicelib_imf_float2bfloat16 LinkageAttributes "__devicelib_imf_float2bfloat16" Export
OpDecorate %__devicelib_imf_float2bfloat16 FuncParamAttr Zext
The declaration of the imported function in the SPIR-V dialect is:
spirv.func @__devicelib_imf_float2bfloat16(f32) -> i16 "Inline" attributes {FuncParamAttr = "Zext", libname = "libdevice", libpath = "", linkage_attributes = ["__devicelib_imf_float2bfloat16", "Import"]}
The error reported when serializing the SPIR-V dialect is:
error: unhandled decoration FuncParamAttr
chengjunlu commented
Here is the SPIR-V dialect IR that links the function with the FuncParamAttr = "Zext" decoration:
// -----// IR Dump After CSE (cse) //----- //
module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume]>, api=OpenCL, #spirv.resource_limits<>>, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 0 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
spirv.GlobalVariable @__builtin_var_LocalInvocationId__ built_in("LocalInvocationId") : !spirv.ptr<vector<3xi64>, Input>
spirv.func @__devicelib_imf_bfloat162float(i16) -> f32 "Inline" attributes {libname = "libdevice", libpath = "", linkage_attributes = ["__devicelib_imf_bfloat162float", "Import"]}
spirv.func @__devicelib_imf_float2bfloat16(f32) -> i16 "Inline" attributes {FuncParamAttr = "Zext", libname = "libdevice", libpath = "", linkage_attributes = ["__devicelib_imf_float2bfloat16", "Import"]}
spirv.func @kernel_0d1d2d(%arg0: !spirv.ptr<f32, CrossWorkgroup> {tt.divisibility = 16 : i32}, %arg1: !spirv.ptr<i16, CrossWorkgroup> {tt.divisibility = 16 : i32}, %arg2: !spirv.ptr<i16, CrossWorkgroup> {tt.divisibility = 16 : i32}) "None" attributes {noinline = false, spirv.entry_point_abi = #spirv.entry_point_abi<>, sym_visibility = "public"} {
%__builtin_var_LocalInvocationId___addr = spirv.mlir.addressof @__builtin_var_LocalInvocationId__ : !spirv.ptr<vector<3xi64>, Input>
%0 = spirv.Load "Input" %__builtin_var_LocalInvocationId___addr : vector<3xi64>
%1 = spirv.CompositeExtract %0[0 : i32] : vector<3xi64>
%2 = spirv.SConvert %1 : i64 to i32
%cst32_i32 = spirv.Constant 32 : i32
%3 = spirv.UMod %2, %cst32_i32 : i32
%4 = spirv.UDiv %2, %cst32_i32 : i32
%cst4_i32 = spirv.Constant 4 : i32
%5 = spirv.UMod %4, %cst4_i32 : i32
%cst128_i32 = spirv.Constant 128 : i32
%6 = spirv.UMod %3, %cst128_i32 : i32
%cst1_i32 = spirv.Constant 1 : i32
%7 = spirv.IMul %5, %cst32_i32 : i32
%8 = spirv.IAdd %6, %7 : i32
%9 = spirv.IMul %cst1_i32, %8 : i32
%10 = spirv.Undef : !spirv.struct<(i32)>
%11 = spirv.Undef : !spirv.struct<(!spirv.ptr<i16, CrossWorkgroup>)>
%12 = spirv.PtrAccessChain %arg1[%9] : !spirv.ptr<i16, CrossWorkgroup>, i32
%true = spirv.Constant true
%13 = spirv.Undef : i16
spirv.BranchConditional %true, ^bb1, ^bb2(%13 : i16)
^bb1: // pred: ^bb0
%14 = spirv.Load "CrossWorkgroup" %12 : i16
spirv.Branch ^bb2(%14 : i16)
^bb2(%15: i16): // 2 preds: ^bb0, ^bb1
%16 = spirv.Undef : !spirv.struct<(i16)>
%17 = spirv.PtrAccessChain %arg2[%9] : !spirv.ptr<i16, CrossWorkgroup>, i32
spirv.BranchConditional %true, ^bb3, ^bb4(%13 : i16)
^bb3: // pred: ^bb2
%18 = spirv.Load "CrossWorkgroup" %17 : i16
spirv.Branch ^bb4(%18 : i16)
^bb4(%19: i16): // 2 preds: ^bb2, ^bb3
%20 = spirv.FunctionCall @__devicelib_imf_bfloat162float(%15) : (i16) -> f32
%21 = spirv.FunctionCall @__devicelib_imf_bfloat162float(%19) : (i16) -> f32
%22 = spirv.FAdd %20, %21 : f32
%23 = spirv.FunctionCall @__devicelib_imf_float2bfloat16(%22) : (f32) -> i16
%24 = spirv.Undef : !spirv.struct<(!spirv.ptr<f32, CrossWorkgroup>)>
%25 = spirv.PtrAccessChain %arg0[%9] : !spirv.ptr<f32, CrossWorkgroup>, i32
%26 = spirv.FunctionCall @__devicelib_imf_bfloat162float(%23) : (i16) -> f32
%27 = spirv.Undef : !spirv.struct<(f32)>
%28 = spirv.Load "Input" %__builtin_var_LocalInvocationId___addr : vector<3xi64>
%29 = spirv.CompositeExtract %28[0 : i32] : vector<3xi64>
%30 = spirv.SConvert %29 : i64 to i32
spirv.BranchConditional %true, ^bb5, ^bb6
^bb5: // pred: ^bb4
%31 = spirv.Bitcast %26 : f32 to i32
%32 = spirv.Bitcast %25 : !spirv.ptr<f32, CrossWorkgroup> to !spirv.ptr<i32, CrossWorkgroup>
spirv.Store "CrossWorkgroup" %32, %31 : i32
spirv.Branch ^bb6
^bb6: // 2 preds: ^bb4, ^bb5
spirv.Return
}
}