alxndrTL/mamba.py

Values of deltaA are very large

anhtienng opened this issue · 5 comments

Hi,

The values of deltaA are very large after discretization:

```python
deltaA = torch.exp(delta.unsqueeze(-1) * A)
```

These large values make the loss go to NaN.
I also found a similar problem reported in the original mamba repo, but I couldn't find a solution there.
I have tried the ZOH discretization to avoid the exp function, but the problem still exists.
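
For what it's worth, one way to see where it blows up is to inspect the argument of the exp before applying it; float32 exp overflows to inf once its argument exceeds roughly 88.7. A hypothetical debugging snippet (not part of the repo, shapes assumed to match the line quoted above):

```python
import torch

def check_discretization(delta, A):
    """Hypothetical debugging helper: inspect the argument of the exp in
    deltaA = exp(delta.unsqueeze(-1) * A) before it overflows."""
    arg = delta.unsqueeze(-1) * A  # (B, L, ED, 1) * (ED, N) -> (B, L, ED, N)
    print("delta range:", delta.min().item(), delta.max().item())
    print("exp argument max:", arg.max().item())
    # float32 exp overflows to inf once its argument exceeds ~88.7
    if arg.max() > 80:
        print("warning: exp(delta * A) is about to overflow -> inf/NaN downstream")
    return torch.exp(arg)
```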

Do you know how to solve it?
Thank you.

The authors of Jamba (a hybrid of Mamba & attention) apply inner layer norms to dt (as well as to B and C).
I've implemented this in the mamba.py file:
https://github.com/alxndrTL/mamba.py/blob/eddec5da76da6594850ea86a7afa56c9ab6b5ac7/mambapy/mamba.py#L246C8-L246C58
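
Roughly, the idea is to normalize dt, B and C right after the projection, before the selective scan. A minimal sketch (illustrative names and shapes; the actual code at the link may differ, e.g. it may use an RMSNorm variant):

```python
import torch
import torch.nn as nn

class InnerNorms(nn.Module):
    """Sketch of Jamba-style inner norms applied to dt, B and C
    before the selective scan. Names/shapes are illustrative."""
    def __init__(self, dt_rank, d_state, eps=1e-5):
        super().__init__()
        self.dt_layernorm = nn.LayerNorm(dt_rank, eps=eps)
        self.B_layernorm = nn.LayerNorm(d_state, eps=eps)
        self.C_layernorm = nn.LayerNorm(d_state, eps=eps)

    def forward(self, dt, B, C):
        # dt: (batch, seq_len, dt_rank), B and C: (batch, seq_len, d_state)
        return self.dt_layernorm(dt), self.B_layernorm(B), self.C_layernorm(C)
```

The norms keep the dt/B/C activations in a bounded range before dt goes through softplus and the discretization.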

Maybe this will help?

The layer norm is not applied to A in the current code.

So you mean I could try to apply it to A?

No, but it is applied to delta, which is used to compute deltaA, the quantity that is blowing up in your case; that's why I proposed this.

I found the problem: I forgot to apply softplus to delta after the projection.
My bad.
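
For anyone hitting the same thing, a minimal sketch of the difference (illustrative shapes and values, not the actual model code):

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: delta (B, L, ED), A (ED, N). In Mamba, A is kept negative (A = -exp(A_log)).
delta_raw = torch.full((1, 4, 8), -40.0)   # a raw dt projection output that happens to be negative
A = -torch.exp(torch.ones(8, 2))           # A ≈ -2.72

# Without softplus, delta * A can be large and positive, and exp() overflows to inf,
# which then turns the loss into NaN.
deltaA_bad = torch.exp(delta_raw.unsqueeze(-1) * A)
print(deltaA_bad.max())                    # inf

# With softplus, delta > 0 and A < 0, so delta * A <= 0 and deltaA stays in (0, 1].
delta = F.softplus(delta_raw)
deltaA = torch.exp(delta.unsqueeze(-1) * A)
print(deltaA.max())                        # <= 1.0
```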

Thank you very much.

Cool, it worked out!