/mamba-simple

An annotated python implementation of selective_scan_cuda

Primary LanguagePython

This is an annotaed implementation of selective_scan_cuda from Mamba. This is not runnable code, but a reference for what is happening in the CUDA kernel. I wrote this in Python for people who don't know C++. There are some slight differences between the Python and C++ code, but the general idea is the same.

Additionally, I wrote comments in the code to explain what is going on and why they chose to do some things. A basic understanding of what a CUDA kernel is would be helpful, though you can still get a general idea of what is going on without it.

Something Things to Note:

  • Pointers don't exist in Python, thus some of the logic for that has been removed.
  • In the C++ implementation we need to update the location of the pointer after every chunk, however, in the annotated Python version I do not do this. (So don't get confused when it looks like every chunk is iterating over the same data)
  • Templates don't exist in Python, thus I kinda get around this by having functions be classes. But don't be fooled, these would be functions in C++
  • This is more like pseudo-code, which is written in a pythonish language (The comments are where I try to explain what we are doing and why)