falkaer/artist-group-factors

GradNorm retain_graph OOM

hongyuntw opened this issue · 2 comments

Hi author,
I adapted your code, including the GradNorm part, to train my own transformer-based model.
Everything works at first, but as the number of iterations grows, a "CUDA out of memory" error occurs.
Did you run into the same error during training?
I think it is caused by retain_graph in these two calls:

loss.backward(retain_graph=True)
and
gygw = torch.autograd.grad(task_losses[k], W.parameters(), retain_graph=True)

Am I right?
And is there any way to avoid this error as the number of iterations grows?

Thank you for your nice code :)

Hi,

Sorry for the late reply, I somehow didn't get a notification or missed it. Please see my reply to this issue. The problem is indeed the retain_graph option on the first backward pass. It should be possible to reduce peak memory use by computing the GradNorm-specific terms first with retain_graph=True, and then backpropagating through the whole model with retain_graph=False afterwards.
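For anyone reading along, here is a minimal sketch of that ordering. It is not the code from this repository: the toy trunk, the two heads, the choice of W, the state dict and train_step are all hypothetical names made up for illustration, and the GradNorm loss is written in a simplified form. The point is only the order of the calls: the per-task torch.autograd.grad calls keep the graph alive, and the final backward() over the whole model is the last user of the graph, so it can run with the default retain_graph=False and free it.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model: a shared trunk with two task heads.
trunk = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 16))
heads = nn.ModuleList([nn.Linear(16, 1), nn.Linear(16, 1)])
W = trunk[2]  # assumed "last shared layer" used for the gradient norms
loss_weights = nn.Parameter(torch.ones(2))

model_params = list(trunk.parameters()) + list(heads.parameters())
opt = torch.optim.SGD(model_params, lr=1e-2)
weight_opt = torch.optim.SGD([loss_weights], lr=1e-2)
alpha = 1.5
state = {"initial_losses": None}


def train_step(x, targets):
    feats = trunk(x)
    task_losses = torch.stack([
        nn.functional.mse_loss(head(feats), t) for head, t in zip(heads, targets)
    ])
    if state["initial_losses"] is None:
        state["initial_losses"] = task_losses.detach()

    # 1) GradNorm terms first. The graph must survive these calls,
    #    so retain_graph=True is needed here.
    norms = []
    for k in range(len(task_losses)):
        gygw = torch.autograd.grad(task_losses[k], W.parameters(), retain_graph=True)
        flat = torch.cat([g.reshape(-1) for g in gygw])
        # The gradients are constants here; only loss_weights[k] carries grad.
        norms.append(loss_weights[k] * flat.norm())
    norms = torch.stack(norms)

    with torch.no_grad():
        ratios = task_losses / state["initial_losses"]
        inverse_rate = ratios / ratios.mean()
        target = norms.mean() * inverse_rate ** alpha  # treated as a constant
    gradnorm_loss = (norms - target).abs().sum()
    weight_opt.zero_grad()
    loss_weights.grad = torch.autograd.grad(gradnorm_loss, loss_weights)[0]

    # 2) Backpropagate through the whole model last. This is the final use
    #    of the graph, so the default retain_graph=False lets it be freed.
    opt.zero_grad()
    total = (loss_weights.detach() * task_losses).sum()
    total.backward()
    opt.step()
    weight_opt.step()

    # Renormalize so the task weights sum to the number of tasks.
    with torch.no_grad():
        loss_weights.mul_(len(task_losses) / loss_weights.sum())

    # Return a Python float so no graph-attached tensor outlives the step.
    return total.item()


x = torch.randn(4, 8)
targets = [torch.randn(4, 1), torch.randn(4, 1)]
for step in range(3):
    print(step, train_step(x, targets))
```

If memory still grows from one iteration to the next with this ordering, it is worth checking whether any tensor that is still attached to the graph (for example a loss stored for logging) is kept around between steps.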

Should be fixed by c8d3d06.