Passing dicts through a datapipe/dataset causes a memory leak
Opened this issue · 3 comments
🐛 Describe the bug
Passing a dict through a datapipe or a dataset causes a memory leak. Minimal repro:
```python
from copy import deepcopy
import gc

from memory_profiler import profile
import torch
from torch.utils.data import DataLoader
from torchdata.datapipes.iter import IterableWrapper
from torchdata.dataloader2 import DataLoader2


def build_dp1(num_batch):
    # Items with nested "clean"/"noisy" sub-dicts.
    item_list = list()
    for idx in range(num_batch):
        item = {
            "id": idx,
            "clean": {
                "path": str(idx),
                "id": idx,
            },
            "noisy": {
                "path": str(idx),
                "id": idx,
            },
        }
        item_list.append(item)
    return IterableWrapper(item_list)


def build_dp2(num_batch):
    # Items as flat dicts, no nesting.
    item_list = list()
    for idx in range(num_batch):
        item = {
            "id": idx,
            "clean_path": str(idx),
            "clean_id": idx,
            "noisy_path": str(idx),
            "noisy_id": idx,
        }
        item_list.append(item)
    return IterableWrapper(item_list)


def add_audio1(item):
    # Mutate the nested dicts in place.
    item["clean"]["audio"] = torch.randn([5000, 10])
    item["noisy"]["audio"] = torch.randn([5000, 10])
    return item


def add_audio2(item):
    # Deep-copy first, then add the tensors to the copy.
    new_item = deepcopy(item)
    new_item["clean"]["audio"] = torch.randn([5000, 10])
    new_item["noisy"]["audio"] = torch.randn([5000, 10])
    return new_item


def add_audio3(item):
    # Add the tensors to the flat dict in place.
    item["clean_audio"] = torch.randn([5000, 10])
    item["noisy_audio"] = torch.randn([5000, 10])
    return item


class MyDataset1(torch.utils.data.Dataset):
    # Returns a mixed dict of ints, strings, and tensors.
    def __init__(self, datalen):
        super().__init__()
        self.datalen = datalen

    def __getitem__(self, index):
        item = {
            "id": index,
            "clean_path": str(index),
            "clean_id": index,
            "clean_audio": torch.randn([5000, 10]),
            "noisy_path": str(index),
            "noisy_id": index,
            "noisy_audio": torch.randn([5000, 10]),
        }
        return item

    def __len__(self):
        return self.datalen


class MyDataset2(torch.utils.data.Dataset):
    # Returns a tuple of tensors, no dict.
    def __init__(self, datalen):
        super().__init__()
        self.datalen = datalen

    def __getitem__(self, index):
        return torch.randn([5000, 10]), torch.randn([5000, 10])

    def __len__(self):
        return self.datalen


@profile
def datapipe(num_batch):
    dp = build_dp2(num_batch).map(add_audio3)
    dl = DataLoader2(dp)
    for i, batch in enumerate(dl):
        pass
    pass
    del dp, dl


@profile
def dataset1(num_batch):
    ds = MyDataset1(num_batch)
    dl = DataLoader(ds)
    for i, batch in enumerate(dl):
        pass
    pass
    del ds, dl


@profile
def dataset2(num_batch):
    ds = MyDataset2(num_batch)
    dl = DataLoader(ds)
    for i, batch in enumerate(dl):
        pass
    pass
    del ds, dl


num_batch = 1000
gc.collect()
datapipe(num_batch)
gc.collect()
dataset1(num_batch)
gc.collect()
dataset2(num_batch)
gc.collect()

num_batch = 5000
gc.collect()
datapipe(num_batch)
gc.collect()
dataset1(num_batch)
gc.collect()
dataset2(num_batch)
gc.collect()
```
Output:

```
Filename: /home/haoyu.tang/uim_se/test_datapipes.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    88    328.1 MiB    328.1 MiB           1   @profile
    89                                         def datapipe(num_batch):
    90    328.4 MiB      0.3 MiB           1       dp = build_dp2(num_batch).map(add_audio3)
    91    330.6 MiB      2.2 MiB           1       dl = DataLoader2(dp)
    92    714.3 MiB    383.6 MiB        1001       for i, batch in enumerate(dl):
    93    714.3 MiB      0.0 MiB        1000           pass
    94    714.3 MiB      0.0 MiB           1       pass
    95    714.3 MiB      0.0 MiB           1       del dp, dl


Filename: /home/haoyu.tang/uim_se/test_datapipes.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    97    714.4 MiB    714.4 MiB           1   @profile
    98                                         def dataset1(num_batch):
    99    714.4 MiB      0.0 MiB           1       ds = MyDataset1(num_batch)
   100    714.4 MiB      0.0 MiB           1       dl = DataLoader(ds)
   101    716.9 MiB      2.5 MiB        1001       for i, batch in enumerate(dl):
   102    716.9 MiB      0.0 MiB        1000           pass
   103    716.9 MiB      0.0 MiB           1       pass
   104    716.9 MiB      0.0 MiB           1       del ds, dl


Filename: /home/haoyu.tang/uim_se/test_datapipes.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
   106    716.9 MiB    716.9 MiB           1   @profile
   107                                         def dataset2(num_batch):
   108    716.9 MiB      0.0 MiB           1       ds = MyDataset2(num_batch)
   109    716.9 MiB      0.0 MiB           1       dl = DataLoader(ds)
   110    716.9 MiB      0.0 MiB        1001       for i, batch in enumerate(dl):
   111    716.9 MiB      0.0 MiB        1000           pass
   112    716.9 MiB      0.0 MiB           1       pass
   113    716.9 MiB      0.0 MiB           1       del ds, dl


Filename: /home/haoyu.tang/uim_se/test_datapipes.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    88    716.9 MiB    716.9 MiB           1   @profile
    89                                         def datapipe(num_batch):
    90    717.0 MiB      0.0 MiB           1       dp = build_dp2(num_batch).map(add_audio3)
    91    721.6 MiB      4.6 MiB           1       dl = DataLoader2(dp)
    92   2254.1 MiB   1532.6 MiB        5001       for i, batch in enumerate(dl):
    93   2254.1 MiB      0.0 MiB        5000           pass
    94   2254.1 MiB      0.0 MiB           1       pass
    95   2252.1 MiB     -2.0 MiB           1       del dp, dl


Filename: /home/haoyu.tang/uim_se/test_datapipes.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    97   2251.5 MiB   2251.5 MiB           1   @profile
    98                                         def dataset1(num_batch):
    99   2251.5 MiB      0.0 MiB           1       ds = MyDataset1(num_batch)
   100   2251.5 MiB      0.0 MiB           1       dl = DataLoader(ds)
   101   2251.5 MiB -7642068.4 MiB        5001       for i, batch in enumerate(dl):
   102   2251.5 MiB -7640538.2 MiB        5000           pass
   103    721.3 MiB  -1530.2 MiB           1       pass
   104    721.3 MiB      0.0 MiB           1       del ds, dl


Filename: /home/haoyu.tang/uim_se/test_datapipes.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
   106    721.3 MiB    721.3 MiB           1   @profile
   107                                         def dataset2(num_batch):
   108    721.3 MiB      0.0 MiB           1       ds = MyDataset2(num_batch)
   109    721.3 MiB      0.0 MiB           1       dl = DataLoader(ds)
   110    721.3 MiB      0.0 MiB        5001       for i, batch in enumerate(dl):
   111    721.3 MiB      0.0 MiB        5000           pass
   112    721.3 MiB      0.0 MiB           1       pass
   113    721.3 MiB      0.0 MiB           1       del ds, dl
```
Versions
torch version: 2.0.0
torchdata version: 0.6.0
It is clear that passing a dict of tensors leaks memory, while passing a list/tuple of tensors does not.
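For comparison, here is a minimal sketch of the tuple-based shape I mean (`build_dp_tuple` and `add_audio_tuple` are illustrative names, not part of the repro above; I have not profiled this exact snippet, but it mirrors `MyDataset2`, which shows no growth in the output):

```python
import torch
from torchdata.datapipes.iter import IterableWrapper

def build_dp_tuple(num_batch):
    # Same fields as build_dp2, packed as a tuple instead of a dict:
    # (id, clean_path, clean_id, noisy_path, noisy_id)
    return IterableWrapper([(idx, str(idx), idx, str(idx), idx) for idx in range(num_batch)])

def add_audio_tuple(item):
    # Append the audio tensors as extra tuple elements rather than dict values.
    return (*item, torch.randn([5000, 10]), torch.randn([5000, 10]))

dp = build_dp_tuple(1000).map(add_audio_tuple)
```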
I use dicts of tensors in my model training, and training has failed multiple times because of this memory leak. I also tried TensorDict (https://pytorch.org/rl/tensordict/), but it cannot hold strings, and I need strings while items pass through my datapipes (one of the datapipes encodes str to tensor).
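For reference, the str-to-tensor step I mean looks roughly like this (`encode_paths` is a hypothetical stand-in for my real encoder, not code from the repro):

```python
import torch

def encode_paths(item):
    # Hypothetical sketch: encode the UTF-8 bytes of each path string as a
    # uint8 tensor so that downstream stages only have to handle tensors.
    for key in ("clean_path", "noisy_path"):
        item[key + "_tensor"] = torch.tensor(
            list(item[key].encode("utf-8")), dtype=torch.uint8
        )
    return item
```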
I'm also using dictionaries and seeing a memory leak. I'm highlighting a different issue, but I'm seeing a small increase in usage over time as well.
@andrew-bydlon I found a temporary workaround: split the original dict into two dicts, one containing only the tensors and the other containing everything else (no tensors).
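Roughly, the workaround looks like this (`add_audio_split` is an illustrative name; the exact split in my code differs):

```python
import torch

def add_audio_split(item):
    # Workaround sketch: return a tensor-only dict plus a tensor-free
    # metadata dict (ids and path strings) instead of one mixed dict.
    tensors = {
        "clean_audio": torch.randn([5000, 10]),
        "noisy_audio": torch.randn([5000, 10]),
    }
    meta = dict(item)  # nothing in the original item is a tensor
    return tensors, meta

# dp = build_dp2(num_batch).map(add_audio_split)
```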
Same problem here with both a plain Python dict and TensorDict.