pomonam/jax-influence

Implementation of PBRF

Hzfinfdu opened this issue · 3 comments

Thanks for your wonderful work, "If Influence Functions are the Answer, Then What is the Question?".

I had trouble reimplementing PBRF on MNIST+MLP: I found that the initial loss can be quite small and that the optimization is unstable. Besides, I am not sure whether my implementation is correct. Could you please help me out or post the PBRF code?

Thanks in advance!

Hi @Hzfinfdu, thank you for your interest in our paper! I am planning to release a new repo soon, but in the meantime I would be happy to help debug the issue. You can post it here or send me an email at jbae [at] cs.toronto.edu.

Hi @pomonam, much appreciation for your quick response and patience. But debugging can be annoying and I don't want to bother you with my janky code. Looking forward to your new repo!

Thanks again for your reply!

Hello,

I have run into the same issue as @Hzfinfdu while implementing PBRF on MNIST. Regardless of how I sample or train, the resulting PBRF change shows no consistent pattern and does not converge. I suspect there may be a problem with my implementation. I have included my code below and would greatly appreciate it if you could help me check it for bugs.

Thanks for considering my request.

import torch
import torch.nn as nn
import numpy as np
import random
import datasets
import torch.optim as optim
from tqdm import tqdm
from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MNIST_dataset = datasets.load_dataset('mnist')
MNIST_dataset['test'], MNIST_dataset['validation'] = MNIST_dataset['test'].train_test_split(test_size=0.5).values()


# Define a Multi-Layer Perceptron (MLP) model.
class MLP(nn.Module):
    def __init__(self, input_size, hidden_size1, hidden_size2, output_size):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size1)
        self.fc2 = nn.Linear(hidden_size1, hidden_size2)
        self.fc3 = nn.Linear(hidden_size2, output_size)

    def forward(self, x):
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        x = self.fc3(x)
        return x


# Cross-entropy loss function.
ce_loss = nn.CrossEntropyLoss(reduction="none")
ce_loss_mean = nn.CrossEntropyLoss(reduction="mean")
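# ce_loss (per-example) is used inside the Bregman divergence and for evaluating the
# query example; ce_loss_mean is applied to the single source example inside PBO.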


def Bregman_divergence_for_crossentropy(p, q, label):
    fp = ce_loss(p, label)
    fq = ce_loss(q, label)
    gradient = torch.softmax(q, dim=1) - torch.nn.functional.one_hot(
        label, num_classes=10
    )
    p_minus_q = p - q
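    # Bregman divergence of softmax cross-entropy:
    #   D(p, q) = CE(p, label) - CE(q, label) - <softmax(q) - one_hot(label), p - q>,
    # which is exactly 0 when p == q and non-negative otherwise (CE is convex in the logits).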
    return (fp - fq - torch.sum(gradient * p_minus_q, dim=1)).sum()


def PBO(output, predict_s, target):
    # Calculate Bregman divergence for cross-entropy loss.
    # the parameter output is the output of the current model and predict_s is the output of the original model.
    # target is the ground truth label.
    BregDiv = Bregman_divergence_for_crossentropy(output, predict_s, target)
    # source_example is placed at the beginning of each batch in collate_fn, so
    # output[0] and target[0] give the cross-entropy of the current model on
    # source_example.
    ce = ce_loss_mean(output[0], target[0])
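    # Total PBO loss for this batch: Bregman divergence summed over the batch minus
    # the cross-entropy on source_example; the damping/proximity term is added in train().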
    loss = BregDiv - ce
    return loss


def get_origin_model():
    origin_model = MLP(784, 512, 256, 10)
    # Suppose the original model is trained and saved in the folder MLP_ckpts.
    origin_model.load_state_dict(torch.load("./MLP_ckpts/best_model.pt")) 
    origin_model.to(device)
    return origin_model


def train(model, train_loader, optimizer, criterion, damping):
    # Create a copy of the model's parameters.
    param_ses = [param.detach().clone() for param in model.parameters()]
    model.train()
    train_loss = []
    train_total = 0
    bar = tqdm(range(5 * len(train_loader)))
    for i in range(5):
        for batch_idx, batch in enumerate(train_loader):
            data, predict_s, target = (
                batch["image"].to(device),
                batch["predict_s"].to(device),
                batch["label"].to(device),
            )
            optimizer.zero_grad()
            output = model(data.view(-1, 784))
            loss = criterion(output, predict_s, target)
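            # Add the proximity term damping / 2 * ||theta - theta_s||^2, where
            # param_ses holds the parameters of the original model captured above.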
            for param, param_s in zip(model.parameters(), param_ses):
                loss = loss + (param - param_s).norm() ** 2 * damping * 0.5
            loss.backward()
            optimizer.step()
            train_total += target.size(0)
            train_loss.append(loss.item())
            bar.update(1)
            bar.set_description(
                "Epoch: {} | Train Loss: {:.8f}".format(i + 1, loss.item())
            )

    return model


def get_PBRF_change(source_example, query_example):
    model = get_origin_model()
    optimizer = optim.Adam(model.parameters(), lr=1e-6, weight_decay=0)
    origin_logits = model(torch.tensor(query_example["image"], device=device).view(784))
    origin_loss = ce_loss(
        origin_logits, torch.tensor(query_example["label"], device=device)
    )
    predicted_s_source = (
        model(torch.tensor(source_example["image"], device=device).view(-1, 784))
        .detach()
        .cpu()
    )

    # Build a dataloader that prepends source_example to every batch.
    def collate_fn(examples):
        return {
            "image": torch.cat(
                [torch.tensor(source_example["image"])]
                + [torch.tensor(example["image"]) for example in examples],
                dim=0,
            ),
            "label": torch.tensor(
                [source_example["label"]] + [example["label"] for example in examples]
            ),
            "predict_s": torch.cat(
                [predicted_s_source]
                + [torch.tensor(example["predict_s"]) for example in examples],
                dim=0,
            ),
        }

    train_loader = DataLoader(
        MNIST_dataset["train"], batch_size=511, shuffle=True, collate_fn=collate_fn
    )
    trained_model = train(model, train_loader, optimizer, PBO, damping=0.512)
    PBRF_logits = trained_model(
        torch.tensor(query_example["image"], device=device).view(-1, 784)
    )
    PBRF_loss = ce_loss(
        PBRF_logits.squeeze(), torch.tensor(query_example["label"], device=device)
    )
    PBRF_change = PBRF_loss - origin_loss
    return PBRF_change


def main():
    # Add the output of the original model to the MNIST dataset as a new column "predict_s".
    if "predict_s" not in MNIST_dataset["train"].column_names:
        model = get_origin_model()
        MNIST_dataset["train"] = MNIST_dataset["train"].map(
            lambda example: {
                "predict_s": model(
                    torch.tensor(example["image"], device=device).view(-1, 784)
                )
            }
        )
    random_numbers = random.sample(range(1, 500), 10)
    for i in random_numbers:
        source_example = MNIST_dataset["train"][i]
        query_example = MNIST_dataset["test"][i]
        PBRF_change = get_PBRF_change(source_example, query_example)
        print("example {}: PBRF change = {:.6f}".format(i, PBRF_change.item()))


if __name__ == "__main__":
    main()
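
As a side note, here is a minimal check of the Bregman term in isolation (run in the same session as the script above, so ce_loss and Bregman_divergence_for_crossentropy are already defined). The divergence should be exactly zero when the two logit tensors coincide and non-negative otherwise, since cross-entropy is convex in the logits:

p = torch.randn(4, 10)                  # stand-in logits for the fine-tuned model
q = torch.randn(4, 10)                  # stand-in logits for the original model
label = torch.randint(0, 10, (4,))

print(Bregman_divergence_for_crossentropy(q, q, label).item())  # expected: 0.0
print(Bregman_divergence_for_crossentropy(p, q, label).item())  # expected: >= 0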