MetricTracker use higher_is_better as default for maximize
🚀 Feature
Change the `maximize` argument of the `MetricTracker` wrapper so that, instead of defaulting to `True`, it uses the `higher_is_better` property of the metric(s) when no `maximize` is supplied.
Motivation
Reduce boilerplate code. In the vast majority of use cases, you want to track the best value of a metric rather than the worst. Most metrics you would want to track already have a `higher_is_better` property, so manually typing `maximize` is just boilerplate. Additionally, the current default of `True` is somewhat arbitrary, since it is roughly 50/50 whether a given metric should be maximized or minimized.
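To make the "best rather than worst" point concrete: the direction the tracker uses decides which step it reports as best. The sketch below is a plain-Python illustration, not the torchmetrics implementation; `best_step` is a hypothetical helper, and it assumes one scalar value is recorded per step.

```python
from typing import List, Tuple


def best_step(history: List[float], maximize: bool) -> Tuple[int, float]:
    """Return (index, value) of the best entry in a per-step metric history.

    Hypothetical helper: it mimics what a tracker does with its `maximize`
    flag, picking the largest value when maximizing and the smallest otherwise.
    """
    if not history:
        raise ValueError("history is empty")
    pick = max if maximize else min
    idx = pick(range(len(history)), key=lambda i: history[i])
    return idx, history[idx]


# A loss-like metric such as mean squared error, where lower is better.
mse_history = [0.9, 0.4, 0.6]
print(best_step(mse_history, maximize=False))  # (1, 0.4): the genuinely best step
print(best_step(mse_history, maximize=True))   # (0, 0.9): what a default of True would report
```

With `maximize` inferred from `higher_is_better`, the loss-like case picks the smallest value without the user having to remember to pass `maximize=False`.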
Pitch
The required code change is minimal: it only affects the `__init__` method and leaves the rest of the implementation unchanged.
Here's a suggested implementation:
```diff
-    def __init__(self, metric: Union[Metric, MetricCollection], maximize: Union[bool, List[bool]] = True) -> None:
+    def __init__(
+        self, metric: Union[Metric, MetricCollection], maximize: Optional[Union[bool, List[bool]]] = None
+    ) -> None:
         super().__init__()
         if not isinstance(metric, (Metric, MetricCollection)):
             raise TypeError(
                 "Metric arg needs to be an instance of a torchmetrics"
                 f" `Metric` or `MetricCollection` but got {metric}"
             )
         self._base_metric = metric
+        if maximize is None:
+            # Infer the direction from `higher_is_better`. Metrics may inherit a default of
+            # None from the base class, so check the value rather than only using hasattr.
+            if isinstance(metric, Metric):
+                if getattr(metric, "higher_is_better", None) is None:
+                    raise AttributeError(
+                        f"The metric '{metric.__class__.__name__}' does not define `higher_is_better`. "
+                        "Please provide the `maximize` argument explicitly."
+                    )
+                self.maximize = metric.higher_is_better
+            elif isinstance(metric, MetricCollection):
+                self.maximize = []
+                for name, m in metric.items():
+                    if getattr(m, "higher_is_better", None) is None:
+                        raise AttributeError(
+                            f"The metric '{name}' in the MetricCollection does not define `higher_is_better`. "
+                            "Please provide the `maximize` argument explicitly."
+                        )
+                    self.maximize.append(m.higher_is_better)
+        else:
-        if not isinstance(maximize, (bool, list)):
-            raise ValueError("Argument `maximize` should either be a single bool or list of bool")
-        if isinstance(maximize, list) and isinstance(metric, MetricCollection) and len(maximize) != len(metric):
-            raise ValueError("The length of argument `maximize` should match the length of the metric collection")
-        if isinstance(metric, Metric) and not isinstance(maximize, bool):
-            raise ValueError("Argument `maximize` should be a single bool when `metric` is a single Metric")
-        self.maximize = maximize
+            # Explicit `maximize`: validate exactly as before, now nested under `else`.
+            if not isinstance(maximize, (bool, list)):
+                raise ValueError("Argument `maximize` should either be a single bool or list of bool")
+            if isinstance(maximize, list) and isinstance(metric, MetricCollection) and len(maximize) != len(metric):
+                raise ValueError("The length of argument `maximize` should match the length of the metric collection")
+            if isinstance(metric, Metric) and not isinstance(maximize, bool):
+                raise ValueError("Argument `maximize` should be a single bool when `metric` is a single Metric")
+            self.maximize = maximize
         self._increment_called = False
```
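The inference logic in the proposed `__init__` can be tried out without torchmetrics installed. The sketch below is a standalone version under stated assumptions: `FakeMetric` and `FakeCollection` are hypothetical stand-ins for `Metric` and `MetricCollection`, and `resolve_maximize` is a hypothetical helper, not part of the torchmetrics API. It treats a missing attribute and a `higher_is_better` of `None` the same way, and it skips the type/length validation that the proposal keeps for explicit values.

```python
from typing import List, Optional, Union


class FakeMetric:
    """Hypothetical stand-in for a torchmetrics `Metric`."""

    def __init__(self, higher_is_better: Optional[bool] = None) -> None:
        self.higher_is_better = higher_is_better


class FakeCollection(dict):
    """Hypothetical stand-in for a `MetricCollection`: a name -> metric mapping."""


def resolve_maximize(
    metric: Union[FakeMetric, FakeCollection],
    maximize: Optional[Union[bool, List[bool]]] = None,
) -> Union[bool, List[bool]]:
    # An explicit argument always wins, matching the current behaviour.
    if maximize is not None:
        return maximize
    if isinstance(metric, FakeCollection):
        # One direction per metric, in the collection's iteration order.
        directions = []
        for name, m in metric.items():
            hib = getattr(m, "higher_is_better", None)
            if hib is None:
                raise AttributeError(f"Metric '{name}' does not define `higher_is_better`; pass `maximize` explicitly.")
            directions.append(hib)
        return directions
    hib = getattr(metric, "higher_is_better", None)
    if hib is None:
        raise AttributeError(f"{type(metric).__name__} does not define `higher_is_better`; pass `maximize` explicitly.")
    return hib


coll = FakeCollection(acc=FakeMetric(True), mse=FakeMetric(False))
print(resolve_maximize(coll))                         # [True, False]
print(resolve_maximize(FakeMetric(False)))            # False
print(resolve_maximize(FakeMetric(), maximize=True))  # True
```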
Additional context
I haven't made many pull requests before, so I'm very open to suggestions and changes :)
Hi! Thanks for your contribution, great first issue!
Hi @MoustHolmes, thanks for the code suggestion :)
I like the idea, but my only worry is keeping it backwards compatible: some users may rely on the current default of `maximize=True`, and their code would break in a future update of torchmetrics.
I am therefore fine with this change, but either the default needs to stay at `True` or we at least need a deprecation phase.
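One possible shape for such a deprecation phase, sketched under assumptions: keep `True` as the effective default for now, but use a private sentinel so the code can tell "argument omitted" apart from an explicit value, and emit a `FutureWarning` only when the planned default (the metric's `higher_is_better`) would give a different answer. `_NOT_PASSED`, `FakeMetric`, and `resolve_maximize_during_deprecation` are hypothetical names, the warning text is illustrative, and only the single-metric case is shown for brevity.

```python
import warnings
from typing import List, Optional, Union


class FakeMetric:
    """Hypothetical stand-in for a single torchmetrics `Metric`."""

    def __init__(self, higher_is_better: Optional[bool] = None) -> None:
        self.higher_is_better = higher_is_better


# Private sentinel: distinguishes "argument omitted" from an explicit True or None.
_NOT_PASSED = object()


def resolve_maximize_during_deprecation(
    metric: FakeMetric, maximize: object = _NOT_PASSED
) -> Union[bool, List[bool]]:
    if maximize is _NOT_PASSED:
        # Deprecation phase: keep the old default of True, but warn when the
        # planned default (the metric's `higher_is_better`) would differ.
        if getattr(metric, "higher_is_better", None) is not True:
            warnings.warn(
                "The default `maximize=True` is planned to change to the metric's `higher_is_better`; "
                "pass `maximize` explicitly to silence this warning.",
                FutureWarning,
                stacklevel=2,
            )
        return True
    if maximize is None:
        # Explicit opt-in to the new behaviour before it becomes the default.
        hib = getattr(metric, "higher_is_better", None)
        if hib is None:
            raise AttributeError(f"{type(metric).__name__} does not define `higher_is_better`; pass a bool instead.")
        return hib
    # Explicit bool: unchanged from the current behaviour.
    return maximize  # type: ignore[return-value]


# A loss-like metric relying on the implicit default still maximizes, but gets a FutureWarning.
print(resolve_maximize_during_deprecation(FakeMetric(higher_is_better=False)))  # True
print(resolve_maximize_during_deprecation(FakeMetric(higher_is_better=False), maximize=None))  # False
```

Code that already passes `maximize` explicitly is unaffected, and code relying on the implicit default keeps its current behaviour while being warned before the default changes. Warning only when the answer would differ avoids noise for metrics where `higher_is_better` is already `True`.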