Unnoticed, incorrectly typed out parameter names
colleenjg opened this issue · 3 comments
Hey @TomGeorge1234, I was just working with ratinabox and decided to implement a small piece of functionality for myself. I wanted to mention it here in case you'd like to add it or anyone else could use it.
So, as we know, the downside to allowing users to specify a wide range of parameters using a params dictionary when initializing a new object (Agent, Environment, Neurons) is that incorrectly typed out parameter names can slip through easily. For example, you might be initializing a GridCells object, accidentally using the key fr_max instead of max_fr in your parameters dictionary, and not notice for a while that you failed to set max_fr as intended.
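To make the failure mode concrete, here is a minimal sketch. ToyNeurons is a hypothetical stand-in, not the actual ratinabox class; it only assumes that every key of the merged params dictionary ends up as an attribute.

```python
class ToyNeurons:
    """Hypothetical stand-in for a params-driven class (not ratinabox code)."""

    def __init__(self, params=None):
        default_params = {"max_fr": 1.0, "min_fr": 0.0}
        default_params.update(params or {})
        self.params = default_params
        # Every key, expected or not, becomes an attribute.
        for key, value in self.params.items():
            setattr(self, key, value)


neurons = ToyNeurons(params={"fr_max": 2})  # typo: should be "max_fr"
print(neurons.max_fr)  # still the default value
print(neurons.fr_max)  # the misspelled key was silently accepted
```

No error or warning is raised here, which is exactly why the typo can go unnoticed.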
So, I wrote a function that checks whether objects have unexpected attributes, and raises a warning if they do.
import inspect
import warnings


def check_attributes(Obj, check_attrs=None):
    """Checks that the attributes of an object are expected, based on a
    default initialization. This is useful when passing a dictionary of
    parameters to set attributes, as it could flag incorrectly typed out
    parameter names.

    Args:
        Obj (object): Object to check.

    Optional Args:
        check_attrs (iterable, optional): Attribute names to check. If None,
            all attributes are checked. Defaults to None.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        argspec = inspect.getfullargspec(type(Obj).__init__)

    arg_names = argspec.args
    if arg_names[0] != "self":
        raise NotImplementedError("Expected the first argument to be 'self'.")

    # Required positional arguments are those without a default value.
    # argspec.defaults is None when no argument has a default.
    num_defaults = len(argspec.defaults or ())
    required_args = [
        getattr(Obj, arg_name)
        for arg_name in arg_names[1 : len(arg_names) - num_defaults]
    ]

    # Avoid infinite recursion when this is called from the class's __init__().
    kwargs = dict()
    if "check_attributes" in arg_names:
        kwargs["check_attributes"] = False

    # Build a default ("dummy") object to compare against.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        base_Obj = type(Obj)(*required_args, **kwargs)

    unexpected_attributes = [
        key for key in Obj.__dict__.keys() if key not in base_Obj.__dict__.keys()
    ]
    if check_attrs is not None:
        unexpected_attributes = [
            key for key in unexpected_attributes if key in check_attrs
        ]

    if len(unexpected_attributes):
        num = len(unexpected_attributes)
        unexpected_attributes = ", ".join(
            [f"'{attr}'" for attr in unexpected_attributes]
        )
        if hasattr(Obj, "name"):
            object_name = Obj.name
        else:
            object_name = str(Obj)
        warnings.warn(
            f"Found {num} unexpected attribute(s) for {object_name}: "
            f"{unexpected_attributes}"
        )
If called in the __init__() of a class, as follows, the function will flag with a warning any keys in params that would not be attributes of a default object of that class. Specifically, using GridCells as an example:
class GridCells(Neurons):
    ...
    def __init__(self, Agent, params={}, check_attributes=True):  # added check_attributes argument
        default_params = {
            ...
        }
        self.Agent = Agent
        default_params.update(params)
        self.params = default_params
        super().__init__(Agent, self.params)

        if check_attributes:  # check, if applicable
            utils.check_attributes(self, params.keys())
If you then call GridCells(Agent, params={"fr_max": 2}) you get the following warning:
UserWarning: Found 1 unexpected attribute(s) for GridCells: fr_max
So, this would be useful for any classes where (1) all parameters received are set as attributes, and (2) all acceptable parameters received have default values in the default_params dictionaries at some level of initialization. I believe most if not all of the classes that take a params dict as an input in ratinabox meet these two criteria.
Still, because this is a bit... non-pythonic... I should mention a few potential unintended consequences I can currently foresee:
- an infinite loop, if calling this from the __init__(), which is avoided as long as the objects this is used with have a check_attributes keyword argument in their __init__(),
- spurious printing when the dummy object is created (the function suppresses warnings.warn calls, but can't suppress print calls, of course, and doesn't currently suppress log calls),
- this seems unlikely, but I'll mention it in case: a spurious reference to the dummy object could be created while running check_attributes() if, for example, when an object like GridCells is initialized, circular references are created (i.e., not only does the GridCells object have a pointer to its Agent (self.Agent), but the Agent then also has a pointer added in return to any associated Neurons (e.g., self.Agent.Neurons.append(self))). As far as I can tell, this kind of circular referencing was not implemented in ratinabox, so I don't believe that my dummy object will create undue clutter in any associated objects.
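On the spurious-output point, one possible way to quiet the dummy object's creation is to temporarily redirect stdout and disable logging, in addition to ignoring warnings. The sketch below is only an illustration: silenced() and NoisyToy are hypothetical names, and the function as posted does not do this. Note that contextlib.redirect_stdout only catches output written through sys.stdout (such as print), not stderr or output written at the C level.

```python
import contextlib
import io
import logging
import warnings


@contextlib.contextmanager
def silenced():
    """Suppress warnings, print output, and logging inside the block."""
    # Remember the current global logging cutoff so it can be restored.
    previous_disable = logging.root.manager.disable
    with warnings.catch_warnings(), contextlib.redirect_stdout(io.StringIO()):
        warnings.simplefilter("ignore")
        logging.disable(logging.CRITICAL)
        try:
            yield
        finally:
            logging.disable(previous_disable)


class NoisyToy:
    """Hypothetical class whose __init__() prints, logs, and warns."""

    def __init__(self):
        print("initializing NoisyToy")
        logging.getLogger(__name__).warning("NoisyToy created")
        warnings.warn("NoisyToy is noisy")


with silenced():
    dummy = NoisyToy()  # nothing is printed, logged, or warned
```

Inside check_attributes(), the same context manager could wrap the line that builds base_Obj.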
This got a bit complicated..., but anyway, do let me know if you think it would be useful!
This might be overkill. I could also just read my dictionary keys more carefully... 😂
No I really like this! I always manage to screw up handing in params (it is very much the downside of using dictionaries) this would be a really nice catch-all , let me take a closer look tomorrow :)))
thanks!
As discussed offline: let's shift this to a system where the default_params dictionary lives just above the __init__(), then some utils function performs a cascading check which concatenates all default_params for the object, its parents, its parents' parents... up until parent.default_params isn't an attribute, and checks this against the passed params dict for inconsistencies. It raises a warning if any are found.
The hidden benefit of this is that users can also check what default_params a class accepts without initialising it (just typing GridCells.default_params should do the trick).
The other benefit is that this won't create a load of dummy objects.
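A rough sketch of how such a cascading check might look, under some assumptions: the helper names collect_default_params and check_params and the toy classes are hypothetical, not ratinabox code. It also walks the class's full method resolution order and skips classes that define no default_params, rather than stopping at the first parent without one.

```python
import warnings


def collect_default_params(cls):
    """Merge the default_params of cls and all of its ancestors.

    Walks the method resolution order from the most distant ancestor down,
    so defaults defined on a child class override those of its parents.
    """
    merged = {}
    for klass in reversed(cls.__mro__):
        # vars() reads only what klass itself defines, not inherited values.
        merged.update(vars(klass).get("default_params", {}))
    return merged


def check_params(obj, params):
    """Warn about keys in params that no class in the hierarchy declares."""
    expected = collect_default_params(type(obj))
    unexpected = [key for key in params if key not in expected]
    if unexpected:
        keys = ", ".join(f"'{key}'" for key in unexpected)
        warnings.warn(
            f"Found {len(unexpected)} unexpected parameter(s) for "
            f"{type(obj).__name__}: {keys}"
        )
    return unexpected


# Hypothetical toy hierarchy (not the ratinabox classes).
class ToyNeurons:
    default_params = {"n": 10, "name": "Neurons"}

    def __init__(self, params=None):
        params = params or {}
        check_params(self, params)
        self.params = {**collect_default_params(type(self)), **params}


class ToyGridCells(ToyNeurons):
    default_params = {"max_fr": 1.0, "gridscale": 0.45}


print(ToyGridCells.default_params)  # class-level defaults, no instance needed
cells = ToyGridCells(params={"fr_max": 2})  # warns about 'fr_max'
```

Because the defaults live on the classes themselves, no dummy object has to be built to know which keys are valid.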