LostOxygen/poison_froggo

CUDA memory keeps growing in "forward_step"

Closed this issue · 3 comments

import copy

import torch
import torch.nn as nn

def forward_step(model: nn.Sequential, img: torch.Tensor, target_image: torch.Tensor, lr: float, target_logits: torch.Tensor) -> torch.Tensor:
    # Detach the incoming image from any earlier autograd graph, then
    # re-enable gradients so only this step's graph is built.
    img.detach_()
    img.requires_grad = True

    logits = model(img)
    loss = torch.norm(logits - target_logits)
    model.zero_grad()
    loss.backward()

    # Gradient step on the image towards the target logits.
    img_grad = img.grad.data
    perturbed_img = img - lr * img_grad
    return perturbed_img

# when you call this function
new_image = forward_step(model, old_image, target_image, learning_rate, copy.deepcopy(target_feat_rep))

Hi, friend.

I found that when calling "create_perturbed_dataset", the CUDA memory keeps growing.

Changing the code as above solves the problem.

But do you know why? I can't figure it out.

Does the attack then still work normally? You removed the call that performs the forward pass inside the PyTorch autograd system (the with torch.enable_grad() context).

Sometimes the gradient graphs generated by PyTorch's autograd system get really large, which causes memory leaks if you copy them into new variables over and over again. The .detach() call can solve this problem, and maybe that is also the cause here. So if it still works fine without the extra with torch.enable_grad() context, this could be the reason.
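To make that concrete, here is a minimal, self-contained sketch (not code from this repository; the linear model, shapes, and learning rate are made up) of the pattern: the update itself is a differentiable op, so without a detach the new image keeps a reference to the whole graph and the graphs of all previous steps stay alive.

import torch
import torch.nn as nn

model = nn.Linear(8, 8)
target_logits = torch.randn(1, 8)
img = torch.randn(1, 8, requires_grad=True)

for step in range(100):
    loss = torch.norm(model(img) - target_logits)
    model.zero_grad()
    loss.backward()

    # Without .detach(), the new img would carry a grad_fn pointing into this
    # iteration's graph, so every previous graph is kept alive and (on a GPU)
    # CUDA memory keeps growing. Detaching lets the old graphs be freed.
    img = (img - 0.1 * img.grad).detach()
    img.requires_grad_(True)

The img.detach_() / img.requires_grad = True pair at the top of forward_step does the same thing for the caller's image.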

If so, feel free to open a pull request with the code changes :)

Thanks for the reply.

The code snippet above still attacks successfully, and CUDA memory no longer increases.

But it also leaks memory without the with torch.enable_grad() context. Perhaps it is because of model.zero_grad(), but no matter, there is already a solution.
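For reference, here is an alternative sketch of the same step (assuming the same argument names as the snippet at the top of this issue, with the unused target_image argument dropped): doing the update under torch.no_grad() means the returned image is never attached to a graph in the first place, so the caller does not have to detach anything.

import torch
import torch.nn as nn

def forward_step(model: nn.Sequential, img: torch.Tensor, lr: float,
                 target_logits: torch.Tensor) -> torch.Tensor:
    # Work on a detached leaf copy so no graph from earlier steps is retained.
    img = img.detach().requires_grad_(True)

    logits = model(img)
    loss = torch.norm(logits - target_logits)
    model.zero_grad()
    loss.backward()

    # The gradient step runs outside autograd, so the result carries no graph.
    with torch.no_grad():
        perturbed_img = img - lr * img.grad
    return perturbed_img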

I will open a pull request, thank you again.

Hello, dear author. I'm having some trouble. I tried to change the dataset: dog to a traffic sign with a speed limit of 30, and fish to a traffic sign with a speed limit of 100.
[attached screenshot]

I only changed this part of the code, because of the library version:
[attached screenshot of the code change]

I changed EPOCHS = 100 to EPOCHS = 15 to save time, but it still took me 4 hours to run the code.

These are the only three changes I made to the code.

I don't understand why the result looks like the pictures below.
Why didn't the attack succeed?
[attached result screenshots]

Hope to get your reply, thank you!