
How to load data in multiple processes?

@gfjiangly The same problem with you. I think current version cannot work for multiple GPUs. The reason is that IterableDataset (here it is TFRecordDataset) cannot work with DistributedSampler in PyTorch. Maybe we can create a special DistributedSampler for IterableDataset later. However, you can load data using multiple data loading threads with num_workers > 1 for a single CPU.

@gfjiangly @JerryLead The way I solve this problem is to shuffle and distribute TFRecord files to different GPUs evenly before each epoch. The problem is, how can to handle these 2 situations:

  1. The number of TFRecord files cannot be divided by number of GPUs, i.e. 16 TFRecord files with 10 GPUs.
  2. The number of samples cannot be divided by number of GPUs, i.e. 1000 samples with 16 GPUs. In other word, how to implement drop_last=True.

Is the IterableDataset the reason why, when usingtfrecord.torch.dataset.MultiTFRecordDataset with torch_xla.distributed.parallel_loader.ParallelLoader gives the following type error-

Exception in device=TPU:0: object of type 'MultiTFRecordDataset' has no len()
What can be done to resolve this error?

@linkun-1998 Yes, IterableDataset does not has __len__ method by default, so len(dataset) is unavailable for it. You must add __len__ method by your self.

@DelightRun awesome, but how can you define a __len__ function when you are streaming from multiple TFRecords, using tfrecord.torch.dataset.MultiTFRecordDataset ?

It's pretty straightforward when the index is available. You can just implement a new "RandomAccessMultiTFRecordDataset" class that inherits from and change the logic. PR are welcome.

I tried to implement newMultiTFRecordDataset which inherits from as follows:

class newMultiTFRecordDataset():
      def __init__(self,
                  data_pattern: str,
                  index_pattern: typing.Union[str, None],
                  splits: typing.Dict[str, float],
                  description: typing.Union[typing.List[str], typing.Dict[str, str], None] = None,
                  shuffle_queue_size: typing.Optional[int] = None,
                  transform: typing.Callable[[dict], typing.Any] = None) -> None:
        super(newMultiTFRecordDataset, self).__init__()
        self.data_pattern = data_pattern
        self.index_pattern = index_pattern
        self.splits = splits
        self.description = description
        self.shuffle_queue_size = shuffle_queue_size
        self.transform = transform

      def __len__(self):
        index_len = 0
        for split in self.splits.keys():
          index_len += len(np.loadtxt(self.index_pattern.replace('{}', str(split)), dtype=np.float32)[:, 0])
        return index_len

      def __getitem__(self, index):
        worker_info =
        if worker_info is not None:
            np.random.seed(worker_info.seed % np.iinfo(np.uint32).max)
        it = tfrecord.reader.multi_tfrecord_loader(
            self.data_pattern, self.index_pattern, self.splits, self.description)
        if self.shuffle_queue_size:
            it = iterator_utils.shuffle_iterator(it, self.shuffle_queue_size)
        if self.transform:
            it = map(self.transform, it)
        data = next(it)
        return data

I try to use next(iter(dataloader)) of the following dataset which also just works fine.
But when I try to implement in a model with TPU device in colab, I get the following error mentioned below:

The implementation of the following training is as follows:

SERIAL_EXEC = xmp.MpSerialExecutor()

class Network(nn.Module):

  def __init__(self):
    super(Network, self).__init__()
    self.conv1 = nn.Conv2d(3, 10, kernel_size = 5)
    self.bn1 = nn.BatchNorm2d(10)
    self.conv2 = nn.Conv2d(10, 20, kernel_size = 5)
    self.bn2 = nn.BatchNorm2d(20)
    self.fc1 = nn.Linear(320, 50)
    self.fc2 = nn.Linear(50, 5)

  def forward(self, x):
    x = F.relu(F.max_pool2d(self.conv1(x), 2))
    x = self.bn1(x)
    x = F.relu(F.max_pool2d(self.conv2(x), 2))
    x = self.bn2(x)
    x = torch.flatten(x, 1)
    x = F.relu(self.fc1(x))
    x = self.fc2(x)
    return F.log_softmax(x, dim=1)

#Only instantiate model weight once in memory
net = xmp.MpModelWrapper(Network())

def train_flowers():
  def get_datasets():

    class newMultiTFRecordDataset():
      def __init__(self,
                  data_pattern: str,
                  index_pattern: typing.Union[str, None],
                  splits: typing.Dict[str, float],
                  description: typing.Union[typing.List[str], typing.Dict[str, str], None] = None,
                  shuffle_queue_size: typing.Optional[int] = None,
                  transform: typing.Callable[[dict], typing.Any] = None) -> None:
        super(newMultiTFRecordDataset, self).__init__()
        self.data_pattern = data_pattern
        self.index_pattern = index_pattern
        self.splits = splits
        self.description = description
        self.shuffle_queue_size = shuffle_queue_size
        self.transform = transform

      def __len__(self):
        index_len = 0
        for split in self.splits.keys():
          index_len += len(np.loadtxt(self.index_pattern.replace('{}', str(split)), dtype=np.float32)[:, 0])
        return index_len

      def __getitem__(self, index):
        worker_info =
        if worker_info is not None:
            np.random.seed(worker_info.seed % np.iinfo(np.uint32).max)
        it = tfrecord.reader.multi_tfrecord_loader(
            self.data_pattern, self.index_pattern, self.splits, self.description)
        if self.shuffle_queue_size:
            it = iterator_utils.shuffle_iterator(it, self.shuffle_queue_size)
        if self.transform:
            it = map(self.transform, it)
        data = next(it)
        return data

    filenames = os.listdir(FLAGS['datadir'])
    filenames = [file_[:-6] for file_ in filenames]
    #split training filenames
    validation_filenames = list(random.sample(filenames, int(len(filenames)*FLAGS['test_split'])))
    training_filenames = [filename for filename in filenames if filename not in validation_filenames]
    #getting tfrecords pattern
    tfrec_pattern = os.path.join(FLAGS['datadir'], '{}.tfrec')
    #getting index pattern
    index_pattern = os.path.join(FLAGS['indexdir'], '{}.idx')

    def primary_transforms(features):
      features['image'] = cv2.resize(cv2.imdecode(features["image"], -1), FLAGS['image_size'])
      features["image"] = cv2.cvtColor(features["image"] , cv2.COLOR_BGR2RGB)
      features["image"] = np.moveaxis(features["image"], -1, 0)
      features['class'] = np.squeeze(np.eye(FLAGS['num_classes'])[np.array([features["class"]]).reshape(-1)])
      return features

    description = { "image": "byte",
                  "class": "int"}
    train_samp_split = {}
    for file_ in training_filenames:
      train_samp_split[file_] = 1/len(training_filenames)

    val_samp_split = {}
    for file_ in validation_filenames:
      val_samp_split[file_] = 1/len(validation_filenames)
    train_dataset = newMultiTFRecordDataset(tfrec_pattern,
    val_dataset = newMultiTFRecordDataset(tfrec_pattern,

    return (train_dataset, val_dataset)

  train_dataset, test_dataset = get_datasets()

  train_sampler =
  train_loader =
      sampler = train_sampler,
  test_loader =
  #Scale learning rate to world size
  lr = FLAGS['learning_rate']*xm.xrt_world_size()

  #Get loss function, optimizer, and model
  device = xm.xla_device()
  model =
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=FLAGS['momentum'])
  loss_fn = nn.NLLLoss()

  def train_loop_fn(loader):
    tracker = xm.RateTracker()
    for x, batch in enumerate(loader):
      data, target = batch['image'], batch['class']
      output = model(data)
      loss = loss_fn(output, target)
      if x % FLAGS['log_steps'] == 0:
        print('[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'.format(
            xm.get_ordinal(), x, loss.item(), tracker.rate(),
            tracker.global_rate(), time.asctime()), flush=True)

  def test_loop_fn(loader):
    total_samples = 0
    correct = 0
    data, pred, target = None, None, None
    for batch in loader:
      data, target = batch['image'], batch['class']
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()
      total_samples += data.size()[0]

    accuracy = 100.0 * correct/total_samples
    print('[xla:{}] Accuracy={:.2f}%'.format(
          xm.get_ordinal(), accuracy), flush=True)
    return accuracy, data, pred, target

  # Train and eval loops
  accuracy = 0.0
  data, pred, target = None, None, None
  for epoch in range(1, FLAGS['num_epochs']+1):
    para_loader = pl.ParallelLoader(train_loader, [device])
    xm.master_print("Finished training epoch {}".format(epoch))

    para_loader = pl.ParallelLoader(test_loader, [device])
    accuracy, data, pred, target = test_loop_fn(para_loader.per_device_loader(device))
    if FLAGS['metrics_debug']:
      xm.master_print(met.metrics_report(), flush=True)

  return accuracy, data, pred, target

def _mp_fn(rank, flags):
  global FLAGS 
  FLAGS = flags
  accuracy, data, pred, target = train_flowers()
  if rank == 0:
    # Retrieve tensors that are on TPU core 0 and plot
    plot_results(data.cpu(), pred.cpu(), target.cpu())
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS['num_cores'], start_method='fork')

What could be the possible error?

@linkun-1998 Due to company's compliance reason, I cannot upload the full code. This is the core part of MultiTFRecordDataset:

class MultiTFRecordDataset(
    """Parse multiple (generic) TFRecords datasets into an `IterableDataset`
    object, which contain `np.ndarrays`s.

    data_pattern: str
        Input data path pattern.

    index_pattern: str or None
        Input index path pattern.

    splits: dict
        Dictionary of (key, value) pairs, where the key is used to
        construct the data and index path(s) and the value determines
        the contribution of each split to the batch.

    description: list or dict of str, optional, default=None
        List of keys or dict of (key, value) pairs to extract from each
        record. The keys represent the name of the features and the
        values ("byte", "float", or "int") correspond to the data type.
        If dtypes are provided, then they are verified against the
        inferred type for compatibility purposes. If None (default),
        then all features contained in the file are extracted.

    is_sequence: bool, optional, default=False
        TFRecord example type. Using tf.train.SequenceExample if
        is_sequence=True, else tf.train.Example.

    shuffle_queue_size: int, optional, default=None
        Length of buffer. Determines how many records are queued to
        sample from.

    transform : a callable, default = None
        A function that takes in the input `features` i.e the dict
        provided in the description, transforms it and returns a
        desirable output.


    def __init__(self,
                 data_pattern: str,
                 index_pattern: typing.Union[str, None],
                 splits: typing.Dict[str, float],
                 description: typing.Union[typing.List[str], typing.Dict[str, str], None] = None,
                 is_sequence: bool = False,
                 shuffle_queue_size: typing.Optional[int] = None,
                 transform: typing.Callable[[dict], typing.Any] = None) -> None:
        super(MultiTFRecordDataset, self).__init__()
        self.data_pattern = data_pattern
        self.index_pattern = index_pattern
        self.splits = splits
        self.description = description
        self.is_sequence = is_sequence
        self.shuffle_queue_size = shuffle_queue_size
        self.transform = transform

        if self.index_pattern is not None:
            self.num_samples = sum(
                sum(1 for _ in open(self.index_pattern.format(split)))
                for split in self.splits
            self.num_samples = None

    def __len__(self):
        if self.num_samples is not None:
            return self.num_samples
            raise NotImplementedError()

    def __iter__(self):
        worker_info =
        if worker_info is not None:
            shard =, worker_info.num_workers
            np.random.seed(worker_info.seed % np.iinfo(np.uint32).max)
            shard = None
        it = reader.multi_tfrecord_loader(
            self.data_pattern, self.index_pattern, self.splits, self.description, self.is_sequence, shard)
        if self.shuffle_queue_size:
            it = iterator_utils.shuffle_iterator(it, self.shuffle_queue_size)
        if self.transform:
            it = map(self.transform, it)
        return it

@DelightRun You just implemented the len function which was required. Nice! Thanks. But Can you just gimme the reason why my code dosenot works?

@DelightRun Moreover I get the same error after adding a __len__() function to MultiTFRecordDataset.

@linkun-1998 Were you able to solve it?

does someone solve the this issue?