JaxGaussianProcesses/GPJax

bug: Memory leak when running the example code with >1k points


Bug Report

GPJax version:

0.7.1

Current behavior:

Memory leak when calling gpx.fit in the "Simple example" code.

Expected behavior:

Example to run correctly.

Steps to reproduce:

Copy the "Simple example" code from the front page of the GitHub README and set `n = 1001`.
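
For reference, a minimal sketch of the reproduction, approximating the README's "Simple example" under the 0.7.x API (the optimiser settings and exact keyword arguments to gpx.fit are taken from that example and may differ slightly across versions):

```python
import gpjax as gpx
import jax.numpy as jnp
import jax.random as jr
import optax as ox

key = jr.PRNGKey(123)

# Simulate n > 1000 noisy observations of a toy function.
n = 1001
x = jr.uniform(key, shape=(n, 1), minval=-3.0, maxval=3.0)
y = jnp.sin(4 * x) + jnp.cos(2 * x) + 0.3 * jr.normal(key, shape=(n, 1))
D = gpx.Dataset(X=x, y=y)

# Standard dense conjugate GP regression model.
prior = gpx.Prior(mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF())
likelihood = gpx.Gaussian(num_datapoints=n)
posterior = prior * likelihood

# With n <= 1000 this runs fine; with n = 1001 memory blows up.
opt_posterior, history = gpx.fit(
    model=posterior,
    objective=gpx.ConjugateMLL(negative=True),
    train_data=D,
    optim=ox.adam(learning_rate=0.01),
    num_iters=500,
)
```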

Other information:

Looking at the stack trace, it seems like the issue is coming from a side effect inside CoLA. Could this be related to a heuristic algorithm swap for operations on more than 1k points?

In fact, any standard dense GPR with more than 1k points seems to cause a memory leak in my environment. Was there a recent change on the CoLA side that could cause this?

Thanks @MDCHAMP. Yes, it is my understanding that CoLA's default (auto) algorithm changes at >1000 datapoints, so this is likely the issue. What version of CoLA do you have installed in your environment?

I'm using cola==0.0.1, as it came with GPJax. I tried 0.0.5, but it doesn't seem that GPJax currently supports that version?

Thanks for the info @MDCHAMP. I have bumped CoLA to v0.0.5 in PR #405. There's certainly a memory leak with default CoLA settings on >1000 datapoints, given that the docs fail to build on such examples; tests otherwise pass fine. I'll have a little dig around the traceback and open an issue over on CoLA for this shortly.

Ah, that makes a lot of sense then. A current workaround is to not jit the objective function before passing it to gpx.fit.
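
Concretely, the workaround amounts to something like this against the repro sketch above (reusing `posterior` and `D` from there; the point is simply to pass the plain objective instance rather than a jitted wrapper):

```python
objective = gpx.ConjugateMLL(negative=True)

# Workaround: pass the objective as-is; do NOT wrap it as
# objective=jax.jit(objective) before handing it to gpx.fit.
opt_posterior, history = gpx.fit(
    model=posterior,   # from the repro sketch above
    objective=objective,
    train_data=D,
    optim=ox.adam(learning_rate=0.01),
    num_iters=500,
)
```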

Ooh, that's interesting. Maybe we are slightly at fault on that one. As briefly mentioned in #402, jitting objective functions before passing them to fit seems a little dodgy, as the objective is a class instance. I think removing the ability to do this, and instead only allowing objective.step to be jitted, would be a better way to go. The fit abstraction would still take in your objective as normal (just no jit now).

But I still think there's something interesting going on with CoLA: on my M1 machine, 999 datapoints seems fast (a few microseconds) and 1001 datapoints seems slow (3 mins) for jit(gpx.ConjugateMLL(negative=True).step). I just need to find what the issue is. But I reckon if we used algorithm=Cholesky() for all CoLA operations we would probably have a fix for now. (Edit: benchmarks are for the #405 branch.)
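
For concreteness, the benchmark above corresponds to roughly the following probe (a sketch reusing `posterior` and `D` from the repro sketch; it assumes ConjugateMLL's step takes (model, train_data), as in the 0.7.x objectives API):

```python
import time
import jax

objective = gpx.ConjugateMLL(negative=True)
step = jax.jit(objective.step)  # jit the step method, not the objective instance

# The first call triggers compilation; this is where n = 1001 stalls
# for minutes while n = 999 returns almost immediately.
t0 = time.perf_counter()
loss = step(posterior, D)
loss.block_until_ready()
print(f"n={D.n}: first step took {time.perf_counter() - t0:.2f}s, loss={float(loss):.3f}")
```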