Object detection example broken
lorinczszabolcs opened this issue · 6 comments
🐛 Bug - Object detection example broken
The object detection example here is broken. I think the actual problem lies in package version incompatibilities. I installed the latest versions, so I assumed it would work out of the box.
To Reproduce
I used the original script that is shared here, and also tried with my own custom inputs, but the same error pops up:
```
/lib/python3.8/site-packages/effdet/anchors.py", line 404, in batch_label_anchors
    box_targets[count:count + steps].view([feat_size[0], feat_size[1], -1]))
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
```
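For context, `Tensor.view` can only reinterpret memory that is already laid out contiguously for the requested shape, while `Tensor.reshape` falls back to a copy when it is not. A minimal sketch of the difference (not the effdet code itself, just the general failure mode):

```python
import torch

# A transposed tensor shares storage with the original but is not contiguous.
t = torch.arange(6).reshape(2, 3).t()
assert not t.is_contiguous()

# view() fails: no stride arrangement over the existing storage can
# describe the requested flat layout.
try:
    t.view(-1)
except RuntimeError as err:
    print(err)  # "view size is not compatible with input tensor's size and stride ..."

# reshape() succeeds by silently copying the data when needed.
flat = t.reshape(-1)
print(flat.tolist())  # [0, 3, 1, 4, 2, 5]
```

This is why the error message itself suggests swapping `.view(...)` for `.reshape(...)` at the failing call site.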
Code sample
Just for the sake of having everything here, copying the original code:
```python
import flash
from flash.core.data.utils import download_data
from flash.image import ObjectDetectionData, ObjectDetector

# 1. Create the DataModule
# Dataset Credit: https://www.kaggle.com/ultralytics/coco128
download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "data/")

datamodule = ObjectDetectionData.from_coco(
    train_folder="data/coco128/images/train2017/",
    train_ann_file="data/coco128/annotations/instances_train2017.json",
    val_split=0.1,
    transform_kwargs={"image_size": 512},
    batch_size=4,
)

# 2. Build the task
model = ObjectDetector(head="efficientdet", backbone="d0", num_classes=datamodule.num_classes, image_size=512)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=1)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Detect objects in a few images!
datamodule = ObjectDetectionData.from_files(
    predict_files=[
        "data/coco128/images/train2017/000000000625.jpg",
        "data/coco128/images/train2017/000000000626.jpg",
        "data/coco128/images/train2017/000000000629.jpg",
    ],
    transform_kwargs={"image_size": 512},
    batch_size=4,
)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("object_detection_model.pt")
```
Expected behavior
The training should not fail with the mentioned error.
Environment
- OS (e.g., Linux):
Ubuntu 20.04.5 LTS
- Python version:
Python 3.8.10
- PyTorch/Lightning/Flash Version (e.g., 1.10/1.5/0.7):
torch==1.13.1, pytorch-lightning==1.9.4, lightning-flash==0.8.1.post0
- GPU models and configuration:
NVIDIA GeForce RTX 3090, CUDA Version: 11.6
- Any other relevant information: -
@lorinczszabolcs would you be interested in debugging it further and eventually sending a fix? 🦩
I retried it now, and a different error is given:
```
---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
<ipython-input-2-430f4efffa8e> in <cell line: 1>()
----> 1 import flash
      2 from flash.core.data.utils import download_data
      3 from flash.image import ObjectDetectionData, ObjectDetector
      4
      5 # 1. Create the DataModule

2 frames
/usr/local/lib/python3.10/dist-packages/flash/core/data/utils.py in <module>
     20 import requests
     21 import urllib3
---> 22 from pytorch_lightning.utilities.apply_func import apply_to_collection
     23 from torch import nn
     24 from tqdm.auto import tqdm as tq

ModuleNotFoundError: No module named 'pytorch_lightning.utilities.apply_func'
```
Unfortunately I won't have the time to look into it in detail, but it still seems like a package-version issue, possibly caused by the pytorch-lightning 2.0.0 (lightning 2.0.0) release.
> possibly caused by pytorch-lightning 2.0.0 (lightning 2.0.0) release
Flash has its dependency pinned below 2.0.
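Until a release carrying that pin reaches PyPI, one stopgap (my suggestion, not an official recommendation) is to constrain pytorch-lightning explicitly at install time so pip cannot resolve to 2.x:

```shell
# Force a pytorch-lightning 1.x in a clean environment, matching the
# intended "<2.0" pin, then install lightning-flash alongside it.
pip install "pytorch-lightning>=1.3.6,<2.0" lightning-flash
```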
Upon installing with `pip install lightning-flash` in a clean environment, it shows me that pytorch-lightning==2.0.2 gets installed. Meanwhile the following message is shown: `Collecting pytorch-lightning>=1.3.6 (from lightning-flash)`, indicating that the pinned dependency is >=1.3.6, even though I can see that is not the case here:
lightning-flash/requirements.txt
Line 8 in 14c2755
Any idea why that's happening?
> Upon installing with `pip install lightning-flash` in a clean environment, it shows me that pytorch-lightning==2.0.2 gets installed.
Most likely this pin adjustment was not yet released, so please install it from source for now:
```shell
pip install https://github.com/Lightning-Universe/lightning-flash/archive/refs/heads/master.zip
```
This shall be fixed in https://github.com/Lightning-Universe/lightning-flash/releases/tag/0.8.2