pytorch/data

StatefulDataLoader stores worker state twice if the IterableDataset is also an Iterator

Closed this issue · 0 comments

๐Ÿ› Describe the bug

```python
import torch
from collections.abc import Iterator
# Stateful protocol from torchdata (import path may vary by version)
from torchdata.stateful_dataloader import Stateful


class MyIterabledataset(torch.utils.data.IterableDataset, Iterator, Stateful):
    def __init__(self, samples):
        self.samples = samples
        self.size = len(self.samples)
        self.i = 0

    def __iter__(self):
        # The dataset is its own iterator.
        return self

    def __next__(self):
        if self.i >= len(self.samples):
            raise StopIteration
        sample = self.samples[self.i]
        self.i += 1
        return sample

    def state_dict(self):
        return {"i": self.i}

    def load_state_dict(self, state_dict):
        self.i = state_dict["i"]
```

In the above example, the same state is stored twice: once in `dataset_state` (`dataset_state = try_to_serialize(dataset)`) and again in `fetcher_state` (`_DATASET_ITER_STATE: try_to_serialize(fetcher.dataset_iter)`). This happens because `__iter__` returns the dataset itself, so `fetcher.dataset_iter is dataset` and both serialization calls snapshot the same object.
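To make the duplication concrete, here is a minimal standalone sketch. `try_to_serialize` is simplified to its essence (call `state_dict()` if present) and the real worker-state layout is not reproduced; the class is a plain-Python stand-in for the dataset above:

```python
def try_to_serialize(obj):
    # Simplified stand-in for torchdata's helper: snapshot state if available.
    if hasattr(obj, "state_dict"):
        return obj.state_dict()
    return None


class SelfIterDataset:
    """Dataset that is also its own iterator, as in the report."""

    def __init__(self, n):
        self.n = n
        self.i = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.i >= self.n:
            raise StopIteration
        value = self.i
        self.i += 1
        return value

    def state_dict(self):
        return {"i": self.i}


dataset = SelfIterDataset(5)
dataset_iter = iter(dataset)      # returns the dataset itself
assert dataset_iter is dataset

next(dataset_iter)
next(dataset_iter)

dataset_state = try_to_serialize(dataset)
fetcher_state = try_to_serialize(dataset_iter)
# Both snapshots carry the identical payload -> stored twice in the checkpoint.
assert dataset_state == fetcher_state == {"i": 2}
```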

State can be quite expensive to store and transfer, so it would be good to avoid replicating it in this scenario.
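Until this is addressed in the loader, one way to sidestep the duplication on the user side is to keep the iteration state out of the dataset object entirely, by having `__iter__` return a separate stateful iterator. A minimal sketch in plain Python (without the `torch`/`torchdata` base classes; class names are illustrative):

```python
class SampleIterator:
    """Separate iterator object; it alone holds the per-epoch cursor."""

    def __init__(self, samples):
        self.samples = samples
        self.i = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.i >= len(self.samples):
            raise StopIteration
        sample = self.samples[self.i]
        self.i += 1
        return sample

    def state_dict(self):
        return {"i": self.i}

    def load_state_dict(self, state_dict):
        self.i = state_dict["i"]


class MyIterableDataset:
    """Dataset carries only the data; it has no state_dict of its own,
    so serializing the dataset and the iterator no longer duplicates state."""

    def __init__(self, samples):
        self.samples = samples

    def __iter__(self):
        return SampleIterator(self.samples)


ds = MyIterableDataset([10, 20, 30, 40])
it = iter(ds)
assert next(it) == 10 and next(it) == 20
snap = it.state_dict()                 # only the iterator is stateful
assert not hasattr(ds, "state_dict")   # nothing to store on the dataset side
resumed = SampleIterator(ds.samples)
resumed.load_state_dict(snap)
assert list(resumed) == [30, 40]
```

The trade-off is a second class, but the dataset-level and iterator-level snapshots are now disjoint by construction.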

Versions

main branch