alxndrTL/mamba.py

delta question

AliYoussef97 opened this issue · 5 comments

Hello,

Thank you for the amazing work!

I had a question regarding the latest commit. If the following dot product was added to the ssm function, is there a specific reason why it was not added to the ssm_step function as well?

Moreover, the softplus in the ssm function uses the bias of dt_proj, while ssm_step uses the previous implementation: delta = F.softplus(self.dt_proj(delta))

Lastly, should D here have a _no_weight_decay attribute similar to A_log?


Edit: If the new modifications to delta were intentionally not added to ssm_step, does that mean that during inference I have to use step and cannot use forward as well?


Edit 2: In the forward function, I do a bidirectional forward (as in Vision Mamba), such as:

output = self.mamba_block(self.norm(x))
x_flip = x.flip([1])
output_flip = self.mamba_block(self.norm(x_flip))
output += output_flip.flip([1])  # flip back to the original order before adding
return output + x

However, in step, each directional forward will return its own cache, and I am not sure how to handle that exactly, as unfortunately I do not fully understand the cache mechanism (h, inputs).

Apologies for the long question.

Thank you!

Hello, thank you for taking an interest in my work!

Yes, I just modified the way delta is computed, and I realize now I should have added comments to make this clear.
Before, this was the code used to compute delta:

delta, B, C = torch.split(deltaBC, [self.config.dt_rank, self.config.d_state, self.config.d_state], dim=-1) # (B, dt_rank), (B, N), (B, N)
delta, B, C = self._apply_layernorms(delta, B, C)
delta = F.softplus(self.dt_proj(delta)) # (B, ED)

and it is also currently used in the step function.

Now, dt_proj is computed in multiple steps: first the matrix multiplication self.dt_proj.weight @ delta.transpose(1, 2), and then later the bias + softplus operations. By later I mean in this call:

self.selective_scan_cuda(x, delta, A, B, C, D, z=z, delta_softplus=True, delta_bias=self.dt_proj.bias.float())

or in this snippet if CUDA is not used:

delta = delta.transpose(1, 2)
delta = F.softplus(delta + self.dt_proj.bias)

Why decompose the delta computation? To take advantage of the fact that selective_scan_cuda fuses the bias + softplus operations. So in the main ssm function (where we don't yet know whether the CUDA kernel will be used or not) we just do the matrix multiplication. We do the rest later.

Because selective_scan_cuda is not used at inference, in the step function we do all three operations in one line, just as before.
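
To make this concrete, here is a minimal standalone sketch (the shapes and the dt_proj layer are made up for illustration, not taken from the repo) showing that the decomposed path and the one-line step path compute the same delta:

import torch
import torch.nn.functional as F

B, L, dt_rank, ED = 2, 8, 4, 16          # illustrative sizes
dt_proj = torch.nn.Linear(dt_rank, ED)
delta = torch.randn(B, L, dt_rank)

# ssm path: only the matrix multiplication is done here...
delta_mm = dt_proj.weight @ delta.transpose(1, 2)                 # (B, ED, L)
# ...the bias + softplus come later (fused inside selective_scan_cuda,
# or explicitly in the non-CUDA branch as quoted above)
delta_full = F.softplus(delta_mm.transpose(1, 2) + dt_proj.bias)  # (B, L, ED)

# ssm_step path: all three operations in one line, one timestep at a time
delta_step = F.softplus(dt_proj(delta[:, 0]))                     # (B, ED)

assert torch.allclose(delta_full[:, 0], delta_step, atol=1e-6)    # same result, two routes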

For D needing the _no_weight_decay attribute: yes, thank you, I hadn't noticed that!

Concerning the bidirectional forward, I'm not sure how you would do inference in both directions?
I tried to explain the cache mechanism with a drawing:
[drawing of the cache mechanism]
The cache is in violet in the drawing. It is stored by the layer at timestep t to be used again when processing timestep t+1.
Hope this is clear!

Why decompose the delta computation? To take advantage of the fact that selective_scan_cuda fuses the bias + softplus operations. So in the main ssm function (where we don't yet know whether the CUDA kernel will be used or not) we just do the matrix multiplication. We do the rest later.

Because selective_scan_cuda is not used at inference, in the step function we do all three operations in one line, just as before.

This makes so much sense, thank you! So essentially, in forward, delta is computed using the weight and bias separately instead of calling the Linear layer directly, to make way for selective_scan_cuda.

For D needing the _no_weight_decay attribute: yes, thank you, I hadn't noticed that!

My pleasure!

Concerning the bidirectional forward, I'm not sure how you would do inference in both directions?
I tried to explain the cache mechanism with a drawing:

Thank you so much for the drawing, this makes a lot of sense. My only issue is that I do not really get the difference between using the output y1 as input for the next layer (which is the normal forward pass) and using h1 (the state during the ssm) and x1 (the input after the inner projection), since during step the input x after projection and the previous inputs are concatenated. So essentially, I do not really understand the benefit of using that over the normal forward pass? I have read the explanation a few times and I am just slightly confused.

Thank you!

My only issue is that I do not really get the difference between using the output y1 as input for the next layer (which is the normal forward pass) and using h1 (the state during the ssm) and x1 (the input after the inner projection)

y1 is used as input for the next layer in both scenarios, forward and step.
The cache (h, the last d_conv-1 inputs), on the other hand, is used by a layer for that same layer later (you have one cache object per layer).
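
As a rough illustration of what that per-layer cache contains (the shapes here are assumptions based on a typical Mamba block, not copied from the repo):

import torch

B, d_inner, d_state, d_conv, n_layers = 1, 32, 16, 4, 2   # illustrative sizes

def empty_cache():
    h = torch.zeros(B, d_inner, d_state)          # SSM hidden state after the last processed timestep
    inputs = torch.zeros(B, d_inner, d_conv - 1)  # last d_conv-1 inner-projected inputs, needed by the causal conv1d
    return (h, inputs)

caches = [empty_cache() for _ in range(n_layers)]  # one cache object per layer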

So essentially, I do not really understand the benefit of using that over the normal forward pass?

Using the step function corresponds to the scenario in which you want to call the forward function multiple times in a row, each time with a new input along the time dimension.
i.e., the first input is (B, 1, D), the second input is (B, 2, D), the third input is (B, 3, D)...
In this scenario, the step function is used instead of the forward function because it stores the necessary variables for the next call (the cache), instead of recomputing them.
Concretely, when Mamba is a language model, this corresponds to sampling from the model (i.e. doing inference). But for your case with Vision Mamba, I'm not sure how inference is done.
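
To make the "call step repeatedly" scenario concrete, here is a toy sketch. dummy_step stands in for a real per-layer Mamba step (it is not the repo's implementation); the only point is that the cache travels from one call to the next instead of being recomputed:

import torch

B, D, d_conv, steps = 1, 16, 4, 5                      # illustrative sizes

def dummy_step(x_t, cache):
    h, inputs = cache
    h = 0.9 * h + x_t                                   # stand-in for the SSM recurrence
    inputs = torch.cat([inputs[:, :, 1:], x_t.unsqueeze(-1)], dim=2)  # shift in the newest input
    return h, (h, inputs)

cache = (torch.zeros(B, D), torch.zeros(B, D, d_conv - 1))
for t in range(steps):
    x_t = torch.randn(B, D)                             # new input at timestep t
    y_t, cache = dummy_step(x_t, cache)                 # the cache carries h and the last d_conv-1 inputs forward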

Hope this helps!

@alxndrTL This makes so much sense, thank you for the great explanation!

The Vision Mamba block is essentially the same as the Mamba block, so I assume inference should be relatively similar (this assumption comes from comparing the Vim code with Mamba). However, in Vim the SSM is computed in both a forward and a backward manner. For simplicity:

# Forward
output_f = self.mamba_block(self.norm(x))

# Backward
x_flip = x.flip([1])
output_b = self.mamba_block(self.norm(x_flip))

# Output projection + residual connection
output = self.mamba_block.out_proj(output_f + output_b.flip([1])) + x

In that case, in the step function, both the forward and backward passes will return their own cache, but the input for the next layer would be the combination of the forward and backward outputs, so it is quite confusing how to handle the cache in that case. I will stick to the forward pass for now, and hopefully I can figure out how to utilise the step function during inference.

In any case, thank you so much for the amazing explanation, I finally understand the cache mechanism 😆!

Cool haha!
Yes, OK, so maybe the step function isn't what you want for inference here because, as I said, it is designed for the scenario where you don't have all the inputs at once. And in your case it seems you do have all the inputs at once (you can't do the bidirectional forward if you don't have them).

Good luck!