iterator_utils.sample_iterators() hangs indefinitely.
Spandan-Madan opened this issue · 2 comments
I added a bunch of debug prints to find out why MultiTFRecordDataset fails with 4 or more tfrecord files. When iterating over the data loader, it looks like reader.multi_tfrecord_loader() is called first, and it eventually calls iterator_utils.sample_iterators().
I added print statements inside sample_iterators(), and this is the part of the code where the program gets stuck:
while True:
    choice = np.random.choice(len(ratios), p=ratios)
    print('made my choice, yielding')
    yield next(iterators[choice])
This loop should run as many times as the batch size (I used 10 to debug). With only 2 tfrecord files, that's exactly what happens. Here's the output:
iter called
Length of loaders is 2
iter obtained, returning
iterators 2
Calling iteration
made my choice, yielding
made my choice, yielding
made my choice, yielding
made my choice, yielding
made my choice, yielding
made my choice, yielding
made my choice, yielding
made my choice, yielding
made my choice, yielding
made my choice, yielding
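The pattern in the output above (one "made my choice, yielding" line per record, ten per batch) can be reproduced without any tfrecord files. Below is a minimal sketch, assuming the sampler is just the loop quoted earlier; `sample_iterators_sketch` and the fake in-memory sources are hypothetical stand-ins, not the tfrecord library's actual code.

```python
import itertools

import numpy as np


def sample_iterators_sketch(iterators, ratios):
    """Hypothetical stand-in for iterator_utils.sample_iterators().

    Each step picks one source at random (weighted by `ratios`) and yields
    exactly one record from it, mirroring the loop quoted in this issue.
    """
    probs = np.asarray(ratios, dtype=float)
    probs = probs / probs.sum()  # np.random.choice needs probabilities that sum to 1
    while True:
        choice = np.random.choice(len(probs), p=probs)
        yield next(iterators[choice])


# Four fake "tfrecord files", each cycling forever over three records.
sources = [itertools.cycle([f"file{i}-rec{j}" for j in range(3)]) for i in range(4)]
sampler = sample_iterators_sketch(sources, [1.0, 1.0, 1.0, 1.0])

# A batch of 10 means 10 separate draws, i.e. 10 passes through the loop.
batch = list(itertools.islice(sampler, 10))
print(batch)
```

In the sketch, four in-memory sources never block, so the loop keeps producing items regardless of how many sources there are. Because the print in the quoted loop runs before `next(iterators[choice])`, a hang right after a "made my choice, yielding" line points at that `next()` call, or at code the consumer runs between items, rather than at `np.random.choice`.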
But with 4 or more tfrecord files, it gets stuck partway through and data loading never finishes. Here's the output:
iter called
Length of loaders is 4
iter obtained, returning
iterators 4
Calling iteration
made my choice, yielding
made my choice, yielding
made my choice, yielding
made my choice, yielding
The number of steps before it hangs varies from run to run: sometimes it hangs after 2 yields, other times after 4 (as above).
I suspect this has to do with how the subprocesses are being spawned. Any idea why it's happening and how to fix it?
Thanks in advance!
For reference, here's the code I'm using:
import cv2
import numpy as np
import torch
from tfrecord.torch.dataset import MultiTFRecordDataset

tfrecord_path = "/om/user/xboix/data/ImageNet/train-00277-of-01024"  # not used below

def resize_image(features):
    # get BGR image from the encoded bytes, then convert to RGB
    img = cv2.imdecode(features["image/encoded"], -1)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    features["image/encoded"] = cv2.resize(img, (224, 224), interpolation=cv2.INTER_AREA)
    # HWC -> CHW, the channel-first layout PyTorch models expect
    features["image/encoded"] = np.transpose(features["image/encoded"], (2, 0, 1))
    return features

# sample the first 4 ImageNet training shards with equal weight
train_splits = {}
tfrecord_pattern = '/om/user/xboix/data/ImageNet/{}'
for i in range(4):
    key = "train-%05d-of-01024" % i
    value = 1.0
    train_splits[key] = value

description = {"image/encoded": "byte", "image/class/label": "int", "image/height": "int", "image/width": "int"}
train_dataset = MultiTFRecordDataset(tfrecord_pattern, None, train_splits, description, transform=resize_image)
train_loader_tfr = torch.utils.data.DataLoader(train_dataset, batch_size=10)

# fetching a single batch is enough to trigger the hang
for data in train_loader_tfr:
    break
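To see where the process is stuck when it hangs, one option is Python's standard-library `faulthandler`, which can dump the stack of every thread after a timeout. A minimal sketch follows; `take_with_watchdog` is a hypothetical helper, and `itertools.count()` stands in for `iter(train_loader_tfr)`. With the DataLoader's default `num_workers=0`, loading runs in the main process, so the dump should include the data-loading frames.

```python
import faulthandler
import itertools


def take_with_watchdog(iterator, n, timeout_s=60.0):
    """Hypothetical helper: pull n items from `iterator`, dumping every
    thread's stack to stderr if a single next() blocks longer than timeout_s."""
    items = []
    for _ in range(n):
        # Arm a one-shot timer; if it fires, faulthandler prints tracebacks
        # and the process keeps running (exit=False).
        faulthandler.dump_traceback_later(timeout_s, exit=False)
        try:
            items.append(next(iterator))
        finally:
            # The item arrived (or an exception was raised): disarm the timer.
            faulthandler.cancel_dump_traceback_later()
    return items


# Stand-in for iter(train_loader_tfr): any iterator works here.
items = take_with_watchdog(itertools.count(), 10, timeout_s=5.0)
print(items)
```

On the real loader, passing `iter(train_loader_tfr)` would time each batch; if one stalls for longer than `timeout_s`, the printed traceback shows which call is blocked.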
I can't reproduce this problem. This synthetic example seems to work fine with any number of iterations:
import numpy as np
import tfrecord
import tfrecord.torch.dataset
import torch

splits = {}
for i in range(10):
    writer = tfrecord.TFRecordWriter(f"/tmp/data_{i}.tfrecord")
    for j in range(np.random.randint(5, 10)):
        writer.write({"data": (np.random.rand(2), "float")})
    writer.close()
    splits[str(i)] = 1

dataset = tfrecord.torch.dataset.MultiTFRecordDataset("/tmp/data_{}.tfrecord", None, splits)
loader = iter(torch.utils.data.DataLoader(dataset, batch_size=3))
for i in range(100000):
    data = next(loader)
    print(data)
I have the same issue; the iterator hangs indefinitely:
import torch
from tfrecord.torch.dataset import MultiTFRecordDataset

dataset_name = "/path-to-tfrecords/{}.tfrecords"
# splits, description and batch_size are defined elsewhere; None = no index_pattern
dataset = MultiTFRecordDataset(dataset_name, None, splits, description, infinite=False)
loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
data = next(iter(loader))  # hangs here; iterating with a for-loop over n iterations works
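A note on the two calling patterns in this thread: `next(iter(loader))` creates a brand-new iterator every time it runs, whereas the synthetic example earlier calls `iter()` once and then `next()` repeatedly. For a DataLoader, each `iter()` starts a fresh pass over the dataset. The pure-Python sketch below uses a hypothetical `CountingDataset` instead of the real loader, so it only illustrates the general iterator protocol, not the cause of the hang.

```python
class CountingDataset:
    """Hypothetical iterable that counts how many fresh passes were started."""

    def __init__(self, records):
        self.records = records
        self.passes_started = 0

    def __iter__(self):
        # Called once per iter(...): each call starts a new pass from the beginning.
        self.passes_started += 1
        return iter(self.records)


ds = CountingDataset(["a", "b", "c"])

# Pattern 1: a new iterator on every call -> always the first record, a new pass each time.
firsts = [next(iter(ds)) for _ in range(3)]
print(firsts, ds.passes_started)  # ['a', 'a', 'a'] 3

# Pattern 2: create the iterator once and keep calling next() on it.
ds.passes_started = 0
it = iter(ds)
seq = [next(it) for _ in range(3)]
print(seq, ds.passes_started)  # ['a', 'b', 'c'] 1
```

If `next(iter(loader))` sits inside a loop, every iteration reopens the dataset from scratch, which is worth ruling out when comparing it with the for-loop version.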