can not test svit-adapter-t-0.5x-ftune.py because self.num_heads is not even
Livioni opened this issue · 1 comments
Great work, a milestone for bringing token pruning into dense predictions.
I found that the svit-adapter-t-0.5x-ftune.py can not be tested because self.num_heads is not even.
In the InteractionBlockWithSelection class within adapter_modules.py, when x.shape[0] != 1 (i.e., the evaluation batch size > 1), x is reshaped into a nested_tensor and passed into blk (which is TransformerEncoderLayer).
def forward(self, x, c, indexes, deform_inputs1, deform_inputs2, shape, blks, selective_modules, keep_ratio):
n_skip = 3
x = self.injector(query=x, reference_points=deform_inputs1[0],
feat=c, spatial_shapes=deform_inputs1[1],
level_start_index=deform_inputs1[2])
layer_ratio_loss = 0.
has_loss = 0
for i in range(indexes[0], indexes[-1] + 1):
if i < n_skip:
x = blks[i](x)
else:
if self.training:
selector, diff_selector = selective_modules[i - n_skip](x)
x = diff_selector * blks[i](x, src_key_padding_mask=~selector) + \
(1 - diff_selector) * x
layer_ratio_loss += self._ratio_loss(diff_selector, keep_ratio[i - n_skip])
has_loss += 1
else:
if x.shape[0] == 1:
selector, _ = selective_modules[i - n_skip](x)
real_indices = torch.argsort(selector.int(), dim=1, descending=True)\
[:, :selector.sum(1)].unsqueeze(-1).expand(-1, -1, x.shape[-1])
selected_x = torch.gather(x, 1, real_indices)
selected_x = blks[i](selected_x)
x.scatter_(1, real_indices, selected_x)
else:
selector, diff_selector = selective_modules[i - n_skip](x)
l_aligned_x, l_aligned_mask = left_align_tokens2(x, selector)
nt_x = torch._nested_tensor_from_mask(l_aligned_x, l_aligned_mask, mask_check=False)
nt_x = blks[i](nt_x, src_key_padding_mask=None)
x.masked_scatter_(selector.unsqueeze(-1), torch.cat(nt_x.unbind(), 0))
c = self.extractor(query=c, reference_points=deform_inputs2[0],
feat=x, spatial_shapes=deform_inputs2[1],
level_start_index=deform_inputs2[2], shape=shape)
if self.extra_extractors is not None:
for extractor in self.extra_extractors:
c = extractor(query=c, reference_points=deform_inputs2[0],
feat=x, spatial_shapes=deform_inputs2[1],
level_start_index=deform_inputs2[2], shape=shape)
return x, c, layer_ratio_loss, has_loss
However, it seems that MultiheadAttention in torch.nn does not support computations when self.num_heads is set to 3. How can I resolve this issue?
发生异常: AssertionError (note: full exception trace is shown but execution is paused at: _run_module_as_main)
MultiheadAttention does not support NestedTensor outside of its fast path. The fast path was not hit because self.num_heads is not even
File "/home/livion/miniconda3/envs/mmdet/lib/python3.10/site-packages/torch/nn/modules/activation.py", line 1212, in forward
assert not any_nested, ("MultiheadAttention does not support NestedTensor outside of its fast path. " +
File "/home/livion/miniconda3/envs/mmdet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/livion/miniconda3/envs/mmdet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/livion/Documents/github/source/ViT_Adapter/mmdet/models/backbones/base/TransformerEncoderLayer.py", line 250, in _sa_block
x = self.self_attn(x, x, x,
File "/home/livion/Documents/github/source/ViT_Adapter/mmdet/models/backbones/base/TransformerEncoderLayer.py", line 239, in forward
x = x + self.drop_path1(self._sa_block(self.norm1(x), src_mask, src_key_padding_mask))
File "/home/livion/miniconda3/envs/mmdet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/livion/miniconda3/envs/mmdet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/livion/Documents/github/source/ViT_Adapter/mmdet/models/backbones/base/selective_vit.py", line 91, in forward
return self.TransformerEncoderLayer(x, src_key_padding_mask=src_key_padding_mask)
File "/home/livion/miniconda3/envs/mmdet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/livion/miniconda3/envs/mmdet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/livion/Documents/github/source/ViT_Adapter/mmdet/models/backbones/adapter_modules.py", line 251, in forward
nt_x = blks[i](nt_x, src_key_padding_mask=None)
File "/home/livion/miniconda3/envs/mmdet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/livion/miniconda3/envs/mmdet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/livion/Documents/github/source/ViT_Adapter/mmdet/models/backbones/selective_vit_adapter.py", line 135, in forward
x, c, layer_ratio_loss, has_loss = layer(x, c, indexes,
File "/home/livion/miniconda3/envs/mmdet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/livion/miniconda3/envs/mmdet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/livion/Documents/github/source/ViT_Adapter/mmdet/models/detectors/gumbel_two_stage.py", line 18, in extract_feat
out = self.backbone(img, need_loss)
File "/home/livion/Documents/github/source/ViT_Adapter/mmdet/models/detectors/two_stage.py", line 227, in predict
x = self.extract_feat(batch_inputs)
File "/home/livion/Documents/github/source/ViT_Adapter/mmdet/models/detectors/base.py", line 94, in forward
return self.predict(inputs, data_samples)
File "/home/livion/miniconda3/envs/mmdet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/livion/miniconda3/envs/mmdet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/livion/miniconda3/envs/mmdet/lib/python3.10/site-packages/mmengine/model/base_model/base_model.py", line 346, in _run_forward
results = self(**data, mode=mode)
File "/home/livion/miniconda3/envs/mmdet/lib/python3.10/site-packages/mmengine/model/base_model/base_model.py", line 145, in test_step
return self._run_forward(data, mode='predict') # type: ignore
File "/home/livion/miniconda3/envs/mmdet/lib/python3.10/site-packages/mmengine/runner/loops.py", line 454, in run_iter
outputs = self.runner.model.test_step(data_batch)
File "/home/livion/miniconda3/envs/mmdet/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/home/livion/miniconda3/envs/mmdet/lib/python3.10/site-packages/mmengine/runner/loops.py", line 435, in run
self.run_iter(idx, data_batch)
File "/home/livion/miniconda3/envs/mmdet/lib/python3.10/site-packages/mmengine/runner/runner.py", line 1823, in test
metrics = self.test_loop.run() # type: ignore
File "/home/livion/Documents/github/source/ViT_Adapter/test.py", line 145, in main
runner.test()
File "/home/livion/Documents/github/source/ViT_Adapter/test.py", line 149, in <module>
main()
File "/home/livion/miniconda3/envs/mmdet/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/home/livion/miniconda3/envs/mmdet/lib/python3.10/runpy.py", line 196, in _run_module_as_main (Current frame)
return _run_code(code, main_globals, None,
AssertionError: MultiheadAttention does not support NestedTensor outside of its fast path. The fast path was not hit because self.num_heads is not even
Hi @Livioni
Thanks for your interest. Since the batching relies on nested_tensor in pytorch BetterTransformer which requires even number of heads, tiny models with 3 heads cannot have batched inference. One workaround is using bigger models as they often have even number of heads, as shown in appendix B. If you need to use batch size > 1 for tiny models, you can change num_heads to 4 when finetuning.