Dao-AILab/flash-attention

The byzantine copy of Tensor O

phantaurus opened this issue · 4 comments

Hello!

Could you help me understand why, for the output tensor O, we first copy it from registers to shared memory, then from shared memory back to registers, and only then from registers to global memory?

I think the first copy from registers to shared memory is needed because the MMA instruction leaves its outputs scattered across the threads. So we have to call make_tiled_copy_C(AtomLayoutO, MMA) to generate the index mapping for the copy, and use DefaultCopy since the data each thread gets from the MMA instruction is not contiguous.

But I'm not quite sure why we need to copy the data back from shared memory to registers. It looks like we could just copy the data directly from shared memory to global memory.

I can understand that GmemTiledCopyO is still DefaultCopy: the shared memory layout is swizzled (in order to avoid bank conflicts in the register-to-smem copy), but the global memory layout should be contiguous, so we are unable to copy a contiguous 128 bits.

But could you help me understand why, starting from shared memory, we copy to registers first and then copy the register values to global memory? Is it because we want to make the data contiguous in registers, so that it can be copied to global memory more efficiently? Does that mean that, this way, we can guarantee that every shared memory access, whether shared memory is the destination or the source, is free of bank conflicts, and that every global memory access is coalesced?
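To make the bank-conflict part of this question concrete, here is a small host-side model in Python (not kernel code). It counts how many distinct 4-byte words land in the same shared memory bank when a warp writes its MMA fragment, with and without an XOR swizzle. The 32 banks of 4 bytes and the m16n8 accumulator layout (row = lane // 4, two adjacent fp16 columns per lane) follow NVIDIA's documentation as I understand it. The 64-element fp16 row width and the exact swizzle function are assumptions chosen for illustration, not necessarily what the FlashAttention kernels use.

```python
NUM_BANKS = 32      # shared memory banks
BANK_BYTES = 4      # each bank serves one 4-byte word at a time
ROW_BYTES = 128     # ASSUMPTION: smem tile row holds 64 fp16 values
CHUNK_BYTES = 16    # ASSUMPTION: swizzle permutes 16-byte (128-bit) chunks


def fragment_write_offset(lane, swizzle):
    """Byte offset in smem of the 4-byte (two fp16) value a lane stores."""
    row = lane // 4               # m16n8 accumulator layout: 4 lanes share a row
    byte_in_row = (lane % 4) * 4  # each lane owns 2 adjacent fp16 columns
    chunk = byte_in_row // CHUNK_BYTES
    if swizzle:
        chunk ^= row % 8          # XOR the 16-byte chunk index with the row index
    return row * ROW_BYTES + chunk * CHUNK_BYTES + byte_in_row % CHUNK_BYTES


def conflict_degree(offsets):
    """Largest number of distinct 4-byte words that hit the same bank."""
    words_per_bank = {}
    for off in offsets:
        word = off // BANK_BYTES
        words_per_bank.setdefault(word % NUM_BANKS, set()).add(word)
    return max(len(words) for words in words_per_bank.values())


plain = conflict_degree(fragment_write_offset(l, False) for l in range(32))
swizzled = conflict_degree(fragment_write_offset(l, True) for l in range(32))
print(plain, swizzled)  # prints "8 1"
```

Under these assumptions the unswizzled write is an 8-way conflict, because all eight rows map onto banks 0 to 3, while the swizzled write spreads the 32 lanes over 32 different banks.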

Thank you so much!

Yes, it's all for copy efficiency: coalesced gmem writes and avoiding bank conflicts on smem. This is typical, e.g. most matmul implementations with tensor cores do this (you can check cutlass; I think the triton compiler does this behind the scenes too).

  1. registers (Mma layout) -> smem, typically using stmatrix (i.e. STSM instruction) on Hopper. If stmatrix is not available we can use DefaultCopy.
  2. smem -> registers, typically loading 16 bytes per instruction (LDS.128 instruction). This is to prepare for the next step.
  3. registers -> gmem. We want to make sure this is coalesced. Will use the STG.128 instruction.

However, this is not strictly necessary. On Hopper, with TMA, instead of smem -> registers -> gmem, you can use TMA to store smem -> gmem directly. We do use this in FA3.
Another option is to not do any of these (1, 2, 3) and store directly from registers to gmem. We sometimes do this if the epilogue (storing O) doesn't take much time. Cutlass calls this "no smem epilogue".
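As a rough illustration of what the smem round trip buys on the gmem side, the Python model below compares one warp-wide store issued straight from the MMA fragment with one issued after the data has been regrouped into 16 contiguous bytes per lane. It counts the 32-byte sectors touched and reports the fraction of fetched bytes that are actually written. The row stride of 256 bytes (head dim 128 in fp16) is an assumption for the example, and the model ignores caching and everything else the hardware does.

```python
SECTOR_BYTES = 32   # global memory is accessed in 32-byte sectors
ROW_BYTES = 256     # ASSUMPTION: row stride of O, head dim 128 in fp16


def sector_efficiency(accesses):
    """accesses: list of (byte_address, num_bytes), one entry per lane."""
    sectors = set()
    for addr, nbytes in accesses:
        sectors.update(range(addr // SECTOR_BYTES,
                             (addr + nbytes - 1) // SECTOR_BYTES + 1))
    useful = sum(nbytes for _, nbytes in accesses)
    return useful / (len(sectors) * SECTOR_BYTES)


# Direct store from the m16n8 fragment: each lane writes its 2 adjacent
# fp16 accumulators (4 bytes) at row = lane // 4.
direct = [((lane // 4) * ROW_BYTES + (lane % 4) * 4, 4) for lane in range(32)]

# Staged epilogue: after the smem round trip each lane holds 16 contiguous
# bytes, and consecutive lanes cover consecutive 16-byte chunks.
staged = [(lane * 16, 16) for lane in range(32)]

direct_eff = sector_efficiency(direct)
staged_eff = sector_efficiency(staged)
print(direct_eff, staged_eff)  # prints "0.5 1.0"
```

In this model the direct store uses only half of every sector it touches and moves 4 bytes per lane per instruction instead of 16. That matches the point above that the direct path is acceptable mainly when the epilogue is a small part of the runtime.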

Thank you so much for your reply!

I'm still trying to understand why copying data directly from shared memory to global memory might cause shared memory bank conflicts or affect global memory coalescing, such that we have to copy to register first.

For shared memory access, writing to registers or global memory doesn’t appear to influence its access pattern in a way that would change the number of bank conflicts.

For global memory, each thread already processes 128 bits at a time. Does it make a difference if thread 0 writes to global memory bits 128–255 while thread 1 writes to bits 0–127, as opposed to a strictly increasing mapping where thread 0 writes to bits 0–127, thread 1 writes to bits 128–255, and so on?
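One way to reason about the ordering question is with a small Python model. It assumes that coalescing depends only on which 32-byte sectors the warp touches as a whole, which is how I read the CUDA programming guide for recent compute capabilities. Under that assumption it counts the sectors for an in-order mapping, for the thread 0 / thread 1 swap described above, and for an arbitrary shuffle of the lanes.

```python
import random

SECTOR_BYTES = 32


def sectors_touched(lane_to_chunk, chunk_bytes=16):
    """Count distinct 32-byte sectors hit when lane i stores chunk lane_to_chunk[i]."""
    sectors = set()
    for chunk in lane_to_chunk:
        start = chunk * chunk_bytes
        sectors.update(range(start // SECTOR_BYTES,
                             (start + chunk_bytes - 1) // SECTOR_BYTES + 1))
    return len(sectors)


in_order = list(range(32))             # lane i writes chunk i
swapped = in_order[:]
swapped[0], swapped[1] = swapped[1], swapped[0]   # lanes 0 and 1 trade places
shuffled = in_order[:]
random.Random(0).shuffle(shuffled)     # arbitrary permutation of the lanes

counts = [sectors_touched(m) for m in (in_order, swapped, shuffled)]
print(counts)  # prints "[16, 16, 16]"
```

In this model the permutation makes no difference: the warp covers the same 512 contiguous bytes, hence the same 16 sectors, in every case. What matters is that the warp as a whole covers full sectors, not which lane owns which chunk.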

There's no way to copy directly from smem -> gmem except with TMA.

Ah, I see! Thank you so much for clarifying this. Direct shared memory to global memory transfers have only been supported since the Hopper architecture, through TMA. In cutlass, even when we use cute::copy(), internally the operation still routes through registers.