Memory Leak with gp.sample() if part of an imported function sampled by emcee
gmduvvuri opened this issue · 5 comments
Hello,
I realize this is a very strange set of conditions, but I've created a minimal working example. One file, `jax_memory_leak.py`, defines a function that uses the `GaussianProcess`:
```python
import jax
import numpy as np
from jax import numpy as jnp
from tinygp import kernels, GaussianProcess
from memory_profiler import profile


@profile
def make_gp():
    x_arr = jnp.sort(np.random.default_rng(1).uniform(0, 10, 100))
    y_arr = np.random.default_rng(1).uniform(0, 10, 100)
    kernel = 1.0 * kernels.Matern32(scale=0.5)
    gp = GaussianProcess(kernel, x_arr, mean=0.0)
    _, cond_gp = gp.condition(y_arr, x_arr)
    return cond_gp.sample(jax.random.PRNGKey(1))
```
and then another file, `jax_test.py`, imports this function and tries to sample with emcee:
```python
from jax_memory_leak import make_gp
import emcee
import gc


def ln_likelihood(param):
    arr = make_gp()
    gc.collect()
    return 0.0


sampler = emcee.EnsembleSampler(4, 1, ln_likelihood)
init_pos = [[i] for i in range(4)]
sampler.run_mcmc(init_pos, 10)
```
Running the profiler with `mprof run jax_test.py` shows that every iteration of the emcee sampling adds a couple of MiB to RAM. I have tried setting various objects to `None` and using combinations of `del` and `gc.collect()`, but the problem persists. My actual problem cannot use tinygp's likelihood operations because I calculate the likelihood from an integral of the product of the GP and another function, so long runs incur prohibitive RAM costs. I've tried digging through the jax documentation to figure this out but can't see a path to a solution. Any help would be appreciated.
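For concreteness, this is the kind of cleanup I tried inside the likelihood (a sketch; none of it stopped the growth):

```python
import gc

from jax_memory_leak import make_gp


def ln_likelihood(param):
    arr = make_gp()
    _ = float(arr.sum())  # pull the result back to the host
    del arr               # drop the only reference to the device array
    gc.collect()          # force an immediate collection pass
    return 0.0
```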
I have tried running this on Ubuntu through the Windows Subsystem for Linux and on macOS, both with fully updated tinygp and jax. If you need any more specific device information, I can provide it.
I'm not at my computer, but I'd recommend removing the `memory_profiler` decorator and replacing it with a `jax.jit` decorator. This will at least reduce the memory usage, and I expect it will fix your problem. Let me know!
(Presumably in your real code you're passing in parameters? In this demo the function will always return the same value...)
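Something like this sketch (I'm assuming for the demo that the amplitude and scale are the parameters you'd be sampling, and that the random key is passed in as an argument):

```python
import jax
import numpy as np
from jax import numpy as jnp
from tinygp import kernels, GaussianProcess

# Build the data once, outside the traced function.
x_arr = jnp.sort(np.random.default_rng(1).uniform(0, 10, 100))
y_arr = np.random.default_rng(1).uniform(0, 10, 100)


@jax.jit
def make_gp(key, amp, scale):
    # jit traces this once; later calls reuse the compiled version
    # instead of rebuilding (and re-allocating) everything per call.
    kernel = amp * kernels.Matern32(scale=scale)
    gp = GaussianProcess(kernel, x_arr, mean=0.0)
    _, cond_gp = gp.condition(y_arr, x_arr)
    return cond_gp.sample(key)


sample = make_gp(jax.random.PRNGKey(1), 1.0, 0.5)
```

The point is that the whole construction is traced and compiled once; subsequent calls reuse the compiled function instead of re-allocating everything per iteration.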
In my real scenario, which is significantly more complicated, I am passing a number of arguments and have been unable to express the function in a JAX-pure way, so whenever I use the `@jax.jit` decorator I get the results from the first call of the function rather than fresh results for the following sampler iterations.

The function takes as arguments: three float parameters sampled by emcee, a float that changes depending on the parameters, and two arrays whose values (but not shapes) change depending on the parameters.
> In my real scenario, which is significantly more complicated, I am passing a number of arguments and have been unable to express the function in a JAX-pure way, so whenever I use the `jax.jit` decorator I get the results from the first call of the function rather than fresh results for the following sampler iterations.
In this case, this is never going to be performant! I'm sure it's possible to set your problem up as needed. Feel free to drop more details, here or over email, about why you're not able to get it working.
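To be concrete: arguments like the ones you describe are fine under `jit` as long as the array shapes are fixed; only the shapes, not the values, need to be static. A toy sketch (every name here is illustrative, not your actual model):

```python
import jax
from jax import numpy as jnp


@jax.jit
def ln_likelihood_core(p0, p1, p2, derived, arr_a, arr_b):
    # The values of all six arguments can change between calls; because
    # the array shapes stay fixed, the function compiles exactly once.
    model = p0 * arr_a + p1 * jnp.exp(-arr_b / p2) + derived
    return -0.5 * jnp.sum(model**2)


val = ln_likelihood_core(1.0, 2.0, 0.5, 0.1,
                         jnp.linspace(0.0, 1.0, 50),
                         jnp.linspace(1.0, 2.0, 50))
```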
The main barrier was that my model is a sum of two components, each of which can have a different number of parameters depending on the version of that component, and I was using array indexing to split the parameter array based on a `likelihood_arg` specifying the break between them. I've added a bunch of extra likelihood args to turn the parameter array into a dict instead, and now the function is jit-wrappable. This got rid of the memory leak and sped things up by a factor of ~200, so I think you can close this issue.
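Roughly, the restructuring looked like this (a heavily simplified sketch; `model_before`/`model_after` and the component names are placeholders):

```python
import jax
from jax import numpy as jnp


# Before: the split point arrived as a likelihood arg, so under jit the
# slice shapes depended on a runtime value and tracing failed.
def model_before(params, n_comp1):
    comp1 = params[:n_comp1]
    comp2 = params[n_comp1:]
    return jnp.sum(comp1) - jnp.sum(comp2)


# After: the parameters arrive pre-split in a dict, which jax treats as
# a pytree whose leaf shapes are fixed, so the whole thing is jittable.
@jax.jit
def model_after(params):
    return jnp.sum(params["comp1"]) - jnp.sum(params["comp2"])


val = model_after({"comp1": jnp.ones(3), "comp2": jnp.ones(2)})
```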
Based on the MWE, though, my guess is that there is still some memory allocated by `jax.random.normal` in the non-jitted case that doesn't get cleaned up by garbage collection, so I'll leave the closing decision up to you. I will also add that I had a workaround for the memory leak: enclosing the `gp.sample` step in a `multiprocessing.Pool(1)` that I could terminate. But that prevented me from using the emcee Pool.
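For the record, the workaround looked roughly like this (a sketch, with `sample_once` as an illustrative wrapper):

```python
import multiprocessing

import numpy as np

from jax_memory_leak import make_gp


def sample_once():
    # Convert to a plain numpy array so nothing jax-owned crosses the
    # process boundary.
    return np.asarray(make_gp())


def ln_likelihood(param):
    # Run the leaky step in a throwaway worker; terminating the pool
    # (which the context manager does on exit) returns its memory to
    # the OS.
    with multiprocessing.Pool(1) as pool:
        arr = pool.apply(sample_once)
    return 0.0
```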