Lightning-Universe/lightning-flash

Creation of an `ImageClassifier` fails with current torchmetrics version

davidefiocco opened this issue · 1 comment

๐Ÿ› Bug

I can't create an ImageClassifier with the current version of lightning-flash.

To Reproduce

The snippet

```python
from flash.image import ImageClassifier

ImageClassifier(backbone="resnet18", num_classes=2)
```

fails with:

```
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100% 44.7M/44.7M [00:00<00:00, 141MB/s]
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-3-0f6f4ac60023> in <module>
      1 from flash.image import ImageClassifier
----> 2 ImageClassifier(backbone="resnet18", num_classes = 2)

/usr/local/lib/python3.8/dist-packages/flash/core/classification.py in _build(self, num_classes, labels, loss_fn, metrics, multi_label)
     65 
     66         if metrics is None:
---> 67             metrics = F1Score(num_classes) if (multi_label and num_classes) else Accuracy()
     68 
     69         if loss_fn is None:

TypeError: __new__() missing 1 required positional argument: 'task'
```

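This happens with the current torchmetrics release (torchmetrics==0.11.0), most likely because of the changes introduced in that version (https://torchmetrics.readthedocs.io/en/v0.11.0/classification/accuracy.html): `Accuracy` and `F1Score` now require a `task` argument, which Flash's default metrics do not pass.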
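To illustrate what the new `task` argument changes, here is a minimal sketch of how the default-metric choice on line 67 of `classification.py` in the traceback could be translated into torchmetrics 0.11 keyword arguments. The helper `default_metric_spec` is hypothetical; it is not Flash's code or the fix Flash adopted, and it raises on a missing class count instead of falling back to a bare `Accuracy()`.

```python
from typing import Any, Dict, Optional, Tuple


def default_metric_spec(
    num_classes: Optional[int], multi_label: bool = False
) -> Tuple[str, Dict[str, Any]]:
    """Hypothetical helper: roughly mirror Flash's default-metric choice, but
    return the keyword arguments expected by task-based torchmetrics (>= 0.11)."""
    if not num_classes:
        # Multiclass and multilabel metrics in 0.11 need an explicit count.
        raise ValueError("num_classes must be a positive integer")
    if multi_label:
        # Multi-label F1: the number of labels goes in `num_labels`.
        return "F1Score", {"task": "multilabel", "num_labels": num_classes}
    # Single-label classification: multiclass accuracy.
    return "Accuracy", {"task": "multiclass", "num_classes": num_classes}


# The reproduction above uses num_classes=2 without multi-label targets.
name, kwargs = default_metric_spec(2)
print(name, kwargs)  # Accuracy {'task': 'multiclass', 'num_classes': 2}
```

With torchmetrics 0.11 installed, the returned keyword arguments could be passed straight to the constructor, e.g. `Accuracy(**kwargs)`.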

Expected behavior

The code should run without errors. Downgrading torchmetrics to 0.10.3 works around the issue (the torchmetrics version is not pinned in https://github.com/Lightning-AI/lightning-flash/blob/67bacaabaa473b9cf41952232e72e0d20d65e05c/requirements.txt#L5).
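The downgrade can also be enforced with a small runtime check that fails early with a clear message. This is a sketch under stated assumptions, not part of Flash: it needs Python 3.8+ for `importlib.metadata`, the helper names `release_tuple`, `predates_task_api` and `check_torchmetrics` are hypothetical, and it uses a deliberately simple version parser instead of `packaging.version` to stay dependency-free.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Tuple


def release_tuple(ver: str) -> Tuple[int, ...]:
    """Turn '0.10.3' or '0.11.0rc1' into (0, 10, 3) / (0, 11, 0).

    Only the leading digits of the first three components are kept, which is
    enough for this check; packaging.version is the more robust option.
    """
    parts = []
    for piece in ver.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


def predates_task_api(ver: str) -> bool:
    """True if this torchmetrics version is older than 0.11.0, i.e. it still
    accepts `Accuracy()` without a `task` argument."""
    return release_tuple(ver) < (0, 11)


def check_torchmetrics() -> None:
    """Hypothetical guard: raise a readable error instead of the TypeError."""
    try:
        installed = version("torchmetrics")
    except PackageNotFoundError:
        return  # not installed; nothing to check
    if not predates_task_api(installed):
        raise RuntimeError(
            f"torchmetrics {installed} requires a `task` argument; "
            "install torchmetrics==0.10.3 as a workaround"
        )
```

Calling `check_torchmetrics()` before constructing an `ImageClassifier` would replace the `TypeError` with an explicit message; the pin itself (torchmetrics==0.10.3) remains the actual workaround.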

Environment

  • OS (e.g., Linux): Linux
  • Python version: 3.9
  • PyTorch/Lightning/Flash version: 1.12.1+cu113/1.8.3.post1/0.8.1
  • Any other relevant information: torchmetrics==0.11.0

Borda commented

I suggest pinning the torchmetrics (TM) version, as your problem with saved models would eventually need to be patched on the TM side. cc: @SkafteNicki @Lightning-AI/core-metrics