Lightning-Universe/lightning-flash

Support ClasswiseWrapper Metrics for Classification Tasks

newzealandpaul opened this issue · 0 comments

🚀 Feature

Currently, the torchmetrics ClasswiseWrapper, which reports a metric separately for each class, is not supported by Lightning.
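To show what "per-class" means in practice, here is a minimal, framework-free sketch of the kind of result a class-wise metric produces: one value per class, returned as a dictionary rather than a single number. It computes per-class accuracy (the fraction of each true class predicted correctly) in plain Python. The function name `per_class_accuracy` and the label-keyed output are hypothetical and chosen for illustration; this is not torchmetrics' implementation, although ClasswiseWrapper likewise returns a dict of per-class values instead of a single tensor.

```python
def per_class_accuracy(preds, targets, labels):
    """Hypothetical helper: fraction of samples of each true class predicted correctly.

    Returns a dict keyed by label name, mirroring the dict-shaped output of a
    class-wise metric rather than a single scalar.
    """
    totals = {label: 0 for label in labels}
    correct = {label: 0 for label in labels}
    for pred, target in zip(preds, targets):
        label = labels[target]  # map the integer class index to its name
        totals[label] += 1
        if pred == target:
            correct[label] += 1
    # Classes with no samples get 0.0 here; real libraries may treat this case differently.
    return {
        label: (correct[label] / totals[label] if totals[label] else 0.0)
        for label in labels
    }


scores = per_class_accuracy(
    preds=[0, 1, 1, 2, 2, 2],
    targets=[0, 1, 2, 2, 2, 1],
    labels=["cat", "dog", "bird"],
)
print(scores)  # {'cat': 1.0, 'dog': 0.5, 'bird': 0.6666666666666666}
```

A single averaged score would hide that "dog" is recognised only half the time, which is exactly the insight per-class metrics provide.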

Motivation

Per-class metrics are essential for many classification tasks because they show how the model performs on each individual class, which a single averaged score can hide.

Pitch

Currently, passing ClasswiseWrapper() metrics when creating a new instance of a Lightning model raises an error at flash/core/model.py:373, because ClasswiseWrapper objects do not have a _forward_cache attribute. Once that is fixed, a second error is raised at trainer/connectors/logger_connector/result.py:548, which expects a tensor rather than a dict of tensors.
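One possible workaround for the second error, until dict-valued metrics are handled natively, is to flatten the dictionary into separate scalar entries so that each logged value is a single number or tensor. The sketch below assumes a hypothetical helper, `flatten_metric_output`, that user code could call before logging; the `name/label` key format is an arbitrary choice. It does not address the missing `_forward_cache` attribute, and it is not how Flash or Lightning handle metrics internally.

```python
from collections.abc import Mapping
from typing import Any, Dict


def flatten_metric_output(name: str, output: Any) -> Dict[str, Any]:
    """Hypothetical helper: turn a metric result into a flat {key: value} mapping.

    A plain scalar (or single tensor) is returned under `name`; a mapping such as
    per-class results is expanded into one entry per class, keyed "name/class".
    Values are passed through unchanged, so tensors stay tensors.
    """
    if isinstance(output, Mapping):
        return {f"{name}/{key}": value for key, value in output.items()}
    return {name: output}


flat = flatten_metric_output("val_acc", {"cat": 1.0, "dog": 0.5})
print(flat)  # {'val_acc/cat': 1.0, 'val_acc/dog': 0.5}
print(flatten_metric_output("val_loss", 0.25))  # {'val_loss': 0.25}
```

In a LightningModule, the flattened dict could then be passed to `self.log_dict`, which logs each entry as its own scalar.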

Users would expect torchmetrics features to be supported natively.