Flops calculation
Dear authors,
Thanks for your amazing work!
I have a small question: could you provide the code for the FLOPs calculation, or explain how to compute the FLOPs of MambaVision? Thanks!
Hi @AndyCao1125
You can use ptflops for this purpose and install it with pip install ptflops.
The following snippet should be useful for FLOPs/params calculation:
from ptflops import get_model_complexity_info
from mamba_vision import *

resolution = 224
model = mamba_vision_T().cuda().eval()
# as_strings=False returns raw counts, converted to GFLOPs and millions of params below
flops, params = get_model_complexity_info(model, (3, resolution, resolution), as_strings=False, print_per_layer_stat=False, verbose=False)
print(f"Model stats: GFLOPs: {flops*1e-9}, and (M) params: {params*1e-6}")
Hope it helps
Dear authors,
Thanks for your prompt reply! I've successfully tested the FLOPs and params of the Mamba models with no errors. However, I have a few more questions:
As discussed in (state-spaces/mamba#303) and (state-spaces/mamba#110) in the original Mamba GitHub issues, the FLOPs of the SSM usually cannot be computed with ordinary libraries (e.g., thop) because of the special selective scan mechanism. The author Albert Gu gave a theoretical FLOPs estimate of 9LN for it.
I noticed that the ptflops library you suggested does not report errors when processing the SSM, but does this calculation actually count the special operations inside the selective scan? Many thanks!
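As a quick sanity check (just a sketch reusing the snippet above with the same model and 224 resolution), enabling ptflops' per-layer statistics shows how the FLOPs are attributed inside each Mamba block:

from ptflops import get_model_complexity_info
from mamba_vision import *

model = mamba_vision_T().cuda().eval()
# print_per_layer_stat=True prints a per-module breakdown; operations without a
# ptflops handler (e.g., a fused selective-scan kernel) are expected to contribute
# nothing beyond the Linear/Conv1d sub-modules they wrap
flops, params = get_model_complexity_info(model, (3, 224, 224), as_strings=True, print_per_layer_stat=True, verbose=True)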
Hi @AndyCao1125
We did not use Albert's estimated formula, but it should not change the numbers reported with ptflops, since the SSM part is not the most compute-intensive operation, even in stages 3 and 4, where the cost is still dominated by the self-attention layers.
I assume L and N in the 9LN formula denote the sequence length and the state dimension, which are 196 and 16, respectively, for stage 3, and 49 and 16 for stage 4.
In mamba_vision_B, for example, we have 5 and 3 SSM layers in stages 3 and 4, respectively. So, assuming ptflops does not account for the SSM part at all, the total added FLOPs should be:
(5 * 196 * 16 + 3 * 49 * 16) * 9e-9 = 0.000162288 GFLOPs
We have reported 15.0 GFLOPs for mamba_vision_B, so this still does not change the reported values.
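In code, this back-of-the-envelope estimate is simply:

# 5 SSM layers at stage 3 (L=196, N=16) and 3 at stage 4 (L=49, N=16),
# at roughly 9*L*N FLOPs each
ssm_gflops = (5 * 196 * 16 + 3 * 49 * 16) * 9 * 1e-9
print(ssm_gflops)  # 0.000162288 GFLOPs, negligible next to the reported 15.0 GFLOPs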
Please feel free to let us know if the above calculations need to be modified to take into account anything that we did not consider.
Ali
Hi @ahatamiz
Thanks! Sorry for the part I omitted from my previous statement. The FLOPs of a complete selective scan also depend on d_model (D) and d_state (N, =16 by default in the Mamba setting), in addition to the sequence length L, i.e., the additional FLOPs = 9 * L * D * N.
Thus, the revised FLOPs should be:
(5 * 196 * 16 * 16 + 3 * 49 * 16 * 16) * 9 * 1e-9 + (additional FLOPs from B and Z) >= 0.00259 GFLOPs
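Written out, leaving aside the extra terms from B and Z:

# revised estimate with the additional d_model/d_state factor from the formula above
revised_ssm_gflops = (5 * 196 * 16 * 16 + 3 * 49 * 16 * 16) * 9 * 1e-9
print(revised_ssm_gflops)  # ~0.0026 GFLOPs before the additional terms from B and Z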
Although the overall GFLOPs for mamba_vision_B are unchanged, since only a few significant digits are kept, there may be an effect on the integer part of the GFLOPs for mamba_vision_L/L2. Moreover, for larger images (e.g., super-resolution datasets), this effect becomes more significant.
Therefore, I think accurately calculating the FLOPs of a Mamba-based model is a somewhat troublesome process, because the selective scan has to be counted manually. In addition to the selective scan operation, the Mamba module itself contains linear and conv1d modules whose FLOPs also need to be added.
Could we try to implement an accurate FLOPs calculation for Mamba-based models? (If needed, I'm willing to contribute :)
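For instance, a minimal sketch of such a combined count could add a hand-computed selective-scan term on top of the ptflops result. The per-stage layer counts and L/N values below are the ones discussed in this thread, D=16 is only a placeholder that should be replaced with the actual d_model of each stage, and a mamba_vision_B factory analogous to mamba_vision_T above is assumed:

from ptflops import get_model_complexity_info
from mamba_vision import *

def selective_scan_flops(num_layers, seq_len, d_model, d_state):
    # theoretical estimate discussed above: ~9 * L * D * N FLOPs per selective scan
    return 9 * num_layers * seq_len * d_model * d_state

model = mamba_vision_B().cuda().eval()
flops, params = get_model_complexity_info(model, (3, 224, 224), as_strings=False, print_per_layer_stat=False, verbose=False)

# placeholder D=16 follows the numbers used above; read the real values from the model config
extra = selective_scan_flops(5, 196, 16, 16) + selective_scan_flops(3, 49, 16, 16)
print(f"GFLOPs (ptflops + manual SSM term): {(flops + extra) * 1e-9:.4f}")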
Thanks @AndyCao1125 for the clarification. I agree that calculating FLOPs for Mamba-based models is quite tricky. But since we report GFLOPs with only a few significant digits, the extra added value won't change things much, at least in our configuration: even if the sequence length increased by 100 times, the added FLOPs due to the SSM part in the above example would still be only about 0.2 GFLOPs.
Other modules such as linear and conv1d are already taken into account -- imagine a hook-based implementation which registers all such layers.
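As a rough illustration of that hook-based idea (not the implementation ptflops uses internally, just a sketch), one could register forward hooks on every Linear and Conv1d module and accumulate their multiply-accumulates:

import torch
import torch.nn as nn

def count_linear_conv1d_macs(model, input_shape, device="cuda"):
    # register a forward hook on every Linear/Conv1d layer and sum their MACs
    total = {"macs": 0}

    def linear_hook(module, inputs, output):
        # each output element needs in_features multiply-accumulates
        total["macs"] += output.numel() * module.in_features

    def conv1d_hook(module, inputs, output):
        # each output element needs kernel_size * in_channels / groups MACs
        total["macs"] += output.numel() * module.kernel_size[0] * module.in_channels // module.groups

    handles = []
    for m in model.modules():
        if isinstance(m, nn.Linear):
            handles.append(m.register_forward_hook(linear_hook))
        elif isinstance(m, nn.Conv1d):
            handles.append(m.register_forward_hook(conv1d_hook))

    with torch.no_grad():
        model(torch.randn(1, *input_shape, device=device))

    for h in handles:
        h.remove()
    return total["macs"]

# e.g. count_linear_conv1d_macs(mamba_vision_T().cuda().eval(), (3, 224, 224))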
Thanks for your reply!