Using dynamic growth & pruning?
varun19299 opened this issue · 7 comments
Hi, is it possible to use dynamic growth and pruning currently by just updating the masks each step?
I'm looking to implement something like RigL (Evci et al. 2020).
Yes, theoretically you can call the pruning function at each mini-batch iteration. If you look at the code, it is currently only called after the end of each epoch. You just need to put this function into the training loop to achieve pruning/growth at each step.
Another option that is already baked in, but comes with predefined static behavior, is setting the prune_every_k_steps variable. Setting it to 1 would execute the prune/regrowth cycle with every mini-batch.
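To make the per-step idea concrete, here is a minimal, self-contained sketch of a drop/grow cycle run inside a training loop. It does not use sparselearning's API: `prune_and_grow`, `update_every`, and the toy quadratic loss are hypothetical names chosen for illustration, and the growth rule (largest gradient magnitude among previously inactive connections) is only a simplified stand-in for what RigL does.

```python
def prune_and_grow(weights, grads, mask, drop_fraction):
    """One simplified drop/grow update on a flat weight list (illustrative only)."""
    active = [i for i, m in enumerate(mask) if m]
    inactive = [i for i, m in enumerate(mask) if not m]
    # Swap the same number of connections so overall density stays constant.
    n_swap = min(int(drop_fraction * len(active)), len(inactive))
    # Drop the active connections with the smallest weight magnitude.
    dropped = sorted(active, key=lambda i: abs(weights[i]))[:n_swap]
    # Grow the previously inactive connections with the largest gradient magnitude.
    grown = sorted(inactive, key=lambda i: -abs(grads[i]))[:n_swap]
    for i in dropped:
        mask[i] = 0
        weights[i] = 0.0
    for i in grown:
        mask[i] = 1
        weights[i] = 0.0  # newly grown connections start at zero
    return mask


# Toy problem: minimise 0.5 * sum((w - t)^2) with half of the weights masked out.
targets = [0.1, -2.0, 0.3, 1.5, -0.2, 3.0, 0.05, -1.0]
weights = [0.5, 0.0, -0.4, 0.0, 0.6, 0.0, 0.2, 0.0]
mask = [1, 0, 1, 0, 1, 0, 1, 0]
lr = 0.1
update_every = 1  # hypothetical knob: 1 means "update the mask every mini-batch"

for step in range(1, 6):
    # Dense gradient, so inactive connections also get a growth score.
    grads = [w - t for w, t in zip(weights, targets)]
    # Only active connections are updated by the optimiser step.
    weights = [w - lr * g * m for w, g, m in zip(weights, grads, mask)]
    if step % update_every == 0:
        prune_and_grow(weights, grads, mask, drop_fraction=0.25)
```

Raising `update_every` in this sketch spaces the mask updates out again, which is the same trade-off as moving between per-epoch and per-step updates.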
Thank you for your reply!
I'm having a bit of trouble understanding the working and objective of the `calc_growth_redistribution` method in `core.py`. A few questions:
- Why loop over 1000 times?
- Line 457 seems to be redundant with line 459 resetting `residual`. Similarly, `expected_var` isn't used elsewhere in the method.
- At first glance, I assumed the function was finding the maximum growth possible, (dropped connections + previously non-zero connections) * threshold, but line 477 seems to accumulate mean_residual across the regrowth variable. Is this necessary?
Thanks for your comment. The method determines the redistribution of weights. There is the problem of what you do if weights are redistributed to layers that are already full (and cannot regrow weights) or if more weights are regrown than a layer can fit. I could keep track of these, but I found it easier and more general to anneal the redistribution over time (1000 iterations).
The residual is the overflow from layers that are already full, and it is redistributed over up to 1000 iterations. In some cases the weights are not easy to redistribute, and the annealing procedure does not converge within 1000 iterations. In that case, the best solution after 1000 iterations is taken, but this solution might not be 100% proportional to the metric used to determine the redistribution fractions.
I hope this makes it a bit clearer. It is definitely a confusing function, and I see that I forgot to clean up some artifacts, as you have pointed out. Let me know if you have more questions.
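As a rough illustration of the overflow idea described in this reply (not the actual `calc_growth_redistribution` code), the sketch below splits a regrowth budget across layers in proportion to given fractions, caps each layer at its remaining capacity, and pushes the overflow back to the layers that still have room, for a bounded number of iterations. The function name, its arguments, and the hard-capacity rule are assumptions made for this example.

```python
def redistribute_growth(fractions, capacity, total_regrow, max_iters=1000):
    """Split total_regrow across layers by `fractions`, respecting capacity.

    Returns (growth per layer, residual that could not be placed).
    Illustrative only; the library's method differs in its details.
    """
    growth = {name: 0.0 for name in fractions}
    residual = float(total_regrow)
    for _ in range(max_iters):
        if residual < 1e-9:
            break  # everything has been placed
        # Layers that can still accept new weights.
        open_layers = [n for n in fractions if capacity[n] - growth[n] > 1e-9]
        norm = sum(fractions[n] for n in open_layers)
        if not open_layers or norm == 0:
            break  # nowhere left to put the residual
        overflow = 0.0
        for n in open_layers:
            share = residual * fractions[n] / norm
            placed = min(share, capacity[n] - growth[n])
            growth[n] += placed
            overflow += share - placed  # what did not fit goes to the next round
        residual = overflow
    return growth, residual


# Layer "a" is nearly full, so most of its share spills over to "b" and "c".
growth, leftover = redistribute_growth(
    fractions={"a": 0.7, "b": 0.2, "c": 0.1},
    capacity={"a": 10, "b": 100, "c": 100},
    total_regrow=100,
)
```

Note how the final split is no longer proportional to the original fractions once a layer hits its capacity, which is the deviation from the redistribution metric mentioned above.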
Hi Tim,
Thanks again for your reply! I have a few more questions :)
- I think the ERK initialisation may not be correct: there isn't a check to see if a layer's capacity is exhausted, i.e., `p_{ERK, layer_i} > 1`. (For instance, RigL marks such layers as dense and excludes them from the ERK distributed set.)
- Also, is this correct?
  sparse_learning/sparselearning/core.py
  Line 193 in f99c2f2
  (It should probably be `abs(current_params - target_params) < tolerance`?)
- Even with this change, since the existing code doesn't explicitly check if capacity is reached, the actual density is much lower than the input density.
Below is a comparison of the existing snippet vs RigL's implementation. Since the existing snippet has no check on capacity, its actual density is lower than the intended one. In the output below, the intended density was 0.2 (or 80% sparsity); note that the value logged as "Overall sparsity" is the overall density.
INFO:root:ERK block1.layer.0.conv1.weight: torch.Size([32, 16, 3, 3]) prob 0.7958154602321671
INFO:root:ERK block1.layer.0.conv2.weight: torch.Size([32, 32, 3, 3]) prob 0.5158063168171453
INFO:root:ERK block1.layer.0.convShortcut.weight: torch.Size([32, 16, 1, 1]) prob 6.631795501934726
INFO:root:ERK block1.layer.1.conv1.weight: torch.Size([32, 32, 3, 3]) prob 0.5158063168171453
INFO:root:ERK block1.layer.1.conv2.weight: torch.Size([32, 32, 3, 3]) prob 0.5158063168171453
INFO:root:ERK block1.layer.2.conv1.weight: torch.Size([32, 32, 3, 3]) prob 0.5158063168171453
INFO:root:ERK block1.layer.2.conv2.weight: torch.Size([32, 32, 3, 3]) prob 0.5158063168171453
INFO:root:ERK block2.layer.0.conv1.weight: torch.Size([64, 32, 3, 3]) prob 0.37580174510963443
INFO:root:ERK block2.layer.0.conv2.weight: torch.Size([64, 64, 3, 3]) prob 0.2468501659053481
INFO:root:ERK block2.layer.0.convShortcut.weight: torch.Size([64, 32, 1, 1]) prob 3.2495797959480157
INFO:root:ERK block2.layer.1.conv1.weight: torch.Size([64, 64, 3, 3]) prob 0.2468501659053481
INFO:root:ERK block2.layer.1.conv2.weight: torch.Size([64, 64, 3, 3]) prob 0.2468501659053481
INFO:root:ERK block2.layer.2.conv1.weight: torch.Size([64, 64, 3, 3]) prob 0.2468501659053481
INFO:root:ERK block2.layer.2.conv2.weight: torch.Size([64, 64, 3, 3]) prob 0.2468501659053481
INFO:root:ERK block3.layer.0.conv1.weight: torch.Size([128, 64, 3, 3]) prob 0.18237437630320497
INFO:root:ERK block3.layer.0.conv2.weight: torch.Size([128, 128, 3, 3]) prob 0.12066183482686793
INFO:root:ERK block3.layer.0.convShortcut.weight: torch.Size([128, 64, 1, 1]) prob 1.608210409219171
INFO:root:ERK block3.layer.1.conv1.weight: torch.Size([128, 128, 3, 3]) prob 0.12066183482686793
INFO:root:ERK block3.layer.1.conv2.weight: torch.Size([128, 128, 3, 3]) prob 0.12066183482686793
INFO:root:ERK block3.layer.2.conv1.weight: torch.Size([128, 128, 3, 3]) prob 0.12066183482686793
INFO:root:ERK block3.layer.2.conv2.weight: torch.Size([128, 128, 3, 3]) prob 0.12066183482686793
INFO:root:ERK fc.weight: torch.Size([10, 128]) prob 7.321502234135937
INFO:root:Overall sparsity 0.18062481420927468
INFO:root:========
INFO:root:Sparsity of var:fc.weight had to be set to 0.
INFO:root:Sparsity of var:block1.layer.0.convShortcut.weight had to be set to 0.
INFO:root:Sparsity of var:block2.layer.0.convShortcut.weight had to be set to 0.
INFO:root:Sparsity of var:block3.layer.0.convShortcut.weight had to be set to 0.
INFO:root:layer: block1.layer.0.conv1.weight, shape: torch.Size([32, 16, 3, 3]), density: 0.8874813710879286
INFO:root:layer: block1.layer.0.conv2.weight, shape: torch.Size([32, 32, 3, 3]), density: 0.5752194071866203
INFO:root:layer: block1.layer.0.convShortcut.weight, shape: torch.Size([32, 16, 1, 1]), density: 1.0
INFO:root:layer: block1.layer.1.conv1.weight, shape: torch.Size([32, 32, 3, 3]), density: 0.5752194071866203
INFO:root:layer: block1.layer.1.conv2.weight, shape: torch.Size([32, 32, 3, 3]), density: 0.5752194071866203
INFO:root:layer: block1.layer.2.conv1.weight, shape: torch.Size([32, 32, 3, 3]), density: 0.5752194071866203
INFO:root:layer: block1.layer.2.conv2.weight, shape: torch.Size([32, 32, 3, 3]), density: 0.5752194071866203
INFO:root:layer: block2.layer.0.conv1.weight, shape: torch.Size([64, 32, 3, 3]), density: 0.4190884252359663
INFO:root:layer: block2.layer.0.conv2.weight, shape: torch.Size([64, 64, 3, 3]), density: 0.2752835734393112
INFO:root:layer: block2.layer.0.convShortcut.weight, shape: torch.Size([64, 32, 1, 1]), density: 1.0
INFO:root:layer: block2.layer.1.conv1.weight, shape: torch.Size([64, 64, 3, 3]), density: 0.2752835734393112
INFO:root:layer: block2.layer.1.conv2.weight, shape: torch.Size([64, 64, 3, 3]), density: 0.2752835734393112
INFO:root:layer: block2.layer.2.conv1.weight, shape: torch.Size([64, 64, 3, 3]), density: 0.2752835734393112
INFO:root:layer: block2.layer.2.conv2.weight, shape: torch.Size([64, 64, 3, 3]), density: 0.2752835734393112
INFO:root:layer: block3.layer.0.conv1.weight, shape: torch.Size([128, 64, 3, 3]), density: 0.20338114754098363
INFO:root:layer: block3.layer.0.conv2.weight, shape: torch.Size([128, 128, 3, 3]), density: 0.13456025418115583
INFO:root:layer: block3.layer.0.convShortcut.weight, shape: torch.Size([128, 64, 1, 1]), density: 1.0
INFO:root:layer: block3.layer.1.conv1.weight, shape: torch.Size([128, 128, 3, 3]), density: 0.13456025418115583
INFO:root:layer: block3.layer.1.conv2.weight, shape: torch.Size([128, 128, 3, 3]), density: 0.13456025418115583
INFO:root:layer: block3.layer.2.conv1.weight, shape: torch.Size([128, 128, 3, 3]), density: 0.13456025418115583
INFO:root:layer: block3.layer.2.conv2.weight, shape: torch.Size([128, 128, 3, 3]), density: 0.13456025418115583
INFO:root:layer: fc.weight, shape: torch.Size([10, 128]), density: 1.0
INFO:root:Overall sparsity 0.2
Here's the source to produce this output.
Adding a threshold (something like growth = min(prob, 1) * weight.numel()) causes the code to take a long time to converge.
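For reference, the RigL-style capacity check discussed in this comment can be sketched as follows. This is a minimal stdlib-only illustration, not the code from either repository: `erk_densities` and `power_scale` are names chosen for this example, and the layer shapes are taken from the log output above. Instead of clipping probabilities, it repeatedly marks the layer with the largest raw ERK score as dense whenever its density would exceed 1, then re-solves for epsilon over the remaining layers.

```python
from math import prod


def erk_densities(shapes, density, power_scale=1.0):
    """Per-layer densities under ERK, marking over-capacity layers as dense."""
    dense_layers = set()
    while True:
        rhs = 0.0      # non-zero budget left for the non-dense layers
        divisor = 0.0  # sum of raw ERK score * parameter count
        raw_prob = {}
        for name, shape in shapes.items():
            n_param = prod(shape)
            if name in dense_layers:
                # A dense layer also spends the weights it would have zeroed.
                rhs -= n_param * (1.0 - density)
            else:
                rhs += n_param * density
                raw_prob[name] = (sum(shape) / n_param) ** power_scale
                divisor += raw_prob[name] * n_param
        if divisor == 0:
            # Every layer ended up dense; nothing left to scale.
            return {name: 1.0 for name in shapes}
        epsilon = rhs / divisor
        max_raw = max(raw_prob.values())
        if max_raw * epsilon <= 1.0:
            break  # no remaining layer exceeds its capacity
        # Mark the most over-allocated layer(s) as dense and solve again.
        dense_layers.update(n for n, p in raw_prob.items() if p == max_raw)
    return {
        name: 1.0 if name in dense_layers else epsilon * raw_prob[name]
        for name in shapes
    }


# Layer shapes as they appear in the log output above.
shapes = {}
for b, (c_in, c_out) in enumerate([(16, 32), (32, 64), (64, 128)], start=1):
    shapes[f"block{b}.layer.0.conv1.weight"] = (c_out, c_in, 3, 3)
    shapes[f"block{b}.layer.0.conv2.weight"] = (c_out, c_out, 3, 3)
    shapes[f"block{b}.layer.0.convShortcut.weight"] = (c_out, c_in, 1, 1)
    for layer in (1, 2):
        for conv in ("conv1", "conv2"):
            shapes[f"block{b}.layer.{layer}.{conv}.weight"] = (c_out, c_out, 3, 3)
shapes["fc.weight"] = (10, 128)

densities = erk_densities(shapes, density=0.2)
```

On these shapes with a target density of 0.2, the sketch ends up with fc.weight and the three convShortcut layers dense, in line with the second listing above, and the parameter-weighted overall density equals the target because the budget spent on dense layers is subtracted before epsilon is solved.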
Great catch! Would you mind submitting a pull request for this? I feel like you are able to quickly pinpoint and fix this issue.
Sure, I would be happy to contribute.
Would you prefer adding RigL's implementation of ERK for this? It does better than trying to tune epsilon for a given sparsity (as seen in the outputs above).
We based our RigL-reproducibility code on @TimDettmers's sparselearning.
Our code has deviated significantly since then, but I could patch in the ERK initialisation change if it's still welcome.