RatInABox-Lab/RatInABox

Unnoticed, incorrectly typed out parameter names

colleenjg opened this issue · 3 comments

Hey @TomGeorge1234, I was just working with ratinabox and decided to implement a small piece of functionality for myself. I wanted to mention it here in case you'd like to add it or anyone else could use it.

So, as we know, the downside of letting users specify a wide range of parameters through a params dictionary when initializing a new object (Agent, Environment, Neurons) is that incorrectly typed-out parameter names can easily slip through. For example, you might initialize a GridCells object, accidentally use the key fr_max instead of max_fr in your parameters dictionary, and not notice for a while that max_fr was never set as intended.
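A minimal sketch of why this goes unnoticed, using a plain dict rather than ratinabox itself (the default values are made up): merging user params into the defaults with dict.update() accepts any key, so a typo is silently added next to the intended key instead of replacing it.

```python
# Hypothetical defaults, standing in for a class's default_params dictionary
default_params = {"max_fr": 1.0, "min_fr": 0.0}
user_params = {"fr_max": 2.0}  # typo: should have been "max_fr"

default_params.update(user_params)  # no error is raised for the unknown key

print(default_params["max_fr"])  # 1.0: the intended value was never applied
print(sorted(default_params))    # ['fr_max', 'max_fr', 'min_fr']: the typo just tags along
```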

So, I wrote a function that checks whether an object has unexpected attributes and raises a warning if it does.

```python
import inspect
import warnings


def check_attributes(Obj, check_attrs=None):
    """Checks that the attributes of an object are expected, based on a
    default initialization. This is useful when passing a dictionary of
    parameters to set attributes, as it can flag incorrectly typed-out
    parameter names.

    Args:
        Obj (object): Object to check.
        check_attrs (iterable of str, optional): Attribute names to check
            (e.g., params.keys()). If None, all attributes are checked.
            Defaults to None.
    """

    # Inspect the signature of the object's __init__() to find the
    # arguments that are required (i.e., have no default value)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        argspec = inspect.getfullargspec(type(Obj).__init__)
    arg_names = argspec.args
    if arg_names[0] != "self":
        raise NotImplementedError("Expected the first argument to be 'self'.")
    # argspec.defaults is None (not an empty tuple) when no argument has a default
    n_defaults = len(argspec.defaults or ())
    # Reuse the object's own values for the required arguments (e.g., self.Agent)
    required_args = [
        getattr(Obj, arg_name)
        for arg_name in arg_names[1:len(arg_names) - n_defaults]
    ]

    # Prevent infinite recursion when the dummy object's __init__() itself
    # calls check_attributes()
    kwargs = dict()
    if "check_attributes" in arg_names:
        kwargs["check_attributes"] = False

    # Build a default ("dummy") object, silencing any warnings it raises
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        base_Obj = type(Obj)(*required_args, **kwargs)

    # Any attribute of Obj that the default object lacks is unexpected
    unexpected_attributes = [
        key for key in Obj.__dict__.keys() if key not in base_Obj.__dict__.keys()
    ]

    # Optionally restrict the check to specific names (e.g., the passed params)
    if check_attrs is not None:
        unexpected_attributes = [
            key for key in unexpected_attributes if key in check_attrs
        ]

    if len(unexpected_attributes):
        num = len(unexpected_attributes)
        unexpected_attributes = ", ".join(
            [f"'{attr}'" for attr in unexpected_attributes]
        )
        if hasattr(Obj, "name"):
            object_name = Obj.name
        else:
            object_name = str(Obj)
        warnings.warn(
            f"Found {num} unexpected attribute(s) for {object_name}: "
            f"{unexpected_attributes}"
        )
```
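The first step of check_attributes() relies on inspect.getfullargspec() to tell required arguments apart from those with defaults, so that only the required ones are passed to the dummy constructor. A minimal sketch of that step on two hypothetical classes (ToyNeurons and NoDefaults are illustrations, not ratinabox classes). Note that spec.defaults is None, not an empty tuple, when no argument has a default, so slicing with -len(spec.defaults) would raise a TypeError for a class like NoDefaults.

```python
import inspect


class ToyNeurons:
    # Hypothetical class with the same signature as the GridCells example below
    def __init__(self, Agent, params={}, check_attributes=True):
        self.Agent = Agent


class NoDefaults:
    # Hypothetical class where no __init__() argument has a default
    def __init__(self, Agent, Environment):
        self.Agent, self.Environment = Agent, Environment


def required_arg_names(cls):
    """Names of the positional __init__() arguments that have no default."""
    spec = inspect.getfullargspec(cls.__init__)
    n_defaults = len(spec.defaults or ())  # spec.defaults is None if there are no defaults
    return spec.args[1:len(spec.args) - n_defaults]  # [1:] skips 'self'


print(inspect.getfullargspec(ToyNeurons.__init__).defaults)  # ({}, True)
print(required_arg_names(ToyNeurons))  # ['Agent']
print(required_arg_names(NoDefaults))  # ['Agent', 'Environment']
```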

If it is called at the end of a class's __init__(), as follows, the function will warn about any keys in params that would not be attributes of a default object of that class. Using GridCells as an example:

```python
class GridCells(Neurons):
    ...
    def __init__(self, Agent, params={}, check_attributes=True):  # added check_attributes argument
        default_params = {
            ...
        }
        self.Agent = Agent
        default_params.update(params)
        self.params = default_params
        super().__init__(Agent, self.params)
        if check_attributes:  # False for the dummy object, which avoids infinite recursion
            # Only flag the keys the user actually passed in params
            utils.check_attributes(self, params.keys())
```

If you then call GridCells(Agent, params={"fr_max": 2}), you get the following warning:

```
UserWarning: Found 1 unexpected attribute(s) for GridCells: 'fr_max'
```

So, this would be useful for any class where (1) all parameters received are set as attributes, and (2) all acceptable parameters have default values in a default_params dictionary at some level of initialization. I believe most, if not all, of the ratinabox classes that take a params dict as input meet these two criteria.
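To make criterion (2) concrete, here is a minimal sketch using a hypothetical Toy class (not a ratinabox class) and a stripped-down version of the attribute comparison that check_attributes() performs. If a valid optional parameter (here a made-up noise_std) has no default, the default instance never gets that attribute, so passing it correctly looks exactly like a typo.

```python
class Toy:
    # Hypothetical class. Only "max_fr" has a default; "noise_std" is meant
    # to be a valid optional parameter but has no default value.
    def __init__(self, params=None):
        merged = {"max_fr": 1.0}
        merged.update(params or {})
        for key, value in merged.items():
            setattr(self, key, value)  # criterion (1): every param becomes an attribute


def unexpected_attributes(obj):
    """Stripped-down version of the comparison check_attributes() performs."""
    base = type(obj)()  # default ("dummy") instance
    return [key for key in vars(obj) if key not in vars(base)]


print(unexpected_attributes(Toy({"fr_max": 2.0})))     # ['fr_max']: typo caught
print(unexpected_attributes(Toy({"noise_std": 0.1})))  # ['noise_std']: valid key, flagged anyway
print(unexpected_attributes(Toy({"max_fr": 2.0})))     # []: known key, nothing flagged
```

Giving noise_std a default (even a placeholder such as None) would remove the false positive.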

Still, because this is a bit... non-Pythonic, I should mention a few potential unintended consequences that I can currently foresee:

  1. An infinite loop if this is called from __init__(). This is avoided as long as the classes it is used with have a check_attributes keyword argument in their __init__().
  2. Spurious printing when the dummy object is created. The function suppresses warnings.warn calls, but it can't suppress print calls, of course, and it doesn't currently suppress logging calls.
  3. This seems unlikely, but I'll mention it just in case: a spurious reference to the dummy object could be created while running check_attributes() if initializing an object like GridCells creates circular references. That is, not only would the GridCells object have a pointer to its Agent (self.Agent), but the Agent would also get a pointer back to any associated Neurons (e.g., self.Agent.Neurons.append(self)). As far as I can tell, this kind of circular referencing is not implemented in ratinabox, so I don't believe the dummy object will create undue clutter in any associated objects.
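Point 2 could probably be mitigated by wrapping the dummy construction in a context manager that also silences stdout and logging. A minimal sketch; quiet() and Chatty are hypothetical names, and this is one possible extension rather than part of the function above. It assumes logging was not globally disabled beforehand, since it re-enables it on exit.

```python
import contextlib
import io
import logging
import warnings


@contextlib.contextmanager
def quiet():
    """Silence warnings, stdout and logging while a dummy object is built.

    Caveats: logging.disable() and redirect_stdout() change global state, so this
    is not thread-safe, and output written below Python's sys.stdout (e.g. by
    C extensions) is not captured.
    """
    with warnings.catch_warnings(), contextlib.redirect_stdout(io.StringIO()):
        warnings.simplefilter("ignore")
        logging.disable(logging.CRITICAL)  # drop log records at CRITICAL and below
        try:
            yield
        finally:
            # Re-enable logging (assumes it was not globally disabled before)
            logging.disable(logging.NOTSET)


class Chatty:
    # Hypothetical class that prints, logs and warns during __init__()
    def __init__(self):
        print("initialising Chatty")
        logging.getLogger(__name__).warning("Chatty created")
        warnings.warn("Chatty warning")


with quiet():
    dummy = Chatty()  # no print output, log record or warning is emitted
```

In check_attributes(), the same wrapper could replace the warnings.catch_warnings() block around base_Obj = type(Obj)(*required_args, **kwargs).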

This got a bit complicated, but anyway, do let me know if you think it would be useful!

This might be overkill. I could also just read my dictionary keys more carefully... 😂

No, I really like this! I always manage to screw up handing in params (it is very much the downside of using dictionaries), and this would be a really nice catch-all. Let me take a closer look tomorrow :)))

thanks!

As discussed offline: let's shift this to a system where the default_params dictionary lives just above the __init__(). Then some utils function performs a cascading check: it concatenates all the default_params for the object, its parents, its parents' parents, and so on, up until a parent no longer has a default_params attribute, then checks the result against the passed params dict for inconsistencies and raises a warning if it finds any.
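A minimal sketch of that cascading check, under a few assumptions: default_params is a class attribute, get_all_default_params() and check_params() are hypothetical names (not existing ratinabox utils), and Neurons/GridCells are stripped-down stand-ins with no Agent argument and made-up defaults. Rather than stopping at the first parent without default_params, it walks the class's method resolution order (__mro__) and simply skips classes that don't define their own dictionary.

```python
import warnings


def get_all_default_params(cls):
    """Merge default_params from cls and all its ancestors."""
    merged = {}
    # Walk from the most basic class to the most derived one, so that
    # subclass defaults override parent defaults with the same key
    for klass in reversed(cls.__mro__):
        merged.update(vars(klass).get("default_params", {}))
    return merged


def check_params(cls, params):
    """Warn about keys in params that no class in the hierarchy declares."""
    allowed = get_all_default_params(cls)
    unexpected = [key for key in params if key not in allowed]
    if unexpected:
        warnings.warn(
            f"Found {len(unexpected)} unexpected parameter(s) for {cls.__name__}: "
            + ", ".join(repr(key) for key in unexpected)
        )
    return unexpected


class Neurons:
    default_params = {"n": 10, "name": "Neurons"}  # hypothetical values

    def __init__(self, params=None):
        params = params or {}
        check_params(type(self), params)  # no dummy object is created
        self.params = {**get_all_default_params(type(self)), **params}


class GridCells(Neurons):
    default_params = {"max_fr": 1.0, "name": "GridCells"}  # hypothetical values


GridCells({"fr_max": 2})  # UserWarning: ... for GridCells: 'fr_max'
```

Because get_all_default_params() only needs the class, it can also list the accepted parameters without creating any object: sorted(get_all_default_params(GridCells)) gives ['max_fr', 'n', 'name']. In this sketch GridCells.default_params on its own holds only the subclass's entries; the inherited ones come from the merge.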

The hidden benefit of this is that users can also check which default_params a class accepts without initialising it (just typing GridCells.default_params should do the trick).

The other benefit is that this won't create a load of dummy objects.