sbrunk/storch

Improve memory management


We can't rely on the garbage collector to free native memory in time, as it is too slow and unpredictable. This is especially true for GPU tensors, because we usually want to utilize the limited GPU memory as much as possible without over-allocating and crashing the process.

So we need a solution based on reference counting and explicit deallocation, not only for single tensors but for groups of tensors as well.
PointerScope (example) from JavaCPP is a working approach to achieve this:

import scala.util.Using
import org.bytedeco.javacpp.PointerScope

trainDL.foreach { (input, label) =>
  optimizer.zeroGrad()
  // all native tensors allocated inside this scope are freed on exit
  Using.resource(new PointerScope()) { _ =>
    val pred = model(input.to(device))
    val loss = lossFn(pred, label.to(device))
    loss.backward()
    optimizer.step()
  }
}

This will deallocate all tensors created inside the Using block as soon as it goes out of scope. One open question is whether we want to expose PointerScope directly.
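If we'd rather not expose PointerScope directly, we could hide it behind a small helper. A minimal sketch, assuming a hypothetical memoryScope function (not an existing storch API):

import scala.util.Using
import org.bytedeco.javacpp.PointerScope

// Hypothetical helper hiding the JavaCPP type from users.
// Caveat: a tensor returned from `f` is also attached to the scope
// and would be freed on exit; see the next paragraph.
def memoryScope[A](f: => A): A =
  Using.resource(new PointerScope()) { _ => f }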

If we want to keep e.g. the loss tensor for stats, we need to manually increase the reference count of the native loss tensor. I'm not sure how well that would work or what it would look like in practice; it needs investigation.
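A minimal sketch of that idea, using JavaCPP's Pointer.retainReference()/releaseReference() and assuming the native tensor is reachable via loss.native:

import scala.util.Using
import org.bytedeco.javacpp.PointerScope

// Bump the reference count before the scope releases its attached
// pointers, so the loss survives the scope with a count of 1.
val keptLoss = Using.resource(new PointerScope()) { _ =>
  val pred = model(input.to(device))
  val loss = lossFn(pred, label.to(device))
  loss.backward()
  loss.native.retainReference()
  loss
}
// keptLoss is still valid here, but must eventually be freed manually,
// e.g. via keptLoss.native.releaseReference()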

scala-torch seems to be doing something similar with its ReferenceManager, but using implicits/context parameters, which is worth a look as well. See https://github.com/microsoft/scala_torch#memory-management
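For illustration, a rough sketch of that pattern with Scala 3 context parameters (not scala-torch's actual API; ReferenceManager and managed are made up here):

import org.bytedeco.javacpp.Pointer
import scala.collection.mutable.ArrayBuffer

// Every tensor-producing op registers its result with the implicit
// manager; everything registered is freed together on close.
final class ReferenceManager extends AutoCloseable {
  private val owned = ArrayBuffer.empty[Pointer]
  def register[P <: Pointer](p: P): P = { owned += p; p }
  def close(): Unit = owned.foreach(_.deallocate())
}

def managed[A](f: ReferenceManager ?=> A): A = {
  val rm = new ReferenceManager
  try f(using rm)
  finally rm.close()
}

// Usage: an op would take a (using rm: ReferenceManager) parameter and
// call rm.register(result) on every native tensor it allocates.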