pytorch/data

Passing a dict in a datapipe/dataset causes a memory leak

Opened this issue · 3 comments

๐Ÿ› Describe the bug

Passing a dict through a datapipe or dataset causes a memory leak.

from copy import deepcopy
import gc

from memory_profiler import profile
import torch
from torch.utils.data import DataLoader
from torchdata.datapipes.iter import IterableWrapper
from torchdata.dataloader2 import DataLoader2


def build_dp1(num_batch):
    """Build a datapipe of ``num_batch`` nested-dict items.

    Every item has a top-level ``"id"`` plus ``"clean"`` and ``"noisy"``
    sub-dicts, each holding a string ``"path"`` and an integer ``"id"``
    derived from the item's index.

    Args:
        num_batch: Number of items to generate (indices ``0..num_batch-1``).

    Returns:
        An ``IterableWrapper`` around the list of generated items.
    """

    def _branch(i):
        # A fresh dict per call so "clean" and "noisy" never share an object.
        return {"path": str(i), "id": i}

    records = [
        {"id": i, "clean": _branch(i), "noisy": _branch(i)}
        for i in range(num_batch)
    ]
    return IterableWrapper(records)

def build_dp2(num_batch):
    """Build a datapipe of ``num_batch`` flat-dict items.

    Every item holds ``"id"``, ``"clean_path"``, ``"clean_id"``,
    ``"noisy_path"`` and ``"noisy_id"``; the paths are the stringified index
    and the ids are the index itself.

    Args:
        num_batch: Number of items to generate (indices ``0..num_batch-1``).

    Returns:
        An ``IterableWrapper`` around the list of generated items.
    """
    records = [
        {
            "id": i,
            "clean_path": str(i),
            "clean_id": i,
            "noisy_path": str(i),
            "noisy_id": i,
        }
        for i in range(num_batch)
    ]
    return IterableWrapper(records)

def add_audio1(item):
    """Attach random audio tensors to the nested sub-dicts of ``item`` in place.

    A new ``torch.randn([5000, 10])`` tensor is stored under the ``"audio"``
    key of ``item["clean"]`` and then of ``item["noisy"]``.

    Args:
        item: Dict with ``"clean"`` and ``"noisy"`` sub-dicts.

    Returns:
        The same ``item`` object that was passed in.
    """
    # "clean" is filled before "noisy" so the RNG draw order is fixed.
    for branch in ("clean", "noisy"):
        item[branch]["audio"] = torch.randn([5000, 10])
    return item

def add_audio2(item):
    """Return a deep copy of ``item`` with random audio tensors attached.

    The input is left untouched; the copy receives a new
    ``torch.randn([5000, 10])`` tensor under the ``"audio"`` key of its
    ``"clean"`` and then its ``"noisy"`` sub-dict.

    Args:
        item: Dict with ``"clean"`` and ``"noisy"`` sub-dicts.

    Returns:
        The deep-copied, augmented dict.
    """
    clone = deepcopy(item)
    # "clean" is filled before "noisy" so the RNG draw order is fixed.
    for branch in ("clean", "noisy"):
        clone[branch]["audio"] = torch.randn([5000, 10])
    return clone

def add_audio3(item):
    """Attach random audio tensors to a flat ``item`` dict in place.

    A new ``torch.randn([5000, 10])`` tensor is stored under
    ``"clean_audio"`` and then under ``"noisy_audio"``.

    Args:
        item: Dict to augment.

    Returns:
        The same ``item`` object that was passed in.
    """
    # "clean_audio" is written before "noisy_audio" to keep the RNG draw order.
    for key in ("clean_audio", "noisy_audio"):
        item[key] = torch.randn([5000, 10])
    return item

class MyDataset1(torch.utils.data.Dataset):
    """Map-style dataset whose samples are flat dicts of ints, strings and tensors."""

    def __init__(self, datalen):
        """Store the number of samples the dataset reports."""
        super().__init__()
        self.datalen = datalen

    def __len__(self):
        """Return the configured dataset length."""
        return self.datalen

    def __getitem__(self, index):
        """Return a freshly generated sample dict for ``index``.

        The dict holds the index under the ``*_id`` keys, its string form
        under the ``*_path`` keys, and a new ``torch.randn([5000, 10])``
        tensor under each ``*_audio`` key.
        """
        sample = {"id": index}
        # The "clean" fields are added before the "noisy" ones so both the
        # key order and the RNG draw order stay fixed.
        sample.update(
            {
                "clean_path": str(index),
                "clean_id": index,
                "clean_audio": torch.randn([5000, 10]),
            }
        )
        sample.update(
            {
                "noisy_path": str(index),
                "noisy_id": index,
                "noisy_audio": torch.randn([5000, 10]),
            }
        )
        return sample

class MyDataset2(torch.utils.data.Dataset):
    """Map-style dataset whose samples are plain tuples of two tensors."""

    def __init__(self, datalen):
        """Store the number of samples the dataset reports."""
        super().__init__()
        self.datalen = datalen

    def __len__(self):
        """Return the configured dataset length."""
        return self.datalen

    def __getitem__(self, index):
        """Return two new ``torch.randn([5000, 10])`` tensors as a tuple.

        ``index`` is not used; every call draws fresh random data.
        """
        first = torch.randn([5000, 10])
        second = torch.randn([5000, 10])
        return first, second

@profile
def datapipe(num_batch):
    """Iterate a ``build_dp2`` datapipe mapped with ``add_audio3`` under the profiler.

    Args:
        num_batch: Number of items passed to ``build_dp2``.
    """
    # Flat-dict items; add_audio3 adds two tensors to each one as it is read.
    dp = build_dp2(num_batch).map(add_audio3)
    dl = DataLoader2(dp)
    # Every batch is discarded; only the iteration itself is exercised.
    for i, batch in enumerate(dl):
        pass
    pass
    # Drop the local references to the pipeline and the loader.
    del dp, dl

@profile
def dataset1(num_batch):
    """Iterate ``MyDataset1`` (dict samples) through ``DataLoader`` under the profiler.

    Args:
        num_batch: Length given to ``MyDataset1``.
    """
    ds = MyDataset1(num_batch)
    dl = DataLoader(ds)
    # Every batch is discarded; only the iteration itself is exercised.
    for i, batch in enumerate(dl):
        pass
    pass
    # Drop the local references to the dataset and the loader.
    del ds, dl

@profile
def dataset2(num_batch):
    """Iterate ``MyDataset2`` (tuple samples) through ``DataLoader`` under the profiler.

    Args:
        num_batch: Length given to ``MyDataset2``.
    """
    ds = MyDataset2(num_batch)
    dl = DataLoader(ds)
    # Every batch is discarded; only the iteration itself is exercised.
    for i, batch in enumerate(dl):
        pass
    pass
    # Drop the local references to the dataset and the loader.
    del ds, dl

# Run the three profiled scenarios at two sizes, forcing a garbage collection
# before the first scenario and after each one.
for num_batch in (1000, 5000):
    gc.collect()
    datapipe(num_batch)
    gc.collect()
    dataset1(num_batch)
    gc.collect()
    dataset2(num_batch)
    gc.collect()

output:

Filename: /home/haoyu.tang/uim_se/test_datapipes.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    88    328.1 MiB    328.1 MiB           1   @profile
    89                                         def datapipe(num_batch):
    90    328.4 MiB      0.3 MiB           1       dp = build_dp2(num_batch).map(add_audio3)
    91    330.6 MiB      2.2 MiB           1       dl = DataLoader2(dp)
    92    714.3 MiB    383.6 MiB        1001       for i, batch in enumerate(dl):
    93    714.3 MiB      0.0 MiB        1000           pass
    94    714.3 MiB      0.0 MiB           1       pass
    95    714.3 MiB      0.0 MiB           1       del dp, dl


Filename: /home/haoyu.tang/uim_se/test_datapipes.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    97    714.4 MiB    714.4 MiB           1   @profile
    98                                         def dataset1(num_batch):
    99    714.4 MiB      0.0 MiB           1       ds = MyDataset1(num_batch)
   100    714.4 MiB      0.0 MiB           1       dl = DataLoader(ds)
   101    716.9 MiB      2.5 MiB        1001       for i, batch in enumerate(dl):
   102    716.9 MiB      0.0 MiB        1000           pass
   103    716.9 MiB      0.0 MiB           1       pass
   104    716.9 MiB      0.0 MiB           1       del ds, dl


Filename: /home/haoyu.tang/uim_se/test_datapipes.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
   106    716.9 MiB    716.9 MiB           1   @profile
   107                                         def dataset2(num_batch):
   108    716.9 MiB      0.0 MiB           1       ds = MyDataset2(num_batch)
   109    716.9 MiB      0.0 MiB           1       dl = DataLoader(ds)
   110    716.9 MiB      0.0 MiB        1001       for i, batch in enumerate(dl):
   111    716.9 MiB      0.0 MiB        1000           pass
   112    716.9 MiB      0.0 MiB           1       pass
   113    716.9 MiB      0.0 MiB           1       del ds, dl


Filename: /home/haoyu.tang/uim_se/test_datapipes.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    88    716.9 MiB    716.9 MiB           1   @profile
    89                                         def datapipe(num_batch):
    90    717.0 MiB      0.0 MiB           1       dp = build_dp2(num_batch).map(add_audio3)
    91    721.6 MiB      4.6 MiB           1       dl = DataLoader2(dp)
    92   2254.1 MiB   1532.6 MiB        5001       for i, batch in enumerate(dl):
    93   2254.1 MiB      0.0 MiB        5000           pass
    94   2254.1 MiB      0.0 MiB           1       pass
    95   2252.1 MiB     -2.0 MiB           1       del dp, dl


Filename: /home/haoyu.tang/uim_se/test_datapipes.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    97   2251.5 MiB   2251.5 MiB           1   @profile
    98                                         def dataset1(num_batch):
    99   2251.5 MiB      0.0 MiB           1       ds = MyDataset1(num_batch)
   100   2251.5 MiB      0.0 MiB           1       dl = DataLoader(ds)
   101   2251.5 MiB -7642068.4 MiB        5001       for i, batch in enumerate(dl):
   102   2251.5 MiB -7640538.2 MiB        5000           pass
   103    721.3 MiB  -1530.2 MiB           1       pass
   104    721.3 MiB      0.0 MiB           1       del ds, dl


Filename: /home/haoyu.tang/uim_se/test_datapipes.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
   106    721.3 MiB    721.3 MiB           1   @profile
   107                                         def dataset2(num_batch):
   108    721.3 MiB      0.0 MiB           1       ds = MyDataset2(num_batch)
   109    721.3 MiB      0.0 MiB           1       dl = DataLoader(ds)
   110    721.3 MiB      0.0 MiB        5001       for i, batch in enumerate(dl):
   111    721.3 MiB      0.0 MiB        5000           pass
   112    721.3 MiB      0.0 MiB           1       pass
   113    721.3 MiB      0.0 MiB           1       del ds, dl

Versions

torch version: 2.0.0
torchdata version: 0.6.0

It is clear that passing a dict of tensors leaks memory, while passing a list of tensors does not.

I use a dict of tensors in my model training, and the training has failed multiple times, every time because of this memory leak. I also tried TensorDict (https://pytorch.org/rl/tensordict/), but it cannot contain strings. I need strings while data passes through my datapipes (one of the datapipes encodes the str into a tensor).

I'm also using dictionaries and see a memory leak. I'm highlighting a different issue but I'm seeing a small increase in usage over time as well:

#1185

@andrew-bydlon I found a temporary workaround for the issue: I split my original dict into two dicts, one containing only the tensors and the other containing everything that is not a tensor.

Same problem with python dict and tensordict.