Need some help
jmandivarapu1 opened this issue · 1 comments
Hi,
I am trying to run the MNIST code, but I am not sure which pruning rate and death rate values to use.
I ran the code like this: python main.py --data mnist --model lenet5 --save-features --bench --growth momentum --redistribution momentum
My accuracy gets high at first but then drops very low again. Can you tell me the right arguments to pass?
The problem is that LeNet-5 has a bottleneck with very few connections. If, by chance, you get a configuration in which no connection exists between one layer and the next, no gradient flows and learning stops (only the biases are learned in that case). You can avoid this by using random growth, which is able to find a connection pattern after an epoch, or you can try running the same model with a different random seed. In my experience, training is usually stable after epoch 20, but with very few weights (<5%) one might need a couple of tries until a solid connection is established.
This problem occurs only with LeNet-5. Other network architectures do not have such extreme bottlenecks and do not show this behavior.
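To make the failure mode described above concrete, here is a minimal sketch of how one might check whether a set of sparse masks still connects the input to the output. It is not part of the repository: the function names `random_mask` and `has_input_to_output_path` are hypothetical, and the layers are simplified to dense 0/1 masks stored as nested lists (convolutional layers are not modeled).

```python
import random


def random_mask(n_out, n_in, density, rng):
    """Build an n_out x n_in mask of 0/1 entries (hypothetical helper).

    Each entry is kept (set to 1) independently with probability `density`.
    """
    return [[1 if rng.random() < density else 0 for _ in range(n_in)]
            for _ in range(n_out)]


def has_input_to_output_path(masks, n_inputs):
    """Return True if at least one input unit can reach an output unit.

    masks[k][j][i] == 1 means unit i of layer k feeds unit j of layer k + 1.
    If this returns False, no gradient can flow through the weights, which
    matches the situation where only the biases keep learning.
    """
    reachable = [True] * n_inputs
    for mask in masks:
        # A unit is reachable if any of its kept weights comes from a reachable unit.
        reachable = [any(w and r for w, r in zip(row, reachable)) for row in mask]
        if not any(reachable):
            return False
    return True


# Both layers have nonzero masks, but the only live unit of the first layer
# is not the one the second layer reads from, so the path is broken.
connected = [[[1, 0], [0, 0]], [[1, 0]]]
broken = [[[1, 0], [0, 0]], [[0, 1]]]
print(has_input_to_output_path(connected, 2))
print(has_input_to_output_path(broken, 2))
```

Note that the second example has at least one kept weight in every layer and is still disconnected, so checking each layer for a nonzero mask is not enough. Calling `random_mask` with a small density and different seeds (for example `random.Random(0)`, `random.Random(1)`) gives a rough feel for how often a narrow layer ends up cut off, which is why retrying with another seed can help.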