Check validity of module name
LukasMut opened this issue · 3 comments
LukasMut commented
We should add a function that checks the validity of `module_name` (i.e., whether the user-specified module name refers to a module that actually exists in the model) before feature extraction starts. Otherwise, the feature extraction method raises a `KeyError` only after the first iteration, which is workable but definitely not great. I am thinking of something along the lines of:
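```python
from typing import List

def get_module_names(self) -> List[str]:
    """Return the names of all submodules of the wrapped model."""
    if self.backend == 'pt':
        # named_modules() yields (name, module) pairs; the root module's name is ''
        module_names, _ = zip(*self.model.named_modules())
        module_names = [n for n in module_names if len(n) > 0]
    else:
        # TensorFlow/Keras: every tf.Module exposes a public `name` attribute
        module_names = [m.name for m in self.model.submodules]
    return module_names

def _is_valid_module(self, module_name: str) -> bool:
    valid_names = self.get_module_names()
    return module_name in valid_names
```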
This check is meant to run before the main feature extraction loop starts, so it should probably be called inside `extractor.extract_features(...)`.
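A minimal sketch of how the fail-fast check could look, assuming the list of valid names comes from `get_module_names()`. The helper `validate_module_name` and the stub `extract_features` below are hypothetical names for illustration only, not part of the existing extractor API; the loop body is a placeholder standing in for the real forward pass.

```python
import difflib
from typing import Iterable, List


def validate_module_name(module_name: str, valid_names: Iterable[str]) -> None:
    """Hypothetical helper: raise before extraction starts if module_name is unknown."""
    valid_names = list(valid_names)
    if module_name in valid_names:
        return
    # Suggest similarly spelled module names to help users spot typos
    suggestions = difflib.get_close_matches(module_name, valid_names, n=3)
    hint = f" Did you mean one of {suggestions}?" if suggestions else ""
    raise ValueError(f"Module '{module_name}' not found in the model.{hint}")


def extract_features(batches, module_name: str, valid_names: List[str]) -> List[str]:
    """Hypothetical stub: validate once, then run the main loop."""
    # The check runs before any batch is consumed, instead of surfacing
    # as a KeyError after the first iteration.
    validate_module_name(module_name, valid_names)
    features = []
    for batch in batches:
        # Placeholder for the real forward pass and feature lookup
        features.append(f"{module_name}:{batch}")
    return features
```

In the real method, the call would presumably be `validate_module_name(module_name, self.get_module_names())` at the top of `extract_features(...)`. Raising a `ValueError` with suggestions (rather than letting a `KeyError` escape from the loop) is one possible design choice; it tells the user what went wrong before any compute is spent.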
andropar commented
I think it's still necessary if we continue to allow users to extract features for individual network modules (which I also think we should do).