bug: Memory leak when running the example code with >1k points
Bug Report
GPJax version: 0.7.1
Current behavior:
Memory leak when calling `gpx.fit` in the "Simple example" code.
Expected behavior:
Example to run correctly.
Steps to reproduce:
Copy the code from the "Simple example" on the GitHub front page and set `n = 1001`.
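For reference, a minimal reproduction sketch modelled on the README's "Simple example"; the exact 0.7.x names (`gpx.Prior`, `gpx.Gaussian`, the `gpx.fit` keywords) are reconstructed from memory and may differ slightly in your install:

```python
import jax.numpy as jnp
import jax.random as jr
import optax as ox
from jax import jit
import gpjax as gpx

key = jr.PRNGKey(123)
n = 1001  # 999 runs fine; anything above 1000 triggers the leak

# Toy 1D regression data.
x = jnp.sort(jr.uniform(key, shape=(n, 1), minval=0.0, maxval=10.0), axis=0)
y = jnp.sin(x) + 0.1 * jr.normal(key, shape=(n, 1))
D = gpx.Dataset(X=x, y=y)

# Standard dense conjugate GPR model.
prior = gpx.Prior(mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF())
posterior = prior * gpx.Gaussian(num_datapoints=D.n)

# Memory blows up here once n > 1000.
opt_posterior, history = gpx.fit(
    model=posterior,
    objective=jit(gpx.ConjugateMLL(negative=True)),
    train_data=D,
    optim=ox.adam(learning_rate=0.01),
    num_iters=100,
    key=key,
)
```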
Other information:
Looking at the stack trace, it seems like the issue is coming from a side effect inside CoLA. Could this be related to a heuristic algorithm swap for operations on more than 1k points?
Actually, it seems like any standard dense GPR with more than 1k points causes a memory leak in my environment. Was there a recent change on the CoLA side that could cause this?
Thanks @MDCHAMP. Yes, it is my understanding that CoLA's default (auto) algorithm changes at >1000 datapoints, so this is likely the issue. What version of CoLA do you have installed in your environment?
I'm using `cola==0.0.1`, as it came with GPJax. I tried it with 0.0.5, but it doesn't seem that GPJax currently supports that version?
Thanks for the info @MDCHAMP. I have bumped CoLA to v0.0.5 in PR #405. There's certainly a memory leak with default CoLA settings on >1000 datapoints, given that the docs fail to build on such examples. Tests pass fine otherwise. Will have a little dig around the traceback and open an issue over on CoLA shortly.
Ah, that makes a lot of sense then. A current workaround is to not jit the objective function in the call to `gpx.fit`.
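Concretely, the workaround is to drop the `jax.jit` wrapper around the objective and pass it to `gpx.fit` as-is (same hypothetical setup as in the reproduction sketch above):

```python
# Leaks for n > 1000:
#   gpx.fit(..., objective=jit(gpx.ConjugateMLL(negative=True)), ...)

# Workaround: pass the objective un-jitted.
opt_posterior, history = gpx.fit(
    model=posterior,
    objective=gpx.ConjugateMLL(negative=True),  # no jit wrapper
    train_data=D,
    optim=ox.adam(learning_rate=0.01),
    num_iters=100,
    key=key,
)
```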
Ooh, that's interesting. Maybe we are slightly at fault on that one. As briefly mentioned in #402, jitting objective functions before passing them to `fit` seems a little dodgy, as the objective is a class. I think removing the ability to do this, and instead only allowing `objective.step` to be jitted, would be a better way to go. For the `fit` abstraction you could still pass in your objective as normal (just no jit now). But I still think there's something interesting going on with CoLA: on my M1 machine, 999 datapoints seems fast (a few microseconds) and 1001 datapoints seems slow (3 mins) for `jit(gpx.ConjugateMLL(negative=True).step)`. I just need to find what the issue is. But I reckon if we used `algorithm=Cholesky()` for all of CoLA, we would probably have a fix for now. (Edit: benchmarks are for the #405 branch.)
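For intuition about the proposed stopgap: forcing `algorithm=Cholesky()` would route every solve through a dense Cholesky factorisation instead of whatever the CoLA auto heuristic switches to above 1000 points. A pure-JAX sketch of what that dense path computes for the conjugate MLL (illustrative only, not GPJax or CoLA internals):

```python
import jax.numpy as jnp
import jax.scipy.linalg as jsl

def conjugate_mll_dense(K, y, noise_variance):
    """Negative log marginal likelihood via a dense Cholesky solve.

    K: (n, n) kernel Gram matrix; y: (n,) targets.
    Equivalent in spirit to pinning CoLA to algorithm=Cholesky():
    no iterative solver, no n > 1000 heuristic switch.
    """
    n = y.shape[0]
    L = jnp.linalg.cholesky(K + noise_variance * jnp.eye(n))
    alpha = jsl.cho_solve((L, True), y)  # (K + sigma^2 I)^{-1} y
    # -log p(y) = 0.5 y^T alpha + sum(log diag(L)) + 0.5 n log(2 pi)
    return 0.5 * (y @ alpha) + jnp.sum(jnp.log(jnp.diag(L))) + 0.5 * n * jnp.log(2.0 * jnp.pi)
```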