[Question] Grad clipping for flows
pi-tau opened this issue · 1 comment
Tutorial: 11
Describe the bug
A question regarding gradient clipping for flows.
Since our flow model is actually a composition of flows, and each flow in the composition is a separate neural net, I was wondering what would be the correct approach to clip the gradients during training.
In Cell 17 I can see that the pl.Trainer is defined with a gradient clip value of 1.0. I am not really sure how this is applied, but I would guess it is something like this:
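```python
import torch

optimizer.zero_grad()
loss.backward()
# Rescale the gradients of *all* parameters jointly so their combined L2 norm is at most 1.0
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
```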
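To make the guess above concrete, here is a minimal pure-Python sketch of what global norm clipping computes. Plain lists of floats stand in for parameter gradients (an assumption for illustration; `clip_grad_norm_global` is a hypothetical helper, while the real `torch.nn.utils.clip_grad_norm_` works in place on tensors). All gradients are treated as one long vector, a single L2 norm is computed, and every gradient is multiplied by the same factor `min(1, max_norm / total_norm)`.

```python
import math


def clip_grad_norm_global(grads, max_norm, eps=1e-6):
    """Jointly rescale a list of gradient vectors so their combined L2 norm <= max_norm.

    `grads` is a list of lists of floats, one inner list per parameter
    (a hypothetical stand-in for PyTorch tensors).
    Returns (clipped_grads, total_norm_before_clipping).
    """
    # One norm over *all* entries of *all* parameters
    total_norm = math.sqrt(sum(g * g for param in grads for g in param))
    # Same scale factor for every parameter; gradients are never scaled up
    scale = min(1.0, max_norm / (total_norm + eps))
    clipped = [[g * scale for g in param] for param in grads]
    return clipped, total_norm


grads = [[3.0, 0.0], [0.0, 4.0]]  # two "parameters" with combined norm 5.0
clipped, norm = clip_grad_norm_global(grads, max_norm=1.0)
print(norm)
print(clipped)
```

Because every parameter is scaled by the same factor, the direction of the overall update is preserved and only its length shrinks.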
Wouldn't it be better if we clip the gradients separately per network, and not together for the entire model?
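```python
# Clip each flow's gradients separately, so each sub-network gets its own norm budget of 1.0
for flow in model.flows:
    torch.nn.utils.clip_grad_norm_(flow.parameters(), max_norm=1.0)
```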
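The difference between the two strategies can be sketched in pure Python, with lists of floats standing in for each flow's gradients (`clip_global` and `clip_per_group` are hypothetical helpers, not PyTorch functions). With global clipping, all flows share one scale factor. With per-flow clipping, each flow is scaled independently, so the combined norm after clipping can reach `sqrt(K) * max_norm` for K flows, and the relative weighting between flows changes.

```python
import math


def l2_norm(vec):
    """L2 norm of a flat list of floats."""
    return math.sqrt(sum(x * x for x in vec))


def clip_global(groups, max_norm, eps=1e-6):
    """One scale factor shared by all groups (like clipping model.parameters() at once)."""
    total = math.sqrt(sum(l2_norm(g) ** 2 for g in groups))
    scale = min(1.0, max_norm / (total + eps))
    return [[x * scale for x in g] for g in groups]


def clip_per_group(groups, max_norm, eps=1e-6):
    """A separate scale factor per group (like one clip call per flow)."""
    clipped = []
    for g in groups:
        scale = min(1.0, max_norm / (l2_norm(g) + eps))
        clipped.append([x * scale for x in g])
    return clipped


# Hypothetical gradients of 4 flows: one with a large spike, three with small gradients
flow_grads = [[30.0, 40.0], [0.3, 0.4], [0.3, 0.4], [0.3, 0.4]]

g = clip_global(flow_grads, max_norm=1.0)
p = clip_per_group(flow_grads, max_norm=1.0)

# Per-flow gradient norms after each strategy
print([round(l2_norm(x), 3) for x in g])
print([round(l2_norm(x), 3) for x in p])
```

In this toy case, global clipping also shrinks the three well-behaved flows by roughly a factor of 50 because of a single spike in the first flow, which is the slowdown the question worries about; per-flow clipping leaves them untouched. The trade-off is that per-flow clipping no longer preserves the direction of the full gradient, and the combined norm after clipping is no longer bounded by `max_norm`.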
I imagine this approach would allow faster training, since the gradient signal would be clipped less, while training would still remain stable. With the original approach (clipping the entire model at once), I am concerned that for a deep flow (20-30+ flows) learning might be very slow.
Personally, I haven't experimented with the clipping strategy. I actually trained a flow on CelebA (resized to 32x32) without any clipping, and the results were fine 🤷
Any comments on the topic would be appreciated : )
Hi @pi-tau, thanks for your question! By default, PyTorch Lightning indeed uses the clip_grad_norm_ function. You can switch to clipping the value of each individual gradient entry by setting gradient_clip_algorithm="value". The purpose of gradient clipping is simply to prevent very large gradient spikes, which can be caused by an input accidentally being placed in the tails of the prior distribution. It is not strictly necessary, and you can often train small flows without gradient clipping. However, I have experienced a couple of times that training suddenly goes to NaN losses, which happens less often with gradient clipping. For large networks, you could indeed clip the norm per network, or simply increase the maximum norm. However, a norm of 1 is already quite large, so with common flow sizes I don't expect it to run into an issue :)
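To illustrate the two options mentioned in the answer: in PyTorch Lightning they are selected through the `Trainer` arguments `gradient_clip_val` and `gradient_clip_algorithm`, which defaults to norm clipping. The following pure-Python sketch (hypothetical helper names, lists of floats instead of tensors) contrasts what each does to a single gradient vector. Value clipping clamps every entry to `[-c, c]` independently, which can change the gradient's direction, whereas norm clipping rescales the whole vector and keeps its direction.

```python
import math


def clip_by_value(grad, clip_value):
    """Clamp each gradient entry to [-clip_value, clip_value] independently."""
    return [max(-clip_value, min(clip_value, g)) for g in grad]


def clip_by_norm(grad, max_norm, eps=1e-6):
    """Rescale the whole gradient vector so its L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grad))
    scale = min(1.0, max_norm / (norm + eps))
    return [g * scale for g in grad]


# A gradient "spike" dominated by one large entry
spike = [10.0, 0.5, -0.2]

by_value = clip_by_value(spike, 1.0)  # entries clamped individually
by_norm = clip_by_norm(spike, 1.0)    # vector shrunk uniformly, direction kept

print(by_value)
print(by_norm)
```

With value clipping, the large entry is cut to 1.0 while the small ones are untouched, so the update points in a noticeably different direction; with norm clipping, all entries shrink by the same factor of roughly 10.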
Feel free to reopen the issue if you have follow up questions!