pytorch/data

What is the right way to serialize DataLoader2 so that a pipeline with shuffle can resume from the right place?

zhengwy888 opened this issue · 2 comments

๐Ÿ› Describe the bug

I tried all of the versions below. The only one that worked was the last, but it is too hacky. Is there a better way?

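    # Import paths assume torch 2.0 and a torchdata 0.7 nightly; they may differ in other releases.
    from torch.utils.data.datapipes.iter.combinatorics import ShufflerIterDataPipe
    from torch.utils.data.datapipes.utils.snapshot import _simple_graph_snapshot_restoration
    from torch.utils.data.graph import traverse_dps
    from torch.utils.data.graph_settings import get_all_graph_pipes
    from torchdata.dataloader2 import DataLoader2, InProcessReadingService
    from torchdata.dataloader2.graph._serialization import deserialize_datapipe, serialize_datapipe
    from torchdata.datapipes.iter import IterableWrapper


    def get_all_pipes(dp):
        # Assumed helper: return every DataPipe in the graph rooted at dp.
        return get_all_graph_pipes(traverse_dps(dp))


    dp = IterableWrapper(list(range(20)))
    dp = dp.shuffle()
    rs = InProcessReadingService()
    dl = DataLoader2(dp, reading_service=rs)
    iter1 = iter(dl)
    for _ in range(4):
        next(iter1)

    # 16 elements left in dl
    state = dl.state_dict()

    # Version 1: restore through DataLoader2.from_state
    dl2 = DataLoader2.from_state(state, reading_service=rs)
    # assert len(list(dl2)) == 20 - 4  # fails: got 20

    # Version 2: serialize and deserialize the datapipe graph
    dp2 = deserialize_datapipe(serialize_datapipe(dl.datapipe))
    # assert len(list(dp2)) == 20 - 4  # fails: got 20

    # Version 3: deserialized graph, then fast-forward by the number of samples yielded
    dp3 = deserialize_datapipe(serialize_datapipe(dl.datapipe))
    _simple_graph_snapshot_restoration(dp3, dp3._number_of_samples_yielded)
    ret3 = list(dp3)
    assert len(ret3) == 20 - 4
    # but the content is not the same

    # Version 4: DataLoader2.from_state, then fast-forward
    dl4 = DataLoader2.from_state(state, reading_service=rs)
    _simple_graph_snapshot_restoration(dl4.datapipe, dl.datapipe._number_of_samples_yielded)
    ret4 = list(dl4)
    assert len(ret4) == 20 - 4
    # but the content is not the same

    # Version 5: capture the shuffle buffer and RNG state, fast-forward, then put them back by hand
    dp5 = deserialize_datapipe(serialize_datapipe(dl.datapipe))
    pipes = get_all_pipes(dp5)
    for pipe in pipes:
        if isinstance(pipe, ShufflerIterDataPipe):
            buffer_cache = pipe._buffer[:]
            assert len(buffer_cache) == 20 - 4
            rng_state = pipe._rng.getstate()
    _simple_graph_snapshot_restoration(dp5, dl.datapipe._number_of_samples_yielded)
    # Overwrite the buffer and RNG state with the values captured above
    dp5._buffer = buffer_cache[:]
    dp5._rng.setstate(rng_state)
    it5 = iter(dp5)
    ret5 = list(it5)
    assert len(ret5) == 20 - 4

    expected = list(iter1)
    # ret5 is the only version that matches the rest of the original iterator
    # assert ret3 == expected  # fails
    # assert ret4 == expected  # fails
    assert ret5 == expected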
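
The last version likely works because it carries over the two pieces of state a buffered shuffle depends on: the elements still waiting in the buffer and the RNG state. The sketch below illustrates that idea in plain Python under stated assumptions. ResumableShuffler is a hypothetical class, not torchdata's Shuffler, and it uses a simplified buffering scheme. It also saves the source position, which only matters when the source is larger than the buffer.

```python
import pickle
import random
from typing import Any, Dict, List


class ResumableShuffler:
    """Hypothetical buffered shuffler (not torchdata's Shuffler) whose position
    can be captured and restored with state_dict() / load_state_dict()."""

    def __init__(self, source: List[int], buffer_size: int = 10000, seed: int = 0) -> None:
        self.source = source              # assumed to be identical when resuming; not saved
        self.buffer_size = buffer_size
        self._rng = random.Random(seed)
        self._pos = 0                     # index of the next source element to pull
        self._buffer: List[int] = []      # elements pulled but not yet yielded

    def __iter__(self) -> "ResumableShuffler":
        return self

    def __next__(self) -> int:
        # Top up the buffer until it is full or the source is exhausted.
        while self._pos < len(self.source) and len(self._buffer) < self.buffer_size:
            self._buffer.append(self.source[self._pos])
            self._pos += 1
        if not self._buffer:
            raise StopIteration
        # Pick a random buffered element; swap it to the end so pop() is O(1).
        idx = self._rng.randrange(len(self._buffer))
        self._buffer[idx], self._buffer[-1] = self._buffer[-1], self._buffer[idx]
        return self._buffer.pop()

    def state_dict(self) -> Dict[str, Any]:
        # Everything needed to resume: source position, pending buffer, RNG state.
        return {"pos": self._pos, "buffer": list(self._buffer), "rng": self._rng.getstate()}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self._pos = state["pos"]
        self._buffer = list(state["buffer"])
        self._rng.setstate(state["rng"])


# Mirror the report: 20 elements, consume 4, snapshot, then resume in a new object.
src = list(range(20))
shuf = ResumableShuffler(src, seed=42)
first_four = [next(shuf) for _ in range(4)]
blob = pickle.dumps(shuf.state_dict())    # serialized snapshot
expected = list(shuf)                     # what the original iterator yields next

resumed = ResumableShuffler(src, seed=0)  # seed is irrelevant once the state is loaded
resumed.load_state_dict(pickle.loads(blob))
ret = list(resumed)
assert len(ret) == 20 - 4
assert ret == expected
```

Because the snapshot is a plain dict of picklable values, it can be stored alongside the rest of a checkpoint, and the seed given to the new instance is simply overwritten by load_state_dict.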

Versions

PyTorch version: 2.0.0a0+gite9ebda2
Is debug build: False
CUDA used to build PyTorch: 12.0
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.3 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: 12.0.1 (https://github.com/conda-forge/clangdev-feedstock d44358f44aef33e9fa7c5f93e2481ee8f1a04ab6)
CMake version: version 3.19.1
Libc version: glibc-2.31

Python version: 3.8.13 | packaged by conda-forge | (default, Mar 25 2022, 06:04:10)  [GCC 10.3.0] (64-bit runtime)
Python platform: Linux-5.4.0-64-generic-x86_64-with-glibc2.10
Is CUDA available: False
CUDA runtime version: 12.0.140
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: False

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] mypy-protobuf==3.3.0
[pip3] numpy==1.23.5
[pip3] pytorch3d==0.6.2
[pip3] torch==2.0.1+1684801906.cuda120.cudnn891.nccl218.ap
[pip3] torch-mlir==1684442443
[pip3] torch-scatter==2.1.0
[pip3] torch-tb-profiler==0.4.1
[pip3] torchdata==0.7.0.dev20230601
[pip3] torchfile==0.1.0
[pip3] torchvision==0.15.1a0+42759b1
[conda] magma-cuda121             2.6.1                         1    pytorch
[conda] mkl                       2020.4             h726a3e6_304    conda-forge
[conda] mkl-include               2023.1.0         h84fe81f_48680    conda-forge
[conda] numpy                     1.23.5           py38h7042d01_0    conda-forge
[conda] pytorch3d                 0.6.2                    pypi_0    pypi
[conda] torch                     2.0.1+1684801906.cuda120.cudnn891.nccl218.ap          pypi_0    pypi
[conda] torch-mlir                1684442443               pypi_0    pypi
[conda] torch-scatter             2.1.0                    pypi_0    pypi
[conda] torch-tb-profiler         0.4.1                    pypi_0    pypi
[conda] torchfile                 0.1.0                    pypi_0    pypi
[conda] torchvision               0.15.1a0+42759b1          pypi_0    pypi
ejguan commented

I think you can rely on dlv2.state_dict() to get the state. But it is still in prototype mode, so it might have some errors.

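But it didn't work; see example 1, where DataLoader2.from_state(state, reading_service=rs) still yields all 20 elements instead of the remaining 16.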
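
For comparison, a replay-style restore can be sketched in plain Python: save only the seed and the number of samples already yielded, then rebuild the shuffle and discard that many elements. The names shuffled_stream and restore_by_replay are hypothetical, and this is not the torchdata implementation. Under that assumption, replay reproduces the remaining order only when the shuffle is rebuilt with the original seed. With a different seed it returns the right count in a different order, which is consistent with what versions 3 and 4 in the report show.

```python
import random
from typing import Iterator, List


def shuffled_stream(source: List[int], seed: int) -> Iterator[int]:
    """Hypothetical seeded shuffle: buffers the whole source and draws with random.Random(seed)."""
    rng = random.Random(seed)
    buffer = list(source)
    while buffer:
        idx = rng.randrange(len(buffer))
        # Swap the chosen element to the end so it can be removed in O(1).
        buffer[idx], buffer[-1] = buffer[-1], buffer[idx]
        yield buffer.pop()


def restore_by_replay(source: List[int], seed: int, n_yielded: int) -> Iterator[int]:
    """Rebuild the stream from its seed and skip the samples that were already yielded."""
    it = shuffled_stream(source, seed)
    for _ in range(n_yielded):
        next(it)
    return it


src = list(range(20))
it = shuffled_stream(src, seed=7)
consumed = [next(it) for _ in range(4)]
state = {"seed": 7, "n_yielded": len(consumed)}  # the only thing this approach saves
expected = list(it)

same = list(restore_by_replay(src, state["seed"], state["n_yielded"]))
other = list(restore_by_replay(src, state["seed"] + 1, state["n_yielded"]))
assert same == expected   # same seed: identical continuation
assert len(other) == 16   # different seed: right count, order generally differs
```

Replay has to regenerate and throw away every sample already yielded, so its cost grows with the position in the epoch, whereas restoring the buffer and RNG state (the last version in the report, ret5) is bounded by the buffer size.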