caogang/wgan-gp

Logic problem with calc_gradient_penalty in CNN case

Closed this issue · 2 comments

Right now, you're getting the norm of the gradient in gan_mnist with gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA.

gradients is of shape [BATCH_SIZE, 1, 28, 28]. We want to calculate the norm of the gradient PER SAMPLE, and then use that as an error metric. That means we need to collapse gradients into one value per sample, i.e. to shape [BATCH_SIZE, 1] or just [BATCH_SIZE].

But gradients.norm(2, dim=1) only collapses the channel axis, giving shape [BATCH_SIZE, 28, 28], which isn't right.

Instead, gradients needs to be reshaped to be flat before you take the norm.
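To make the shape mismatch concrete, here's a small sketch using NumPy in place of PyTorch tensors (the batch size of 4 is arbitrary):

```python
import numpy as np

# Fake per-pixel gradients for a batch of 4 MNIST-sized samples.
gradients = np.ones((4, 1, 28, 28))

# Norm over dim=1 only collapses the channel axis -- still one value per pixel.
per_pixel = np.linalg.norm(gradients, ord=2, axis=1)
print(per_pixel.shape)  # (4, 28, 28) -- not one value per sample

# Flatten each sample first, then take the norm: one value per sample.
flat = gradients.reshape(gradients.shape[0], -1)
per_sample = np.linalg.norm(flat, ord=2, axis=1)
print(per_sample.shape)  # (4,)
```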

I monitored the value of gradient_penalty, and with the current code it explodes to around 10000 for gan_mnist, even when the network's gradients were reasonable, so the formal argument aside, I'm pretty sure there's a bug.

Great library by the way, it's made my life really easy. Thanks for posting it.

Want me to make a PR?

What do you think?

I changed the code from

    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones(disc_interpolates.size()),
                              create_graph=True, retain_graph=True, only_inputs=True)[0]

    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
    return gradient_penalty

to

    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones(disc_interpolates.size()),
                              create_graph=True, retain_graph=True, only_inputs=True)[0]

    gradients_reshaped = gradients.view(gradients.size()[0], -1)
    gradient_penalty = ((gradients_reshaped.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
    return gradient_penalty
I think that fixes the problem. What do you think?

Yeah, the bug is solved for gan_cifar, but I have not checked it for gan_mnist. Oops, that's my fault. :( So your PR is welcome.