Orkis-Research/Quaternion-Recurrent-Neural-Networks

Gradients not going backward


Hello,

I am trying a simple MNIST network:
```python
model = nn.Sequential(QuaternionLinear(784, 128), nn.LeakyReLU(), QuaternionLinear(128, 10), nn.LogSoftmax(dim=1))
```

However, the loss never decreases. I could not find the equations for backpropagating the gradients, so my assumption is that the gradients are not flowing backward. Is this expected? If not, how can I debug it? It would be really helpful if you could provide an example on a simple task (MNIST). Thanks!

I am using optim.SGD with lr=0.003 and momentum=0.9, and trained for 10 epochs.
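Here is roughly how I am checking whether any gradient reaches the quaternion weights (a quick sketch; `criterion`, `images`, and `labels` come from the standard MNIST setup shown in the full code below):

```python
images = images.view(images.shape[0], -1)  # flatten a batch of MNIST images to 784 features
loss = criterion(model(images), labels)
loss.backward()

# If gradients flow through QuaternionLinear, none of these should be None (or all zeros).
for name, p in model.named_parameters():
    print(name, None if p.grad is None else p.grad.norm().item())
```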

Hi,

Could you please provide your entire code?

Thanks.

Sure:

```python
import torch
from torch.autograd import Variable
# get_r / get_i / get_j / get_k are the quaternion component helpers from the repo's ops module


class QuaternionLinearFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, r_weight, i_weight, j_weight, k_weight, bias=None):
        ctx.save_for_backward(input, r_weight, i_weight, j_weight, k_weight, bias)
        #check_input(input)
        cat_kernels_4_r = torch.cat([r_weight, -i_weight, -j_weight, -k_weight], dim=0)
        cat_kernels_4_i = torch.cat([i_weight,  r_weight, -k_weight,  j_weight], dim=0)
        cat_kernels_4_j = torch.cat([j_weight,  k_weight,  r_weight, -i_weight], dim=0)
        cat_kernels_4_k = torch.cat([k_weight, -j_weight,  i_weight,  r_weight], dim=0)
        cat_kernels_4_quaternion = torch.cat([cat_kernels_4_r, cat_kernels_4_i, cat_kernels_4_j, cat_kernels_4_k], dim=1)
        if input.dim() == 2:
            if bias is not None:
                return torch.addmm(bias, input, cat_kernels_4_quaternion)
            else:
                return torch.mm(input, cat_kernels_4_quaternion)
        else:
            output = torch.matmul(input, cat_kernels_4_quaternion)
            if bias is not None:
                return output + bias
            else:
                return output

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        #print('hello inside ')

        input, r_weight, i_weight, j_weight, k_weight, bias = ctx.saved_tensors
        grad_input = grad_weight_r = grad_weight_i = grad_weight_j = grad_weight_k = grad_bias = None

        input_r = torch.cat([r_weight, -i_weight, -j_weight, -k_weight], dim=0)
        input_i = torch.cat([i_weight,  r_weight, -k_weight,  j_weight], dim=0)
        input_j = torch.cat([j_weight,  k_weight,  r_weight, -i_weight], dim=0)
        input_k = torch.cat([k_weight, -j_weight,  i_weight,  r_weight], dim=0)
        cat_kernels_4_quaternion_T = Variable(torch.cat([input_r, input_i, input_j, input_k], dim=1).permute(1, 0), requires_grad=False)

        r = get_r(input)
        i = get_i(input)
        j = get_j(input)
        k = get_k(input)
        input_r = torch.cat([r, -i, -j, -k], dim=0)
        input_i = torch.cat([i,  r, -k,  j], dim=0)
        input_j = torch.cat([j,  k,  r, -i], dim=0)
        input_k = torch.cat([k, -j,  i,  r], dim=0)
        input_mat = Variable(torch.cat([input_r, input_i, input_j, input_k], dim=1), requires_grad=False)

        r = get_r(grad_output)
        i = get_i(grad_output)
        j = get_j(grad_output)
        k = get_k(grad_output)
        input_r = torch.cat([ r,  i,  j,  k], dim=1)
        input_i = torch.cat([-i,  r,  k, -j], dim=1)
        input_j = torch.cat([-j, -k,  r,  i], dim=1)
        input_k = torch.cat([-k,  j, -i,  r], dim=1)
        grad_mat = torch.cat([input_r, input_i, input_j, input_k], dim=0)

        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(cat_kernels_4_quaternion_T)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_mat.permute(1, 0).mm(input_mat).permute(1, 0)
            unit_size_x = r_weight.size(0)
            unit_size_y = r_weight.size(1)
            grad_weight_r = grad_weight.narrow(0, 0, unit_size_x).narrow(1, 0, unit_size_y)
            grad_weight_i = grad_weight.narrow(0, 0, unit_size_x).narrow(1, unit_size_y, unit_size_y)
            grad_weight_j = grad_weight.narrow(0, 0, unit_size_x).narrow(1, unit_size_y * 2, unit_size_y)
            grad_weight_k = grad_weight.narrow(0, 0, unit_size_x).narrow(1, unit_size_y * 3, unit_size_y)
        if ctx.needs_input_grad[5]:
            grad_bias = grad_output.sum(0).squeeze(0)

        return grad_input, grad_weight_r, grad_weight_i, grad_weight_j, grad_weight_k, grad_bias
```
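By the way, one way I found to sanity-check the custom backward above (a sketch I put together, not something from the repo) is `torch.autograd.gradcheck` on small double-precision tensors; it compares the handwritten backward against numerical gradients and prints `True` when they match (and raises otherwise):

```python
import torch
from torch.autograd import gradcheck

# Tiny double-precision problem: 2x3 quaternion weights -> full 8x12 real matrix,
# and a batch of 5 inputs with 8 features (4 quaternion components of size 2 each).
inp  = torch.randn(5, 8, dtype=torch.double, requires_grad=True)
r_w  = torch.randn(2, 3, dtype=torch.double, requires_grad=True)
i_w  = torch.randn(2, 3, dtype=torch.double, requires_grad=True)
j_w  = torch.randn(2, 3, dtype=torch.double, requires_grad=True)
k_w  = torch.randn(2, 3, dtype=torch.double, requires_grad=True)
bias = torch.randn(12, dtype=torch.double, requires_grad=True)

print(gradcheck(QuaternionLinearFunction.apply,
                (inp, r_w, i_w, j_w, k_w, bias), eps=1e-6, atol=1e-4))
```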

```python
model2 = nn.Sequential(QuaternionLinear(784, 128),
                       nn.LeakyReLU(),
                       QuaternionLinear(128, 10),
                       nn.LogSoftmax(dim=1))

optimizer = optim.SGD(model.parameters(), lr=0.003, momentum=0.9)
time0 = time()
epochs = 10

for e in range(epochs):
    running_loss = 0
    for images, labels in trainloader:
        # Flatten MNIST images into a 784-long vector
        images = images.view(images.shape[0], -1)

        # Training pass
        optimizer.zero_grad()

        output = model2(images)
        loss = criterion(output, labels)

        # This is where the model learns by backpropagating
        loss.backward()

        # And optimizes its weights here
        optimizer.step()
        #model2.step()

        running_loss += loss.item()
    else:
        print("Epoch {} - Training loss: {}".format(e, running_loss/len(trainloader)))

print("\nTraining Time (in minutes) =", (time() - time0)/60)
```
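The remaining pieces not pasted above are just the standard torchvision MNIST setup, roughly as follows (the data path and batch size are placeholders):

```python
import torch
from torch import nn, optim
from torchvision import datasets, transforms
from time import time

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])
trainset = datasets.MNIST('MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

# The LogSoftmax output pairs with the negative log-likelihood loss
criterion = nn.NLLLoss()
```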


So the first thing to do is to use autograd by calling QuaternionLinearAutograd instead of QuaternionLinear. Then, even though the results should be the same, I would suggest updating everything from the Pytorch Quaternion Neural Networks repository (which also contains more examples). QuaternionLinear is a layer optimized for memory usage and is about two times slower than QuaternionLinearAutograd, but it should still work. Let me know if QuaternionLinearAutograd solves the problem (and please also update to the latest repository).
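The swap would look roughly like this (a sketch; the exact import path depends on where `quaternion_layers.py` sits in your checkout of the repo):

```python
from quaternion_layers import QuaternionLinearAutograd  # import path may differ in your checkout

model2 = nn.Sequential(QuaternionLinearAutograd(784, 128),
                       nn.LeakyReLU(),
                       QuaternionLinearAutograd(128, 10),
                       nn.LogSoftmax(dim=1))

# Build the optimizer from the parameters of the model you actually train (model2 here).
optimizer = optim.SGD(model2.parameters(), lr=0.003, momentum=0.9)
```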