tamarott/SinGAN

Possible one-line solution for Runtime error (variables modified in-place)

williantrevizan opened this issue · 8 comments

Hi, thanks for the repository and this amazing work!

I opened this issue because it might provide a solution to the runtime error reported by cdjameson in another thread, which occurs with newer versions of torch ('one of the variables needed for gradient computation has been modified by an inplace operation...'). It seems more straightforward than the solution Clefspear99 is proposing in a pull request.

The problem occurs in the function train_single_scale() in training.py.
This function consists essentially of two sequential loops: one optimizes the discriminator D and the other optimizes the generator G. At the end of the first loop, the generator produces a fake image. As soon as the second loop starts, this fake image is passed through the discriminator, which produces a patch discrimination map that is then used to calculate the loss errG. The call errG.backward() computes the gradients, which optimizerG.step() then uses to update netG's weights.

The first pass through this second loop runs smoothly, and the optimizer updates netG's weights in place. On the second pass, however, the same fake image is used to calculate the loss, that is, a fake image that was generated with a previous set of netG weights. Therefore, when we call backward(), the computational graph points back to netG weights as they were before the optimization step, even though those tensors have since been modified in place. Newer versions of torch detect this inconsistency (autograd checks a version counter on every tensor it saved for the backward pass), and that seems to be why the error occurs.
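The failure mode described above can be reproduced outside SinGAN with a minimal sketch. ToyG and stale_fake_generator_steps are hypothetical stand-ins (a two-layer MLP instead of SinGAN's convolutional generator, a single linear layer instead of the patch discriminator); only the pattern of reusing one fake image across several optimizerG.step() calls is taken from the description. With a recent torch, the second backward() should raise the in-place RuntimeError.

```python
import torch
import torch.nn as nn


class ToyG(nn.Module):
    """Hypothetical stand-in for SinGAN's netG: refines `prev` using `noise`."""

    def __init__(self, dim=4):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))

    def forward(self, noise, prev):
        # Residual form: the generated detail is added on top of `prev`.
        return self.body(noise + prev) + prev


def stale_fake_generator_steps(steps=3):
    """Reuse one `fake` across several G steps; return the step where backward fails."""
    torch.manual_seed(0)
    netG, netD = ToyG(), nn.Linear(4, 1)  # netD is a toy stand-in, never stepped here
    optimizerG = torch.optim.Adam(netG.parameters(), lr=1e-3)
    noise, prev = torch.randn(2, 4), torch.randn(2, 4)

    # Generated once, like the `fake` left over from the end of the D loop.
    fake = netG(noise, prev)

    for j in range(steps):
        netG.zero_grad()
        errG = -netD(fake).mean()
        try:
            # retain_graph=True keeps netG's graph alive, so the next step reaches it again.
            errG.backward(retain_graph=True)
        except RuntimeError as err:
            print(f"step {j}: {err}")
            return j
        # Updates netG's weights in place, which bumps their autograd version counters.
        optimizerG.step()
    return None


failed_at = stale_fake_generator_steps()
```

The first step succeeds; the second one fails because the retained graph still refers to weight tensors that optimizerG.step() has already overwritten.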

So, instead of downgrading torch, a simple solution would be to add the line,

fake = netG(noise.detach(), prev.detach())

right at the beginning of the second loop, so that the fake image is always recomputed with the current weights.
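Applied to the same kind of toy setup, the proposed line could look like the sketch below. As before, ToyG and fixed_generator_steps are hypothetical stand-ins rather than SinGAN code; the only part taken from the proposal is regenerating fake with netG(noise.detach(), prev.detach()) at the top of each generator step. The backward call still passes retain_graph=True to show that the fix does not depend on removing it.

```python
import torch
import torch.nn as nn


class ToyG(nn.Module):
    """Hypothetical stand-in for SinGAN's netG: refines `prev` using `noise`."""

    def __init__(self, dim=4):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))

    def forward(self, noise, prev):
        # Residual form: the generated detail is added on top of `prev`.
        return self.body(noise + prev) + prev


def fixed_generator_steps(steps=3):
    """Run several G steps, regenerating `fake` each time; return the errG values."""
    torch.manual_seed(0)
    netG, netD = ToyG(), nn.Linear(4, 1)  # netD is a toy stand-in, never stepped here
    optimizerG = torch.optim.Adam(netG.parameters(), lr=1e-3)
    noise, prev = torch.randn(2, 4), torch.randn(2, 4)

    losses = []
    for _ in range(steps):
        # The proposed one-line fix: rebuild `fake` from the current netG weights.
        fake = netG(noise.detach(), prev.detach())
        netG.zero_grad()
        errG = -netD(fake).mean()
        # Each step now backpropagates through a fresh graph, so no saved tensor is stale.
        errG.backward(retain_graph=True)
        optimizerG.step()
        losses.append(errG.item())
    return losses


losses = fixed_generator_steps()
print(losses)
```

Because every step builds its graph from the weights as they are at that moment, backward() never reaches tensors that optimizerG.step() has modified since the graph was recorded.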

tamarott, I think this might solve the problem. If you agree, I will submit a pull request with this modification.

This is a possible solution, but pay attention: it changes the optimization process and therefore might affect performance.
So the results won't necessarily be identical to those of the original version.

You are right, I'll pay attention to that! I ran a few tests with the application I'm working on, and it seems to work fine with this modification, but I didn't test it very extensively.

About the optimization process: when I first read your paper and code, it made sense to me that, conceptually, the fake image should be recalculated at every step of that loop (the one that optimizes G). However, what seems to be happening is that the adversarial loss is kept fixed (because you use the same fake image 3 times) and only the reconstruction loss is updated inside the loop. Is there a reason why that should work better?
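To make the point about the fixed adversarial loss concrete, the sketch below records the value of errG over three generator steps when the fake image is reused versus recomputed. It is a hypothetical toy: ToyG, z_opt, real, alpha, and the MSE reconstruction term are illustrative assumptions, not SinGAN's actual code. In the reused variant, fake is detached to avoid the in-place error, so unlike the original code no adversarial gradient reaches netG through it; the sketch only shows that the adversarial value stays constant when fake is reused.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyG(nn.Module):
    """Hypothetical stand-in for SinGAN's netG: refines `prev` using `noise`."""

    def __init__(self, dim=4):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))

    def forward(self, noise, prev):
        # Residual form: the generated detail is added on top of `prev`.
        return self.body(noise + prev) + prev


def adversarial_values(recompute_fake, steps=3, alpha=10.0):
    """Return errG per G step, with `fake` either reused or regenerated each step."""
    torch.manual_seed(0)
    netG, netD = ToyG(), nn.Linear(4, 1)  # netD is a toy stand-in, never stepped here
    optimizerG = torch.optim.Adam(netG.parameters(), lr=1e-2)
    noise, prev = torch.randn(2, 4), torch.randn(2, 4)
    # Hypothetical reconstruction input and target for the toy reconstruction loss.
    z_opt, real = torch.randn(2, 4), torch.randn(2, 4)

    # Detached so the reused variant cannot hit the in-place error (unlike the original).
    fake = netG(noise, prev).detach()
    values = []
    for _ in range(steps):
        netG.zero_grad()
        if recompute_fake:
            fake = netG(noise, prev)
        errG = -netD(fake).mean()
        rec_loss = alpha * F.mse_loss(netG(z_opt, prev), real)
        (errG + rec_loss).backward()
        optimizerG.step()
        values.append(errG.item())
    return values


reused = adversarial_values(recompute_fake=False)
fresh = adversarial_values(recompute_fake=True)
print("reused fake:    ", reused)
print("recomputed fake:", fresh)
```

In the reused variant all three recorded values are identical, since neither fake nor netD changes during the G loop; with recomputation they drift as netG is updated.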

We found it to work better empirically, but other solutions might also work.
Just be careful and make sure the performance stays the same.

Nice, thanks a lot!!

Thanks @williantrevizan, your fix worked for me.

It works well for me too.
You saved me a lot of time!!
Thanks a lot!

WZLHQ commented

Thanks. You really saved me time!

Thank you @williantrevizan! Confirmed that this solution works on torch==1.12.0.