Bayer-Group/pado

Compatibility of TileDataset with DataLoader

Closed this issue · 3 comments

The following code

ds = PadoDataset('../../../../pathdrive-pado-tggates/local_dataset/', 'r')
ds = ds.filter(filter_local)

tile_ds = TileDataset(
    ds,
    tiling_strategy=FastGridTiling(
        tile_size=(512, 512),
        target_mpp=MPP(1, 1),
        overlap=0,
        min_chunk_size=0.5,
        normalize_chunk_sizes=True,
    ),
)

dl = DataLoader(tile_ds, batch_size=2)
for batch in dl:
    print(batch)
    break

gives the following error

TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'NoneType'>
I assume it's because some of the metadata is None:

ImageId('acetamide', 'Kidney', '65440.svs', sit... 713 1 1 acetamide Kidney ftp://ftp.biosciencedbc.jp/archive/open-tggate... 1/1 Rat Repeat ... 345600.0 0.0 mg/kg None None None None None None
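
One way to check that assumption without involving the DataLoader is to walk a single item and list where None appears. This is only a minimal sketch: `find_none_paths` is a hypothetical helper, not part of pado, and the `Item` namedtuple and its field names are made up for illustration. It only descends into tuples, lists and dicts, so pandas objects would need their own branch.

```python
from collections import namedtuple


def find_none_paths(obj, path="item"):
    """Return the paths inside a nested item whose value is None."""
    if obj is None:
        return [path]
    found = []
    if isinstance(obj, dict):
        for key, value in obj.items():
            found += find_none_paths(value, f"{path}[{key!r}]")
    elif isinstance(obj, (list, tuple)):
        # namedtuples are tuples, so they are covered here as well
        for index, value in enumerate(obj):
            found += find_none_paths(value, f"{path}[{index}]")
    return found


# Stand-in for a real tile item; the field names are assumptions.
Item = namedtuple("Item", ["tile", "metadata"])
sample = Item(
    tile=[[0, 1], [2, 3]],
    metadata={"dose": 0.0, "unit": "mg/kg", "finding": None},
)
print(find_none_paths(sample))  # ["item[1]['finding']"]
```

Running this over the first few items of the dataset should show whether a whole field or only a nested value is None.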

Iterating directly over the TileDataset works fine, though. The error only appears when I pass the TileDataset to the DataLoader.

If I explicitly pass the collate_fn defined in TileDataset

def collate_fn(batch):
    # regroup the batch into one list per PadoTileItem field
    it = zip(PadoTileItem._fields, map(list, zip(*batch)))
    # noinspection PyArgumentList
    dct = CollatedPadoTileItems(it)  # type: ignore
    tile = dct["tile"]
    # collate tiles
    if tile:
        if isinstance(tile[0], np.ndarray):
            dct["tile"] = np.stack(tile)
        else:
            dct["tile"] = stack(tile)
    return dct
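
To make it concrete why this collate function is not bothered by None: it only regroups the batch into one list per field and stacks the tile column, so the other fields are never converted to tensors. Below is a pure-Python sketch of that regrouping step. The `Item` namedtuple is a hypothetical stand-in for PadoTileItem, `transpose_batch` is an illustrative name, and plain lists stand in for arrays.

```python
from collections import namedtuple

# Hypothetical stand-in for PadoTileItem with only two fields.
Item = namedtuple("Item", ["tile", "metadata"])


def transpose_batch(batch):
    """Turn a list of namedtuple items into a dict with one list per field."""
    fields = type(batch[0])._fields
    # zip(*batch) regroups the items column by column
    return dict(zip(fields, map(list, zip(*batch))))


batch = [
    Item(tile=[1, 2], metadata=None),
    Item(tile=[3, 4], metadata={"dose": 0.0}),
]
columns = transpose_batch(batch)
print(columns["tile"])      # [[1, 2], [3, 4]]
print(columns["metadata"])  # [None, {'dose': 0.0}]
```

The None entry simply stays in the metadata list, whereas default_collate tries to convert every field and raises on it.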

as follows

dl = DataLoader(tile_ds, batch_size=2, collate_fn=collate_fn)

everything works. Shouldn't DataLoader pick up the dataset's own collate_fn automatically?

ap-- commented

This is by design. Check here to see how the dataset's collate_fn method is passed to the DataLoader constructor:

pado/pado/itertools.py

Lines 621 to 637 in 635d7b8

loader = DataLoader(
    dataset,
    batch_size=10,
    shuffle=True,
    sampler=None,
    batch_sampler=None,
    num_workers=3,
    collate_fn=dataset.collate_fn,
    pin_memory=False,
    drop_last=False,
    timeout=0,
    worker_init_fn=None,
    multiprocessing_context=None,
    generator=None,
    prefetch_factor=2,
    persistent_workers=False,
)

We could check whether, since this was written, PyTorch has made changes that allow a dataset to provide these methods via introspection.

Works as suggested!