pytorch/data

Memory spikes with large DataPipes

sesquipedalianist opened this issue · 11 comments

🐛 Describe the bug

I’ve noticed large “spikes” in memory usage at the start of epochs when using IterDataPipes with attributes that take a lot of memory. These can cause my training jobs to fail with out-of-memory errors.

Here’s a minimal example to reproduce:

import torch
import torchdata.datapipes as dp
from torch.utils.data import DataLoader
from tqdm import tqdm

NUM_WORKERS = 2
NUM_ITEMS = 100
ITEM_SIZE = 5_000_000


def get_item(x):
    # Each item is a ~20 MB float32 tensor (5 million elements).
    return torch.rand(ITEM_SIZE)


def get_datapipe():
    datapipe = dp.iter.IterableWrapper(range(NUM_ITEMS))
    datapipe = datapipe.map(get_item)
    datapipe = datapipe.in_memory_cache()  # keeps every generated item in memory
    return datapipe


def main():
    datapipe = get_datapipe()
    dataloader = DataLoader(
        datapipe, batch_size=1, num_workers=NUM_WORKERS, persistent_workers=True
    )

    for epoch in range(3):
        print(f"Epoch {epoch + 1}")
        # No sharding_filter, so each worker yields all NUM_ITEMS items.
        for _ in tqdm(dataloader, total=NUM_ITEMS * NUM_WORKERS):
            pass


if __name__ == "__main__":
    main()

The memory usage (logged with psutil) looks like this:

[Plot: datapipe_memory_spikes, memory usage over time with a spike at the start of each epoch]

Here, start_epoch marks the start of an epoch and first_iter marks the first time in each epoch that we reach the pass statement inside the dataloader loop. (To simplify the example code above I removed the lines that log start_epoch and first_iter; the memory usage itself was logged from a separate process.)
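For reference, a minimal sketch of the kind of external logger that can record this (my reconstruction, not the exact script, assuming psutil is available and the training script's PID is passed on the command line):

import sys
import time

import psutil

# Poll the RSS of the training process and its worker children once per second.
pid = int(sys.argv[1])
proc = psutil.Process(pid)

while proc.is_running():
    rss = proc.memory_info().rss
    rss += sum(child.memory_info().rss for child in proc.children(recursive=True))
    print(f"{time.time():.0f}\t{rss / 1e9:.2f} GB", flush=True)
    time.sleep(1)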

After some debugging, I can say that the memory spikes occur during the datapipe graph traversal performed by torch/utils/data/graph_settings.py::apply_random_seed() at the beginning of each epoch. Disabling the body of this function removes the memory spikes.

The spikes seem to be caused by the pickling in https://github.com/pytorch/pytorch/blob/99ded8bbcea896b02f1c0babb055329c503ca95e/torch/utils/data/graph.py#L23
The code here defines f = io.BytesIO() and pickles to f. If there are large datapipes to be pickled, it makes sense that the memory usage will blow up quickly and then fall again when f goes out of scope.
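The effect can be reproduced in isolation (a standalone sketch, not the actual graph-traversal code): pickling an object that holds large tensors into an in-memory buffer temporarily allocates roughly the same number of bytes again, and the memory is released when the buffer is dropped, which matches the shape of the spikes.

import io
import pickle

import torch

cache = [torch.rand(5_000_000) for _ in range(10)]  # ~200 MB of float32 data

f = io.BytesIO()
pickle.dump(cache, f)  # peak RSS grows by roughly the size of `cache`
print(f"pickled size: {f.getbuffer().nbytes / 1e6:.0f} MB")
del f  # the buffer is freed and RSS falls back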

I tried replacing f = io.BytesIO() with f = open(os.devnull, "wb") (and adding f.close() at the end of the function). This didn’t eliminate the memory spikes but it did make them a bit smaller.
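For comparison, the devnull variant of the same standalone sketch (again, not the upstream code): the serialized bytes are discarded as they are written instead of accumulating in a BytesIO buffer, but pickle still materializes intermediate copies of each tensor's storage, which is consistent with the spikes shrinking rather than disappearing.

import os
import pickle

import torch

cache = [torch.rand(5_000_000) for _ in range(10)]

with open(os.devnull, "wb") as f:
    pickle.dump(cache, f)  # bytes are thrown away as soon as they are written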

A few notes:

  • it’s not necessary to use .in_memory_cache() to see these spikes; any datapipe that occupies a lot of memory seems to cause them
  • I verified that the spikes do not occur with similar Dataset and IterableDataset subclasses (see the sketch after this list)
  • the spikes do not occur if we remove the DataLoader and iterate directly over the datapipe.
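For reference, an IterableDataset with a comparable in-memory cache can be sketched like this (a minimal reconstruction for the comparison above): it holds the same cached tensors, but DataLoader never traverses or pickles it at the start of an epoch, and no spikes appear.

import torch
from torch.utils.data import IterableDataset

NUM_ITEMS = 100
ITEM_SIZE = 5_000_000


class CachedIterableDataset(IterableDataset):
    # Holds as much data as the in_memory_cache datapipe, but is not part of
    # a datapipe graph, so nothing is pickled at the start of each epoch.
    def __init__(self):
        self.cache = []

    def __iter__(self):
        if self.cache:
            yield from self.cache
            return
        for _ in range(NUM_ITEMS):
            item = torch.rand(ITEM_SIZE)
            self.cache.append(item)
            yield item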

Versions

I have tested the above with both

  • torch 1.13.1 and torchdata 0.5.1 (my development environment)
  • torch 2.0 and torchdata 0.6.0

I observed the same behavior in both cases.

ejguan commented

Have you tried the same pipeline with DataLoader2 and MultiprocessingReadingService?

I just tried it (I had previously tried DataLoader2 but perhaps with torch 1.13.1) and the spikes still occur. This makes sense to me because it seems the datapipe graph is still traversed in the same way.
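Roughly what I ran (a sketch reusing get_datapipe() and NUM_WORKERS from the repro script above; the exact code may have differed slightly):

from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService

# Same datapipe as before, driven by DataLoader2 with a multiprocessing
# reading service; the datapipe graph is still traversed at each epoch start.
rs = MultiProcessingReadingService(num_workers=NUM_WORKERS)
dl2 = DataLoader2(get_datapipe(), reading_service=rs)

for epoch in range(3):
    for _ in dl2:
        pass

dl2.shutdown()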

I posted another MemoryError that may be related here:

https://discuss.pytorch.org/t/torchdata-w-ddp-start-of-epoch-2-get-memoryerror/179523

My MemoryError also occurs at the start of an epoch while using DDP and distributed multiprocessing. It seems to depend on the size of the shuffle buffers that I put into the datapipe (one for files, one for fixed-length decoding, one for augmentations); most recently I got through 9 epochs before hitting the OOM error.

It's really weird. I use < 150 GB of RAM during training, yet my 500 GB of RAM gets overwhelmed at the beginning of epoch 2. I considered shutting down and restarting the pipe to work around it.

ejguan commented

I just tried it (I had previously tried DataLoader2 but perhaps with torch 1.13.1) and the spikes still occur. This makes sense to me because it seems the datapipe graph is still traversed in the same way.

@sesquipedalianist
Thanks for reporting. So, it happens after the first epoch, and, due to in_memory_cache, the memory usage while traversing the DataPipes becomes more significant. This might require some in-depth investigation into how to properly skip inner objects like buffers during traversal, since we only need the DataPipes themselves.
One approach might be adding a wrapper around those objects to prevent them from going through pickle during traversal.
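A rough sketch of what such a wrapper could look like (hypothetical; not an existing torchdata API): the heavy contents are excluded from the pickled state, so graph traversal never serializes the cached data.

class _TraversalSafeBuffer:
    # Hypothetical wrapper around a DataPipe's heavy internal buffer.
    def __init__(self):
        self.data = []

    def append(self, item):
        self.data.append(item)

    def __iter__(self):
        return iter(self.data)

    def __getstate__(self):
        # When the graph is traversed (i.e. the DataPipe is pickled),
        # serialize an empty shell instead of the cached items.
        state = self.__dict__.copy()
        state["data"] = []
        return state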

ejguan commented

@andrew-bydlon
It's kind of weird to me, because the shuffler should not contain any buffered data during the reset at epoch 2, and that buffer is the only place that could hold on to a few samples. So the peak memory consumption is unexpected. It would be great if you could share a minimal reproducible script for us to reproduce and debug.

It's difficult to provide code for this, as the code is the property of a large corp. Some other notes and an expansion of my earlier thoughts:

Mention of shuffling causing a memory increase: pytorch/pytorch#13246 (comment)

I am generally storing data in tars in the form (arbitrary length audio, {labels: tensor, dataset: string, ID: string})

And here is an expansion of my list of pipes (a rough code sketch follows the list):

  1. FileLister: About 800 tars each with 1024 or 2048 samples.
  2. shuffle
  3. sharding_filter
  4. metadata and fixed-length audio extraction: yields fixed-length audio clips from each initial audio (1 : many).
  5. shuffle
  6. augment samples: applies various augmentations to the audios, including background addition and reverberation.
  7. shuffle
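In code, the pipeline is shaped roughly like this (a simplified sketch; extract_fixed_length_clips and augment are placeholders for the real decoding and augmentation steps, not the actual code):

import torchdata.datapipes as dp


def extract_fixed_length_clips(tar_sample):
    # Placeholder for step 4: return fixed-length audio clips cut from one
    # tar sample (1 : many).
    return [tar_sample]


def augment(clip):
    # Placeholder for step 6: background addition, reverberation, etc.
    return clip


def get_audio_datapipe(root):
    pipe = dp.iter.FileLister(root, masks="*.tar")   # 1. ~800 tars
    pipe = pipe.shuffle()                            # 2. shuffle tar order
    pipe = pipe.sharding_filter()                    # 3. split tars across workers
    pipe = pipe.flatmap(extract_fixed_length_clips)  # 4. one sample -> many clips
    pipe = pipe.shuffle()                            # 5. shuffle clips
    pipe = pipe.map(augment)                         # 6. augment samples
    pipe = pipe.shuffle()                            # 7. shuffle augmented samples
    return pipe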

For now I have solved my issue by monkeypatching torch.utils.data.graph_settings.apply_random_seed as follows:

import torch
import torch.utils.data.graph_settings
from torch.utils.data.graph import DataPipe


def apply_random_seed_overwrite(datapipe: DataPipe, rng: torch.Generator) -> DataPipe:
    return datapipe  # no-op: skip the per-epoch datapipe graph traversal


torch.utils.data.graph_settings.apply_random_seed = apply_random_seed_overwrite

This effectively disables apply_random_seed, which for my purposes is not a problem since in training I am not providing a seed and in validation/testing I am not shuffling. Doing this completely eliminates the memory spikes (since we no longer traverse the datapipe at the beginning of each epoch).

@andrew-bydlon are you saving anything in memory (like audio samples)? That would likely cause the same issue I was having.

I'm not saving anything in memory other than prefetching. I'm using iterable datapipes to do all of the above per recommendations on the homepage. These default to prefetch factors of 10. The augmentation operations take some compute, but all of this happens at the start of epoch 2 (going from 20% memory -> 100%), so it seems extremely unexpected.

Thank you both for your help. I have finally deep-dived this topic and made an issue:

#1185

There is a lot of talk about memory leaks in the issues. I really like the DataLoader2 API, but I will be temporarily switching back to DataLoader (DL1) because of the issues I mention there.

For now I have solved my issue by monkeypatching torch.utils.data.graph_settings.apply_random_seed as follows: […]

I tried this out without success. Glad it worked for you!