A simple CUDA implementation that helps you understand how FlashAttention works.
No dropout, no mask, fixed input size.
- CUDA implementation
  - QKVO struct (struct and tile-load sketch after this list)
  - QKV global memory load
  - QKV shared memory store
  - QKV shared memory load
  - GEMM (Q·Kᵀ sketch below)
  - softmax (online softmax sketch below)
  - C++ API
  - C++ comparison against a reference answer (host-side sketch below)
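As a rough sketch of the first four items, one way to bundle the Q/K/V/O pointers in a struct and cooperatively stage Q/K/V tiles from global into shared memory might look like this; `Qkvo`, `load_tile`, and the tile sizes are illustrative assumptions, not the repo's actual identifiers:

```cpp
#include <cuda_runtime.h>

// Assumed fixed sizes: since the input size is fixed, these can be
// compile-time constants.
constexpr int SEQ_LEN  = 1024;  // sequence length N (assumed)
constexpr int HEAD_DIM = 64;    // head dimension d (assumed)
constexpr int TILE     = 32;    // rows of Q/K/V staged per block (assumed)

// Bundles the four device pointers so a single kernel argument
// carries Q, K, V, and the output O.
struct Qkvo {
    const float* q;  // [SEQ_LEN, HEAD_DIM]
    const float* k;  // [SEQ_LEN, HEAD_DIM]
    const float* v;  // [SEQ_LEN, HEAD_DIM]
    float*       o;  // [SEQ_LEN, HEAD_DIM]
};

// Cooperative global -> shared copy of one TILE x HEAD_DIM tile:
// every thread in the block strides over the tile's elements.
__device__ void load_tile(float dst[TILE][HEAD_DIM],
                          const float* src, int row0) {
    int tid      = threadIdx.y * blockDim.x + threadIdx.x;
    int nthreads = blockDim.x * blockDim.y;
    for (int i = tid; i < TILE * HEAD_DIM; i += nthreads) {
        int r = i / HEAD_DIM, c = i % HEAD_DIM;
        dst[r][c] = src[(row0 + r) * HEAD_DIM + c];
    }
    // Caller must __syncthreads() before reading dst.
}
```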
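The GEMM step then computes the score tile S = Q·Kᵀ entirely out of shared memory. A minimal sketch under the same assumed sizes, with one thread per score element and the 1/√d attention scaling folded in:

```cpp
constexpr int HEAD_DIM = 64;  // same assumed head dimension as above
constexpr int TILE     = 32;  // same assumed tile size as above

// S = Q_tile * K_tile^T on shared-memory operands, assuming
// blockDim = (TILE, TILE); each thread produces one score.
__device__ void gemm_qk(const float q[TILE][HEAD_DIM],
                        const float k[TILE][HEAD_DIM],
                        float s[TILE][TILE]) {
    int row = threadIdx.y;
    int col = threadIdx.x;
    float acc = 0.0f;
    for (int i = 0; i < HEAD_DIM; ++i)
        acc += q[row][i] * k[col][i];  // K^T: row-with-row dot product
    s[row][col] = acc * rsqrtf((float)HEAD_DIM);
}
```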
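The softmax is the online (streaming) variant that lets FlashAttention avoid materializing the full N×N score matrix: per query row it keeps a running max `m` and running sum `l`, and rescales the partial output whenever the max grows. A per-row sketch, where the state persists across the loop over K/V tiles (names are assumptions):

```cpp
#include <cuda_runtime.h>

constexpr int HEAD_DIM = 64;  // same assumed sizes as above
constexpr int TILE     = 32;

// Folds one TILE-wide chunk of scores for a single query row into the
// running (m, l, o_acc) state. Initialize m = -INFINITY, l = 0,
// o_acc = 0; after the last K/V tile, divide o_acc by l.
__device__ void online_softmax_row(const float s_row[TILE],
                                   const float v[TILE][HEAD_DIM],
                                   float& m, float& l,
                                   float o_acc[HEAD_DIM]) {
    // 1. New running max including this chunk.
    float m_new = m;
    for (int j = 0; j < TILE; ++j) m_new = fmaxf(m_new, s_row[j]);

    // 2. Rescale previous partial results by exp(m_old - m_new).
    float scale = __expf(m - m_new);
    l *= scale;
    for (int t = 0; t < HEAD_DIM; ++t) o_acc[t] *= scale;

    // 3. Accumulate this chunk's exponentials and its V contribution.
    for (int j = 0; j < TILE; ++j) {
        float p = __expf(s_row[j] - m_new);
        l += p;
        for (int t = 0; t < HEAD_DIM; ++t) o_acc[t] += p * v[j][t];
    }
    m = m_new;
}
```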
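For the C++ API and the correctness check, the usual pattern is to run the kernel and then compare against a naive O(N²) host attention within a float tolerance. A self-contained sketch of the host side (function names are assumptions):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Naive reference attention on the host: O = softmax(Q K^T / sqrt(d)) V.
void attention_ref(const float* q, const float* k, const float* v,
                   float* o, int n, int d) {
    std::vector<float> s(n);
    for (int i = 0; i < n; ++i) {
        float m = -INFINITY, l = 0.0f;
        for (int j = 0; j < n; ++j) {           // scores + running max
            float acc = 0.0f;
            for (int t = 0; t < d; ++t) acc += q[i*d + t] * k[j*d + t];
            s[j] = acc / std::sqrt((float)d);
            m = std::max(m, s[j]);
        }
        for (int j = 0; j < n; ++j) {           // stable softmax weights
            s[j] = std::exp(s[j] - m);
            l += s[j];
        }
        for (int t = 0; t < d; ++t) {           // weighted sum of V
            float acc = 0.0f;
            for (int j = 0; j < n; ++j) acc += s[j] * v[j*d + t];
            o[i*d + t] = acc / l;
        }
    }
}

// Element-wise comparison of the kernel's output against the reference.
bool all_close(const float* got, const float* want, int count,
               float tol = 1e-3f) {
    for (int i = 0; i < count; ++i)
        if (std::fabs(got[i] - want[i]) > tol) {
            std::printf("mismatch at %d: %f vs %f\n", i, got[i], want[i]);
            return false;
        }
    return true;
}
```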
- Python wrapper (binding sketch after this list)
  - setup.py
  - benchmark.py
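On the wrapper side, a PyTorch CUDA extension is one common way to wire the kernel up to Python: setup.py would build it with torch.utils.cpp_extension.CUDAExtension, and benchmark.py would time the result against torch.nn.functional.scaled_dot_product_attention. A sketch of the C++ glue (module and function names are assumptions):

```cpp
#include <torch/extension.h>

// Hypothetical launcher implemented in the .cu file: runs the
// FlashAttention kernel on q, k, v and returns the output tensor.
torch::Tensor flash_attention_forward(torch::Tensor q,
                                      torch::Tensor k,
                                      torch::Tensor v);

// Exposes the kernel to Python as `module.forward(q, k, v)`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &flash_attention_forward,
          "FlashAttention forward (no dropout, no mask, fixed input size)");
}
```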