Make autograd ignore computing gradients for certain functions
MLpatzer opened this issue · 2 comments
Hi, I am a bit new to torch and autograd. Sorry if this is a bit of an obvious question.
I am trying to use a function not supported by autograd, namely torch.sort(), so that I can then backprop only through certain components of my loss function. Since the sort just produces indices to be used later on, it seems there should be some way to make autograd execute it but leave it out of the gradient computation. I've tried a few variants of doing this but can't seem to get it to work.
Pseudocode would look something like this:
function features(x)
   ...
end

function Loss(x, y)
   local feat = features(x)
   local _, inds = torch.sort(feat, true) -- ignore this line in the gradient computation
   x = x:index(1, inds[{{1,20}}])
   y = y:index(1, inds[{{1,20}}])
   ...
end
I think the easiest way to do this is to define an nn module that returns inds on the forward pass and just returns gradOutput on the backward pass, and then functionalize that module.
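For reference, a bare-bones sketch of that kind of nn module (untested; the name and the descending-sort choice are just illustrative) could look like:

local nn = require 'nn'

-- Hypothetical module: forward returns the sort indices,
-- backward passes gradOutput straight through as an identity
local SortIndices, parent = torch.class('nn.SortIndices', 'nn.Module')

function SortIndices:updateOutput(input)
   -- descending sort, matching the pseudocode above
   local _, inds = torch.sort(input, true)
   self.output = inds
   return self.output
end

function SortIndices:updateGradInput(input, gradOutput)
   -- treat the op as an identity for gradient purposes
   self.gradInput = gradOutput
   return self.gradInput
end

You'd then wrap that module with autograd's functionalize machinery so it can be called inside your loss function.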
I think you'd want something like this, outside of your loss function:
-- (assuming grad is the autograd package, e.g. local grad = require 'autograd')

-- Build your own custom module
mymodule = {}

-- Make a sort function in your module, which just returns the sorted array
mymodule.sort = function(x)
   local sorted, _ = torch.sort(x)
   return sorted
end

-- Define the gradient for the module
grad.overload.module("mymodule", mymodule, function(module)
   module.gradient("sort", {
      function(g, ans, x)
         -- You need to define the gradient of sort here
         -- Would involve "unsorting" g, and returning it
      end
   })
end)
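That "unsorting" amounts to recomputing the sort permutation and scattering g back to the original positions. A rough, untested sketch for a 1-D x (this is the kind of body you'd drop into the gradient table above):

-- Rough sketch of the "unsorting" gradient for a 1-D input (untested)
local sortGrad = function(g, ans, x)
   -- recompute the permutation that torch.sort applied to x
   local _, inds = torch.sort(x)
   -- scatter the incoming gradient back: gx[inds[i]] = g[i]
   local gx = g:clone():zero()
   gx:indexCopy(1, inds, g)
   return gx
end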
You can't ignore the gradient of sort, because it changes the indexing for subsequent use of your data, which needs to be undone in the backwards pass.
Also, in autograd, you don't want to use function calls of the form x:blah(), because they are often in-place, and we don't support that in autograd. You'll want to rewrite it as torch.blah(x).
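For example, an in-place call like x:add(y) would be rewritten in the out-of-place form:

-- in-place update, which autograd can't track:
-- x:add(y)
-- out-of-place equivalent that autograd can differentiate through:
x = torch.add(x, y)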