Compatibility of TileDataset with DataLoader
The following code
ds = PadoDataset('../../../../pathdrive-pado-tggates/local_dataset/', 'r')
ds = ds.filter(filter_local)
tile_ds = TileDataset(
    ds,
    tile_size=(512, 512),
    target_mpp=MPP(1, 1),
)
dl = DataLoader(tile_ds, batch_size=2)
for batch in dl:
    ...
gives the following error
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'NoneType'>
I assume it's because some of the metadata is None:
ImageId('acetamide', 'Kidney', '65440.svs', sit... 713 1 1 acetamide Kidney 1/1 Rat Repeat ... 345600.0 0.0 mg/kg None None None None None None
Iterating directly over the TileDataset works fine, though. The issue only appears when I feed the TileDataset to the DataLoader.
If I explicitly pass the collate_fn defined in TileDataset
def collate_fn(batch):
    # transpose the batch: one list per PadoTileItem field
    it = zip(PadoTileItem._fields, map(list, zip(*batch)))
    # noinspection PyArgumentList
    dct = CollatedPadoTileItems(it)  # type: ignore
    tile = dct["tile"]
    # collate tiles
    if tile:
        if isinstance(tile[0], np.ndarray):
            dct["tile"] = np.stack(tile)
        else:
            # non-numpy tiles (`stack` is presumably torch.stack)
            dct["tile"] = stack(tile)
    return dct
as follows
dl = DataLoader(tile_ds, batch_size=2, collate_fn=collate_fn)
everything works. Shouldn't DataLoader call the dataset's own collate_fn automatically?
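For anyone wondering why the custom collate_fn copes with None metadata: it only transposes the batch into one list per field and never recurses into the values, whereas the default collate tries to convert every field and raises on None. Here is a minimal pure-Python sketch of that transposition step. `Item` and `collate_keep_none` are hypothetical stand-ins, not pado names, and the real PadoTileItem has different fields.

```python
from collections import namedtuple

# Hypothetical stand-in for PadoTileItem; the real field names differ.
Item = namedtuple("Item", ["tile", "metadata"])


def collate_keep_none(batch):
    """Transpose a list of items into a dict with one list per field."""
    # zip(*batch) groups the values field by field; None values are
    # kept as-is instead of being converted to tensors.
    return dict(zip(Item._fields, map(list, zip(*batch))))


batch = [Item(tile=[0, 1], metadata=None), Item(tile=[2, 3], metadata=None)]
collated = collate_keep_none(batch)
print(collated)
# prints {'tile': [[0, 1], [2, 3]], 'metadata': [None, None]}
```

Only the "tile" entry would then need stacking; the metadata lists can stay plain Python lists.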
This is by design. Check here to see how the dataset's collate_fn method is passed to the DataLoader constructor:
Lines 621 to 637 in 635d7b8
We could check whether, since this was written, PyTorch has made changes that allow a dataset to provide these methods via inspection.
Works as suggested!
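In case it helps later readers, the "via inspection" idea mentioned above could be approximated on the user side with a small helper. This is only a sketch: `loader_kwargs`, `WithCollate` and `Plain` are hypothetical names, not part of pado or PyTorch, and it assumes the dataset exposes its collate function under the attribute name `collate_fn`.

```python
from typing import Any, Callable, Dict


def loader_kwargs(dataset: Any) -> Dict[str, Callable]:
    """Collect DataLoader keyword arguments that the dataset itself provides."""
    # Look the attribute up dynamically; datasets without it yield no kwargs.
    collate_fn = getattr(dataset, "collate_fn", None)
    return {"collate_fn": collate_fn} if callable(collate_fn) else {}


class WithCollate:
    """Hypothetical dataset that ships its own collate function."""

    @staticmethod
    def collate_fn(batch):
        return list(batch)


class Plain:
    """Hypothetical dataset without a collate function."""


print(sorted(loader_kwargs(WithCollate())))  # prints ['collate_fn']
print(loader_kwargs(Plain()))                # prints {}
```

With this, the call from the issue would become `DataLoader(tile_ds, batch_size=2, **loader_kwargs(tile_ds))`, which still passes collate_fn explicitly, just without naming it at every call site.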