alxndrTL/mamba.py

Different gradient of the input between pscan and RNN

GlassyWing opened this issue · 2 comments

I wrote a test to compare the output of pscan with that of a sequential RNN. The forward outputs match, but the
gradients with respect to the input x are inconsistent.

import torch

# pscan comes from this repo; adjust the import path to wherever pscan.py lives
from pscan import pscan

def test_pscan():

    # (B, L, D, N)
    a = torch.tensor([1, 2, 3, 4], dtype=float).reshape(1, 4, 1, 1)
    x = torch.tensor([2, 1, 3, 3], dtype=float).reshape(1, 4, 1, 1).requires_grad_(True)

    h = pscan(a, x)
    print("h output")
    print(h.shape, h)
    h.backward(torch.ones_like(h))
    print("x, grad")
    print(x.shape, x.grad)

def test_rnn():

    # (B, L, D, N)
    a = torch.tensor([1, 2, 3, 4], dtype=float).reshape(1, 4, 1, 1)
    x = torch.tensor([2, 1, 3, 3], dtype=float).reshape(1, 4, 1, 1).requires_grad_(True)

    h_prev = torch.zeros(1, 1, 1, 1)
    print("h output")
    for t in range(x.shape[1]):
        h_next = a[:, [t]] * h_prev + x[:, [t]]
        h_prev = h_next
        print(h_prev)
    
    h_prev.backward(torch.ones_like(h_prev))
    print("x, grad")
    print(x.shape, x.grad)

if __name__ == "__main__":
    print("===================")
    print("Test pscan")
    print("===================")
    test_pscan()
    print("===================")
    print("Test RNN")
    print("===================")
    test_rnn()

The results:

===================
Test pscan
===================
h output
torch.Size([1, 4, 1, 1]) tensor([[[[ 2.]],

         [[ 5.]],

         [[18.]],

         [[75.]]]], dtype=torch.float64, grad_fn=<PScanBackward>)
x, grad
torch.Size([1, 4, 1, 1]) tensor([[[[33.]],

         [[16.]],

         [[ 5.]],

         [[ 1.]]]], dtype=torch.float64)
===================
Test RNN
===================
h output
tensor([[[[2.]]]], dtype=torch.float64, grad_fn=<AddBackward0>)
tensor([[[[5.]]]], dtype=torch.float64, grad_fn=<AddBackward0>)
tensor([[[[18.]]]], dtype=torch.float64, grad_fn=<AddBackward0>)
tensor([[[[75.]]]], dtype=torch.float64, grad_fn=<AddBackward0>)
x, grad
torch.Size([1, 4, 1, 1]) tensor([[[[24.]],

         [[12.]],

         [[ 4.]],

         [[ 1.]]]], dtype=torch.float64)

Is this a bug, or have I missed something?

Sorry, I forgot that the gradient of x(t) needs to be the sum of [dH(t)/dx(t), ..., dH(T)/dx(t)].
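
For reference, here is a corrected version of the sequential test (a minimal sketch, not part of the original post, assuming the same setup as above): it backpropagates through every h(t) rather than only the last one, so each x(t) accumulates dH(t)/dx(t) + ... + dH(T)/dx(t), and the gradients should then match pscan's [33, 16, 5, 1].

def test_rnn_all_steps():

    # (B, L, D, N)
    a = torch.tensor([1, 2, 3, 4], dtype=float).reshape(1, 4, 1, 1)
    x = torch.tensor([2, 1, 3, 3], dtype=float).reshape(1, 4, 1, 1).requires_grad_(True)

    h_prev = torch.zeros(1, 1, 1, 1)
    hs = []
    for t in range(x.shape[1]):
        h_prev = a[:, [t]] * h_prev + x[:, [t]]
        hs.append(h_prev)

    # concatenate all hidden states and backprop through the whole sequence,
    # exactly like h.backward(torch.ones_like(h)) in the pscan test
    h = torch.cat(hs, dim=1)  # (B, L, D, N)
    h.backward(torch.ones_like(h))
    print("x, grad")
    print(x.shape, x.grad)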

Hello, thanks for the update :)
I was going to look into it, but I guess it's all good now.