Question about Remove Hook
QiyaoWei opened this issue · 6 comments
Dear Shaojie,
Hi there! This is Qiyao, a huge fan of your work! I am writing to ask about these lines. I noticed that if I remove them, training no longer works, but I am having a hard time figuring out why. In my understanding, the program should never create more than one hook in a single forward pass, so I don't see the purpose of this check. For example, this tutorial does not check for an existing hook, so I am confused about what is happening here.
Hi @QiyaoWei,
Great question. In the tutorial, as you might have noticed, the fixed point `z` is cloned into `z0` and passed through another layer, which incurs an extra computation cost. In contrast, the implementation in this repo avoids that extra cost (and the cloning) by applying the hook directly on `z1s`. However, if we did NOT remove the hook, then this line would recursively call the hook (because `autograd.grad` triggers the backward hook), and the program would hang.
I hope this clarifies things for you.
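For concreteness, here is a minimal sketch of the self-removing hook pattern described above; the class name, the zero initialization, and the naive fixed-point loops are illustrative assumptions, not this repo's actual solver API:

```python
import torch
from torch import autograd, nn

class DEQSketch(nn.Module):
    """A DEQ-style layer whose backward hook removes itself."""

    def __init__(self, f, f_iters=30, b_iters=30):
        super().__init__()
        self.f = f              # layer defining the fixed point z* = f(z*, x)
        self.f_iters = f_iters  # forward fixed-point iterations
        self.b_iters = b_iters  # backward fixed-point iterations
        self.hook = None

    def forward(self, x):
        # Forward solve without the autograd tape (constant memory).
        with torch.no_grad():
            z = torch.zeros_like(x)
            for _ in range(self.f_iters):
                z = self.f(z, x)

        # One extra differentiable call re-attaches the fixed point to
        # the graph, so the hook can go directly on the output (no clone).
        z = z.requires_grad_()
        new_z = self.f(z, x)

        if self.training:
            def backward_hook(grad):
                # Remove the hook FIRST: the autograd.grad call below
                # starts a backward pass from new_z, which fires any hook
                # registered on new_z and would recurse into this very
                # function, hanging the program.
                if self.hook is not None:
                    self.hook.remove()
                    self.hook = None
                # Backward solve of g = J_f^T g + grad by iteration.
                g = grad
                for _ in range(self.b_iters):
                    g = autograd.grad(new_z, z, g, retain_graph=True)[0] + grad
                return g

            self.hook = new_z.register_hook(backward_hook)
        return new_z
```

The key point is the `self.hook.remove()` at the top of the hook: since the hook sits on the very tensor that the inner `autograd.grad` backpropagates from, it must be detached before that call is made.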
Gotcha. If I may ask a follow-up question, also about hooks: say I want to do something like Jacobian regularization, except that the regularization term comes from the solver itself, i.e. aside from `z1`, I also get my regularization term from this return. Is there a way to let that regularization loss backprop through the solver while keeping everything else intact? Basically, I want to keep the forward solver wrapped in `torch.no_grad()`, but somehow allow my regularization loss to live outside `torch.no_grad()`. I'm not sure if hooks will work in this case, so is that even possible?
I don't think I fully understand. If your goal is to backprop through the solver, then you will have to pay all the intermediate activation memory costs anyway, so there is no point in using `torch.no_grad()`. Or do you believe your regularization loss will only use a tiny portion of the solver information?
Yep, that's exactly right. Ideally my regularization loss would only need less than half of the solver trace, so I think there is still merit in investigating whether I can keep the original hook routine.
Then maybe it's possible to do this at a finer granularity: put the `torch.no_grad()` inside the solver implementation and keep only the part you actually need differentiable. Another way is to break the fixed-point solving into two parts, one within `torch.no_grad()` and the other within `torch.enable_grad()` (see the sketch below).
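A minimal sketch of that second option, assuming a layer `f(z, x)`, naive fixed-point iteration, and a hypothetical `total_iters`/`grad_iters` split:

```python
import torch

def split_fixed_point_solve(f, z0, x, total_iters=30, grad_iters=5):
    """Run most iterations under no_grad; tape only the last few."""
    # Phase 1: tape-free iterations; nothing here is differentiable.
    with torch.no_grad():
        z = z0
        for _ in range(total_iters - grad_iters):
            z = f(z, x)

    # Phase 2: re-enter autograd for the tail of the trajectory, so a
    # regularization term built from these iterates can backprop.
    z = z.requires_grad_()
    traced = []
    with torch.enable_grad():
        for _ in range(grad_iters):
            z = f(z, x)
            traced.append(z)
    return z, traced
```

A regularization loss computed from `traced` would then backprop only through the last `grad_iters` applications of `f`, while the bulk of the solve still runs without paying any activation memory.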
Ah I see. That makes sense. Thanks a lot for the quick response!