IBM/aihwkit

Backward pass not ideal when using InferenceRPUConfig

HeatPhoenix opened this issue · 4 comments

Description

From my understanding of InferenceRPUConfig, training with this RPUConfig should make the forward pass noisy while keeping the backward pass ideal. InferenceRPUConfig already defines the backward pass as perfect by default. So shouldn't setting rpu_config.backward.is_perfect = True have no effect?

Instead, training with and without rpu_config.backward.is_perfect = True produces different results.

[Figure: inference plot with backward.is_perfect = True]
vs.
[Figure: inference plot with default InferenceRPUConfig settings]

Loss development:
[Figure: loss curve with backward.is_perfect = True]
[Figure: loss curve with default InferenceRPUConfig settings]
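The expectation stated above can be illustrated with a toy linear layer that uses only the standard library. This is a hypothetical sketch, not aihwkit's implementation. The helpers `noisy_matvec`, `exact_matvec` and `toy_analog_linear` are invented for illustration. Modelling the weight noise as Gaussian with a constant standard deviation is an assumption meant to mirror `w_noise_type = ADDITIVE_CONSTANT` and `w_noise = 0.02`. The forward pass always perturbs the weights. The backward pass uses the exact transposed weights only when `backward_is_perfect` is true.

```python
import random


def exact_matvec(weights, x):
    """Noise-free matrix-vector product W x."""
    return [sum(w * xj for w, xj in zip(row, x)) for row in weights]


def noisy_matvec(weights, x, w_noise, rng):
    """(W + w_noise * xi) x with a fresh standard-normal xi for every weight.

    Toy stand-in for additive, constant-std weight noise (Gaussian is an assumption).
    """
    return [
        sum((w + w_noise * rng.gauss(0.0, 1.0)) * xj for w, xj in zip(row, x))
        for row in weights
    ]


def toy_analog_linear(weights, x, grad_out, w_noise, backward_is_perfect, rng):
    """Hypothetical toy layer: noisy forward pass, optionally perfect backward pass."""
    y = noisy_matvec(weights, x, w_noise, rng)             # forward: always noisy
    w_t = [list(col) for col in zip(*weights)]             # W^T
    if backward_is_perfect:
        grad_in = exact_matvec(w_t, grad_out)              # exact W^T g
    else:
        grad_in = noisy_matvec(w_t, grad_out, w_noise, rng)  # noisy W^T g
    return y, grad_in


W = [[0.5, -0.2, 0.1],
     [0.3, 0.8, -0.4]]   # 2 outputs x 3 inputs
x = [1.0, 2.0, 3.0]
g = [0.1, -0.3]          # upstream gradient dL/dy

y_noisy, grad_perfect = toy_analog_linear(W, x, g, 0.02, True, random.Random(0))
_, grad_noisy = toy_analog_linear(W, x, g, 0.02, False, random.Random(0))

print("exact forward:   ", exact_matvec(W, x))
print("noisy forward:   ", y_noisy)
print("perfect backward:", grad_perfect)
print("noisy backward:  ", grad_noisy)
```

In this toy, `backward_is_perfect=True` always returns exactly W^T g, whatever noise the forward pass saw. If a configuration already defaults to a perfect backward pass, setting the flag to True again cannot change anything. That is the behavior the issue expected.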

How to reproduce

I am training a simple regression network like so:

import torch.nn as nn

neurons = 128
# 6 inputs -> three hidden layers of 128 units -> 3 outputs
model = nn.Sequential(
    nn.Linear(6, neurons),
    nn.Softplus(),
    nn.Linear(neurons, neurons),
    nn.Softplus(),
    nn.Linear(neurons, neurons),
    nn.Softplus(),
    nn.Linear(neurons, 3),
)

With the following RPUConfig:

from aihwkit.inference.compensation.drift import GlobalDriftCompensation
from aihwkit.inference.noise.pcm import PCMLikeNoiseModel
from aihwkit.simulator.configs.configs import InferenceRPUConfig
from aihwkit.simulator.parameters.enums import WeightNoiseType

# Inference / hardware-aware training tile configuration.
rpu_config = InferenceRPUConfig()
rpu_config.forward.out_res = -1.0  # Turn off (output) ADC discretization.
rpu_config.forward.w_noise_type = WeightNoiseType.ADDITIVE_CONSTANT
rpu_config.forward.w_noise = 0.02  # Short-term weight noise.
# Inference noise model.
rpu_config.noise_model = PCMLikeNoiseModel(g_max=25.0)
# Drift compensation.
rpu_config.drift_compensation = GlobalDriftCompensation()
rpu_config.backward.is_perfect = True  # or omit this line for the default behavior

Training is done with AnalogAdam for 100 epochs to produce the plots above. The network trains correctly in fully digital mode. It also trains correctly when both forward.is_perfect and backward.is_perfect are set to True.

Expected behavior

I would expect backward.is_perfect to have no effect in inference-only setups, but it has a very significant effect. Is this write noise? Or read noise when reading out the activations? The documentation and discussions on GitHub state things like "the backward pass and update is thought to be perfect but noise is injected in the forward pass only; see InferenceRPUConfig".
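The forward noise is random, so when comparing the two curves it is worth checking whether both runs used the same random seed. Otherwise two runs with an identical configuration can already produce different loss curves. The stdlib-only toy below illustrates this under stated assumptions. `train_toy`, its one-weight regression task (target y = 2x) and the Gaussian forward noise are hypothetical and are not aihwkit code.

```python
import random


def train_toy(seed, w_noise=0.02, lr=0.1, steps=50):
    """Hypothetical one-weight regression (target y = 2x) with a noisy forward pass.

    The forward pass uses w + w_noise * N(0, 1); the update uses the exact
    derivative with respect to the clean weight (a toy 'ideal backward').
    Returns the per-step squared errors.
    """
    rng = random.Random(seed)  # all noise comes from this seeded generator
    w = 0.0
    losses = []
    for step in range(steps):
        x = (step % 5 + 1) / 5.0                   # inputs cycle through 0.2 ... 1.0
        w_fwd = w + w_noise * rng.gauss(0.0, 1.0)  # noisy weight for this forward pass
        err = w_fwd * x - 2.0 * x
        losses.append(err * err)
        w -= lr * 2.0 * err * x                    # gradient of err**2 w.r.t. w
    return losses


same_seed = train_toy(seed=0) == train_toy(seed=0)
other_seed = train_toy(seed=0) == train_toy(seed=1)
print("identical config, same seed:     ", same_seed)
print("identical config, different seed:", other_seed)
```

With `w_noise=0.0` the seed stops mattering, because the random draws then have no effect on the trajectory.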

Other information

  • PyTorch version: 2.1.0
  • Package version: 0.9.0
  • OS: Debian

That's indeed surprising. You are right; it should not have any effect. Thanks for raising the issue.

Hi.
I ran this sample script:

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from aihwkit.inference.compensation.drift import GlobalDriftCompensation
from aihwkit.inference.noise.pcm import PCMLikeNoiseModel
from aihwkit.nn.conversion import convert_to_analog
from aihwkit.optim.analog_optimizer import AnalogSGD
from aihwkit.simulator.configs.configs import InferenceRPUConfig
from aihwkit.simulator.parameters.enums import WeightNoiseType


def instantiate_model(neurons=128):
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(1, 32, 3, 1)
            self.conv2 = nn.Conv2d(32, 64, 3, 1)
            self.dropout1 = nn.Dropout(0.25)
            self.dropout2 = nn.Dropout(0.5)
            self.fc1 = nn.Linear(9216, 128)
            self.fc2 = nn.Linear(128, 10)
        def forward(self, x):
            x = self.conv1(x)
            x = F.relu(x)
            x = self.conv2(x)
            x = F.relu(x)
            x = F.max_pool2d(x, 2)
            x = self.dropout1(x)
            x = torch.flatten(x, 1)
            x = self.fc1(x)
            x = F.relu(x)
            x = self.dropout2(x)
            x = self.fc2(x)
            # output = F.log_softmax(x, dim=1)
            # return output
            return x
    return Net()

def gen_rpu_config(is_perfect):
    rpu_config = InferenceRPUConfig()
    rpu_config.forward.out_res = -1.0  # Turn off (output) ADC discretization.
    rpu_config.forward.w_noise_type = WeightNoiseType.ADDITIVE_CONSTANT
    rpu_config.forward.w_noise = 0.02  # Short-term w-noise.
    # Inference noise model.
    rpu_config.noise_model = PCMLikeNoiseModel(g_max=25.0)
    # Drift compensation
    rpu_config.drift_compensation = GlobalDriftCompensation()
    rpu_config.backward.is_perfect = is_perfect
    return rpu_config

def get_data():
    transform= torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,))
    ])
    dataset = torchvision.datasets.MNIST('data', train=True, download=True,
                       transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset, batch_size=256, shuffle=True)
    return train_loader

if __name__ == '__main__':
    cuda = torch.cuda.is_available()  # GPU use also needs a CUDA-enabled aihwkit build
    model = instantiate_model()
    analog_models = [
        convert_to_analog(model, gen_rpu_config(is_perfect=True)),
        convert_to_analog(model, gen_rpu_config(is_perfect=False)),
    ]
    optimizers = []
    for analog_model in analog_models:
        if cuda:
            analog_model.cuda()
        analog_model.train()
        optimizer = AnalogSGD(analog_model.parameters(), lr=0.001)
        optimizers.append(optimizer)
    n_iters = -1  # -1 runs a full epoch; a positive value stops after n_iters batches
    logging_interval = 25
    criterion = nn.CrossEntropyLoss()
    for batch_idx, (data, target) in enumerate(get_data()):
        if cuda:
            data, target = data.cuda(), target.cuda()
        gradients_mean = []
        gradients_std = []
        losses = []
        for analog_model, optimizer in zip(analog_models, optimizers):
            optimizer.zero_grad()
            model_output = analog_model(data)
            loss = criterion(model_output, target)
            loss.backward()
            optimizer.step()
            gradients = torch.tensor([])
            for p in analog_model.parameters():
                if p.grad is not None:
                    p_grad = p.grad.detach().cpu().data.flatten()
                    gradients = torch.cat([gradients, p_grad], dim=-1)
            gradients_mean.append(gradients.mean())
            gradients_std.append(gradients.std())
            losses.append(loss.item())
        if batch_idx % logging_interval == 0:
            print('Iteration {} mean: '.format(batch_idx), gradients_mean, ', std: ', gradients_std, 'loss: ', losses)
        if batch_idx == n_iters - 1:
            break

and this is the log I get:

Iteration 0 mean:  [tensor(0.0003), tensor(0.0002)] , std:  [tensor(0.0017), tensor(0.0023)] loss:  [2.440281629562378, 2.4453179836273193]
Iteration 25 mean:  [tensor(0.0003), tensor(0.0002)] , std:  [tensor(0.0015), tensor(0.0024)] loss:  [2.4614646434783936, 2.4668071269989014]
Iteration 50 mean:  [tensor(0.0002), tensor(0.0002)] , std:  [tensor(0.0015), tensor(0.0023)] loss:  [2.465376615524292, 2.409703016281128]
Iteration 75 mean:  [tensor(0.0003), tensor(0.0002)] , std:  [tensor(0.0014), tensor(0.0022)] loss:  [2.433210611343384, 2.42989444732666]
Iteration 100 mean:  [tensor(0.0002), tensor(0.0001)] , std:  [tensor(0.0014), tensor(0.0022)] loss:  [2.3848493099212646, 2.417853355407715]
Iteration 125 mean:  [tensor(1.3142e-05), tensor(0.0001)] , std:  [tensor(0.0014), tensor(0.0021)] loss:  [2.3443965911865234, 2.314314365386963]
Iteration 150 mean:  [tensor(0.0001), tensor(8.9972e-05)] , std:  [tensor(0.0015), tensor(0.0023)] loss:  [2.37026309967041, 2.2686352729797363]
Iteration 175 mean:  [tensor(1.2524e-05), tensor(5.3390e-05)] , std:  [tensor(0.0014), tensor(0.0021)] loss:  [2.365387201309204, 2.37935733795166]
Iteration 200 mean:  [tensor(8.2563e-05), tensor(-4.6383e-06)] , std:  [tensor(0.0014), tensor(0.0023)] loss:  [2.332214117050171, 2.349396228790283]
Iteration 225 mean:  [tensor(0.0001), tensor(0.0002)] , std:  [tensor(0.0015), tensor(0.0023)] loss:  [2.3589439392089844, 2.356438398361206]

In each list, the first entry is for is_perfect=True and the second for is_perfect=False. I don't see a clear difference.
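The two columns of the log above can be averaged per configuration for a more precise comparison than eyeballing. The values below are copied from that log. `column_means` is a hypothetical helper. Note that the tensors in the log are printed to four decimals, so the gradient figures are averages of rounded values.

```python
from statistics import mean

# (is_perfect=True, is_perfect=False) pairs copied from the log above,
# one pair per logged iteration (0, 25, ..., 225).
GRAD_STD = [
    (0.0017, 0.0023), (0.0015, 0.0024), (0.0015, 0.0023), (0.0014, 0.0022),
    (0.0014, 0.0022), (0.0014, 0.0021), (0.0015, 0.0023), (0.0014, 0.0021),
    (0.0014, 0.0023), (0.0015, 0.0023),
]
LOSS = [
    (2.440281629562378, 2.4453179836273193),
    (2.4614646434783936, 2.4668071269989014),
    (2.465376615524292, 2.409703016281128),
    (2.433210611343384, 2.42989444732666),
    (2.3848493099212646, 2.417853355407715),
    (2.3443965911865234, 2.314314365386963),
    (2.37026309967041, 2.2686352729797363),
    (2.365387201309204, 2.37935733795166),
    (2.332214117050171, 2.349396228790283),
    (2.3589439392089844, 2.356438398361206),
]


def column_means(pairs):
    """Return (mean of first entries, mean of second entries)."""
    perfect, not_perfect = zip(*pairs)
    return mean(perfect), mean(not_perfect)


for name, pairs in (("gradient std", GRAD_STD), ("loss", LOSS)):
    m_true, m_false = column_means(pairs)
    print(f"{name}: is_perfect=True {m_true:.5f}, is_perfect=False {m_false:.5f}")
```

On these values the averaged losses are about 2.396 and 2.384, a difference of roughly 0.012. The averaged gradient standard deviations are about 0.00147 and 0.00225.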

Are you sure you are setting the rpu_config? When you pass None to convert_to_analog, another default Tile class will be used.

Hi @HeatPhoenix! Did you get a chance to try the script @jubueche suggested?

Please give us some feedback when you can, so we can help further if any problems arise. Thank you!

Closing as there has been no response from @HeatPhoenix. If this issue still persists, please reopen it or open a new issue. Thanks!