hash2430/pitchtron

Does gradient reversal layer work?

CODEJIN opened this issue · 7 comments

Hi,

Thank you so much for your source code! I am currently trying to replicate your code for my study, but I have a question about the gradient reversal layer (GRL). When I use the GRL from your code, the backward function of GradientReversalLayer is never called during the backward pass. So I tested the GRL with the following code, and the gradient of check_layer was not reversed.

import torch

class GRL(torch.nn.Module):
    def forward(self, x):
        return x

    # Note: autograd does not call a Module method named `backward`,
    # so this is never reached during loss.backward().
    def backward(self, x):
        print('#######################')
        assert False    # I want to check this....
        return -x

def gradient_reversal_layer(x):
    grl = GRL()
    return grl(x)


a = torch.nn.Linear(3, 4)
b = GRL()
c = torch.nn.Linear(4, 5)

optimizer = torch.optim.SGD(
    params= list(a.parameters()) + list(c.parameters()),
    lr= 0.1
    )

x = torch.randn(1, 3)
y = torch.randn(1, 5)

y1 = x
y1 = a(y1)
y1 = c(y1)
loss1 = torch.nn.L1Loss()(y1, y)
optimizer.zero_grad()
loss1.backward()
print(loss1)
print(a.weight.grad)
print(c.weight.grad)

y2 = x
y2 = a(y2)
y2 = gradient_reversal_layer(y2)
y2 = c(y2)
loss2 = torch.nn.L1Loss()(y2, y)
optimizer.zero_grad()
loss2.backward()
print(loss2)
print(a.weight.grad)
print(c.weight.grad)

However, when I checked another GRL implementation (https://github.com/janfreyberg/pytorch-revgrad), I confirmed that the gradient was reversed. Could you check whether the current GRL implementation works correctly? And if it does, please let me know which part of your code is critical for the backward of the GRL to work.

You are correct.
I discovered that issue a while ago and fixed it, but the committed code does not include the fixed version.
I will commit the fix when I have some time.
The problem in the committed version is that I subclassed 'Module'; the right choice would be to subclass 'Function'.
A Function can have a custom 'backward', while a Module cannot.
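For reference, a minimal sketch of a gradient reversal layer built on `torch.autograd.Function` might look like this (illustrative code using the static-method API of recent PyTorch versions, not the repo's actual implementation; the names `_ReverseGrad` and `gradient_reversal` are made up here):

```python
import torch

class _ReverseGrad(torch.autograd.Function):
    # Identity in the forward pass; negates (and optionally scales)
    # the gradient in the backward pass.
    @staticmethod
    def forward(ctx, x, scale=1.0):
        ctx.scale = scale
        return x.view_as(x)  # return a view so autograd records the op

    @staticmethod
    def backward(ctx, grad_output):
        # One return value per forward input; None for the non-tensor `scale`.
        return -ctx.scale * grad_output, None

def gradient_reversal(x, scale=1.0):
    return _ReverseGrad.apply(x, scale)

# Quick check: the gradient through the layer is negated.
x = torch.ones(3, requires_grad=True)
gradient_reversal(x).sum().backward()
print(x.grad)  # tensor([-1., -1., -1.])
```

Unlike a `Module`, the `backward` here is actually invoked by autograd, which is exactly the behavior the test script above fails to observe.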

Got it. Thank you!

Thanks for letting me know.
Which version of pytorch do you use?

The PyTorch version I tested your code with was 1.4.0, but I usually use 1.5.1.

I only tested with 1.0.1, so I can't say anything about other versions, but in 1.0.1, ctx.saved_tensors returns a tuple even though I only saved one tensor. So I had to retrieve the 0th element of that tuple to get the desired tensor.
Providing PyTorch version compatibility is not part of my plan so far.
But it's good to know it works with your setup after a tiny adjustment.
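That tuple behavior is the same in recent versions too; a small illustrative example (the `Square` function here is made up for demonstration, not part of the repo):

```python
import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Stash the input for use in backward.
        ctx.save_for_backward(x)
        return x ** 2

    @staticmethod
    def backward(ctx, grad_output):
        # saved_tensors is always a tuple, even with a single saved tensor,
        # so unpack (or index) to get the tensor itself.
        (x,) = ctx.saved_tensors
        return 2 * x * grad_output

x = torch.tensor([3.0], requires_grad=True)
Square.apply(x).backward()
print(x.grad)  # tensor([6.])
```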
Good luck with your study!

Thank you!