Koukyosyumei/AIJack

In aijack/attack/inversion/gan_attack.py, "from aijack.attack import GAN_Attack" fails: I cannot find GAN_Attack.

akyo1o opened this issue · 4 comments

@akyo1o
Thank you for your interest! Please check #68.

Some of the example code is outdated, so any contribution (such as a pull request) is very welcome!

In this case, you can update the GAN attack example as follows (this may not be complete):

  • import

old (line 10)

    from aijack.attack import GAN_Attack

new

    from aijack.attack import GANAttackManager

  • setup

old

    gan_attacker = GAN_Attack(
        client_2,
        target_label,
        generator,
        optimizer_g,
        criterion,
        nz=nz,
        device=device,
    )

    global_model = Net()
    global_model.to(device)
    server = FedAvgServer(clients, global_model)

new

    manager = GANAttackManager(
        target_label,
        generator,
        optimizer_g,
        criterion,
        nz=nz,
    )
    GANAttackFedAvgClient = manager.attach(FedAvgClient)

    net_2 = Net()
    client_2 = GANAttackFedAvgClient(net_2, user_id=1)
    optimizer_2 = optim.SGD(
        client_2.parameters(), lr=0.02, weight_decay=1e-7, momentum=0.9
    )

    clients = [client_1, client_2]
    optimizers = [optimizer_1, optimizer_2]

  • training

old

    for epoch in range(5):
        for client_idx in range(client_num):
            client = clients[client_idx]
            trainloader = trainloaders[client_idx]
            optimizer = optimizers[client_idx]

            running_loss = 0.0
            for _, data in enumerate(trainloader, 0):
                # get the inputs; data is a list of [inputs, labels]
                inputs, labels = data
                inputs = inputs.to(device)
                labels = labels.to(device)

                if epoch != 0 and client_idx == adversary_client_id:
                    fake_image = gan_attacker.attack(fake_batch_size)
                    inputs = torch.cat([inputs, fake_image])
                    labels = torch.cat(
                        [
                            labels,
                            torch.tensor([fake_label] * fake_batch_size, device=device),
                        ]
                    )

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward + backward + optimize
                outputs = client(inputs)
                loss = criterion(outputs, labels.to(torch.int64))
                loss.backward()
                optimizer.step()

                running_loss += loss.item()

            print(
                f"epoch {epoch}: client-{client_idx+1}",
                running_loss / dataset_nums[client_idx],
            )

        server.update()
        server.distribtue()

        gan_attacker.update_discriminator()
        gan_attacker.update_generator(batch_size=64, epoch=1000, log_interval=100)

new

    for epoch in range(2):
        for client_idx in range(client_num):
            client = clients[client_idx]
            trainloader = trainloaders[client_idx]
            optimizer = optimizers[client_idx]

            for _, data in enumerate(trainloader, 0):
                # get the inputs; data is a list of [inputs, labels]
                inputs, labels = data
                inputs = inputs.to(device)
                labels = labels.to(device)

                if epoch != 0 and client_idx == adversary_client_id:
                    fake_image = client.attack(fake_batch_size)
                    inputs = torch.cat([inputs, fake_image])
                    labels = torch.cat(
                        [
                            labels,
                            torch.tensor([fake_label] * fake_batch_size, device=device),
                        ]
                    )
    
                # zero the parameter gradients
                optimizer.zero_grad()
    
                # forward + backward + optimize
                outputs = client(inputs)
                loss = criterion(outputs, labels.to(torch.int64))
                client.backward(loss)
                optimizer.step()

        server.action()
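
The new setup relies on `manager.attach(FedAvgClient)` returning a client class that is constructed and trained like an ordinary `FedAvgClient` but also provides `attack()`. Below is a minimal pure-Python sketch of that class-factory idea, assuming it works by subclassing the base client. `ToyClient`, `ToyGANAttackManager`, and the noise-vector "generator" are hypothetical stand-ins, not AIJack's actual classes, and AIJack's real `attach` may be implemented differently.

```python
import random


class ToyClient:
    """Hypothetical stand-in for a federated-learning client such as FedAvgClient."""

    def __init__(self, model, user_id=0):
        self.model = model
        self.user_id = user_id

    def __call__(self, x):
        # Forward pass delegates to the wrapped model.
        return self.model(x)


class ToyGANAttackManager:
    """Hypothetical manager: stores attack settings and builds an attacking client class."""

    def __init__(self, target_label, generator, nz=100):
        self.target_label = target_label
        self.generator = generator
        self.nz = nz

    def attach(self, base_cls):
        manager = self  # captured by the subclass defined below

        class GANAttackClientWrapper(base_cls):
            def __init__(self, *args, **kwargs):
                # Build the ordinary client first, then add the attack state.
                super().__init__(*args, **kwargs)
                self.target_label = manager.target_label
                self.generator = manager.generator
                self.nz = manager.nz

            def attack(self, n):
                # Feed n random noise vectors of length nz through the generator.
                return [
                    self.generator([random.gauss(0.0, 1.0) for _ in range(self.nz)])
                    for _ in range(n)
                ]

        return GANAttackClientWrapper


# Usage: this toy "generator" just reports the noise-vector length.
manager = ToyGANAttackManager(target_label=3, generator=lambda z: len(z), nz=8)
AttackClient = manager.attach(ToyClient)
client = AttackClient(lambda x: x * 2, user_id=1)
print(client(5))                      # 10 (ordinary client behavior is kept)
print(client.attack(2))               # [8, 8]
print(isinstance(client, ToyClient))  # True
```

Returning a subclass rather than wrapping an existing instance means the attacker is created exactly like an honest client (compare `GANAttackFedAvgClient(net_2, user_id=1)` in the setup), so the rest of the training loop can treat it as any other client.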

Thank you very much! I have successfully run the code! Also, do you have any recommended, already-tuned parameters so that I can reproduce the results quickly?

@akyo1o

Have you tried the parameters given in the code or in the original paper? Do they not work in your environment?

Today I tuned the parameters and achieved good results. Thank you again!