Lightning-AI/torchmetrics

MetricTracker use higher_is_better as default for maximize


🚀 Feature

Change the maximize argument for the MetricTracker wrapper from defaulting to True to using the higher_is_better property of the metric(s) if no maximize is supplied.
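
For illustration, a sketch of the intended call sites (MeanSquaredError is just an example metric whose higher_is_better is False):

    from torchmetrics import MeanSquaredError, MetricTracker

    # today: maximize must be spelled out, or the default of True silently
    # tracks the worst MSE instead of the best
    tracker = MetricTracker(MeanSquaredError(), maximize=False)

    # with this proposal: maximize is inferred from MeanSquaredError.higher_is_better
    tracker = MetricTracker(MeanSquaredError())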

Motivation

Reduce boilerplate code. In 99% of use cases you want to track the best value of a metric rather than the worst, and most metrics you would want to track already have a higher_is_better property, so manually typing maximize is just boilerplate. Additionally, the current default of True is somewhat arbitrary: for a randomly chosen metric it is roughly 50/50 whether you want to maximize or minimize it.

Pitch

The required code change is minimal, as it only affects the __init__ method and leaves the rest of the implementation unchanged.

Here's a suggested implementation:

- def __init__(self, metric: Union[Metric, MetricCollection], maximize: Union[bool, List[bool]] = True) -> None:
+ # note: `Optional` needs to be added to the existing `typing` imports
+ def __init__(self, metric: Union[Metric, MetricCollection], maximize: Optional[Union[bool, List[bool]]] = None) -> None:
    super().__init__()
    if not isinstance(metric, (Metric, MetricCollection)):
        raise TypeError(
            "Metric arg needs to be an instance of a torchmetrics"
            f" `Metric` or `MetricCollection` but got {metric}"
        )
    self._base_metric = metric

+   if maximize is None:
+       # `higher_is_better` is defined on the `Metric` base class but may be
+       # `None` for metrics with no sensible direction, so check the value
+       # rather than using `hasattr` (which would always be True)
+       if isinstance(metric, Metric):
+           if metric.higher_is_better is None:
+               raise ValueError(
+                   f"The metric '{metric.__class__.__name__}' does not set 'higher_is_better'. "
+                   "Please provide the `maximize` argument explicitly."
+               )
+           self.maximize = metric.higher_is_better
+       elif isinstance(metric, MetricCollection):
+           self.maximize = []
+           for name, m in metric.items():
+               if m.higher_is_better is None:
+                   raise ValueError(
+                       f"The metric '{name}' in the MetricCollection does not set 'higher_is_better'. "
+                       "Please provide the `maximize` argument explicitly."
+                   )
+               self.maximize.append(m.higher_is_better)
+   else:
        if not isinstance(maximize, (bool, list)):
            raise ValueError("Argument `maximize` should either be a single bool or list of bool")
        if isinstance(maximize, list) and isinstance(metric, MetricCollection) and len(maximize) != len(metric):
            raise ValueError("The length of argument `maximize` should match the length of the metric collection")
        if isinstance(metric, Metric) and not isinstance(maximize, bool):
            raise ValueError("Argument `maximize` should be a single bool when `metric` is a single Metric")
        self.maximize = maximize

    self._increment_called = False
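
With this change, usage would look like the following (a sketch; MeanSquaredError and R2Score are example metrics whose higher_is_better are False and True, respectively):

    from torchmetrics import MeanSquaredError, MetricCollection, MetricTracker, R2Score

    # maximize is inferred per metric, in the collection's iteration order:
    # MeanSquaredError.higher_is_better is False, R2Score.higher_is_better is True
    tracker = MetricTracker(MetricCollection([MeanSquaredError(), R2Score()]))

    # equivalent to the current explicit call:
    tracker = MetricTracker(
        MetricCollection([MeanSquaredError(), R2Score()]), maximize=[False, True]
    )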

Additional context

I haven't made many pull requests before, so I'm very open to suggestions and changes :)

Hi! Thanks for your contribution, great first issue!

Hi @MoustHolmes, thanks for the code suggestion :)
I like the idea, but my only worry is keeping it backwards compatible: users who rely on the current default of maximize=True would have their code break in a future update of torchmetrics.
I am therefore fine with this change, but either the default needs to stay at True or we at least need a deprecation phase.
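
Such a deprecation phase could look roughly like this (just a sketch, assuming None is used as the "not supplied" sentinel):

    import warnings

    # inside MetricTracker.__init__, with `maximize` defaulting to None:
    if maximize is None:
        warnings.warn(
            "The default of `maximize` will change from `True` to the metric's"
            " `higher_is_better` attribute in a future release. Pass `maximize`"
            " explicitly to keep the current behavior and silence this warning.",
            DeprecationWarning,
        )
        maximize = True  # keep today's behavior throughout the deprecation phase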