google/python-fire

Unexpected behaviour when using wrapped / decorated functions - can't supply added arguments

jamesowers-roo opened this issue · 2 comments

Here's a simple example - say we want to wrap a function to set a logging level. We want to use a wrapper to apply this to every function within a module and remove the boilerplate.

test_script.py:

import functools
import logging

import fire

logging.basicConfig()
LOGGER = logging.getLogger(__name__)

def set_logging_level(func):
    """
    A decorator that sets the logging level for the decorated function.
    The level should be a string (e.g. 'DEBUG', 'INFO', etc.).
    """

    @functools.wraps(func)
    def wrapper(*args, logging_level="INFO", **kwargs):
        LOGGER.info(f"Setting logging level to {logging_level}")
        LOGGER.setLevel(logging_level)
        return func(*args, **kwargs)

    return wrapper

@set_logging_level
def test_function(a=1):
    """Docstring"""
    LOGGER.info("inside test_function")
    LOGGER.debug(f"{a=}")

if __name__ == "__main__":
    fire.Fire(test_function)

The function test_function itself behaves as expected, i.e. it's correctly wrapped:

>>> test_function.__doc__
Docstring

>>> test_function()
INFO:__main__:Setting logging level to INFO
INFO:__main__:inside test_function

>>> test_function(logging_level="DEBUG")
INFO:__main__:Setting logging level to DEBUG
INFO:__main__:inside test_function
DEBUG:__main__:a=1

But calling with python-fire fails:

python test_script.py --logging-level DEBUG
INFO:__main__:inside test_function
ERROR: Could not consume arg: --logging-level
Usage: test_script.py -

For detailed information on this command, run:
  test_script.py - --help

Am I misunderstanding how python-fire should behave with wrapped functions?

I am very happy to provide a PR to fix this. If someone could help get me started and point to the right files to edit, I'd be very grateful.

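This is not an issue with Fire but with the decorator: it changes the function's call signature without updating the signature that introspection reports. functools.wraps sets wrapper.__wrapped__, and inspect.signature follows __wrapped__ by default, so the reported signature is still that of the original test_function. Fire therefore correctly complains that logging_level isn't an argument of test_function.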
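The mismatch can be reproduced without Fire at all. The following is a minimal sketch using only the standard library; the names add_flag and target are hypothetical and exist only for illustration. It compares the signature that inspect reports by default with the one the wrapper really accepts:

```python
import functools
import inspect


def add_flag(func):
    """Hypothetical decorator that accepts one extra keyword argument."""

    @functools.wraps(func)
    def wrapper(*args, logging_level="INFO", **kwargs):
        return func(*args, **kwargs)

    return wrapper


@add_flag
def target(a=1):
    return a


# functools.wraps sets wrapper.__wrapped__, and inspect.signature follows it,
# so the reported signature is that of the undecorated function.
reported = inspect.signature(target)
print(reported)  # (a=1)
print("logging_level" in reported.parameters)  # False

# Ignoring __wrapped__ shows what the wrapper really accepts.
actual = inspect.signature(target, follow_wrapped=False)
print(actual)  # (*args, logging_level='INFO', **kwargs)
```

Any tool that builds its command-line flags from the reported signature will see only a, which matches the "Could not consume arg" error above.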

To fix this, add code inside the decorator, after wrapper is defined, that sets the signature reported for the wrapper so that it includes the new argument, for example:

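    # Start from the signature of the undecorated function.
    wrapper_signature = inspect.signature(func)
    parameters = list(wrapper_signature.parameters.values())
    # Advertise the extra keyword-only argument that the wrapper accepts.
    parameters.append(
        inspect.Parameter(
            "logging_level",
            inspect.Parameter.KEYWORD_ONLY,
            default=DEFAULT_LOGGING_LEVEL,
        )
    )
    wrapper_signature = wrapper_signature.replace(parameters=parameters)
    functools.update_wrapper(wrapper, func)
    # An explicit __signature__ takes precedence over following __wrapped__.
    wrapper.__signature__ = wrapper_signature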
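One caveat with appending the new parameter at the end: it works for test_function, but if the decorated function already takes **kwargs, Signature.replace raises ValueError, because a keyword-only parameter may not follow a **kwargs parameter. A possible sketch of a more general approach is shown below; with_keyword_only is a hypothetical helper name, not part of inspect or Fire:

```python
import inspect


def with_keyword_only(signature, name, default):
    """Return a copy of `signature` with a keyword-only parameter added.

    Hypothetical helper: the new parameter is placed before any **kwargs
    parameter, because keyword-only parameters may not follow it.
    """
    new_param = inspect.Parameter(
        name, inspect.Parameter.KEYWORD_ONLY, default=default
    )
    params = list(signature.parameters.values())
    if params and params[-1].kind is inspect.Parameter.VAR_KEYWORD:
        # Insert just before the trailing **kwargs parameter.
        params.insert(len(params) - 1, new_param)
    else:
        params.append(new_param)
    return signature.replace(parameters=params)


def plain(a=1):
    pass


def with_kwargs(a=1, **options):
    pass


print(with_keyword_only(inspect.signature(plain), "logging_level", "INFO"))
# (a=1, *, logging_level='INFO')
print(with_keyword_only(inspect.signature(with_kwargs), "logging_level", "INFO"))
# (a=1, *, logging_level='INFO', **options)
```

Inside the decorator, the result would be assigned to wrapper.__signature__ in the same way as in the snippet above.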

So a full solution which works as expected is:

import functools
import inspect
import logging

import fire

logging.basicConfig()
LOGGER = logging.getLogger(__name__)

def set_logging_level(func):
    """
    A decorator that sets the logging level for the decorated function.
    The level should be a string (e.g. 'DEBUG', 'INFO', etc.).
    """
    DEFAULT_LOGGING_LEVEL = "INFO"

    @functools.wraps(func)
    def wrapper(*args, logging_level=DEFAULT_LOGGING_LEVEL, **kwargs):
        LOGGER.info(f"Setting logging level to {logging_level}")
        LOGGER.setLevel(logging_level)
        return func(*args, **kwargs)

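    # Start from the signature of the undecorated function.
    wrapper_signature = inspect.signature(func)
    parameters = list(wrapper_signature.parameters.values())
    # Advertise the extra keyword-only argument that the wrapper accepts.
    parameters.append(
        inspect.Parameter(
            "logging_level",
            inspect.Parameter.KEYWORD_ONLY,
            default=DEFAULT_LOGGING_LEVEL,
        )
    )
    wrapper_signature = wrapper_signature.replace(parameters=parameters)
    # Redundant with @functools.wraps(func) above, but harmless.
    functools.update_wrapper(wrapper, func)
    # An explicit __signature__ takes precedence over following __wrapped__.
    wrapper.__signature__ = wrapper_signature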

    return wrapper

@set_logging_level
def test_function(a=1):
    """Docstring"""
    LOGGER.info("inside test_function")
    LOGGER.debug(f"{a=}")

if __name__ == "__main__":
    fire.Fire(test_function)