huggingface/pytorch-image-models

[BUG] Incorrect string 'map ' with an extra ' '?

Closed · 1 comment

sry2 commented

Describe the bug
On line 362 of ./timm/models/vision_transformer.py, the conditional statement reads

elif global_pool != 'map ' and self.attn_pool is not None:

where the expected string 'map' appears with an extra trailing space ('map '). This changes the logic of the whole conditional block:

    if global_pool is not None:
        assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
        if global_pool == 'map' and self.attn_pool is None:
            assert False, "Cannot currently add attention pooling in reset_classifier()."
        elif global_pool != 'map ' and self.attn_pool is not None:
            self.attn_pool = None  # remove attention pooling
        self.global_pool = global_pool

Because 'map' can never equal 'map ', the elif branch fires whenever self.attn_pool exists, so self.attn_pool is set back to None even when the global_pool argument is 'map'.
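For reference, this is presumably what the condition was intended to be, with only the stray space removed (a sketch, not a confirmed patch):

    elif global_pool != 'map' and self.attn_pool is not None:
        self.attn_pool = None  # remove attention pooling only when switching away from 'map'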

To Reproduce
Steps to reproduce the behavior:

  1. Instantiate the class VisionTransformer with the init parameter global_pool='map'.
  2. Call the member method reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) with global_pool='map' and an arbitrary num_classes.
  3. The member self.attn_pool, initialized normally as a MAP head in __init__(), is set back to None by the call to reset_classifier() because execution enters the second conditional branch; this is incorrect and causes errors later on (see the reproduction sketch below).
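A minimal reproduction sketch (assuming timm v1.0.9; the model name vit_small_patch16_224 is only an example):

    import timm

    # Build a ViT whose MAP (attention pooling) head is created in __init__().
    model = timm.create_model('vit_small_patch16_224', pretrained=False, global_pool='map')
    assert model.attn_pool is not None

    # global_pool stays 'map', so the MAP head should survive this call.
    model.reset_classifier(num_classes=10, global_pool='map')
    print(model.attn_pool)  # observed on v1.0.9: None, because 'map' != 'map ' is always True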

Expected behavior
Since setting global_pool='map' explicitly requests the attention pooling module, a later call to reset_classifier() with global_pool='map' should not change that original behavior. The conditional statement is meant to check and reset the attention pooling setting, but with this bug the combination of global_pool == 'map' and self.attn_pool is not None always enters the second branch on line 362.
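Until the typo is fixed, one possible workaround is to save the module before the call and restore it afterwards (a sketch, reusing the model object from the reproduction above):

    # Hypothetical workaround: preserve the MAP head across the buggy reset.
    attn_pool = model.attn_pool
    model.reset_classifier(num_classes=10, global_pool='map')
    model.attn_pool = attn_pool  # restore the head that the buggy elif branch removed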

Desktop (please complete the following information):

  • OS: Linux version 4.19.90-vhulk2211.3.0.h1543.eulerosv2r10.aarch64
  • This repository version: v1.0.9 from the PyPI package
  • PyTorch version: 2.1.0

Weird, I recall finding this extra space before and thought it was fixed but maybe I never pushed :/ tx