STDPLearner never frees the tensors it creates, even after going out of scope, causing torch.OutOfMemoryError
Issue type
- [x] Bug Report
- [ ] Feature Request
- [ ] Help wanted
- [ ] Other
SpikingJelly version
0.0.0.0.15
Description
I am performing a hyperparameter grid search, which involves creating and training a different model on every iteration of the loop. The instantiation and training of the model are performed inside a function that is called with different configuration parameters on every run.
When each iteration of the loop finishes, the model and all created variables, including the optimizer, dataset, and STDPLearners, go out of scope, which should make them eligible for deletion. However, the current behaviour is that tensors created by `STDPLearner` are never deleted, filling up memory until the program crashes. The MWE crashes after only 3 iterations when the model includes a `layer.Conv2d`, and the problem still exists, although to a lesser extent, when the model only uses `layer.Linear` layers.
This happens with both the torch and cupy backends, and explicitly calling the garbage collector does not help. When training with gradient descent instead, as shown in the minimal example by swapping `train_model_stdp` for the `train_model_gd` function, the problem does not show up, which links it to the `STDPLearner` class.
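For reference, the per-run growth can also be confirmed directly from the allocator with a small helper of my own (not part of the MWE below; it assumes a CUDA device):

```python
import torch

def report_cuda_memory(tag: str):
    # memory_allocated() counts bytes held by live tensors, while
    # memory_reserved() counts bytes the caching allocator keeps from
    # the driver; both grow run after run in the failing case.
    allocated = torch.cuda.memory_allocated() / 2**20
    reserved = torch.cuda.memory_reserved() / 2**20
    print(f"[{tag}] allocated={allocated:.1f} MiB, reserved={reserved:.1f} MiB")
```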
Minimal code to reproduce the error/bug
This is the basic training code that causes the issue. A single epoch per run is used because the problem only shows up when starting new runs. It includes a debugging function that counts how many tensors currently exist.
```python
import itertools
import gc

import torch
import torch.nn as nn
import torchvision
import torch.utils.data as D
from spikingjelly.activation_based import layer, neuron, functional, learning
from tqdm.auto import tqdm


def f_weight(x):
    return torch.clamp(x, -1., 1.)


class ConvMNIST(nn.Module):
    def __init__(self, tau: float):
        super().__init__()
        self.network = nn.Sequential(
            layer.Conv2d(1, 32, kernel_size=3, stride=1, bias=False),
            neuron.IFNode(),
            layer.AvgPool2d(2, 2),
            layer.Flatten(),
            layer.Linear(13 * 13 * 32, 10, bias=False),
            neuron.LIFNode(tau=tau)
        )
        functional.set_step_mode(self.network, 'm')
        functional.set_backend(self.network, 'torch')

    def forward(self, x):
        return self.network(x)

    def get_stdp_learners(self):
        stdp_learners = []
        for i in range(len(self.network)):
            if isinstance(self.network[i], (layer.Conv2d, layer.Linear)):
                stdp_learners.append(
                    learning.STDPLearner(step_mode='m', synapse=self.network[i], sn=self.network[i + 1],
                                         tau_pre=2.0, tau_post=2.0, f_pre=f_weight, f_post=f_weight)
                )
        return stdp_learners


def train_model_stdp(
        model: nn.Module,
        learning_rate: float,
        train_data: D.Dataset,
        device="cuda"):
    model.to(device)
    train_data = D.Subset(train_data, range(400))
    train_data = D.DataLoader(train_data, batch_size=64, shuffle=True, drop_last=True, num_workers=4)
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.0)
    stdp_learners = model.get_stdp_learners()
    print_tensors()
    for epoch in range(1, 2):
        model.train()
        for l in stdp_learners:
            l.enable()
        for batch, targets in (
            pbar := tqdm(train_data, dynamic_ncols=True, desc=f"Epoch {epoch}")
        ):
            batch = batch.to(device)
            targets = targets.to(device)
            output = model(batch.expand(100, -1, -1, -1, -1))
            optimizer.zero_grad()
            for l in stdp_learners:
                l.step()
            optimizer.step()
            functional.reset_net(model)
            for l in stdp_learners:
                l.reset()
    print_tensors()


def print_tensors(show_tensors=False):
    tensor_count = 0
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
                if show_tensors:
                    print(type(obj), obj.size())
                tensor_count += 1
        except:
            pass
    print(f"{tensor_count=}")


def run(tau: float, learning_rate: float):
    model = ConvMNIST(tau)
    train_dataset = torchvision.datasets.MNIST(
        root='./datasets',
        train=True,
        transform=torchvision.transforms.ToTensor(),
        download=True
    )
    train_model_stdp(model=model, learning_rate=learning_rate, train_data=train_dataset)


if __name__ == '__main__':
    search_space = {
        "tau": [2.0, 2.5, 3.0, 5.0, 10.0, 100.0],
        "learning_rate": [0.1, 0.01, 1e-3, 1e-4, 1e-5]
    }
    n_runs = 0
    for tau, learning_rate in itertools.product(*search_space.values()):
        n_runs += 1
        print(f"Starting run {n_runs} with params: {tau=} {learning_rate=}")
        run(tau=tau, learning_rate=learning_rate)
        print(gc.collect())
        print_tensors()
```
By swapping the `train_model_stdp` function for the following one, which simulates training with gradient descent, the problem disappears:
```python
def train_model_gd(
        model: nn.Module,
        learning_rate: float,
        train_data: D.Dataset,
        device="cuda"):
    model.to(device)
    train_data = D.Subset(train_data, range(400))
    train_data = D.DataLoader(train_data, batch_size=64, shuffle=True, drop_last=True, num_workers=4)
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.0)
    print_tensors()
    for epoch in range(1, 2):
        model.train()
        for batch, targets in (
            pbar := tqdm(train_data, dynamic_ncols=True, desc=f"Epoch {epoch}")
        ):
            batch = batch.to(device)
            targets = targets.to(device)
            output = model(batch.expand(100, -1, -1, -1, -1))
            mean_activations = output.mean(0)
            batch_loss = nn.functional.cross_entropy(mean_activations, targets)
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
            functional.reset_net(model)
    print_tensors()
```
Thanks for your issue!
Solution
The out-of-memory problem you've encountered can be addressed by running `STDPLearner.step()` within a `torch.no_grad()` context. You may modify your STDP code as follows:
```python
...
optimizer.zero_grad()
for l in stdp_learners:
    with torch.no_grad():
        l.step()
optimizer.step()
...
```
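Equivalently, since the STDP update needs no gradient tracking at all, you can hoist the context manager so the graph is skipped for the whole loop:

```python
optimizer.zero_grad()
with torch.no_grad():
    for l in stdp_learners:
        l.step()
optimizer.step()
```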
A deep dive
`STDPLearner.step(on_grad=True)` updates the weights by accumulating the STDP update into `w.grad`, where `w` is the weight tensor; the actual weight update then happens when `optimizer.step()` is called. Running `STDPLearner.step()` in a `torch.no_grad()` context ensures that its internal computations are excluded from the PyTorch computational graph. Without `torch.no_grad()`, the operations inside `STDPLearner.step()` become part of the graph; since `backward()` is never called in your STDP code, the computational graph is never freed, which leads to the out-of-memory error.
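The mechanism is easy to reproduce outside SpikingJelly. In this minimal sketch (it assumes a CUDA device, and the exact numbers will vary), the first loop keeps every intermediate tensor alive through the autograd graph, while the `no_grad` variant does not:

```python
import torch

w = torch.randn(512, 512, device='cuda', requires_grad=True)

trace = w
for _ in range(300):
    # sin's backward pass needs its input, so every intermediate tensor
    # stays alive as long as `trace` references the newest graph node
    trace = torch.sin(trace)
print(f"{torch.cuda.memory_allocated() / 2**20:.0f} MiB held by the graph")

del trace  # dropping the result tensor frees the whole graph at once
with torch.no_grad():
    trace = w
    for _ in range(300):
        trace = torch.sin(trace)  # nothing is recorded; memory stays flat
print(f"{torch.cuda.memory_allocated() / 2**20:.0f} MiB without a graph")
```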
Thank you, this seems to prevent the memory usage from growing uncontrollably, which is good enough for me. However, runs following the first one still take up more memory than a single run would. According to the PyTorch docs, the reference to the computational graph is held by the resulting tensor of an operation, meaning that if that tensor goes out of scope the whole graph is freed. It appears that there is still some reference to the `STDPLearner` tensors that is accessible from outside the function, probably at the module level; otherwise they would be collected.
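In case it helps anyone debug this further, here is a rough sketch of my own, built only on `gc` and tensor introspection, that prints what still holds each surviving CUDA tensor after a run; a module-level container inside the library should show up among the referrers:

```python
import gc
import torch

def show_tensor_referrers(max_tensors=5):
    # After run() returns and gc.collect() has been called, any CUDA
    # tensor that is still reachable must be held by something outside
    # the function's scope; its referrers hint at where.
    shown = 0
    for obj in gc.get_objects():
        if torch.is_tensor(obj) and obj.is_cuda:
            print(f"tensor of shape {tuple(obj.shape)} is referenced by:")
            for ref in gc.get_referrers(obj):
                print("   ", type(ref))
            shown += 1
            if shown >= max_tensors:
                return
```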