mlverse/torch

Backward pass fails on torch_max with the "inplace operation" error


It seems that the backward pass does not work with the torch_max function when a dim argument is given, or I made a mistake somewhere. R code:

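library(torch)

# matrix() fills column-wise, so the rows are 1 3 5 7 and 2 4 6 8
m_tensor <- torch_tensor(matrix(1:8, nrow = 2), dtype = torch_float64(), requires_grad = TRUE)
n <- torch_max(m_tensor, dim = 2)[[1]]  # [[1]] = values, [[2]] = indices
n_sum <- torch_sum(n)
n_sum$backward()  # the error below is raised here
m_tensor$grad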

Fails with an error:

Error in (function (self, inputs, gradient, retain_graph, create_graph) :
one of the variables needed for gradient computation has been modified by an inplace operation

Analogous code in Python seems to work fine:

import torch

m_tensor = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.float64, requires_grad=True)
n = torch.max(m_tensor, dim=1).values  # dim=1 (0-based) corresponds to dim = 2 in R (1-based)
n_sum = torch.sum(n)
n_sum.backward()
m_tensor.grad

Returns:

tensor([[0., 0., 0., 1.],
        [0., 0., 0., 1.]], dtype=torch.float64)
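
That result matches what one would expect by hand: the sum of the row-wise maxima depends only on the largest entry of each row, so the gradient is 1 at that entry and 0 elsewhere. Below is a minimal pure-Python sketch of that check, without torch. The helper name rowwise_max_sum_grad is made up for illustration, and it assumes every row has a unique maximum.

```python
def rowwise_max_sum_grad(rows):
    """Gradient of sum(max(row) for row in rows) with respect to each entry.

    Assumes each row has a unique maximum. On ties this sketch puts the
    whole gradient on the first maximal entry, which is only one possible
    convention and is not claimed to match any library's behavior.
    """
    grad = []
    for row in rows:
        # Index of the row maximum (first one if there are ties).
        best = max(range(len(row)), key=lambda j: row[j])
        # d(max)/d(entry) is 1 at the maximum and 0 everywhere else.
        grad.append([1.0 if j == best else 0.0 for j in range(len(row))])
    return grad


# Same values as the Python example above.
print(rowwise_max_sum_grad([[1, 2, 3, 4], [5, 6, 7, 8]]))
# The R matrix(1:8, nrow = 2) has rows 1 3 5 7 and 2 4 6 8.
print(rowwise_max_sum_grad([[1, 3, 5, 7], [2, 4, 6, 8]]))
```

Both inputs have their row maxima in the last column, so the expected gradient is the same for the R and the Python version even though the entries are ordered differently.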

I would greatly appreciate your help, thanks!

By the way, torch_amax works fine:

m_tensor <- torch_tensor(matrix(1:8, nrow = 2), dtype = torch_float64(), requires_grad = TRUE)
n <- torch_amax(m_tensor, dim = 2)
n_sum <- torch_sum(n)
n_sum$backward()
m_tensor$grad

Returns:

torch_tensor
0 0 0 1
0 0 0 1
[ CPUDoubleType{2,4} ]