StatefulDataLoader stores worker state twice if the IterableDataset is also an Iterator
Closed this issue · 0 comments
gokulavasan commented
🐛 Describe the bug
```python
from collections.abc import Iterator

import torch
from torchdata.stateful_dataloader.stateful import Stateful


class MyIterabledataset(torch.utils.data.IterableDataset, Iterator, Stateful):
    def __init__(self, samples):
        self.samples = samples
        self.size = len(self.samples)
        self.i = 0

    def __iter__(self):
        # The dataset is its own iterator.
        return self

    def __next__(self):
        if self.i >= len(self.samples):
            raise StopIteration
        sample = self.samples[self.i]
        self.i += 1
        return sample

    def state_dict(self):
        return {"i": self.i}

    def load_state_dict(self, state_dict):
        self.i = state_dict["i"]
```
In the above example, the state will be stored twice: once in `dataset_state` (torchdata/stateful_dataloader/worker.py, line 219 in a0412de) and again via the iterator state (torchdata/stateful_dataloader/worker.py, line 213 in a0412de), because the dataset is also its own iterator.
State can be quite expensive to store and transfer, so it would be good to avoid replicating it in this scenario.
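To illustrate the duplication without depending on torch, here is a minimal pure-Python sketch; the class name and the snapshot logic are hypothetical stand-ins for what the worker loop does when it checks both the dataset and its iterator for a `state_dict`:

```python
from collections.abc import Iterator


class SelfIteratingDataset(Iterator):
    """A dataset that is also its own iterator, like the example above."""

    def __init__(self, samples):
        self.samples = samples
        self.i = 0

    def __iter__(self):
        return self  # the dataset *is* the iterator

    def __next__(self):
        if self.i >= len(self.samples):
            raise StopIteration
        sample = self.samples[self.i]
        self.i += 1
        return sample

    def state_dict(self):
        return {"i": self.i}


ds = SelfIteratingDataset(list(range(10)))
it = iter(ds)
next(it)
next(it)

# A worker that snapshots the dataset's state and the iterator's state
# separately captures the exact same dict twice, since ds is it:
dataset_state = ds.state_dict() if hasattr(ds, "state_dict") else None
iterator_state = it.state_dict() if hasattr(it, "state_dict") else None
assert ds is it
assert dataset_state == iterator_state  # duplicated {"i": 2}
```

One way to avoid the duplication would be to detect `iter(dataset) is dataset` and store the state only once, restoring it through a single path on resume.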
Versions
main branch