lucidrains/byol-pytorch

GPU Memory Usage Extremely High

ClemensSchwarke opened this issue

Hi,
I am not an expert in PyTorch and would appreciate some help understanding my VRAM utilization. When replacing

images = torch.cat((image_one, image_two), dim = 0)

with

images = torch.rand(128, 3, 256, 256, device=image_one.device)

the required VRAM explodes (20 GB instead of 2 GB) when


is executed. I can't find an explanation for this behavior :/ It matters to me because the same thing happens in my actual use case, which obviously doesn't involve a random tensor but is trickier to explain.
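One way to narrow this down is to compare the two inputs directly. The sketch below assumes a CUDA device and prints the shape, dtype and device of each variant, then measures the peak memory of one forward + backward pass using `torch.cuda.reset_peak_memory_stats` and `torch.cuda.max_memory_allocated`. The small convolutional `model`, the helper names `describe` and `peak_forward_backward_mib`, and the 64 × 64 view size are hypothetical placeholders, not the actual BYOL encoder or data; swap in the real encoder and augmented views to reproduce the comparison.

```python
import torch
import torch.nn as nn


def describe(t: torch.Tensor) -> str:
    # Batch size, spatial size and dtype all scale the activation memory
    # that the forward and backward passes have to keep around.
    return f"shape={tuple(t.shape)} dtype={t.dtype} device={t.device}"


def peak_forward_backward_mib(model: nn.Module, images: torch.Tensor):
    """Peak CUDA memory in MiB for one forward + backward pass.

    Returns None when the input is not on a CUDA device, because the
    torch.cuda memory statistics only track CUDA allocations.
    """
    if images.device.type != "cuda":
        return None
    torch.cuda.synchronize(images.device)
    torch.cuda.reset_peak_memory_stats(images.device)
    model(images).sum().backward()
    torch.cuda.synchronize(images.device)
    return torch.cuda.max_memory_allocated(images.device) / 2**20


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Hypothetical stand-in for the BYOL online encoder.
    model = nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
    ).to(device)

    # Hypothetical shapes: two augmented views of a 64-image batch at 64 x 64.
    image_one = torch.rand(64, 3, 64, 64, device=device)
    image_two = torch.rand(64, 3, 64, 64, device=device)

    # Build each input lazily so only one of them is allocated per measurement.
    variants = {
        "cat": lambda: torch.cat((image_one, image_two), dim=0),
        "rand": lambda: torch.rand(128, 3, 256, 256, device=image_one.device),
    }
    for name, make_images in variants.items():
        images = make_images()
        print(name, describe(images), peak_forward_backward_mib(model, images))
        del images
```

If the printed shapes differ, for example in height and width, that alone can explain a large gap: the activation memory of a convolutional encoder grows roughly in proportion to batch × height × width, so it is the first thing worth ruling out.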

Many thanks in advance :)