ViCCo-Group/thingsvision

Check validity of module name

LukasMut opened this issue · 3 comments

We should add a function that checks whether `module_name` is valid (i.e., whether the user-specified module name refers to a module that actually exists in the model) before feature extraction starts. Otherwise, the feature extraction method raises a `KeyError` after the first iteration, which is probably fine but definitely not great. I am thinking about something along the lines of:

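```python
from typing import List

def get_module_names(self) -> List[str]:
    """Return the names of all submodules of the wrapped model."""
    if self.backend == 'pt':
        # named_modules() yields (name, module) pairs; the root module
        # has an empty name, so it is skipped
        module_names = [name for name, _ in self.model.named_modules() if name]
    else:
        # TensorFlow/Keras: collect the public name of every nested submodule
        module_names = [m.name for m in self.model.submodules]
    return module_names

def _is_valid_module(self, module_name: str) -> bool:
    """Check whether `module_name` refers to an existing submodule."""
    return module_name in self.get_module_names()
```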

This check should run before the main feature extraction loop starts, so it should probably be evaluated in `extractor.extract_features(...)`.
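One way to act on the check is to fail fast with an informative error instead of a `KeyError` mid-extraction. Below is a minimal sketch, assuming a hypothetical standalone helper `validate_module_name` (not part of thingsvision) that receives the list of valid names, e.g. from `get_module_names()`. It uses the standard-library `difflib` to suggest near matches, which helps with typos in long module paths.

```python
import difflib
from typing import Iterable, List


def validate_module_name(module_name: str, valid_names: Iterable[str]) -> None:
    """Raise a ValueError if `module_name` is not among `valid_names`.

    Hypothetical helper for illustration; not part of the thingsvision API.
    """
    names: List[str] = list(valid_names)
    if module_name in names:
        return
    # Offer up to three of the closest existing names as suggestions
    suggestions = difflib.get_close_matches(module_name, names, n=3)
    hint = f" Did you mean one of {suggestions}?" if suggestions else ""
    raise ValueError(f"Module '{module_name}' does not exist in the model.{hint}")


if __name__ == "__main__":
    # Example names in the style of torch's named_modules() output
    names = ["features", "features.0", "classifier", "classifier.6"]
    validate_module_name("classifier.6", names)  # valid: returns silently
    try:
        validate_module_name("clasifier.6", names)  # typo: raises
    except ValueError as err:
        print(err)
```

Inside `extract_features(...)`, this would be called once before the loop, e.g. `validate_module_name(module_name, self.get_module_names())`. Raising instead of returning a bool (as `_is_valid_module` does) lets the error message carry the suggestions directly to the user.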

This does not seem to be necessary if we resolve issue #89.

I think it's still necessary if we continue allowing users to extract features for individual network modules (which I also think we should do).

@andropar I agree.