Different gradient of the input between pscan and RNN
GlassyWing opened this issue · 2 comments
GlassyWing commented
I wrote a test to compare the outputs of pscan with those of a sequential RNN loop, but the gradients with respect to the input x are inconsistent.
import torch

from pscan import pscan  # assuming the pscan module from this repo is importable


def test_pscan():
    # (B, L, D, N)
    a = torch.tensor([1, 2, 3, 4], dtype=float).reshape(1, 4, 1, 1)
    x = torch.tensor([2, 1, 3, 3], dtype=float).reshape(1, 4, 1, 1).requires_grad_(True)

    h = pscan(a, x)
    print("h output")
    print(h.shape, h)

    # upstream gradient of ones on every h_t
    h.backward(torch.ones_like(h))
    print("x, grad")
    print(x.shape, x.grad)
def test_rnn():
    # (B, L, D, N)
    a = torch.tensor([1, 2, 3, 4], dtype=float).reshape(1, 4, 1, 1)
    x = torch.tensor([2, 1, 3, 3], dtype=float).reshape(1, 4, 1, 1).requires_grad_(True)

    h_prev = torch.zeros(1, 1, 1)
    print("h output")
    for t in range(x.shape[1]):
        h_next = a[:, [t]] * h_prev + x[:, [t]]
        h_prev = h_next
        print(h_prev)

    # backpropagate through the final hidden state only
    h_prev.backward(torch.ones_like(h_prev))
    print("x, grad")
    print(x.shape, x.grad)
if __name__ == "__main__":
    print("===================")
    print("Test pscan")
    print("===================")
    test_pscan()

    print("===================")
    print("Test RNN")
    print("===================")
    test_rnn()
The results:
===================
Test pscan
===================
h output
torch.Size([1, 4, 1, 1]) tensor([[[[ 2.]],
          [[ 5.]],
          [[18.]],
          [[75.]]]], dtype=torch.float64, grad_fn=<PScanBackward>)
x, grad
torch.Size([1, 4, 1, 1]) tensor([[[[33.]],
          [[16.]],
          [[ 5.]],
          [[ 1.]]]], dtype=torch.float64)
===================
Test RNN
===================
h output
tensor([[[[2.]]]], dtype=torch.float64, grad_fn=<AddBackward0>)
tensor([[[[5.]]]], dtype=torch.float64, grad_fn=<AddBackward0>)
tensor([[[[18.]]]], dtype=torch.float64, grad_fn=<AddBackward0>)
tensor([[[[75.]]]], dtype=torch.float64, grad_fn=<AddBackward0>)
x, grad
torch.Size([1, 4, 1, 1]) tensor([[[[24.]],
          [[12.]],
          [[ 4.]],
          [[ 1.]]]], dtype=torch.float64)
Is this a bug, or have I missed something?
GlassyWing commented
Sorry, I forgot that the gradient of x(t) needs to be the sum of [dH(t)/dx(t), ..., dH(T)/dx(t)]; my RNN test only backpropagated through the final H(T).
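For reference, a minimal sketch in plain PyTorch (same a and x as above; the function name test_rnn_all_steps is just illustrative) of the RNN loop with the backward pass taken through every hidden state instead of only the last one. Summing the per-step contributions, e.g. for t=1: 1 + a(2) + a(2)*a(3) + a(2)*a(3)*a(4) = 1 + 2 + 6 + 24 = 33, reproduces the x.grad of [33, 16, 5, 1] that pscan reports above.

import torch

def test_rnn_all_steps():
    # (B, L, D, N), same toy values as in the tests above
    a = torch.tensor([1, 2, 3, 4], dtype=float).reshape(1, 4, 1, 1)
    x = torch.tensor([2, 1, 3, 3], dtype=float).reshape(1, 4, 1, 1).requires_grad_(True)

    h_prev = torch.zeros(1, 1, 1, 1, dtype=float)
    hs = []
    for t in range(x.shape[1]):
        h_prev = a[:, [t]] * h_prev + x[:, [t]]
        hs.append(h_prev)

    # stack h(1)..h(T) along the length dim, like pscan's (B, L, D, N) output,
    # and send a gradient of ones through every step, not only the last one
    h = torch.cat(hs, dim=1)
    h.backward(torch.ones_like(h))
    print(x.grad)  # [[33.], [16.], [5.], [1.]] -- matches pscan's x.grad above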
alxndrTL commented
Hello, thanks for the update :)
I was going to look into it, but now I guess that's OK.