mini-flashattention

A simple CUDA implementation to help you understand how FlashAttention works.

No dropout, no mask, and a fixed input size.
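For reference, single-head attention computes `O = softmax(Q K^T / sqrt(d)) V`. A minimal CPU sketch of that baseline (sizes and names are illustrative; this is the kind of "right answer" the C++ comparison step below can check the kernel against):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Naive single-head attention on the CPU: O = softmax(Q K^T / sqrt(d)) V.
// Q, K, V, O are row-major N x d matrices; sizes here are hypothetical.
void attention_ref(const std::vector<float>& Q, const std::vector<float>& K,
                   const std::vector<float>& V, std::vector<float>& O,
                   int N, int d) {
    const float scale = 1.0f / std::sqrt(static_cast<float>(d));
    for (int i = 0; i < N; ++i) {
        // Scores for row i, with the usual max-subtraction for stability.
        std::vector<float> s(N);
        float m = -INFINITY;
        for (int j = 0; j < N; ++j) {
            float dot = 0.0f;
            for (int k = 0; k < d; ++k) dot += Q[i * d + k] * K[j * d + k];
            s[j] = dot * scale;
            m = std::max(m, s[j]);
        }
        float sum = 0.0f;
        for (int j = 0; j < N; ++j) { s[j] = std::exp(s[j] - m); sum += s[j]; }
        // O[i] = sum_j softmax(s)[j] * V[j].
        for (int k = 0; k < d; ++k) {
            float acc = 0.0f;
            for (int j = 0; j < N; ++j) acc += (s[j] / sum) * V[j * d + k];
            O[i * d + k] = acc;
        }
    }
}
```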

Todo list

- CUDA implementation
  - QKVO struct
  - QKV global memory load
  - QKV shared memory store
  - QKV shared memory load
  - GEMM
  - softmax
- C++ API
- C++ comparison against a reference answer
- Python wrapper
  - setup.py
  - benchmark.py
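The softmax step is where FlashAttention differs from a naive kernel: each tile updates a running max and a rescaled running sum, so the full score row never has to be materialized. A CPU sketch of this online softmax trick (function and variable names are illustrative, not from this repo):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Online (streaming) softmax: a single pass keeps the running max `m`
// and the running sum `l` of exponentials, rescaling `l` whenever a new
// maximum appears -- the same update FlashAttention applies per tile.
std::vector<float> online_softmax(const std::vector<float>& x) {
    float m = -INFINITY;  // running max
    float l = 0.0f;       // running sum of exp(x - m)
    for (float v : x) {
        float m_new = std::max(m, v);
        // Rescale the old sum to the new max, then add the new term.
        l = l * std::exp(m - m_new) + std::exp(v - m_new);
        m = m_new;
    }
    std::vector<float> y(x.size());
    for (std::size_t i = 0; i < x.size(); ++i)
        y[i] = std::exp(x[i] - m) / l;
    return y;
}
```

The same rescaling applies to the partially accumulated output rows, which is why the GEMM and softmax items above end up fused in one kernel.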