Error when loading string feature
hcarlens opened this issue · 2 comments
hcarlens commented
I'm getting errors when loading a byte feature from the Waymo Open Motion Dataset (I can't share the data because of the license, but it's available here: https://waymo.com/open/).
Minimal reproducible example:
import torch
from tfrecord.torch.dataset import TFRecordDataset

tfrecord_path = "uncompressed_tf_example_training_training_tfexample.tfrecord-00000-of-01000"
index_path = None
# 'byte' is this library's feature type for tf.string / bytes features.
scenario_features = {
    'scenario/id': 'byte',
}
dataset = TFRecordDataset(tfrecord_path, index_path, scenario_features)
loader = torch.utils.data.DataLoader(dataset, batch_size=32)
data = next(iter(loader))
print(data)
Error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-32-b161469c8e99> in <module>
11 loader = torch.utils.data.DataLoader(dataset, batch_size=32)
12
---> 13 data = next(iter(loader))
14 print(data)
~/.venv/wenv/lib/python3.8/site-packages/torch/utils/data/dataloader.py in __next__(self)
515 if self._sampler_iter is None:
516 self._reset()
--> 517 data = self._next_data()
518 self._num_yielded += 1
519 if self._dataset_kind == _DatasetKind.Iterable and \
~/.venv/wenv/lib/python3.8/site-packages/torch/utils/data/dataloader.py in _next_data(self)
555 def _next_data(self):
556 index = self._next_index() # may raise StopIteration
--> 557 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
558 if self._pin_memory:
559 data = _utils.pin_memory.pin_memory(data)
~/.venv/wenv/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
33 else:
34 data = next(self.dataset_iter)
---> 35 return self.collate_fn(data)
36
37
~/.venv/wenv/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py in default_collate(batch)
71 return batch
72 elif isinstance(elem, container_abcs.Mapping):
---> 73 return {key: default_collate([d[key] for d in batch]) for key in elem}
74 elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
75 return elem_type(*(default_collate(samples) for samples in zip(*batch)))
~/.venv/wenv/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py in <dictcomp>(.0)
71 return batch
72 elif isinstance(elem, container_abcs.Mapping):
---> 73 return {key: default_collate([d[key] for d in batch]) for key in elem}
74 elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
75 return elem_type(*(default_collate(samples) for samples in zip(*batch)))
~/.venv/wenv/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py in default_collate(batch)
61 raise TypeError(default_collate_err_msg_format.format(elem.dtype))
62
---> 63 return default_collate([torch.as_tensor(b) for b in batch])
64 elif elem.shape == (): # scalars
65 return torch.as_tensor(batch)
~/.venv/wenv/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py in default_collate(batch)
53 storage = elem.storage()._new_shared(numel)
54 out = elem.new(storage)
---> 55 return torch.stack(batch, 0, out=out)
56 elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
57 and elem_type.__name__ != 'string_':
RuntimeError: stack expects each tensor to be equal size, but got [16] at entry 0 and [14] at entry 4
Using TensorFlow directly, I'm able to load it with this feature definition:
'scenario/id': tf.io.FixedLenFeature((), tf.string, default_value=None),
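For reference, a rough sketch of the TensorFlow side, assuming the standard tf.data parsing pattern and the same tfrecord_path as above:

import tensorflow as tf

feature_description = {
    'scenario/id': tf.io.FixedLenFeature((), tf.string, default_value=None),
}

raw_dataset = tf.data.TFRecordDataset(tfrecord_path)
# Parse each serialized tf.train.Example into a dict of tensors;
# the scalar tf.string feature holds the variable-length ID bytes.
parsed_dataset = raw_dataset.map(
    lambda record: tf.io.parse_single_example(record, feature_description))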
vahidk commented
You can't create batches from variable-length tensors. To get fixed-length tensors, simply pass a transform function to TFRecordDataset and do your preprocessing there. See the README for an example.
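A minimal sketch of that approach, assuming you want to batch the scenario IDs as strings (decode_scenario_id is a hypothetical helper name; decoding the raw bytes lets default_collate return a list of strings instead of trying to stack unequal-size tensors):

import torch
from tfrecord.torch.dataset import TFRecordDataset

def decode_scenario_id(features):
    # 'scenario/id' arrives as a variable-length uint8 array; turn it
    # into a Python string so the DataLoader can batch it as a list.
    features['scenario/id'] = features['scenario/id'].tobytes().decode('utf-8')
    return features

dataset = TFRecordDataset(tfrecord_path, index_path, scenario_features,
                          transform=decode_scenario_id)
loader = torch.utils.data.DataLoader(dataset, batch_size=32)

Alternatively, the transform could pad or truncate the byte array to a fixed length if you need an actual tensor per example.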
hcarlens commented
Thanks!