fangwei123456/spikingjelly

STDPLearner never frees the created tensors even when out of scope, causing torch.OutOfMemoryError


Issue type

  • Bug Report
  • Feature Request
  • Help wanted
  • Other

SpikingJelly version

0.0.0.0.15

Description

I am performing a hyperparameter search using grid search, which means creating and training a different model on every iteration of the loop. The model is instantiated and trained inside a function that is called with a different configuration on every run.

At the end of every iteration, the model and all other objects created in that run, including the optimizer, the dataset and the STDPLearners, go out of scope, so their memory should be released. However, the tensors created by STDPLearner are never freed, and memory keeps filling up until the program crashes. The MWE crashes after only 3 iterations when the model includes a layer.Conv2d; the problem also occurs, to a lesser extent, when the model uses only layer.Linear layers.

This happens with both the torch and cupy backends, and explicitly calling the garbage collector does not help. When the model is trained with gradient descent instead (by replacing train_model_stdp with the train_model_gd function shown below), the problem does not appear, which points to the STDPLearner class as the cause.
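A leak like this can be confirmed by calling a training function several times and comparing the number of live tensors before and after, in the same spirit as the print_tensors helper in the reproduction code. The following is a minimal CPU-only sketch; count_live_tensors, tensor_growth, clean_run, leaky_run and _leaked are hypothetical names, with leaky_run simulating a leak through a module-level list.

```python
import gc

import torch

_leaked = []  # hypothetical module-level list that keeps tensors alive


def count_live_tensors() -> int:
    """Count tensor objects currently tracked by the garbage collector."""
    gc.collect()
    count = 0
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                count += 1
        except Exception:
            # Some exotic objects can raise during inspection; skip them.
            pass
    return count


def tensor_growth(fn, runs: int = 3) -> int:
    """Call fn repeatedly and return how many more tensors are alive afterwards."""
    fn()  # warm-up call so one-time lazy allocations are not counted
    before = count_live_tensors()
    for _ in range(runs):
        fn()
    return count_live_tensors() - before


def clean_run():
    # Every tensor created here is unreachable once the function returns.
    x = torch.randn(16, 16)
    (x @ x).sum()


def leaky_run():
    # One tensor per call survives through the module-level reference.
    _leaked.append(torch.randn(16, 16))


print(tensor_growth(clean_run), tensor_growth(leaky_run))
```

A growth that scales with the number of runs, as with leaky_run, is the pattern the reproduction code shows between grid-search runs.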

Minimal code to reproduce the error/bug

This is the basic training code that triggers the issue. A single epoch per run is used because the problem only shows up when a new run starts. It includes a debugging function, print_tensors, that counts the tensors currently alive.

import itertools
import gc

import torch
import torch.nn as nn
import torchvision
import torch.utils.data as D
from spikingjelly.activation_based import layer, neuron, functional, learning
from tqdm.auto import tqdm


def f_weight(x):
    return torch.clamp(x, -1., 1.)


class ConvMNIST(nn.Module):
    def __init__(self, tau: float):
        super().__init__()

        self.network = nn.Sequential(
            layer.Conv2d(1, 32, kernel_size=3, stride=1, bias=False),
            neuron.IFNode(),
            layer.AvgPool2d(2, 2),

            layer.Flatten(),
            layer.Linear(13*13*32, 10, bias=False),
            neuron.LIFNode(tau=tau)
        )

        functional.set_step_mode(self.network, 'm')
        functional.set_backend(self.network, 'torch')

    def forward(self,x):
        return self.network(x)
    
    def get_stdp_learners(self):
        stdp_learners = []

        for i in range(len(self.network)):
            if isinstance(self.network[i], (layer.Conv2d, layer.Linear)):
                stdp_learners.append(
                    learning.STDPLearner(step_mode='m', synapse=self.network[i], sn=self.network[i+1], tau_pre=2.0, tau_post=2.0,
                                        f_pre=f_weight, f_post=f_weight)
                )
        
        return stdp_learners


def train_model_stdp(
        model: nn.Module,
        learning_rate: float,
        train_data: D.Dataset,
        device="cuda"):

    model.to(device)

    train_data = D.Subset(train_data, range(400))
    train_data = D.DataLoader(train_data, batch_size=64, shuffle=True, drop_last=True, num_workers=4)

    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.0)

    stdp_learners = model.get_stdp_learners()

    print_tensors()
    for epoch in range(1, 2):
        model.train()
        for l in stdp_learners:
            l.enable()

        for batch, targets in (
            pbar := tqdm(train_data, dynamic_ncols=True, desc=f"Epoch {epoch}")
        ):

            batch = batch.to(device)
            targets = targets.to(device)

            output = model(batch.expand(100, -1, -1, -1, -1))

            optimizer.zero_grad()
            for l in stdp_learners:
                l.step()
            optimizer.step()

            functional.reset_net(model)
            for l in stdp_learners:
                l.reset()
    
    print_tensors()

def print_tensors(show_tensors=False):
    tensor_count = 0
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
                if show_tensors: print(type(obj), obj.size())
                tensor_count += 1
        except Exception: pass
    print(f"{tensor_count=}")

def run(tau: float, learning_rate: float):
    model = ConvMNIST(tau)

    train_dataset = torchvision.datasets.MNIST(
        root='./datasets',
        train=True,
        transform=torchvision.transforms.ToTensor(),
        download=True
    )

    train_model_stdp(model=model, learning_rate=learning_rate, train_data=train_dataset)

if __name__ == '__main__':
    search_space = {
        "tau": [2.0, 2.5, 3.0, 5.0, 10.0, 100.0],
        "learning_rate": [0.1, 0.01, 1e-3, 1e-4, 1e-5]
    }

    n_runs = 0
    for tau, learning_rate in itertools.product(*search_space.values()):
        n_runs += 1
        print(f"Starting run {n_runs} with params: {tau=} {learning_rate=}")
        run(tau=tau, learning_rate=learning_rate)
        print(gc.collect())
        print_tensors()

Replacing the train_model_stdp function with the following one, which trains the model with gradient descent instead, makes the problem disappear:

def train_model_gd(
        model: nn.Module,
        learning_rate: float,
        train_data: D.Dataset,
        device="cuda"):

    model.to(device)

    train_data = D.Subset(train_data, range(400))
    train_data = D.DataLoader(train_data, batch_size=64, shuffle=True, drop_last=True, num_workers=4)

    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.0)

    print_tensors()
    for epoch in range(1, 2):
        model.train()

        for batch, targets in (
            pbar := tqdm(train_data, dynamic_ncols=True, desc=f"Epoch {epoch}")
        ):

            batch = batch.to(device)
            targets = targets.to(device)

            output = model(batch.expand(100, -1, -1, -1, -1))
            mean_activations = output.mean(0)
            batch_loss = nn.functional.cross_entropy(mean_activations, targets)

            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            functional.reset_net(model)
    
    print_tensors()

Thanks for your issue!

solution

The out-of-memory problem you encountered can be avoided by running STDPLearner.step() inside a torch.no_grad() context. You can modify your STDP code as follows:

...
optimizer.zero_grad()
for l in stdp_learners:
    with torch.no_grad():
        l.step()
optimizer.step()
...
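When several learners are stepped per batch, the same fix can also be applied once by decorating a small helper with torch.no_grad(), which PyTorch supports as a decorator. This is a minimal sketch; step_learners and FakeLearner are hypothetical names, and FakeLearner only stands in for an object that, like STDPLearner.step(), computes a weight update from tensors that require grad. It does not use spikingjelly.

```python
import torch


@torch.no_grad()
def step_learners(learners):
    """Call step() on every learner with autograd recording disabled."""
    for learner in learners:
        learner.step()


class FakeLearner:
    """Hypothetical stand-in for a learner whose step() uses tensors that require grad."""

    def __init__(self):
        self.weight = torch.zeros(3, requires_grad=True)
        self.trace = torch.ones(3, requires_grad=True)
        self.last_delta = None

    def step(self):
        # Outside no_grad, delta would carry a grad_fn and keep a graph alive.
        delta = self.trace * 0.5
        self.last_delta = delta
        # Write -delta into the gradient, mirroring the update described in this thread.
        self.weight.grad = -delta


learners = [FakeLearner(), FakeLearner()]
step_learners(learners)
print([l.last_delta.grad_fn for l in learners])
```

Inside the decorated helper no graph is recorded, so nothing ties the stored update to earlier computations.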

a deep dive

STDPLearner.step(on_grad=True) updates the weight by adding $-\Delta w$ to w.grad, where w is the weight tensor. When optimizer.step() is then called (plain SGD with learning rate $\eta$), the weight becomes $w - \eta(-\Delta w) = w + \eta\,\Delta w$, which is the intended STDP update. Wrapping STDPLearner.step() in a torch.no_grad() context keeps its internal computations out of the PyTorch computational graph. Without torch.no_grad(), the operations in STDPLearner.step() are recorded in the graph, and because your STDP code never calls backward(), that graph is never freed, which leads to the out-of-memory error.
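Both points can be checked with plain PyTorch: an update computed from tensors that require grad carries a grad_fn unless it is computed under torch.no_grad(), and writing $-\Delta w$ into w.grad before an SGD step moves w by $+\eta\,\Delta w$. This is a minimal sketch, assuming plain SGD without momentum as in the reproduction code; the values of w and pre_spike are arbitrary, and spikingjelly is not used.

```python
import torch

w = torch.nn.Parameter(torch.tensor([1.0, 2.0]))
optimizer = torch.optim.SGD([w], lr=0.1, momentum=0.0)

# Arbitrary input that requires grad, like spikes produced during a forward pass.
pre_spike = torch.tensor([1.0, 0.0], requires_grad=True)

# Outside no_grad: the update is recorded in the autograd graph.
delta_w_tracked = pre_spike * 2.0
print(delta_w_tracked.grad_fn is not None)  # True

# Inside no_grad: the same computation yields a plain tensor.
with torch.no_grad():
    delta_w = pre_spike * 2.0
print(delta_w.grad_fn is None)  # True

w.grad = -delta_w   # mirrors the on_grad=True update described above
optimizer.step()    # SGD: w <- w - lr * grad = w + lr * delta_w
print(w.detach())   # tensor([1.2000, 2.0000])
```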

Thank you, this seems to prevent the memory usage from growing uncontrollably, which is good enough for me. However, runs following the first one still take up more memory than a single run would.

According to the PyTorch docs, the reference to the computational graph is held by the tensor resulting from an operation, so when that tensor goes out of scope the whole graph is freed. It appears that something outside the function, probably at module level, still holds a reference to the STDPLearner tensors; otherwise they would be collected.
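One way to test this hypothesis is to keep only a weak reference to an object created inside the run function and check whether it is still alive after the function returns and gc.collect() has run; if it is, gc.get_referrers() lists the objects still pointing at it. This is a minimal pure-Python sketch; Learner, REGISTRY, run_once and is_freed are hypothetical names, with REGISTRY standing in for an unknown module-level reference.

```python
import gc
import weakref

REGISTRY = []  # hypothetical module-level container that retains objects


class Learner:
    """Hypothetical stand-in for an object created once per run."""


def run_once(register: bool) -> weakref.ref:
    learner = Learner()
    if register:
        REGISTRY.append(learner)  # an outside reference survives the return
    return weakref.ref(learner)  # a weak reference does not keep it alive


def is_freed(ref: weakref.ref) -> bool:
    gc.collect()
    return ref() is None


clean = run_once(register=False)
leaked = run_once(register=True)
print(is_freed(clean), is_freed(leaked))  # True False

# Check which objects still hold the leaked instance.
obj = leaked()
if obj is not None:
    print(any(r is REGISTRY for r in gc.get_referrers(obj)))  # True
```

The same check can be pointed at a learner or a tensor created inside a run (tensors support weak references) to find out what keeps it alive between runs.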