alxndrTL/mamba.py

CUDA Version

EddieEduardo opened this issue · 5 comments

Hi, great work!

How do I enable CUDA? I'm asking because I found this in the code:

```python
if self.config.use_cuda:
    try:
        from mamba_ssm.ops.selective_scan_interface import selective_scan_fn  # I did not find mamba_ssm in this repo
        self.selective_scan_cuda = selective_scan_fn
    except ImportError:
        print("Failed to import mamba_ssm. Falling back to mamba.py.")
        self.config.use_cuda = False
```

Looking forward to your response.

Thanks !

Hello, thanks!

To enable the CUDA version, set use_cuda to True when creating a MambaConfig.
Make sure you have installed the official CUDA implementation (the mamba_ssm package from https://github.com/state-spaces/mamba/) before doing so.
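For example, a minimal sketch (the import path and the d_model/n_layers values are illustrative and may differ between repo versions):

```python
import torch
from mambapy.mamba import Mamba, MambaConfig  # import path may differ in older versions

# use_cuda=True makes the model call the fused CUDA selective scan
# (requires the mamba_ssm package from state-spaces/mamba to be installed)
config = MambaConfig(d_model=16, n_layers=2, use_cuda=True)
model = Mamba(config).to("cuda")

x = torch.randn(1, 64, 16, device="cuda")  # (batch, seq_len, d_model)
y = model(x)
```

If mamba_ssm is missing, the snippet you quoted simply falls back to the pure PyTorch path.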

Thanks for the reply.

Is there any performance discrepancy between the CUDA and non-CUDA versions?

Yes, that's the first graph on the README ;) There is also a detailed section on the README: https://github.com/alxndrTL/mamba.py?tab=readme-ov-file#performances
Note also that memory consumption is considerably higher with the non-CUDA version (having fused operations really helps to manage memory efficiently).
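If you want to measure the gap on your own hardware, a rough sketch could look like this (the benchmark helper and the chosen sizes are purely illustrative, not part of the repo):

```python
import time
import torch
from mambapy.mamba import Mamba, MambaConfig  # import path may differ

def benchmark(use_cuda_kernel: bool, batch=8, seq_len=1024, d_model=64):
    # Quick comparison of the two code paths; numbers will depend heavily
    # on hardware, batch size and sequence length.
    config = MambaConfig(d_model=d_model, n_layers=2, use_cuda=use_cuda_kernel)
    model = Mamba(config).to("cuda")
    x = torch.randn(batch, seq_len, d_model, device="cuda")

    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = time.time()
    y = model(x)
    torch.cuda.synchronize()
    elapsed = time.time() - start

    peak_mb = torch.cuda.max_memory_allocated() / 2**20
    print(f"use_cuda={use_cuda_kernel}: {elapsed * 1000:.1f} ms, peak {peak_mb:.0f} MiB")

benchmark(True)   # fused CUDA selective scan (needs mamba_ssm installed)
benchmark(False)  # pure PyTorch parallel scan from mamba.py
```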

Hope this helps.

Thanks!!! Have you tested other metrics besides time? For example, accuracy?

Memory still has to be measured precisely.
As for accuracy, of course: both versions are numerically equivalent.
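If you want to verify that yourself, a quick sanity check could look like this (a sketch, assuming both code paths are available and that the two configs produce identical parameter layouts; the tolerances only account for floating-point reordering in the fused kernel):

```python
import torch
from mambapy.mamba import Mamba, MambaConfig  # import path may differ

d_model = 32
cfg_cuda = MambaConfig(d_model=d_model, n_layers=2, use_cuda=True)
cfg_ref  = MambaConfig(d_model=d_model, n_layers=2, use_cuda=False)

model_cuda = Mamba(cfg_cuda).to("cuda")
model_ref = Mamba(cfg_ref).to("cuda")
model_ref.load_state_dict(model_cuda.state_dict())  # copy the same weights into both

x = torch.randn(2, 128, d_model, device="cuda")
with torch.no_grad():
    y_cuda = model_cuda(x)
    y_ref = model_ref(x)

# small tolerances because the fused kernel reorders floating-point operations
print(torch.allclose(y_cuda, y_ref, atol=1e-4, rtol=1e-4))
```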