Torchvision compatibility
Enable compatibility with the most recent torchvision versions (>= 0.13.0).
I am thinking about either a dictionary, a named tuple, or a (custom) class that maps the model names (usually lower-case) to the names of the model weights (mostly a mix of upper-case and capitalized parts; see here: https://pytorch.org/vision/stable/models.html). We want to avoid a bunch of if-statements to obtain the corresponding model weights. Something along the lines of model_weights = getattr(torchvision.models, model2weights[model_name]).DEFAULT might be appropriate, where model2weights is a dictionary (or something similar) that maps model names to weight names, which are, unfortunately, different.
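A minimal sketch of the explicit-mapping idea, assuming a hand-maintained dictionary. The name MODEL2WEIGHTS and its few entries are illustrative, not exhaustive; the enum names follow the torchvision >= 0.13 naming convention. The optional namespace parameter is a hypothetical addition so the lookup can be exercised without torchvision installed.

```python
from typing import Any, Dict, Optional

# Hypothetical, hand-maintained mapping from lower-case model builder names
# to the mixed-case weights enum names in torchvision.models (>= 0.13).
MODEL2WEIGHTS: Dict[str, str] = {
    "resnet18": "ResNet18_Weights",
    "resnet50": "ResNet50_Weights",
    "vgg16": "VGG16_Weights",
    "vgg16_bn": "VGG16_BN_Weights",
    "efficientnet_b0": "EfficientNet_B0_Weights",
}


def get_default_weights(model_name: str, namespace: Optional[Any] = None) -> Any:
    """Return the DEFAULT weights for model_name via the explicit lookup table."""
    try:
        weights_name = MODEL2WEIGHTS[model_name]
    except KeyError:
        raise ValueError(f"No weights mapping for {model_name!r}") from None
    if namespace is None:
        # Imported lazily so the table lookup itself does not need torchvision.
        import torchvision
        namespace = torchvision.models
    return getattr(namespace, weights_name).DEFAULT
```

With torchvision installed, get_default_weights("resnet50") would return ResNet50_Weights.DEFAULT. The trade-off is the one raised in the thread: every new model needs a manual entry.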
I thought about that too, but then I saw that strings are also supported, so if we only care about the default weights, we could instead create our models with model(weights='DEFAULT'), which might be a bit easier?
I agree that this is the easier thing to do, but there is a problem with this approach. Each set of weights has its own image transformations, which can only be accessed through the weight objects themselves. We would not be able to access the transformations using the string approach. See the example below.
from torchvision.models import ResNet50_Weights
# Initialize the weight transforms
weights = ResNet50_Weights.DEFAULT
preprocess = weights.transforms()
# Apply them to the input image (img: a PIL image or image tensor)
img_transformed = preprocess(img)
Ah, I see, you're right. Then what about something like this:
import inspect
import torchvision
weights_name = [
    member for member, _ in inspect.getmembers(torchvision.models)
    if member.lower().startswith(model_name) and 'Weights' in member
][0]
# do some error handling in case nothing was found ([0] raises IndexError)
weights = getattr(torchvision.models, weights_name).DEFAULT
Writing everything into a dict just seems like too much work, but we would have to think about whether this could produce any issues later ...
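One issue worth checking: a prefix match can hit more than one enum. For example, torchvision has both VGG16_Weights and VGG16_BN_Weights, and since dir() and inspect.getmembers() return names in sorted order, 'vgg16' would pick VGG16_BN_Weights first. The sketch below compares the prefix filter with a stricter exact-name match. The helper names are hypothetical, and the exact match assumes the enum is named '<model_name>_Weights' up to case, which holds for many but not necessarily all torchvision builders.

```python
from typing import Iterable, List


def prefix_matches(model_name: str, names: Iterable[str]) -> List[str]:
    """All names accepted by the startswith / 'Weights' filter from the thread."""
    return [n for n in names if n.lower().startswith(model_name) and "Weights" in n]


def exact_match(model_name: str, names: Iterable[str]) -> str:
    """Stricter variant: the enum name must equal '<model_name>_weights' case-insensitively."""
    target = f"{model_name.lower()}_weights"
    for n in names:
        if n.lower() == target:
            return n
    raise ValueError(f"No weights enum named like {target!r}")


# A small hand-picked sample of names, sorted like dir(torchvision.models).
names = sorted(["ResNet50_Weights", "VGG16_BN_Weights", "VGG16_Weights",
                "resnet50", "vgg16", "vgg16_bn"])
print(prefix_matches("vgg16", names))  # ['VGG16_BN_Weights', 'VGG16_Weights']
print(exact_match("vgg16", names))     # VGG16_Weights
```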
Yes, that's actually a good idea! I like it. Since member is a string, we could slightly improve the list comprehension by iterating over dir() directly and matching with a regex:
import re
import torchvision
weights_name = [
    m for m in dir(torchvision.models)
    if re.search(f'(?=^{model_name})(?=.*weights$)', m.lower())
][0]
# do some error handling in case nothing was found
weights = getattr(torchvision.models, weights_name).DEFAULT
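For readers unfamiliar with the pattern: both lookaheads are zero-width, and because of ^ the pattern can only succeed at position 0, where it asserts that the lower-cased name starts with model_name and ends with 'weights'. A small sketch (helper names are hypothetical) checking that it behaves like plain string methods on a few sample names. Note that model_name is interpolated unescaped; that is harmless for builder names made of letters, digits and underscores, but re.escape would be needed otherwise.

```python
import re
from typing import List


def regex_match(model_name: str, name: str) -> bool:
    # Both lookaheads are pinned to the start of the string by '^'.
    return re.search(f"(?=^{model_name})(?=.*weights$)", name.lower()) is not None


def string_match(model_name: str, name: str) -> bool:
    # Equivalent check using str methods instead of a regex.
    lowered = name.lower()
    return lowered.startswith(model_name) and lowered.endswith("weights")


samples: List[str] = ["ResNet50_Weights", "resnet50", "Wide_ResNet50_2_Weights", "VGG16_BN_Weights"]
for s in samples:
    print(s, regex_match("resnet50", s), string_match("resnet50", s))
```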
We should probably reconsider whether a list comprehension is the right tool here. We implicitly assume that the resulting list has exactly one member, i.e., all we really want is a single string. So something like this should do the trick:
import re
import torchvision

def get_weights_name(model_name: str) -> str:
    """Return the first weights enum name in torchvision.models that matches model_name."""
    weights_name = None
    regex = f'(?=^{model_name})(?=.*weights$)'
    for m in dir(torchvision.models):
        if re.search(regex, m.lower()):
            weights_name = m
            break
    if not weights_name:
        raise ValueError(
            f'\nCould not find weights for {model_name} in <torchvision>. Choose a different model or change source.\n'
        )
    return weights_name

model_name = 'resnet50'  # example model name
weights_name = get_weights_name(model_name)
weights = getattr(torchvision.models, weights_name).DEFAULT
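The loop-with-break can also be written with next() and a default value, which returns the first match without building a list. A sketch of that variant; passing the candidate names in as a parameter is a hypothetical design choice so the function can be tested without importing torchvision (in practice you would pass dir(torchvision.models)).

```python
from typing import Iterable, Optional


def find_weights_name(model_name: str, names: Iterable[str]) -> str:
    """Return the first name that starts with model_name and ends with 'weights' (case-insensitive)."""
    prefix = model_name.lower()
    match: Optional[str] = next(
        (n for n in names if n.lower().startswith(prefix) and n.lower().endswith("weights")),
        None,
    )
    if match is None:
        raise ValueError(f"Could not find weights for {model_name} in <torchvision>.")
    return match
```

Usage would be find_weights_name("resnet50", dir(torchvision.models)), followed by getattr(torchvision.models, ...).DEFAULT as above.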
Enabled. Note that with the current approach, a user only has access to the DEFAULT weights. We need to make this more flexible by defining extractor classes for each source.
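One way to go beyond DEFAULT is to accept a spec such as 'ResNet50_Weights.IMAGENET1K_V1' and resolve it by enum name lookup. The helper below is a hypothetical sketch, not part of this project or of torchvision; it relies on the weights classes being Python enums, where name lookup also resolves aliases such as DEFAULT. Recent torchvision versions also provide torchvision.models.get_weight for full names of this form, which may make a custom helper unnecessary; check the documentation for the version you target.

```python
import enum
from typing import Any


def resolve_weights(spec: str, namespace: Any) -> enum.Enum:
    """Resolve 'EnumName' or 'EnumName.VARIANT' to a member of an enum in `namespace`.

    Hypothetical helper: in practice `namespace` would be torchvision.models.
    A bare enum name falls back to the DEFAULT alias.
    """
    enum_name, _, variant = spec.partition(".")
    weights_enum = getattr(namespace, enum_name, None)
    if weights_enum is None:
        raise ValueError(f"Unknown weights enum: {enum_name!r}")
    try:
        # Enum name lookup also resolves aliases such as DEFAULT.
        return weights_enum[variant or "DEFAULT"]
    except KeyError:
        raise ValueError(f"{enum_name} has no variant {variant!r}") from None
```

For example, resolve_weights("ResNet50_Weights.IMAGENET1K_V1", torchvision.models) would select a non-default set of weights, and its transforms() remain available as before.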