caogang/wgan-gp

gradients.norm(2, dim=1), dim=1?

Closed this issue · 6 comments

@caogang Thanks for your good code! But something in gan_cifar10.py confuses me.

[Screenshot of the gradient penalty code in gan_cifar10.py, where the norm is taken with dim=1]

dim=1? Why is the norm only taken along the second axis? I think it should be taken across all axes except the batch axis.

Oh, very good question. I hadn't noticed this problem before. I implemented this code following the WGAN-GP authors' code https://github.com/igul222/improved_wgan_training/blob/master/gan_cifar.py . Maybe it is a trick from the authors? :) However, have you tried taking the norm across all axes except the batch axis? Maybe it would give better results :). If you get better results, please tell me and open a Pull Request. Thank you.

@caogang It is indeed normed across all axes except the batch axis, just as in the paper. As shown below (from the authors' code https://github.com/igul222/improved_wgan_training/blob/master/gan_cifar.py), the output of the generator is reshaped to rank 2. In your code, the output of the generator is still rank 4.

[Screenshot from the authors' TensorFlow code, where the generator output is reshaped to rank 2 before the norm is taken]
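To make the shape issue concrete, here is a minimal sketch (not from the repository) of what happens to a rank-4 gradient tensor. Taking the norm with dim=1 only reduces the channel axis and yields one norm per pixel, whereas flattening first gives one norm per sample, as the paper intends. The tensor sizes below are hypothetical CIFAR-10-like values.

```python
import torch

# hypothetical gradients w.r.t. interpolated images: (batch, channels, height, width)
gradients = torch.randn(64, 3, 32, 32)

per_pixel_norms = gradients.norm(2, dim=1)                 # shape (64, 32, 32): one norm per pixel
per_sample_norms = gradients.view(64, -1).norm(2, dim=1)   # shape (64,): one norm per sample

print(per_pixel_norms.shape, per_sample_norms.shape)
```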

Yes, I also found this problem.

Maybe you can change the penalty term as below:

gradient_penalty = ((gradients.view(gradients.size(0), -1)
                     .norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
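For context, here is a hedged sketch of a full gradient-penalty term using that flatten-then-norm fix. It assumes a PyTorch discriminator `netD`, real and fake batches of identical rank-4 shape, and a `LAMBDA` coefficient as in gan_cifar10.py; the function name and other identifiers are hypothetical, not the repository's own code.

```python
import torch
from torch import autograd

def gradient_penalty(netD, real_data, fake_data, LAMBDA=10.0):
    batch_size = real_data.size(0)
    # Random interpolation coefficients, broadcast over all non-batch axes
    alpha = torch.rand(batch_size, 1, 1, 1, device=real_data.device)
    interpolates = (alpha * real_data + (1 - alpha) * fake_data).requires_grad_(True)

    disc_interpolates = netD(interpolates)
    gradients = autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,
        retain_graph=True,
    )[0]

    # Flatten to rank 2 so the norm covers all axes except the batch axis
    gradients = gradients.view(batch_size, -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
```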

Yes, you are right. Thanks for your suggestions @LynnHo @eriche2016. It was a bug :(, and it is fixed in the latest version. Hope this is helpful.

klory commented

The MNIST code still has this issue...