ViCCo-Group/thingsvision

Torchvision compatibility

Closed this issue · 6 comments

Enable compatibility with most recent torchvision version (>= 0.13.0)

I am thinking about a dictionary, a named tuple, or a (custom) class that maps the model names (usually lower-case) to the names of the model weights (mostly a mix of upper-case and capitalized words; see https://pytorch.org/vision/stable/models.html). We want to avoid a long chain of if-statements to obtain the corresponding model weights. One option is something along the lines of model_weights = getattr(torchvision.models, model2weights[model_name]).DEFAULT, where model2weights is a dictionary (or something similar) that maps model names to weight names, which unfortunately differ.

I thought about that too, but then I saw that using strings is also supported, so if we only care about the default weights, we could create our models with model(weights='DEFAULT') instead, which might be a bit easier?

I agree that this is the easier thing to do but there is a problem with this approach. Each set of weights has its own image transformations which can only be accessed through the weight objects themselves. We would not be able to access the transformations by using the string approach. See example below.

from torchvision.models import ResNet50_Weights

# Initialize the weight transforms
weights = ResNet50_Weights.DEFAULT
preprocess = weights.transforms()

# Apply it to the input image
img_transformed = preprocess(img)

Ah, I see, you're right. Then what about something like this:

import inspect
import torchvision

weights_name = [
    member for member, _ in inspect.getmembers(torchvision.models)
    if (member.lower().startswith(model_name) and 'Weights' in member)
][0]
 
# do some error handling in case nothing was found

weights = getattr(torchvision.models, f'{weights_name}').DEFAULT

Writing everything into a dict just seems like too much work, but we would have to think about whether this could produce any issues later ...

Yes, that's actually a good idea! I like it. If member is a string, we could slightly improve the list comprehension by doing

import re
import torchvision

weights_name = [
    m for m in dir(torchvision.models) 
    if re.search(f'(?=^{model_name})(?=.*weights$)', m.lower())
][0]

# do some error handling in case nothing was found

weights = getattr(torchvision.models, f'{weights_name}').DEFAULT

We probably need to think about whether a list comprehension is the right tool here. We implicitly assume that the final list has a single member, which boils down to a single string. So something like this should do the trick, I guess:

import re
import torchvision

def get_weights_name(model_name: str) -> str:
    weights_name = None
    regex = f'(?=^{model_name})(?=.*weights$)'
    for m in dir(torchvision.models):
        if re.search(regex, m.lower()):
            weights_name = m
            break
    if not weights_name:
        raise ValueError(
            f'\nCould not find weights for {model_name} in <torchvision>. Choose a different model or change source.\n'
        )
    return weights_name


weights_name = get_weights_name(model_name)
weights = getattr(torchvision.models, f'{weights_name}').DEFAULT
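
One caveat with prefix-based matching: a model name that is a prefix of another model's name can resolve to the wrong weights. dir() returns names in sorted order, so for vgg11 the name VGG11_BN_Weights comes before VGG11_Weights and is the first match. The sketch below reproduces this on a hard-coded list of names (standing in for dir(torchvision.models)) and compares it with a stricter rule that requires the lower-cased name to equal f'{model_name}_weights'. Both helper functions are hypothetical, and the stricter rule assumes that every weights enum is named after its builder function, which we would need to verify against the installed torchvision version.

```python
import re

# Stand-in for dir(torchvision.models); dir() returns names in sorted order.
NAMES = sorted([
    "ResNet50_Weights",
    "VGG11_BN_Weights",
    "VGG11_Weights",
    "vgg11",
    "vgg11_bn",
])


def first_prefix_match(model_name, names):
    """Prefix-based lookup: return the first name that starts with model_name."""
    regex = f"(?=^{re.escape(model_name)})(?=.*weights$)"
    for m in names:
        if re.search(regex, m.lower()):
            return m
    raise ValueError(f"Could not find weights for {model_name}.")


def exact_match(model_name, names):
    """Stricter lookup: the lower-cased name must equal '<model_name>_weights'."""
    target = f"{model_name}_weights"
    for m in names:
        if m.lower() == target:
            return m
    raise ValueError(f"Could not find weights for {model_name}.")


prefix_result = first_prefix_match("vgg11", NAMES)  # picks the batch-norm variant
exact_result = exact_match("vgg11", NAMES)          # picks the plain variant
```

If the naming assumption turns out not to hold for some models, we could try the exact match first and fall back to the prefix match.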

Enabled. Note that, with the current approach, users only have access to the DEFAULT weights. We need to make this more flexible by defining extractor classes for each source.
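
As a possible starting point for that flexibility, the lookup could accept an optional weights identifier and fall back to DEFAULT. The sketch below assumes the weight collections are Enum subclasses with member names such as IMAGENET1K_V1, as in torchvision >= 0.13. DemoWeights is a hypothetical stand-in so the snippet runs without torchvision, and select_weights is an illustrative helper, not part of thingsvision or torchvision.

```python
from enum import Enum


class DemoWeights(Enum):
    """Hypothetical stand-in for a torchvision weights enum."""
    IMAGENET1K_V1 = "v1"
    IMAGENET1K_V2 = "v2"
    DEFAULT = "v2"  # alias of IMAGENET1K_V2


def select_weights(weights_enum, variant="DEFAULT"):
    """Return the requested member of weights_enum, defaulting to DEFAULT."""
    try:
        return weights_enum[variant]
    except KeyError:
        # __members__ lists all names, including aliases such as DEFAULT.
        available = ", ".join(weights_enum.__members__)
        raise ValueError(
            f"Unknown weights {variant!r}. Available: {available}"
        ) from None


default_weights = select_weights(DemoWeights)
older_weights = select_weights(DemoWeights, "IMAGENET1K_V1")
```

With the real library, weights_enum would be the class returned by getattr(torchvision.models, weights_name), and the selected member would still expose its own transforms().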