gradflow-check

Check gradient flow in PyTorch


Gradient flow check in PyTorch

Check that gradients flow properly through the network by recording the average gradient of each layer in every training iteration and plotting the results at the end. If the average gradients in the initial layers are close to zero, the network is probably too deep for the gradient to flow back to them (the vanishing-gradient problem).
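The recording step can be sketched as follows. This is not the repo's code: `mean_abs_grads` is a hypothetical helper, the toy model and random data are assumptions, and taking the mean of absolute gradient values (skipping biases) is one reasonable reading of "average gradient", since signed gradients would largely cancel out.

```python
import torch
import torch.nn as nn

def mean_abs_grads(named_parameters):
    """Hypothetical helper: (layer names, mean |grad|) for each trainable weight.

    Biases are skipped so that each entry corresponds to one layer's weights.
    """
    layers, avg_grads = [], []
    for name, param in named_parameters:
        if param.requires_grad and param.grad is not None and "bias" not in name:
            layers.append(name)
            avg_grads.append(param.grad.abs().mean().item())
    return layers, avg_grads

# Toy setup (assumed for illustration): a small MLP on random regression data.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 16), nn.Tanh(),
                      nn.Linear(16, 16), nn.Tanh(),
                      nn.Linear(16, 1))
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
inputs, labels = torch.randn(64, 8), torch.randn(64, 1)

history = []  # one list of per-layer averages per training iteration
for step in range(20):
    optimizer.zero_grad()
    loss = criterion(model(inputs), labels)
    loss.backward()  # .grad is populated from here until the next zero_grad()
    layers, avg_grads = mean_abs_grads(model.named_parameters())
    history.append(avg_grads)
    optimizer.step()

# Average each layer over all iterations, then compare first layer with last.
per_layer = [sum(col) / len(col) for col in zip(*history)]
for name, g in zip(layers, per_layer):
    print(f"{name:>10}: {g:.2e}")
if per_layer[0] < 1e-3 * per_layer[-1]:  # arbitrary heuristic threshold
    print("warning: early-layer gradients look like they are vanishing")
```

The 1e-3 threshold is only a heuristic; what counts as "close to zero" depends on the model and the scale of the loss.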

Usage

loss = criterion(outputs, labels)
loss.backward()  # fills in .grad for every parameter; plot only after this call
plot_grad_flow(model.named_parameters())  # version 1
# OR
plot_grad_flow_v2(model.named_parameters())  # version 2
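To show what a plotting function of this kind might do, here is a minimal sketch. `plot_grad_flow_sketch` is a hypothetical stand-in, not the repo's `plot_grad_flow`; it assumes matplotlib and overlays one translucent line per call on a shared axes object, so calling it after every `backward()` builds up the final plot.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line for interactive use
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

def plot_grad_flow_sketch(named_parameters, ax):
    """Hypothetical helper: add one line of per-layer mean |grad| to `ax`."""
    layers, avg_grads = [], []
    for name, param in named_parameters:
        if param.requires_grad and param.grad is not None and "bias" not in name:
            layers.append(name)
            avg_grads.append(param.grad.abs().mean().item())
    ax.plot(avg_grads, alpha=0.3, color="b")  # faint, so iterations overlay
    ax.set_xticks(range(len(layers)))
    ax.set_xticklabels(layers, rotation="vertical")
    ax.set_xlabel("Layers")
    ax.set_ylabel("average gradient")
    ax.set_title("Gradient flow")
    ax.grid(True)

# Toy model and data, assumed for illustration.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
inputs, labels = torch.randn(32, 8), torch.randn(32, 1)

fig, ax = plt.subplots()
for step in range(10):
    optimizer.zero_grad()
    loss = criterion(model(inputs), labels)
    loss.backward()
    # Plot after backward() (grads exist) and before the next zero_grad().
    plot_grad_flow_sketch(model.named_parameters(), ax)
    optimizer.step()
ax.axhline(0, linewidth=1, color="k")  # zero reference line, drawn once
fig.tight_layout()
fig.savefig("grad_flow.png")
```

Drawing on an explicit `ax` rather than the global pyplot state keeps separate runs from piling onto the same figure.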

Result

Bad gradient flow:

[Figure: bad gradient flow]

Good gradient flow:

[Figure: good gradient flow]
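The bad and good cases can also be reproduced numerically. This experiment is an illustration, not taken from the repo: a 10-layer sigmoid MLP and a 1-hidden-layer ReLU MLP (both assumed for the example, using PyTorch's default initialisation) are compared by the ratio of the first layer's mean |grad| to the last layer's. `first_to_last_grad_ratio` and `mlp` are hypothetical helpers. A ratio many orders of magnitude below 1 is the numeric counterpart of the bad plot.

```python
import torch
import torch.nn as nn

def first_to_last_grad_ratio(model, inputs, labels):
    """Mean |grad| of the first Linear weight divided by that of the last one."""
    model.zero_grad()
    nn.functional.mse_loss(model(inputs), labels).backward()
    linears = [m for m in model.modules() if isinstance(m, nn.Linear)]
    first = linears[0].weight.grad.abs().mean()
    last = linears[-1].weight.grad.abs().mean()
    return (first / last).item()

def mlp(depth, width, activation, in_features=16):
    """Stack `depth` Linear+activation blocks, then a Linear output layer."""
    layers = []
    for _ in range(depth):
        layers += [nn.Linear(in_features, width), activation()]
        in_features = width
    layers.append(nn.Linear(in_features, 1))
    return nn.Sequential(*layers)

torch.manual_seed(0)
inputs, labels = torch.randn(128, 16), torch.randn(128, 1)
deep_sigmoid = mlp(depth=10, width=32, activation=nn.Sigmoid)
shallow_relu = mlp(depth=1, width=32, activation=nn.ReLU)

bad = first_to_last_grad_ratio(deep_sigmoid, inputs, labels)
good = first_to_last_grad_ratio(shallow_relu, inputs, labels)
print(f"deep sigmoid : first/last = {bad:.1e}")
print(f"shallow ReLU : first/last = {good:.1e}")
```

Each sigmoid layer scales the backward signal by at most 0.25 (the sigmoid's maximum slope), so the shrinkage compounds with depth.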

This repo is based on a PyTorch Discuss forum post.