Compatibility of TileDataset with DataLoader
The following code:

```python
from torch.utils.data import DataLoader

# filter_local is a user-defined filter function (not shown here)
ds = PadoDataset('../../../../pathdrive-pado-tggates/local_dataset/', 'r')
ds = ds.filter(filter_local)
tile_ds = TileDataset(
    ds,
    tiling_strategy=FastGridTiling(
        tile_size=(512, 512),
        target_mpp=MPP(1, 1),
        overlap=0,
        min_chunk_size=0.5,
        normalize_chunk_sizes=True,
    ),
)
dl = DataLoader(tile_ds, batch_size=2)
for batch in dl:
    print(batch)
    break
```
gives the following error:

```
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'NoneType'>
```
I assume this is because some of the metadata fields are None:

```
ImageId('acetamide', 'Kidney', '65440.svs', sit... 713 1 1 acetamide Kidney ftp://ftp.biosciencedbc.jp/archive/open-tggate... 1/1 Rat Repeat ... 345600.0 0.0 mg/kg None None None None None None
```
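The assumption above can be checked in isolation. The sketch below feeds PyTorch's `default_collate` a batch of plain dicts in which one field is `None`; the field names are hypothetical and only mimic the None-valued metadata columns. Integers and strings collate fine, but the `None` field raises the same `TypeError` as in the traceback.

```python
from torch.utils.data import default_collate

# Hypothetical samples: ints and strings are collatable, None is not.
batch = [
    {"tile_index": 0, "organ": "Kidney", "dose_note": None},
    {"tile_index": 1, "organ": "Kidney", "dose_note": None},
]

try:
    default_collate(batch)
    failed = False
except TypeError as exc:
    # default_collate has no rule for NoneType and raises TypeError
    failed = True
    print(exc)
```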
Iterating directly over the TileDataset works fine, though; the issue only appears when I pass the TileDataset to the DataLoader.
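If I explicitly pass the collate_fn defined in TileDataset:

```python
import numpy as np
from torch import stack  # assumed: `stack` in the original refers to torch.stack

# PadoTileItem and CollatedPadoTileItems are pado types (imports not shown)
def collate_fn(batch):
    # transpose the batch: one list per PadoTileItem field
    it = zip(PadoTileItem._fields, map(list, zip(*batch)))
    # noinspection PyArgumentList
    dct = CollatedPadoTileItems(it)  # type: ignore
    tile = dct["tile"]
    # collate tiles: numpy arrays via np.stack, otherwise via stack
    if tile:
        if isinstance(tile[0], np.ndarray):
            dct["tile"] = np.stack(tile)
        else:
            dct["tile"] = stack(tile)
    return dct
```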
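The key idea in that collate_fn is to transpose the batch into one list per field and stack only the tiles, so metadata (including `None` values) is never handed to `default_collate`. Below is a minimal numpy-only sketch of the same pattern; `TileItem` is a hypothetical stand-in for `PadoTileItem`, and unlike the original it handles only numpy tiles, not torch tensors.

```python
from collections import namedtuple

import numpy as np

# Hypothetical stand-in for PadoTileItem: one tile plus some metadata.
TileItem = namedtuple("TileItem", ["tile", "image_id", "metadata"])


def collate_tile_items(batch):
    """Transpose a list of TileItem tuples into a dict of per-field lists.

    Metadata fields stay plain Python lists, so None values pass through
    untouched; only the "tile" field is stacked into one array.
    """
    # zip(*batch) groups the i-th field of every item together
    dct = dict(zip(TileItem._fields, map(list, zip(*batch))))
    tiles = dct.get("tile", [])
    if tiles:
        dct["tile"] = np.stack(tiles)  # shape: (batch, H, W, C)
    return dct


batch = [
    TileItem(np.zeros((4, 4, 3), dtype=np.uint8), "img-a", None),
    TileItem(np.ones((4, 4, 3), dtype=np.uint8), "img-b", None),
]
out = collate_tile_items(batch)
print(out["tile"].shape)  # (2, 4, 4, 3)
print(out["metadata"])    # [None, None]
```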
as follows:

```python
dl = DataLoader(tile_ds, batch_size=2, collate_fn=collate_fn)
```

everything works. Shouldn't DataLoader pick up the dataset's collate_fn automatically?
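This is by design: PyTorch's DataLoader does not look up a collate_fn on the dataset; it falls back to default_collate unless a collate_fn is passed explicitly. Check here to see how the dataset's collate_fn method is passed to the DataLoader constructor: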
Lines 621 to 637 in 635d7b8
We could check whether PyTorch has since added a way for a dataset to provide these methods so that DataLoader discovers them via inspection.
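Until then, the inspection can be done on the caller's side. The sketch below uses a hypothetical helper, `make_dataloader`, that checks for a `collate_fn` attribute with `getattr` and forwards it to the standard `DataLoader` constructor unless the caller supplied one; `ToyTileDataset` is a hypothetical dataset whose items carry a `None` field. Neither is part of pado or PyTorch.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyTileDataset(Dataset):
    """Hypothetical dataset whose items carry a None metadata field."""

    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return {"tile": torch.zeros(3, 8, 8), "note": None}

    @staticmethod
    def collate_fn(batch):
        # Stack tiles; keep every other field as a plain list.
        out = {key: [item[key] for item in batch] for key in batch[0]}
        out["tile"] = torch.stack(out["tile"])
        return out


def make_dataloader(dataset, **kwargs):
    """Hypothetical helper: use the dataset's collate_fn if it defines one."""
    collate_fn = getattr(dataset, "collate_fn", None)
    if collate_fn is not None and "collate_fn" not in kwargs:
        kwargs["collate_fn"] = collate_fn
    return DataLoader(dataset, **kwargs)


dl = make_dataloader(ToyTileDataset(), batch_size=2)
batch = next(iter(dl))
print(batch["tile"].shape)  # torch.Size([2, 3, 8, 8])
print(batch["note"])        # [None, None]
```

An explicit `collate_fn=` argument still takes precedence, so the helper does not change behavior for callers who already pass one.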
Works as suggested!