pytorch/data

Loading `.tfrecords` files that require a deserialization method

fteufel opened this issue

๐Ÿ› Describe the bug

Hi,

I have a dataset in TFRecord format and am trying to switch to TorchData's API for loading `.tfrecords` files.
Here is a minimal example:

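from torchdata.datapipes.iter import FileOpener, IterableWrapper

datapipe1 = IterableWrapper(['path/to/my/tfrecords/file.tfrecords'])
datapipe2 = FileOpener(datapipe1, mode="b")  # yields (path, binary stream) tuples
tfrecord_loader_dp = datapipe2.load_from_tfrecord()

for d in tfrecord_loader_dp:
    pass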
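For context, a TFRecord file is a sequence of framed records: an 8-byte little-endian length, a 4-byte masked CRC32C of that length, the payload, and a 4-byte masked CRC32C of the payload. Any TFRecord reader, including `load_from_tfrecord()`, has to walk this framing before protobuf parsing can start. The stdlib-only sketch below illustrates that step; `iter_tfrecord_payloads` and `frame_payloads` are hypothetical names, and unlike a real reader it skips CRC verification and assumes an uncompressed file.

```python
import io
import struct
from typing import BinaryIO, Iterable, Iterator


def iter_tfrecord_payloads(stream: BinaryIO) -> Iterator[bytes]:
    """Yield the raw payload of each record in an *uncompressed* TFRecord stream.

    Hypothetical helper: the two CRC fields are skipped, not verified.
    """
    while True:
        header = stream.read(8)
        if not header:
            return  # clean end of file
        if len(header) != 8:
            raise ValueError("truncated length header")
        (length,) = struct.unpack("<Q", header)  # little-endian uint64
        stream.read(4)  # masked CRC32C of the length field (ignored)
        payload = stream.read(length)
        if len(payload) != length:
            raise ValueError(f"truncated record: expected {length} bytes")
        stream.read(4)  # masked CRC32C of the payload (ignored)
        yield payload


def frame_payloads(payloads: Iterable[bytes]) -> bytes:
    """Build TFRecord-style framing with zeroed CRCs (only valid for this reader)."""
    out = io.BytesIO()
    for p in payloads:
        out.write(struct.pack("<Q", len(p)) + b"\x00" * 4 + p + b"\x00" * 4)
    return out.getvalue()


data = frame_payloads([b"first", b"second record"])
print(list(iter_tfrecord_payloads(io.BytesIO(data))))  # [b'first', b'second record']
```

The framing layer only yields opaque bytes; turning each payload into features is a separate deserialization step.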

It fails, apparently because the datapipe does not know how to deserialize this TFRecord file:

File ~/.conda/envs/bend/lib/python3.10/site-packages/torchdata/datapipes/iter/util/tfrecordloader.py:245, in TFRecordLoaderIterDataPipe.__iter__(self)
    243 pathname, data_stream = data
    244 try:
--> 245     for example_bytes in iterate_tfrecord_file(data_stream):
    246         example = example_pb2.SequenceExample()  # type: ignore
    247         example.ParseFromString(example_bytes)  # type: ignore

File ~/.conda/envs/bend/lib/python3.10/site-packages/torchdata/datapipes/iter/util/tfrecordloader.py:83, in iterate_tfrecord_file(data)
     81 (length,) = struct.unpack("<Q", length_bytes)
     82 if length > len(data_bytes):
---> 83     data_bytes = data_bytes.zfill(int(length * 1.5))
     84 data_bytes_view = memoryview(data_bytes)[:length]
     85 if data.readinto(data_bytes_view) != length:

OverflowError: Python int too large to convert to C ssize_t
This exception is thrown by __iter__ of TFRecordLoaderIterDataPipe(datapipe=FileOpenerIterDataPipe, length=-1, spec=None)
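Note that the traceback ends in `iterate_tfrecord_file` right after the 8-byte record length is unpacked, so `ParseFromString` is never reached. One hypothesis consistent with this (an assumption, not confirmed) is that the file is compressed; the TensorFlow code below does pass a `compression_type`. A gzip stream starts with the magic bytes `1f 8b`, a method byte, a flags byte and a 4-byte modification time, and read as a little-endian uint64 those 8 bytes give an enormous "length". The sketch below uses only the standard library; `looks_gzipped` and `length_as_tfrecord_reader_sees_it` are hypothetical helper names.

```python
import gzip
import struct
import sys


def looks_gzipped(first_bytes: bytes) -> bool:
    """True if the data starts with the gzip magic number 0x1f 0x8b."""
    return first_bytes[:2] == b"\x1f\x8b"


def length_as_tfrecord_reader_sees_it(first_bytes: bytes) -> int:
    """Decode the first 8 bytes as a TFRecord reader would: little-endian uint64."""
    (length,) = struct.unpack("<Q", first_bytes[:8])
    return length


# Simulate a GZIP-compressed TFRecord file whose gzip header carries a
# 2023-era modification time (bytes 4-7 of the header).
compressed = gzip.compress(b"some framed records", mtime=1_690_000_000)
length = length_as_tfrecord_reader_sees_it(compressed)

print(looks_gzipped(compressed))        # True
print(int(length * 1.5) > sys.maxsize)  # True on 64-bit builds: exceeds C ssize_t
```

If the real file's first two bytes match the gzip magic, torchdata's `Decompressor` datapipe (functional form `.decompress(file_type="gzip")`) placed between `FileOpener` and `load_from_tfrecord()` might be worth trying; that combination has not been tested against this dataset.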

In the legacy TensorFlow codebase, I have to pass a deserialization function to `Dataset.map`, like this:

import json

import tensorflow as tf
import tensorflow_datasets as tfds

# compression_type is one of "", "ZLIB" or "GZIP"; it and json_file are defined elsewhere.
dataset = tf.data.Dataset.from_tensor_slices(['path/to/my/tfrecords/file.tfrecords'])
dataset = dataset.interleave(
    lambda fp: tf.data.TFRecordDataset(fp, compression_type=compression_type),
    cycle_length=1,
    block_length=1,
    num_parallel_calls=tf.data.AUTOTUNE,
)

# json_file describes the features stored in the .tfrecords file I'm trying to load
features = tfds.features.FeaturesDict.from_json(json.load(json_file))
dataset = dataset.map(features.deserialize_example, num_parallel_calls=tf.data.AUTOTUNE)

iterator = dataset.as_numpy_iterator()
for d in iterator:
    pass  # this works, yielding a dict of NumPy arrays per example

So the problem is that the records need to be deserialized, but I can't apply any function to the output of TFRecordLoaderIterDataPipe, because it fails before yielding anything.
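One way to get a hook for deserialization is to keep the three stages separate: optional decompression, record framing, then a user-supplied decode function per record. The sketch below does this with the standard library only; `iter_deserialized` and `_iter_payloads` are hypothetical names, CRCs are not verified, and `bytes.decode` merely stands in for a real deserializer.

```python
import gzip
import os
import struct
import tempfile
from typing import BinaryIO, Callable, Iterator, TypeVar

T = TypeVar("T")


def _iter_payloads(stream: BinaryIO) -> Iterator[bytes]:
    # TFRecord framing: <Q length, 4-byte length CRC, payload, 4-byte payload CRC.
    # The CRCs are skipped here rather than verified.
    while True:
        header = stream.read(8)
        if not header:
            return
        if len(header) != 8:
            raise ValueError("truncated length header")
        (length,) = struct.unpack("<Q", header)
        stream.read(4)
        payload = stream.read(length)
        if len(payload) != length:
            raise ValueError("truncated record")
        stream.read(4)
        yield payload


def iter_deserialized(
    path: str, deserialize: Callable[[bytes], T], gzipped: bool = False
) -> Iterator[T]:
    """Hypothetical helper: decompress (optionally), split records, decode each one."""
    opener = gzip.open if gzipped else open
    with opener(path, "rb") as f:
        for payload in _iter_payloads(f):
            yield deserialize(payload)


# Demo: a GZIP-compressed file holding two records.
records = [b"alpha", b"beta"]
framed = b"".join(struct.pack("<Q", len(r)) + b"\x00" * 4 + r + b"\x00" * 4 for r in records)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "demo.tfrecords")
    with gzip.open(path, "wb") as f:
        f.write(framed)
    decoded = list(iter_deserialized(path, bytes.decode, gzipped=True))
print(decoded)  # ['alpha', 'beta']
```

In practice `deserialize` would wrap whatever decoding the dataset needs, for example the tfds `features.deserialize_example` call from the TensorFlow snippet above if TensorFlow stays a dependency. A module-level generator like this could presumably be hooked into a datapipe over file paths (e.g. with torchdata's `flatmap`), though that wiring is untested here.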

Is there a workaround? I tried wrapping the TensorFlow dataset object in an `IterableWrapper`, but the TensorFlow dataset can't be pickled, so this fails in DataLoader2.
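On the pickling point: a common pattern is to make the object that gets pickled hold only plain data (file paths plus a module-level function) and to open files or build any unpicklable reader lazily inside `__iter__`, which runs after the object has reached the worker. The stdlib sketch below illustrates this; `LazyRecordSource` and `read_lines` are hypothetical names, and `read_lines` is a stand-in for a real TFRecord reader.

```python
import os
import pickle
import tempfile
from typing import Callable, Iterator, List


class LazyRecordSource:
    """Hypothetical picklable iterable: holds only paths and a module-level function.

    Files (or any other unpicklable reader) are opened inside __iter__, so
    nothing unpicklable exists at the time the object is serialized.
    """

    def __init__(self, paths: List[str], read_records: Callable[[str], Iterator[bytes]]):
        self.paths = list(paths)
        self.read_records = read_records  # module-level functions pickle by reference

    def __iter__(self) -> Iterator[bytes]:
        for path in self.paths:
            yield from self.read_records(path)


def read_lines(path: str) -> Iterator[bytes]:
    # Stand-in reader for the demo; a real one would parse TFRecord framing.
    with open(path, "rb") as f:
        yield from f


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "records.txt")
    with open(path, "wb") as f:
        f.write(b"x\ny\n")
    source = LazyRecordSource([path], read_lines)
    clone = pickle.loads(pickle.dumps(source))  # succeeds: only plain data is pickled
    items = list(clone)
    # By contrast, a live generator object cannot be pickled at all.
    try:
        pickle.dumps(iter(source))
        generator_pickle_failed = False
    except TypeError:
        generator_pickle_failed = True

print(items)                    # [b'x\n', b'y\n']
print(generator_pickle_failed)  # True
```

An object like this could then be passed to `IterableWrapper`, which only requires an iterable; whether it fits a given DataLoader2 reading-service setup has not been verified here.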

Thanks!

Versions

Collecting environment information...
PyTorch version: 2.0.1+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.27.4
Libc version: glibc-2.31

Python version: 3.10.12 (main, Jul 5 2023, 18:54:27) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-1027-aws-x86_64-with-glibc2.31
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 46 bits physical, 48 bits virtual
CPU(s): 16
On-line CPU(s) list: 0-15
Thread(s) per core: 2
Core(s) per socket: 8
Socket(s): 1
NUMA node(s): 1
Vendor ID: GenuineIntel
CPU family: 6
Model: 85
Model name: Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz
Stepping: 7
CPU MHz: 2499.994
BogoMIPS: 4999.98
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 256 KiB
L1i cache: 256 KiB
L2 cache: 8 MiB
L3 cache: 35.8 MiB
NUMA node0 CPU(s): 0-15
Vulnerability Itlb multihit: KVM: Mitigation: VMX unsupported
Vulnerability L1tf: Mitigation; PTE Inversion
Vulnerability Mds: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed: Vulnerable
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves ida arat pku ospke

Versions of relevant libraries:
[pip3] numpy==1.24.3
[pip3] torch==2.0.1
[pip3] torchdata==0.6.1
[pip3] torchvision==0.15.2
[pip3] triton==2.0.0
[conda] numpy 1.24.3 pypi_0 pypi
[conda] torch 2.0.1 pypi_0 pypi
[conda] torchdata 0.6.1 pypi_0 pypi
[conda] torchvision 0.15.2 pypi_0 pypi
[conda] triton 2.0.0 pypi_0 pypi