vahidk/tfrecord

Error when loading string feature

hcarlens opened this issue · 2 comments

I'm getting an error when loading a byte feature from the Waymo Open Motion Dataset (I can't share the data because of the license, but it's available here: https://waymo.com/open/).

Minimal reproducible example:

import torch
from tfrecord.torch.dataset import TFRecordDataset

tfrecord_path = "uncompressed_tf_example_training_training_tfexample.tfrecord-00000-of-01000"
index_path = None
scenario_features = {
    'scenario/id': 'byte',
}

dataset = TFRecordDataset(tfrecord_path, index_path, scenario_features)
loader = torch.utils.data.DataLoader(dataset, batch_size=32)

data = next(iter(loader))
print(data)

Error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-32-b161469c8e99> in <module>
     11 loader = torch.utils.data.DataLoader(dataset, batch_size=32)
     12 
---> 13 data = next(iter(loader))
     14 print(data)

~/.venv/wenv/lib/python3.8/site-packages/torch/utils/data/dataloader.py in __next__(self)
    515             if self._sampler_iter is None:
    516                 self._reset()
--> 517             data = self._next_data()
    518             self._num_yielded += 1
    519             if self._dataset_kind == _DatasetKind.Iterable and \

~/.venv/wenv/lib/python3.8/site-packages/torch/utils/data/dataloader.py in _next_data(self)
    555     def _next_data(self):
    556         index = self._next_index()  # may raise StopIteration
--> 557         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    558         if self._pin_memory:
    559             data = _utils.pin_memory.pin_memory(data)

~/.venv/wenv/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
     33         else:
     34             data = next(self.dataset_iter)
---> 35         return self.collate_fn(data)
     36 
     37 

~/.venv/wenv/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py in default_collate(batch)
     71         return batch
     72     elif isinstance(elem, container_abcs.Mapping):
---> 73         return {key: default_collate([d[key] for d in batch]) for key in elem}
     74     elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
     75         return elem_type(*(default_collate(samples) for samples in zip(*batch)))

~/.venv/wenv/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py in <dictcomp>(.0)
     71         return batch
     72     elif isinstance(elem, container_abcs.Mapping):
---> 73         return {key: default_collate([d[key] for d in batch]) for key in elem}
     74     elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
     75         return elem_type(*(default_collate(samples) for samples in zip(*batch)))

~/.venv/wenv/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py in default_collate(batch)
     61                 raise TypeError(default_collate_err_msg_format.format(elem.dtype))
     62 
---> 63             return default_collate([torch.as_tensor(b) for b in batch])
     64         elif elem.shape == ():  # scalars
     65             return torch.as_tensor(batch)

~/.venv/wenv/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py in default_collate(batch)
     53             storage = elem.storage()._new_shared(numel)
     54             out = elem.new(storage)
---> 55         return torch.stack(batch, 0, out=out)
     56     elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
     57             and elem_type.__name__ != 'string_':

RuntimeError: stack expects each tensor to be equal size, but got [16] at entry 0 and [14] at entry 4

In TensorFlow, I'm able to load it using this feature definition:

'scenario/id': tf.io.FixedLenFeature((), tf.string, default_value=None),
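For context, this is roughly how that definition gets used on the TensorFlow side; only the `scenario/id` line above is from my actual code, the rest is standard `tf.data` boilerplate:

```python
import tensorflow as tf

# only the 'scenario/id' definition is from the original snippet;
# the surrounding tf.data usage is standard parsing boilerplate
features = {
    'scenario/id': tf.io.FixedLenFeature((), tf.string, default_value=None),
}

raw_dataset = tf.data.TFRecordDataset(tfrecord_path)
parsed = raw_dataset.map(lambda record: tf.io.parse_single_example(record, features))

for example in parsed.take(1):
    print(example['scenario/id'])  # a scalar tf.string tensor per record
```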

You can't create batches from variable-length tensors. To get fixed-length tensors, pass a transform function to TFRecordDataset and do your preprocessing there; see the README for an example.
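For example, a minimal sketch of such a transform. The helper name `decode_scenario_id` is just illustrative, and it assumes the `'byte'` feature arrives as a variable-length uint8 numpy array (consistent with the `[16]` vs `[14]` sizes in your traceback); adjust the decoding (or pad to a fixed length) to fit your data:

```python
import torch
from tfrecord.torch.dataset import TFRecordDataset

def decode_scenario_id(features):
    # the raw 'byte' feature is a variable-length uint8 numpy array;
    # decoding it to a Python string lets the default collate_fn batch it
    # as a list of strings instead of stacking unequal-length tensors
    features['scenario/id'] = features['scenario/id'].tobytes().decode('utf-8')
    return features

dataset = TFRecordDataset(tfrecord_path, index_path, scenario_features,
                          transform=decode_scenario_id)
loader = torch.utils.data.DataLoader(dataset, batch_size=32)
```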

Thanks!