Memory leak in PyTorch interaction
lorenzoh opened this issue · 7 comments
In trying to use PyTorch with PyCallChainRules.jl, I discovered a memory leak (rejuvyesh/PyCallChainRules.jl#24) that seems to be in its interaction with DLPack.jl.
I've reduced the problem to an MWE that ends in an out-of-memory error because GPU memory isn't being freed correctly.
The following loads a model and a CuArray, uses DLPack.jl to share the CuArray with PyTorch, and then runs the array through the model repeatedly. The GPU memory increases linearly until the OOM error.
using CUDA, PyCall, DLPack
dlpack = pyimport("torch.utils.dlpack")
torch = pyimport("torch")
pytorch_from_dlpack(x) = @pycall dlpack.from_dlpack(x)::PyObject
function memoryused()
    info = CUDA.MemoryInfo()
    return 1 - (info.free_bytes / info.total_bytes)
end
pymodel = torch.hub.load("pytorch/vision", "resnet18").to("cuda")
xs = cu(randn(Float32, 224, 224, 3, 16))
xs_shared = DLPack.share(xs, PyObject, pytorch_from_dlpack)
usages = [memoryused()]
numrefs = [length(DLPack.SHARES_POOL)]
for i in 1:100
    pymodel(xs_shared)
    push!(usages, memoryused())
    push!(numrefs, length(DLPack.SHARES_POOL))
end
The memory usage: [figure not preserved]
I also looked into DLPack.SHARES_POOL to see if there are references piling up there, but the number of references stays constant (2).
The GPU memory stays exhausted until I run torch.cuda.empty_cache(), but this only works after calling GC.gc(), leading me to believe that some references are being held on the Julia side that prevent the memory from being cleared by PyTorch.
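The manual workaround described above can be wrapped in a small helper. This is only a sketch: `release_torch_cache` is a hypothetical name, and it assumes `torch` is the module handle returned by `pyimport("torch")`, as in the MWE.

```julia
# Hypothetical helper (not part of PyCall, DLPack.jl, or PyTorch).
# `torch` is assumed to be the handle returned by `pyimport("torch")`.
function release_torch_cache(torch)
    # Full Julia collection first, so that PyObjects that are no longer
    # reachable get finalized and drop their references to the tensors.
    GC.gc()
    # Then ask PyTorch's caching allocator to release its unused blocks.
    torch.cuda.empty_cache()
    return nothing
end

# Usage, after the loop in the MWE:
# release_torch_cache(torch)
```

The order matters here: as observed above, calling torch.cuda.empty_cache() before GC.gc() does not bring the memory back.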
Any help with this or pointers for places to look in the code base would be much appreciated! I'm happy to run further tests that could help diagnose the problem.
Any progress on this?
Haven't had any luck finding the issue
Sorry I had missed this. I will try to look into it soon.
Would you be able to provide more details on the python packages and versions needed to reproduce this?
@pabloferz I think only PyTorch is needed for this MWE. I have seen this on all recent torch versions: 1.10, 1.11, and 1.12.
OK, I verified this and everything seems to work as it should. What you are seeing is the same as in JuliaPy/PyCall.jl#436 and JuliaPy/PyCall.jl#529. That is, there's no memory leak, but Julia has no idea that it should garbage-collect the PyObjects that result from the calls to pymodel more frequently.
What I would do here is the following:
using CUDA, PyCall, DLPack
dlpack = pyimport("torch.utils.dlpack")
torch = pyimport("torch")
pytorch_from_dlpack(x) = @pycall dlpack.from_dlpack(x)::PyObject
function memoryused()
    info = CUDA.MemoryInfo()
    return 1 - (info.free_bytes / info.total_bytes)
end
pymodel = torch.hub.load("pytorch/vision", "resnet18").to("cuda")
apply_model = x -> (y = pymodel(x); GC.gc(false); y) # garbage collect, but only recent "young" objects
xs = cu(randn(Float32, 224, 224, 3, 16))
xs_shared = DLPack.share(xs, PyObject, pytorch_from_dlpack)
usages = [memoryused()]
for i in 1:100
    apply_model(xs_shared)
    push!(usages, memoryused())
end
# Alternatively:
for i in 1:100
    pymodel(xs_shared)
    GC.gc(false)
    push!(usages, memoryused())
end
Confirmed that this is true for PythonCall as well. Would be useful to document this prominently somewhere!
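For documentation purposes, the mechanism discussed in this thread can be illustrated without a GPU or Python. The following is a minimal sketch under the assumption that a PyObject behaves like a tiny Julia wrapper whose finalizer releases a large external resource. `FakeTensor`, `make_tensor`, `run_loop`, and `released` are hypothetical names used only for illustration.

```julia
# Hypothetical stand-in for a PyObject wrapping a large GPU tensor: the Julia
# object is only a few bytes, so allocating many of them puts almost no
# pressure on Julia's GC, even if each one pins a lot of external memory.
mutable struct FakeTensor
    id::Int
end

const released = Ref(0)  # counts how many finalizers have run

# @noinline keeps the wrapper a real heap object that escapes via `return`.
@noinline function make_tensor(id)
    t = FakeTensor(id)
    # The finalizer plays the role of dropping the Python reference.
    finalizer(_ -> (released[] += 1), t)
    return t
end

function run_loop(n)
    for i in 1:n
        make_tensor(i)  # result is discarded, like pymodel(xs_shared)
    end
end

run_loop(100)
GC.gc()  # explicit collection so the unreachable wrappers are finalized
println("finalized wrappers: ", released[])
```

Nothing in the loop tells Julia how much external memory each wrapper holds, which is why an explicit GC.gc() or GC.gc(false) call is needed in the real PyTorch loop.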