Logic problem with calc_gradient_penalty in CNN case
Right now, you're getting the norm of the gradient in gan_mnist with gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA, but gradients is of shape [BATCH_SIZE, 1, 28, 28]. We want to calculate the norm of the gradient PER SAMPLE, and then use that as an error metric. That means we need to collapse gradients into one value per sample, i.e. to shape [BATCH_SIZE, 1] or just [BATCH_SIZE]. But gradients.norm(2, dim=1) collapses it to size [BATCH_SIZE, 28, 28], which isn't right.
Instead, gradients needs to be reshaped to be flat before you take the norm.
I monitored the value of gradient_penalty, and computed the way it is now, it explodes to around 10000 for gan_mnist, even when the network's gradients were reasonable, so formal reasoning aside, I'm pretty sure there's a bug.
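For concreteness, a quick shape check (the batch size of 4 is just an example) shows the difference:

import torch

gradients = torch.randn(4, 1, 28, 28)  # same shape as the MNIST gradients
print(gradients.norm(2, dim=1).shape)               # torch.Size([4, 28, 28]) -- per-pixel norms, not per-sample
print(gradients.view(4, -1).norm(2, dim=1).shape)   # torch.Size([4]) -- one norm per sample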
Great library by the way, it's made my life really easy. Thanks for posting it.
Want me to make a PR?
What do you think?
I changed the code from
gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                          grad_outputs=torch.ones(disc_interpolates.size()),
                          create_graph=True, retain_graph=True, only_inputs=True)[0]
gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
return gradient_penalty
to
gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                          grad_outputs=torch.ones(disc_interpolates.size()),
                          create_graph=True, retain_graph=True, only_inputs=True)[0]
# Flatten each sample's gradient so the 2-norm is taken per sample
gradients_reshaped = gradients.view(gradients.size()[0], -1)
gradient_penalty = ((gradients_reshaped.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
return gradient_penalty
I think that fixes the problem. What do you think?
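For reference, here is a minimal self-contained sketch of the corrected function, written for current PyTorch; netD is a hypothetical discriminator taking [B, C, H, W] image tensors, and the alpha interpolation is the standard WGAN-GP construction rather than a verbatim copy of the repo's:

import torch
from torch import autograd

def calc_gradient_penalty(netD, real_data, fake_data, LAMBDA=10.0):
    batch_size = real_data.size(0)
    # One interpolation coefficient per sample, broadcast across C/H/W
    alpha = torch.rand(batch_size, 1, 1, 1).expand_as(real_data)
    # Detach so interpolates is a leaf tensor we can require gradients on
    interpolates = (alpha * real_data + (1 - alpha) * fake_data).detach().requires_grad_(True)

    disc_interpolates = netD(interpolates)
    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones_like(disc_interpolates),
                              create_graph=True, retain_graph=True, only_inputs=True)[0]

    # Flatten to [batch_size, -1] so the 2-norm is taken per sample
    gradients = gradients.view(batch_size, -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA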
Yeah, the bug is already fixed in gan_cifar; however, I had not checked gan_mnist. Oops, my fault. :( A PR is welcome.