CASIA-IVA-Lab/DPT

Extracting patches by calling MSDeformAttnFunction

RebornForPower opened this issue · 3 comments

Thanks for sharing your work~
I found that depatch_embed.py (L112) calls this line:

output = MSDeformAttnFunction.apply(x, value_spatial_shapes, self.value_level_start_index, sampling_locations, attention_weights, 1)
I suspect this code performs an extra deformable attention computation. If so, it adds extra computation. Is the comparison with PVT still fair?

Thanks for your attention here.

  1. DPT does add extra computation. However, we only add one DePatch module per stage. This addition is minor compared with the whole PVT model, especially for the larger models (3 linear layers compared with tens of attention modules). Please refer to Table 1 of our paper for a detailed FLOPs comparison.
  2. In this code base, DPT reuses the existing MSDeformAttn package from Deformable-DETR only for its ability to sample at arbitrary locations. We do not actually need to compute deformable attention: as in depatch_embed.py (L110), we set the attention map to an all-ones matrix. This calculation is redundant and could be eliminated if a more suitable package for bilinear sampling were available.
  3. More specifically, the real additional computation lies in predicting the sampling offsets and in the denser sampling points (if more than 2x2 points are used).
    • Predicting the sampling offsets involves two linear layers with dimensions $C\times C$ and $C\times4$.
    • Using 3x3 sampling points instead of the 4 patches in the original PVT enlarges the patch embedding layer from $4\cdot C/2\times C$ to $9\cdot C/2\times C$ (input dimension × output dimension).
  4. Overall, the added computation is comparable to 2~3 linear layers, which is much lighter than the attention modules. So there is no need to worry about it.
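Item 2's point is that deformable attention with all-ones attention weights reduces to plain bilinear sampling. The toy, pure-Python sketch below illustrates this. The helpers `bilinear_sample` and `toy_deform_attn` are hypothetical. They model one head, one level and one query, using pixel coordinates with zero padding; they do not reproduce the real CUDA op or its coordinate convention. In PyTorch, `torch.nn.functional.grid_sample` is one existing bilinear-sampling routine that could plausibly fill the role item 2 describes.

```python
import math


def bilinear_sample(feat, x, y):
    """Bilinearly sample a 2-D map `feat` (list of rows) at continuous (x, y).

    Toy convention (an assumption, not the real op's): pixel centres sit at
    integer coordinates and out-of-range neighbours contribute zero.
    """
    h, w = len(feat), len(feat[0])
    x0, y0 = math.floor(x), math.floor(y)
    value = 0.0
    for yi in (y0, y0 + 1):
        for xi in (x0, x0 + 1):
            if 0 <= yi < h and 0 <= xi < w:
                weight = (1.0 - abs(x - xi)) * (1.0 - abs(y - yi))
                value += weight * feat[yi][xi]
    return value


def toy_deform_attn(feat, points, weights):
    """Hypothetical single-head, single-level deformable attention for one
    query: a weighted sum of bilinear samples at the given points."""
    return sum(a * bilinear_sample(feat, x, y) for (x, y), a in zip(points, weights))


feat = [[0.0, 1.0],
        [2.0, 3.0]]
# One sampling point with attention weight 1.0: the "attention" is just sampling.
print(toy_deform_attn(feat, [(0.5, 0.5)], [1.0]))  # 1.5
print(bilinear_sample(feat, 0.5, 0.5))             # 1.5
```

In this toy, several points per query with all-ones weights would return the sum of their samples. The redundant part is therefore only the multiply-by-one, not the sampling itself.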
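The layer sizes in item 3 can be tallied as a rough count of extra weights per DePatch module. For a linear layer applied to every output token, the multiply-accumulates per token equal its weight count, so the same number approximates the extra per-token cost. The sketch rests on a few assumptions: biases are ignored, $C/2$ is read as the channel count of the sampled vectors, and `depatch_extra_weights` is a hypothetical helper. `c = 64` is an arbitrary example value.

```python
def depatch_extra_weights(c, k=3):
    """Extra weights (biases ignored) one DePatch module adds over PVT's
    patch embedding, using the layer sizes listed in item 3.

    Assumes the patch embedding maps k*k sampled vectors of C/2 channels
    to C channels (PVT corresponds to k = 2, i.e. 4 patches).
    """
    offset_pred = c * c + c * 4          # two linear layers: C x C and C x 4
    embed_pvt = 4 * (c // 2) * c         # PVT: 2x2 patches of C/2 channels -> C
    embed_dpt = k * k * (c // 2) * c     # DPT: k x k points of C/2 channels -> C
    return offset_pred + (embed_dpt - embed_pvt)


c = 64
extra = depatch_extra_weights(c)
print(extra)            # 14592
print(extra / (c * c))  # 3.5625, i.e. about 3.5 C x C layers
```

Under this reading, 3x3 sampling costs about 3.5 $C\times C$ layers per module. That is the same order of magnitude as the 2~3 linear layers estimated in item 4. With k = 2, only the offset-prediction layers remain.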

OK, I understand. Thanks for your patient reply~

Hi @volgachen,
Thank you for your explanation.
As you explained, is the output of MSDeformAttnFunction.apply the sampled features for each box?

output = MSDeformAttnFunction.apply(x, self.value_spatial_shapes, self.value_level_start_index, sampling_locations, attention_weights, 1)
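The question above can be illustrated with a toy, pure-Python sketch of one plausible reading, which is not a confirmed description of depatch_embed.py. Each output patch gets a box, for instance the four values from the $C\times4$ layer in item 3 decoded into a centre and a size. The box is covered by k×k sampling points, and the sampling op returns one bilinearly interpolated feature per point. The box parametrisation, the point layout and the helper names (`box_sampling_points`, `sample_box`, `bilinear_sample`) are all hypothetical.

```python
import math


def bilinear_sample(feat, x, y):
    """Toy bilinear sampling at continuous (x, y); pixel centres at integer
    coordinates, out-of-range neighbours contribute zero (an assumption)."""
    h, w = len(feat), len(feat[0])
    x0, y0 = math.floor(x), math.floor(y)
    value = 0.0
    for yi in (y0, y0 + 1):
        for xi in (x0, x0 + 1):
            if 0 <= yi < h and 0 <= xi < w:
                value += (1.0 - abs(x - xi)) * (1.0 - abs(y - yi)) * feat[yi][xi]
    return value


def box_sampling_points(cx, cy, w, h, k=3):
    """Hypothetical layout: k*k points at the centres of a k x k grid of
    cells covering the box centred at (cx, cy) with size (w, h), row-major."""
    return [
        (cx - w / 2 + (j + 0.5) * w / k, cy - h / 2 + (i + 0.5) * h / k)
        for i in range(k)
        for j in range(k)
    ]


def sample_box(feat, cx, cy, w, h, k=3):
    """One bilinearly sampled value per point, i.e. k*k values per box."""
    return [bilinear_sample(feat, x, y) for x, y in box_sampling_points(cx, cy, w, h, k)]


# 4x4 map whose value at (x, y) is x + 10*y, so samples are easy to check.
feat = [[x + 10.0 * y for x in range(4)] for y in range(4)]
print(sample_box(feat, cx=1.5, cy=1.5, w=3.0, h=3.0))
# [5.5, 6.5, 7.5, 15.5, 16.5, 17.5, 25.5, 26.5, 27.5]
```

Under these assumptions, each box yields k×k sampled vectors (scalars in this single-channel toy). With $C/2$ channels and k = 3, that gives the $9\cdot C/2$ input dimension of the patch embedding layer mentioned in item 3.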