How to load data in multiple processes?
gfjiangly opened this issue · 12 comments
How to load data in multiple processes?
@gfjiangly I have the same problem. I think the current version cannot work with multiple GPUs. The reason is that an `IterableDataset` (here, `TFRecordDataset`) cannot be used with `DistributedSampler` in PyTorch. Maybe we can create a special `DistributedSampler` for `IterableDataset` later. However, you can still parallelize data loading for a single device by setting `num_workers > 1`.
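For reference, a minimal sketch of that single-device, multi-worker setup with the library's `TFRecordDataset` (not from this issue; the paths and the feature description are placeholders):

```python
import torch
from tfrecord.torch.dataset import TFRecordDataset

# Placeholder paths and feature description -- adjust to your data.
tfrecord_path = "data/train.tfrecord"
index_path = "data/train.index"
description = {"image": "byte", "label": "int"}

# For variable-length features (e.g. encoded images), pass a `transform`
# that decodes them to a fixed shape so the default collate can batch them.
dataset = TFRecordDataset(tfrecord_path, index_path, description)

# num_workers > 1 parallelizes reading/decoding in CPU worker processes,
# but everything still feeds a single device.
loader = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=4)
batch = next(iter(loader))
```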
@gfjiangly @JerryLead The way I solve this problem is to shuffle the TFRecord files and distribute them evenly across the GPUs before each epoch. The problem is how to handle these two situations (a rough sketch of the file-sharding part follows this list):
- The number of TFRecord files is not divisible by the number of GPUs, e.g. 16 TFRecord files on 10 GPUs.
- The number of samples is not divisible by the number of GPUs, e.g. 1000 samples on 16 GPUs. In other words, how to implement `drop_last=True`.
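A rough sketch of the first point (even file distribution, with leftover files dropped for that epoch); the helper name and file names are hypothetical. The second point still needs batch-level handling on each rank, e.g. `drop_last=True` in the per-rank `DataLoader`:

```python
import random

def shard_tfrecord_files(files, rank, world_size, epoch, seed=0):
    """Shuffle the file list identically on every rank, then take this rank's slice.

    Dropping the remainder keeps the per-rank file count equal, which is a
    coarse, file-level analogue of drop_last=True.
    """
    files = sorted(files)
    rng = random.Random(seed + epoch)  # same permutation on every rank
    rng.shuffle(files)

    per_rank = len(files) // world_size  # leftover files are skipped this epoch
    start = rank * per_rank
    return files[start:start + per_rank]

# 16 files on 10 GPUs: every rank gets 1 file, 6 files are skipped this
# epoch (a different 6 next epoch, because the shuffle depends on `epoch`).
files = [f"train-{i:03d}.tfrec" for i in range(16)]
print(shard_tfrecord_files(files, rank=3, world_size=10, epoch=0))
```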
Is the `IterableDataset` the reason why using `tfrecord.torch.dataset.MultiTFRecordDataset` with `torch_xla.distributed.parallel_loader.ParallelLoader` gives the following type error?
Exception in device=TPU:0: object of type 'MultiTFRecordDataset' has no len()
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 231, in _start_fn
    fn(gindex, *args)
  File "<ipython-input-25-bd5e4111d32a>", line 182, in _mp_fn
    accuracy, data, pred, target = train_mnist()
  File "<ipython-input-25-bd5e4111d32a>", line 165, in train_mnist
    para_loader = pl.ParallelLoader(train_loader, [device])
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/parallel_loader.py", line 80, in __init__
    self._per_device_samples = len(loader) // len(devices)
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 315, in __len__
    length = self._IterableDataset_len_called = len(self.dataset)
TypeError: object of type 'MultiTFRecordDataset' has no len()

---------------------------------------------------------------------------
Exception                                 Traceback (most recent call last)
<ipython-input-25-bd5e4111d32a> in <module>()
    184   # Retrieve tensors that are on TPU core 0 and plot
    185   plot_results(data.cpu(), pred.cpu(), target.cpu())
--> 186 xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS['num_cores'], start_method='fork')

2 frames
/usr/local/lib/python3.6/dist-packages/torch/multiprocessing/spawn.py in join(self, timeout)
    111             raise Exception(
    112                 "process %d terminated with exit code %d" %
--> 113                 (error_index, exitcode)
    114             )
    115

Exception: process 0 terminated with exit code 17
What can be done to resolve this error?
@linkun-1998 Yes, `IterableDataset` does not have a `__len__` method by default, so `len(dataset)` is unavailable for it. You must add a `__len__` method yourself.
@DelightRun Awesome, but how can you define a `__len__` function when you are streaming from multiple TFRecords using `tfrecord.torch.dataset.MultiTFRecordDataset`?
It's pretty straightforward when the index is available. You can just implement a new `RandomAccessMultiTFRecordDataset` class that inherits from `torch.utils.data.Dataset` and change the logic accordingly. PRs are welcome.
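For the `__len__` part specifically, a minimal sketch (assuming, as the snippet posted later in this thread does, that each index file has one line per record and that the constructor stores `index_pattern` and `splits` as attributes; the patterns and splits below are hypothetical):

```python
from tfrecord.torch.dataset import MultiTFRecordDataset

class SizedMultiTFRecordDataset(MultiTFRecordDataset):
    """MultiTFRecordDataset that reports a length computed from its index files."""

    def __len__(self) -> int:
        # One index line per record, summed over all splits.
        return sum(
            sum(1 for _ in open(self.index_pattern.format(split)))
            for split in self.splits
        )

dataset = SizedMultiTFRecordDataset(
    data_pattern="data/{}.tfrec",
    index_pattern="index/{}.idx",
    splits={"shard-000": 0.5, "shard-001": 0.5},
    description={"image": "byte", "class": "int"},
)
print(len(dataset))
```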
I tried to implement a `newMultiTFRecordDataset` that inherits from `torch.utils.data.Dataset`, as follows:
class newMultiTFRecordDataset():
    def __init__(self,
                 data_pattern: str,
                 index_pattern: typing.Union[str, None],
                 splits: typing.Dict[str, float],
                 description: typing.Union[typing.List[str], typing.Dict[str, str], None] = None,
                 shuffle_queue_size: typing.Optional[int] = None,
                 transform: typing.Callable[[dict], typing.Any] = None) -> None:
        super(newMultiTFRecordDataset, self).__init__()
        self.data_pattern = data_pattern
        self.index_pattern = index_pattern
        self.splits = splits
        self.description = description
        self.shuffle_queue_size = shuffle_queue_size
        self.transform = transform

    def __len__(self):
        index_len = 0
        for split in self.splits.keys():
            index_len += len(np.loadtxt(self.index_pattern.replace('{}', str(split)), dtype=np.float32)[:, 0])
        return index_len

    def __getitem__(self, index):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            np.random.seed(worker_info.seed % np.iinfo(np.uint32).max)
        it = tfrecord.reader.multi_tfrecord_loader(
            self.data_pattern, self.index_pattern, self.splits, self.description)
        if self.shuffle_queue_size:
            it = iterator_utils.shuffle_iterator(it, self.shuffle_queue_size)
        if self.transform:
            it = map(self.transform, it)
        data = next(it)
        return data
Calling `next(iter(dataloader))` on this dataset works fine. But when I use it to train a model on a TPU device in Colab, I get the error below:
Exception                                 Traceback (most recent call last)
<ipython-input-10-51f2165c039a> in <module>()
    175   # Retrieve tensors that are on TPU core 0 and plot
    176   plot_results(data.cpu(), pred.cpu(), target.cpu())
--> 177 xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS['num_cores'], start_method='fork')

2 frames
/usr/local/lib/python3.6/dist-packages/torch/multiprocessing/spawn.py in join(self, timeout)
    106             raise Exception(
    107                 "process %d terminated with signal %s" %
--> 108                 (error_index, name)
    109             )
    110         else:

Exception: process 0 terminated with signal SIGABRT
The training implementation is as follows:
SERIAL_EXEC = xmp.MpSerialExecutor()

class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.conv1 = nn.Conv2d(3, 10, kernel_size=5)
        self.bn1 = nn.BatchNorm2d(10)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.bn2 = nn.BatchNorm2d(20)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 5)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = self.bn1(x)
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = self.bn2(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

# Only instantiate model weight once in memory
net = xmp.MpModelWrapper(Network())
print(net)

def train_flowers():
    torch.manual_seed(1)

    # ------------------------------------------------------------------------
    def get_datasets():
        class newMultiTFRecordDataset():
            def __init__(self,
                         data_pattern: str,
                         index_pattern: typing.Union[str, None],
                         splits: typing.Dict[str, float],
                         description: typing.Union[typing.List[str], typing.Dict[str, str], None] = None,
                         shuffle_queue_size: typing.Optional[int] = None,
                         transform: typing.Callable[[dict], typing.Any] = None) -> None:
                super(newMultiTFRecordDataset, self).__init__()
                self.data_pattern = data_pattern
                self.index_pattern = index_pattern
                self.splits = splits
                self.description = description
                self.shuffle_queue_size = shuffle_queue_size
                self.transform = transform

            def __len__(self):
                index_len = 0
                for split in self.splits.keys():
                    index_len += len(np.loadtxt(self.index_pattern.replace('{}', str(split)), dtype=np.float32)[:, 0])
                return index_len

            def __getitem__(self, index):
                worker_info = torch.utils.data.get_worker_info()
                if worker_info is not None:
                    np.random.seed(worker_info.seed % np.iinfo(np.uint32).max)
                it = tfrecord.reader.multi_tfrecord_loader(
                    self.data_pattern, self.index_pattern, self.splits, self.description)
                if self.shuffle_queue_size:
                    it = iterator_utils.shuffle_iterator(it, self.shuffle_queue_size)
                if self.transform:
                    it = map(self.transform, it)
                data = next(it)
                return data

        filenames = os.listdir(FLAGS['datadir'])
        filenames = [file_[:-6] for file_ in filenames]

        # split training filenames
        random.seed(1)
        validation_filenames = list(random.sample(filenames, int(len(filenames) * FLAGS['test_split'])))
        training_filenames = [filename for filename in filenames if filename not in validation_filenames]

        # getting tfrecords pattern
        tfrec_pattern = os.path.join(FLAGS['datadir'], '{}.tfrec')
        # getting index pattern
        index_pattern = os.path.join(FLAGS['indexdir'], '{}.idx')

        def primary_transforms(features):
            features['image'] = cv2.resize(cv2.imdecode(features["image"], -1), FLAGS['image_size'])
            features["image"] = cv2.cvtColor(features["image"], cv2.COLOR_BGR2RGB)
            features["image"] = np.moveaxis(features["image"], -1, 0)
            features['class'] = np.squeeze(np.eye(FLAGS['num_classes'])[np.array([features["class"]]).reshape(-1)])
            return features

        description = {"image": "byte",
                       "class": "int"}

        train_samp_split = {}
        for file_ in training_filenames:
            train_samp_split[file_] = 1 / len(training_filenames)

        val_samp_split = {}
        for file_ in validation_filenames:
            val_samp_split[file_] = 1 / len(validation_filenames)

        train_dataset = newMultiTFRecordDataset(tfrec_pattern,
                                                index_pattern,
                                                train_samp_split,
                                                description,
                                                transform=primary_transforms)
        val_dataset = newMultiTFRecordDataset(tfrec_pattern,
                                              index_pattern,
                                              val_samp_split,
                                              description,
                                              transform=primary_transforms)
        return (train_dataset, val_dataset)

    # ------------------------------------------------------------------------
    train_dataset, test_dataset = get_datasets()

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=FLAGS['batch_size'],
        sampler=train_sampler,
        num_workers=FLAGS['num_workers'],
        drop_last=True)

    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=FLAGS['batch_size'],
        shuffle=False,
        num_workers=FLAGS['num_workers'],
        drop_last=True)

    # ------------------------------------------------------------------------

    # Scale learning rate to world size
    lr = FLAGS['learning_rate'] * xm.xrt_world_size()

    # Get loss function, optimizer, and model
    device = xm.xla_device()
    model = net.to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=FLAGS['momentum'])
    loss_fn = nn.NLLLoss()

    def train_loop_fn(loader):
        tracker = xm.RateTracker()
        model.train()
        for x, batch in enumerate(loader):
            optimizer.zero_grad()
            data, target = batch['image'], batch['class']
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS['batch_size'])
            if x % FLAGS['log_steps'] == 0:
                print('[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'.format(
                    xm.get_ordinal(), x, loss.item(), tracker.rate(),
                    tracker.global_rate(), time.asctime()), flush=True)

    def test_loop_fn(loader):
        total_samples = 0
        correct = 0
        model.eval()
        data, pred, target = None, None, None
        for batch in loader:
            data, target = batch['image'], batch['class']
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]
        accuracy = 100.0 * correct / total_samples
        print('[xla:{}] Accuracy={:.2f}%'.format(
            xm.get_ordinal(), accuracy), flush=True)
        return accuracy, data, pred, target

    # Train and eval loops
    accuracy = 0.0
    data, pred, target = None, None, None
    for epoch in range(1, FLAGS['num_epochs'] + 1):
        para_loader = pl.ParallelLoader(train_loader, [device])
        train_loop_fn(para_loader.per_device_loader(device))
        xm.master_print("Finished training epoch {}".format(epoch))

        para_loader = pl.ParallelLoader(test_loader, [device])
        accuracy, data, pred, target = test_loop_fn(para_loader.per_device_loader(device))
        if FLAGS['metrics_debug']:
            xm.master_print(met.metrics_report(), flush=True)

    return accuracy, data, pred, target

def _mp_fn(rank, flags):
    global FLAGS
    FLAGS = flags
    torch.set_default_tensor_type('torch.FloatTensor')
    accuracy, data, pred, target = train_flowers()
    if rank == 0:
        # Retrieve tensors that are on TPU core 0 and plot
        plot_results(data.cpu(), pred.cpu(), target.cpu())

xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS['num_cores'], start_method='fork')
What could be the possible error?
@linkun-1998 For company compliance reasons, I cannot upload the full code. This is the core part of `MultiTFRecordDataset`:
class MultiTFRecordDataset(torch.utils.data.IterableDataset):
    """Parse multiple (generic) TFRecords datasets into an `IterableDataset`
    object, which contain `np.ndarrays`s.

    Params:
    -------
    data_pattern: str
        Input data path pattern.
    index_pattern: str or None
        Input index path pattern.
    splits: dict
        Dictionary of (key, value) pairs, where the key is used to
        construct the data and index path(s) and the value determines
        the contribution of each split to the batch.
    description: list or dict of str, optional, default=None
        List of keys or dict of (key, value) pairs to extract from each
        record. The keys represent the name of the features and the
        values ("byte", "float", or "int") correspond to the data type.
        If dtypes are provided, then they are verified against the
        inferred type for compatibility purposes. If None (default),
        then all features contained in the file are extracted.
    is_sequence: bool, optional, default=False
        TFRecord example type. Using tf.train.SequenceExample if
        is_sequence=True, else tf.train.Example.
    shuffle_queue_size: int, optional, default=None
        Length of buffer. Determines how many records are queued to
        sample from.
    transform: a callable, default=None
        A function that takes in the input `features` i.e the dict
        provided in the description, transforms it and returns a
        desirable output.
    """

    def __init__(self,
                 data_pattern: str,
                 index_pattern: typing.Union[str, None],
                 splits: typing.Dict[str, float],
                 description: typing.Union[typing.List[str], typing.Dict[str, str], None] = None,
                 is_sequence: bool = False,
                 shuffle_queue_size: typing.Optional[int] = None,
                 transform: typing.Callable[[dict], typing.Any] = None) -> None:
        super(MultiTFRecordDataset, self).__init__()
        self.data_pattern = data_pattern
        self.index_pattern = index_pattern
        self.splits = splits
        self.description = description
        self.is_sequence = is_sequence
        self.shuffle_queue_size = shuffle_queue_size
        self.transform = transform
        if self.index_pattern is not None:
            self.num_samples = sum(
                sum(1 for _ in open(self.index_pattern.format(split)))
                for split in self.splits
            )
        else:
            self.num_samples = None

    def __len__(self):
        if self.num_samples is not None:
            return self.num_samples
        else:
            raise NotImplementedError()

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            shard = worker_info.id, worker_info.num_workers
            np.random.seed(worker_info.seed % np.iinfo(np.uint32).max)
        else:
            shard = None
        it = reader.multi_tfrecord_loader(
            self.data_pattern, self.index_pattern, self.splits,
            self.description, self.is_sequence, shard)
        if self.shuffle_queue_size:
            it = iterator_utils.shuffle_iterator(it, self.shuffle_queue_size)
        if self.transform:
            it = map(self.transform, it)
        return it
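For completeness, constructing the class above and reading from it might look like this (the patterns, splits, and description are placeholders, not from this thread):

```python
# '{}' in the patterns is replaced with each split name.
dataset = MultiTFRecordDataset(
    data_pattern="data/{}.tfrecord",
    index_pattern="index/{}.index",
    splits={"shard-000": 0.5, "shard-001": 0.5},
    description={"image": "byte", "class": "int"},
    shuffle_queue_size=256,
)

print(len(dataset))           # works, because __len__ is derived from the index files
record = next(iter(dataset))  # dict with the keys listed in `description`
```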
@DelightRun You just implemented the `__len__` function, which was what was needed. Nice, thanks! But can you tell me why my code does not work?
@DelightRun Moreover, I get the same error after adding a `__len__()` method to `MultiTFRecordDataset`.
@linkun-1998 Were you able to solve it?
Has anyone solved this issue?