gradients.norm(2, dim=1), dim=1?
@caogang Thanks for the nice code! But something in gan_cifar10.py confuses me:
Why dim = 1? Why is the norm taken only over the second axis? I think it should be taken across all axes except the batch axis.
Oh, very good question. I hadn't noticed this problem before. I implemented this code following the WGAN-GP authors' implementation, https://github.com/igul222/improved_wgan_training/blob/master/gan_cifar.py . Maybe it is a trick from the authors? :) Have you tried taking the norm across all axes except the batch axis? Maybe it would give better results :). If you get better results, please tell me and open a Pull Request. Thank you.
@caogang In the authors' code (https://github.com/igul222/improved_wgan_training/blob/master/gan_cifar.py) the norm is in fact taken across all axes except the batch axis, exactly as in the paper: the output of the generator is reshaped to rank 2 there, so dim=1 already covers everything but the batch dimension. In your code, the output of the generator is still rank 4.
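(For illustration, a minimal sketch, not taken from either repository, of what dim=1 does on a rank-4 gradient tensor versus flattening each sample first; the batch size of 64 and the CIFAR-10 image shape are assumptions.)

```python
import torch

# Hypothetical stand-in for the gradients of D w.r.t. the interpolates:
# for CIFAR-10 images they have shape (batch, 3, 32, 32), i.e. rank 4.
gradients = torch.randn(64, 3, 32, 32)

# norm(2, dim=1) only reduces over the channel axis, leaving a (64, 32, 32)
# tensor, so the penalty averages many partial norms instead of one norm per sample.
per_channel_norm = gradients.norm(2, dim=1)
print(per_channel_norm.shape)  # torch.Size([64, 32, 32])

# Flattening each sample first gives one L2 norm per sample, matching the paper.
per_sample_norm = gradients.view(gradients.size(0), -1).norm(2, dim=1)
print(per_sample_norm.shape)  # torch.Size([64])
```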
Yes, I also found this problem. Maybe you can change the penalty term as below:
gradient_penalty = ((gradients.view(gradients.size(0), -1).norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
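(For completeness, a minimal sketch of how the suggested penalty term could fit into a full gradient-penalty function; `netD`, `real_data`, `fake_data`, and `LAMBDA` are assumed names here, not code copied from the repository.)

```python
import torch
from torch import autograd

LAMBDA = 10  # assumed penalty coefficient, as in the WGAN-GP paper

def calc_gradient_penalty(netD, real_data, fake_data):
    batch_size = real_data.size(0)
    # One random interpolation coefficient per sample, broadcast over (C, H, W).
    alpha = torch.rand(batch_size, 1, 1, 1, device=real_data.device)
    # Detach so the penalty does not backpropagate into the generator.
    interpolates = (alpha * real_data + (1 - alpha) * fake_data).detach().requires_grad_(True)

    disc_interpolates = netD(interpolates)

    # Gradients of the critic output w.r.t. the interpolated inputs.
    gradients = autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,
        retain_graph=True,
    )[0]

    # Flatten each sample before taking the L2 norm, as suggested above.
    gradient_penalty = ((gradients.view(batch_size, -1).norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
    return gradient_penalty
```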
Yes, you are right. Thanks for your suggestions @LynnHo @eriche2016. It is a bug :(, and it has been fixed in the latest version. Hope it helps.
The MNIST code still has this bug...