TimDettmers/sparse_learning

Need some help

jmandivarapu1 opened this issue · 1 comment

Hi,

I am trying to run the MNIST code, but I am not sure which values to use for the pruning rate and the death rate. I ran the code like this:

    python main.py --data mnist --model lenet5 --save-features --bench --growth momentum --redistribution momentum

My accuracy gets high at first, but then it drops very low again. Could you tell me which arguments I should pass?

[Screenshot attached: Screen Shot 2019-11-26 at 12 50 02 PM]

The problem is that LeNet-5 has a bottleneck with very few connections. If, by chance, you end up with a configuration in which no connection exists between one layer and the next, no gradient flows and learning stops (only the biases are learned in that case). You can avoid this by using random growth, which is able to find a connection pattern again after an epoch, or you can run the same model with a different random seed. In my experience, training is usually stable after epoch 20, but with very few weights (<5%) it may take a couple of tries until a solid connection is established.
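The failure mode and both workarounds can be illustrated with a small sketch. It assumes each layer's sparsity pattern is a boolean mask indexed as mask[out][in], with weights kept uniformly at random. The helper names (`random_mask`, `has_path`, `random_regrow`, `sample_connected`) and the layer sizes are hypothetical and are not part of sparse_learning. `has_path` checks whether any input unit can still reach an output unit through active weights. If no such path exists, the backward pass has no path either, so nothing upstream of the cut receives a gradient. `random_regrow` activates inactive positions uniformly at random and needs no gradient signal. Assuming momentum growth picks the inactive weights with the largest momentum magnitude, those momenta are all zero upstream of a cut, so that ranking carries no information there. `sample_connected` mimics "try another random seed".

```python
import random
from typing import List, Tuple

# Hypothetical representation: mask[out][in] is True if that weight is active.
Mask = List[List[bool]]


def random_mask(fan_in: int, fan_out: int, density: float, rng: random.Random) -> Mask:
    """Keep round(density * fan_in * fan_out) weights, chosen uniformly at random."""
    total = fan_in * fan_out
    keep = set(rng.sample(range(total), round(density * total)))
    return [[(o * fan_in + i) in keep for i in range(fan_in)] for o in range(fan_out)]


def has_path(masks: List[Mask], n_inputs: int) -> bool:
    """True if at least one input unit reaches at least one output unit via active weights."""
    reachable = set(range(n_inputs))  # units of the current layer reachable from the input
    for mask in masks:
        reachable = {o for o, row in enumerate(mask) if any(row[i] for i in reachable)}
        if not reachable:
            return False  # a cut: no forward path, hence no backward (gradient) path
    return True


def random_regrow(mask: Mask, k: int, rng: random.Random) -> int:
    """Activate up to k inactive weights uniformly at random; return how many were grown."""
    inactive = [(o, i) for o, row in enumerate(mask) for i, on in enumerate(row) if not on]
    chosen = rng.sample(inactive, min(k, len(inactive)))
    for o, i in chosen:
        mask[o][i] = True
    return len(chosen)


def sample_connected(sizes: List[int], density: float, max_tries: int = 10) -> Tuple[int, List[Mask]]:
    """Retry with new seeds until the sampled masks leave an input-to-output path."""
    for seed in range(max_tries):
        rng = random.Random(seed)
        masks = [random_mask(a, b, density, rng) for a, b in zip(sizes, sizes[1:])]
        if has_path(masks, sizes[0]):
            return seed, masks
    raise RuntimeError("no connected configuration found")


if __name__ == "__main__":
    rng = random.Random(0)
    dense_first = [[True] * 8 for _ in range(4)]   # 8 -> 4, fully connected
    cut_second = [[False] * 4 for _ in range(2)]   # 4 -> 2, no active weights
    print(has_path([dense_first, cut_second], 8))  # False: learning would stall
    random_regrow(cut_second, 2, rng)
    print(has_path([dense_first, cut_second], 8))  # True: random growth reconnected it
```

The check only asks whether at least one path exists. A real bottleneck can also be partially cut, which slows learning without stopping it completely.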

This problem is specific to LeNet-5. Other network architectures do not have such extreme bottlenecks and do not show this behavior.
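A rough calculation shows why narrow layers are the issue. Assume each weight is kept independently with probability d. This is a simplification, because the repository may allocate density per layer differently. Under that assumption, a unit with fan-in n is left with no incoming connection with probability (1 - d)^n. The layer shapes below are hypothetical. The fan-in of 84 mirrors the 84-unit layer of the classic LeNet-5, but the repository's exact shapes are not confirmed here.

```python
def p_unit_disconnected(fan_in: int, density: float) -> float:
    """Probability that a unit has no active incoming weight under independent Bernoulli masking."""
    return (1.0 - density) ** fan_in


def expected_dead_units(fan_in: int, fan_out: int, density: float) -> float:
    """Expected number of units in a layer left with zero incoming weights."""
    return fan_out * p_unit_disconnected(fan_in, density)


if __name__ == "__main__":
    density = 0.05  # 5% of weights kept, the regime mentioned in the reply
    for fan_in, fan_out in [(84, 10), (512, 10), (4096, 1000)]:  # hypothetical layer shapes
        print(fan_in, fan_out, expected_dead_units(fan_in, fan_out, density))
```

At 5% density, a unit with a fan-in of 84 is disconnected with a probability of about 1.3%. With a fan-in of 512, that probability drops to the order of 1e-12. Wider layers are therefore far less likely to be cut by an unlucky random mask.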