Cambricon/triton-linalg

[Question] Linalg_ext.gather behavior different from tensor.gather

Closed this issue · 8 comments

In the following example given in /test/Dialect/LinalgExt/ops.mlir:
// CHECK: linalg_ext.gather
func.func @gather_tensor_i8_indice(%indices : tensor<4x1xi8>, %window: tensor<4x2x4xf32>, %data: tensor<16x8xf32>, %mask: tensor<4xi1>) -> tensor<4x2x4xf32> {
  %gather = linalg_ext.gather
      dimension_map = [1]
      ranged_data(true) signed_indice(false)
      ins(%data, %indices, %mask: tensor<16x8xf32>, tensor<4x1xi8>, tensor<4xi1>)
      outs(%window: tensor<4x2x4xf32>) {
    ^bb0(%arg0: f32, %arg1: f32):
      linalg_ext.yield %arg0 : f32
  } -> tensor<4x2x4xf32>
  return %gather : tensor<4x2x4xf32>
}

A similar gather op in the tensor dialect would produce a tensor<4x16xf32> or tensor<4x16x1xf32>, with 1 being the gather dimension of %data. Why does linalg_ext.gather produce a tensor<4x2x4xf32>? I do notice that the ^bb0 block reduces the output size by a factor of 2, but that would still yield a tensor<4x8xf32> rather than a tensor<4x2x4xf32>.

Our definition of linalg_ext.gather is based on hlo.gather, which supports an index value that points to a block of data; with that knowledge, the op can be lowered to more efficient hardware instructions in certain scenarios. In contrast, tensor.gather currently only supports an index that points to a single data element or to an entire dimension, and it does not support masking.

Taking the example above: there are 4 indices, each storing a linearized offset into the whole data, and each retrieving a 2x4 block from the 16x8 data. Hence the output shape is 4x2x4.
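To make this concrete, here is a minimal standalone C++ sketch of the behavior described above. It is an illustration under assumptions, not the project's implementation: the offset values, the mask contents, and the row-major anchoring of each 2x4 window at its offset are all made up for the example.

#include <cstdio>

int main() {
  constexpr int R = 16, C = 8;   // data shape: 16x8
  constexpr int WR = 2, WC = 4;  // window shape: 2x4
  float data[R][C];
  for (int r = 0; r < R; ++r)
    for (int c = 0; c < C; ++c)
      data[r][c] = static_cast<float>(r * C + c);

  int indices[4] = {0, 8, 16, 24};          // hypothetical linearized offsets
  bool mask[4] = {true, true, true, false}; // last batch is masked out
  float out[4][WR][WC] = {};                // gathered result: 4x2x4

  for (int b = 0; b < 4; ++b) {
    if (!mask[b]) continue;                 // masked batches keep init values
    int r0 = indices[b] / C, c0 = indices[b] % C; // de-linearize the offset
    for (int r = 0; r < WR; ++r)
      for (int c = 0; c < WC; ++c)
        out[b][r][c] = data[r0 + r][c0 + c]; // copy one 2x4 block
  }
  std::printf("out[1][0][0] = %g\n", out[1][0][0]); // prints 8
  return 0;
}

Each of the 4 batches thus contributes one 2x4 block, giving the 4x2x4 output.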

Our project involves converting and lowering linalg_ext.gather, and we found that part of the description of linalg_ext.gather in linalgExtOps.md may be confusing. Specifically:

SmallVector<int64_t> wholeIdx(n);
for (int i = 0; i < dimension_map.size(); ++i) {
  wholeIdx[dimension_map[i]] = indice[i];
}

Here the elements of wholeIdx are scalar int64_t values, while indice[i] denotes the index for the i-th dimension, which may be of vector, tensor, or some other type.

The corresponding implementation is:

for (int i = 0; i < dimension_map.size(); ++i) {
  auto index = rewriter.create<mlir::arith::ConstantIndexOp>(loc, i);
  auto indiceElem = rewriter.create<mlir::tensor::ExtractOp>(loc, indice, mlir::ValueRange{index});
  wholeIdx[dimension_map[i]] = indiceElem.getResult().getDefiningOp<mlir::arith::ConstantIndexOp>().value();
}

This compiles, but at runtime the assignment between operands of different types causes a segmentation fault.
(screenshot of the segfault backtrace)

The line of code the red arrow in the screenshot points to is:
wholeIdx[dimension_map[i]] = indiceElem.getResult().getDefiningOp<mlir::arith::ConstantIndexOp>().value();

Please clarify the operand types involved in this assignment, or provide a demo of the relevant code. Thanks!

@sethbrin

@xuel0707 The type of indices is constrained in the td definition to be a tensor or memref whose element type is a signless integer or index, which is consistent with the requirements of the tensor dialect.

The error you mention appears to be a type mismatch; to pin down the cause, it would help if you could provide the erroneous IR for analysis.

You can refer to the IR example here: https://github.com/Cambricon/triton-linalg/blob/master/test/Dialect/LinalgExt/ops.mlir#L243.


Yes, I used that IR for my test and analysis, and it still fails at runtime, pointing to the same line of code: wholeIdx[dimension_map[i]] = indiceElem.getResult().getDefiningOp<mlir::arith::ConstantIndexOp>().value();

In the tensor dialect, the indices operand is a ranked tensor of signless integer or index values. In the example above (https://github.com/Cambricon/triton-linalg/blob/master/test/Dialect/LinalgExt/ops.mlir#L243), indice is a 2-D tensor<4x1xi32>. Therefore the description of linalg_ext.gather in linalgExtOps.md, wholeIdx[dimension_map[i]] = indice[i], is wrong.

Please provide the correct computation formula. Thanks! @sethbrin

@xuel0707 The dimension_map is not meaningful in the Triton lowering flow; it exists primarily to accommodate the start_index_map attribute from MHLO.

The gather/scatter operators are defined in a complex manner so as to be compatible with both Triton and MHLO. We will consider how to improve the operator descriptions in subsequent versions, and PRs with updates are welcome.

Regarding the segmentation fault you mentioned: the definingOp of indiceElem.getResult() is the tensor::ExtractOp, not an arith::ConstantIndexOp, so getDefiningOp<mlir::arith::ConstantIndexOp>() returns a null handle. Without a null check, calling the value() member function through that null handle dereferences a null pointer when it tries to read the op's data, which crashes the program.

for (int i = 0; i < dimension_map.size(); ++i) { 
    auto index = rewriter.create<mlir::arith::ConstantIndexOp>(loc, i); 
    auto indiceElem = rewriter.create<mlir::tensor::ExtractOp>(loc, indice, mlir::ValueRange{index}); 
    // indiceElem.getResult().getDefiningOp<mlir::arith::ConstantIndexOp>() is nullptr
    wholeIdx[dimension_map[i]] = indiceElem.getResult().getDefiningOp<mlir::arith::ConstantIndexOp>().value(); 
}
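As a hedged sketch of one way to avoid the crash (variable names and context are taken from the snippet above; this is an illustration, not the project's code): since the extracted element is a runtime SSA value, store mlir::Value entries instead of trying to fold them back into int64_t constants, and guard any constant-folding attempt with a null check.

// Assumes `rewriter`, `loc`, `indice`, `dimension_map`, and `n` as above.
SmallVector<mlir::Value> wholeIdx(n);
for (int i = 0, e = dimension_map.size(); i < e; ++i) {
  auto index = rewriter.create<mlir::arith::ConstantIndexOp>(loc, i);
  auto indiceElem = rewriter.create<mlir::tensor::ExtractOp>(
      loc, indice, mlir::ValueRange{index});
  // getDefiningOp<OpTy>() dyn_casts the defining op, so here it yields a
  // null handle (the defining op is the tensor::ExtractOp); always check it.
  if (auto cst = indiceElem.getResult()
                     .getDefiningOp<mlir::arith::ConstantIndexOp>())
    (void)cst.value(); // safe only when the element folds to a constant
  wholeIdx[dimension_map[i]] = indiceElem.getResult(); // keep the SSA value
}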

@xuel0707 The wording in the td is indeed confusing. Let me offer my personal take. First, as for the input operands of the linalg_ext.gather op, there can be 2 (input, indices) or 3 (input, indices, mask). And yes, I think calling it indices here makes the distinction clearer.

The description of the gather op in include/triton-linalg/Dialect/LinalgExt/IR/LinalgExtOps.td can be summarized as follows:

- input has shape [i0, i1, ..., in-1]
- indices has shape [Batch0, Batch1, ..., Batchm-1, k]
  - there are [Batch0, Batch1, ..., Batchm-1] groups of indice
  - each indice group holds k values: [idx0, idx1, ..., idxk-1], so k must never be dynamic
- mask has shape [Batch0, Batch1, ..., Batchm-1]

- init
  - shape [Batch0, Batch1, ..., Batchm-1, o0, o1, ..., on-1]
  - rank >= 2
  - mask and init have the same first `indices.getRank() - 1` (a.k.a. batchNum) dimension sizes
  - init[idx + batchNum] <= inputType[idx]
  - init holds [Batch0, Batch1, ..., Batchm-1] groups of data of shape [o0, o1, ..., on-1] extracted from input

Computation semantics:

for (i0 = 0; i0 < Batch0; ++i0) {
  ...
  for (im-1 = 0; im-1 < Batchm-1; ++im-1) { // [Batch0, Batch1, ..., Batchm-1] groups
    indice = indices[i0, ..., im-1]; // each group holds k values: [idx0, idx1, ..., idxk-1]
    if (mask[i0, ..., im-1]) { // check whether this group is masked out
      // if region is empty, only copy will apply on init.
      computation(input[indice], init[i0, ..., im-1]);
    }
  }
}

On top of this, linalg_ext.gather has a dimension_map attribute, which as I understand it transposes the indice, i.e. it applies a permutation to the length-k array [idx0, idx1, ..., idxk-1]. That is why the constraint dimension_map.size() == k exists: the map must encode how the actually-used realIndice is laid out relative to indice. The subsequent computation(input[indice], init[i0, ..., im-1]) then uses this realIndice:

SmallVector<int64_t> realIndice(n);
for (int i = 0; i < dimension_map.size(); ++i) {
  realIndice[dimension_map[i]] = indice[i];
}
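As a tiny standalone C++ illustration of that permutation (the dimension_map and indice values below are made up for the example):

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> dimension_map = {1, 0};  // hypothetical permutation, k = 2
  std::vector<int64_t> indice = {3, 7};     // one indice group [idx0, idx1]
  std::vector<int64_t> realIndice(indice.size());
  for (size_t i = 0; i < dimension_map.size(); ++i)
    realIndice[dimension_map[i]] = indice[i];
  // realIndice == {7, 3}: the indice permuted according to dimension_map
  std::printf("realIndice = [%lld, %lld]\n",
              static_cast<long long>(realIndice[0]),
              static_cast<long long>(realIndice[1]));
  return 0;
}

With the identity map dimension_map = [0, ..., k-1], realIndice equals indice, which matches the dimension_map = [0] case discussed next.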

It looks like when linalg_ext.gather is built in the Triton-Linalg project, dimension_map = [0] is always given directly, so this attribute effectively does nothing: no transpose is applied to indice.

> In the tensor dialect, the indices operand is a ranked tensor of signless integer or index values. In the example above (https://github.com/Cambricon/triton-linalg/blob/master/test/Dialect/LinalgExt/ops.mlir#L243), indice is a 2-D tensor<4x1xi32>. Therefore the description of linalg_ext.gather in linalgExtOps.md, wholeIdx[dimension_map[i]] = indice[i], is wrong.

So the indice you mention, the 2-D tensor<4x1xi32>, is actually what I described above as the indices input operand, whereas the indice in wholeIdx[dimension_map[i]] = indice[i] is one group taken along the last (innermost) dimension of indices:

> indices has shape [Batch0, Batch1, ..., Batchm-1, k]
>   - there are [Batch0, Batch1, ..., Batchm-1] groups of indice
>   - each indice group holds k values: [idx0, idx1, ..., idxk-1]

Take the following linalg_ext.gather:
%107 = linalg_ext.gather
    dimension_map = [0]
    ranged_data(false) signed_indice(true)
    ins(%105, %expanded_24, %collapsed_25 : tensor<9223372036854775807xf16>, tensor<1024x1xi32>, tensor<1024xi1>)
    outs(%106 : tensor<1024x1xf16>) {
  ^bb0(%arg13: f16, %arg14: f16):
    linalg_ext.yield %arg13 : f16
} -> tensor<1024x1xf16>

It is converted to the following tensor.gather:
%gather = tensor.gather %105[%expanded_24] gather_dims([0]) : (tensor<9223372036854775807xf16>, tensor<1024x1xi32>) -> tensor<1024x1xf16>
Is this conversion semantically equivalent? @sethbrin @tfruan2000

@xuel0707 Yes, they are equivalent.