Disclaimer: this code is modified from pytorch-tutorial.
I implemented Decoupled Neural Interfaces using Synthetic Gradients in PyTorch. The paper uses synthetic gradients to decouple the layers of a network, which is interesting because the layers no longer suffer from update locking. I tested my model on MNIST, and it achieves almost the same performance as the model updated with backpropagation.
- pytorch
- python 3.5
- torchvision
- seaborn (optional)
- matplotlib (optional)
- use multi-threading on the GPU to analyze the speed
We often optimize neural networks with backpropagation, which is implemented in every well-known framework. However, is there another way for the layers of a network to communicate with each other? This is where synthetic gradients come in. They give neural networks a way to communicate, to learn to send messages between themselves, in a decoupled, scalable manner, paving the way for multiple neural networks to communicate with each other or for improving the long-term temporal dependencies of recurrent networks.
Each layer automatically receives an error signal (δa_head) produced by its synthetic layer and uses it to optimize its weights right away. How is this error signal generated? The network actually still performs backpropagation, but the true error signal (δa) from the objective function is not used to optimize the neurons in the network. Instead, it is used to train the synthetic layer so that its output δa_head approximates δa. The following is the illustration from the paper:
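The mechanism above can be sketched for a two-layer MLP as follows. This is a minimal illustration, not this repository's code: the names (`SyntheticGradient`, `train_step`, `sg1`), the single linear synthetic layer with zero initialization, the layer sizes, and the SGD learning rates are all assumptions. Layer 1 is updated immediately with δa_head; the true δa that backpropagation delivers from layer 2 is used only as the regression target for the synthetic layer. In a deeper network, intermediate synthetic layers would instead be trained against the next layer's (synthetic) gradient.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)


class SyntheticGradient(nn.Module):
    """Hypothetical module that predicts the error signal delta_a_head from activation a."""

    def __init__(self, dim):
        super().__init__()
        self.net = nn.Linear(dim, dim)
        # Zero init: delta_a_head starts at 0 and is learned over time.
        nn.init.zeros_(self.net.weight)
        nn.init.zeros_(self.net.bias)

    def forward(self, a):
        return self.net(a)


# Hypothetical two-layer MLP for flattened 28x28 MNIST inputs, split into two decoupled parts.
layer1 = nn.Linear(784, 256)
layer2 = nn.Linear(256, 10)
sg1 = SyntheticGradient(256)

opt1 = torch.optim.SGD(layer1.parameters(), lr=0.1)
opt2 = torch.optim.SGD(layer2.parameters(), lr=0.1)
opt_sg = torch.optim.SGD(sg1.parameters(), lr=0.01)


def train_step(x, y):
    # 1) Layer 1 updates immediately with delta_a_head, without waiting for layer 2.
    a1 = F.relu(layer1(x))
    delta_a_head = sg1(a1.detach())
    opt1.zero_grad()
    a1.backward(delta_a_head.detach())  # inject the synthetic gradient
    opt1.step()

    # 2) Layer 2 computes the real loss on a detached copy of a1.
    a1_in = a1.detach().requires_grad_()
    loss = F.cross_entropy(layer2(a1_in), y)
    opt2.zero_grad()
    loss.backward()  # also fills a1_in.grad with the true delta_a
    opt2.step()

    # 3) The true delta_a is only used to train the synthetic layer.
    sg_loss = F.mse_loss(delta_a_head, a1_in.grad)
    opt_sg.zero_grad()
    sg_loss.backward()
    opt_sg.step()
    return loss.item(), sg_loss.item()


# Usage on a random batch standing in for MNIST.
x = torch.randn(32, 784)
y = torch.randint(0, 10, (32,))
history = [train_step(x, y) for _ in range(5)]
```

Because layer 1 never waits for layer 2's backward pass, the two parts could in principle run on different workers; the price is that layer 1 follows an approximate gradient whose quality depends on how well the synthetic layer has been trained.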
Achieves 96% accuracy (compared to 97% for the original model).
[Figure: classification loss | gradient loss (log scale)]
[Figure: cDNI classification loss | cDNI gradient loss (log scale)]
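cDNI differs from plain DNI in that the synthetic layer also sees auxiliary information, here the class label. A minimal sketch, assuming the label is fed as a one-hot vector concatenated with the activation and a single linear layer predicts δa_head; the class name `ConditionedSyntheticGradient` and the sizes are hypothetical, not taken from this repository.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConditionedSyntheticGradient(nn.Module):
    """Hypothetical cDNI-style module: predicts delta_a_head from [activation, one-hot label]."""

    def __init__(self, dim, num_classes=10):
        super().__init__()
        self.num_classes = num_classes
        self.net = nn.Linear(dim + num_classes, dim)
        # Zero init so the first synthetic gradients are exactly 0.
        nn.init.zeros_(self.net.weight)
        nn.init.zeros_(self.net.bias)

    def forward(self, a, y):
        # Encode integer labels as one-hot floats and append them to the activation.
        y_onehot = F.one_hot(y, self.num_classes).to(a.dtype)
        return self.net(torch.cat([a, y_onehot], dim=1))


sg = ConditionedSyntheticGradient(256)
a = torch.randn(32, 256)
y = torch.randint(0, 10, (32,))
delta_a_head = sg(a, y)
print(delta_a_head.shape)  # torch.Size([32, 256])
```

Only the synthetic layer sees the label; the layer it serves is still updated with δa_head alone, and the synthetic layer is still trained to regress the true δa.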
Achieves 96% accuracy (compared to 98% for the original model).
[Figure: classification loss | gradient loss (log scale)]
Right now I have only implemented the FCN and CNN versions, which are set as the default network structures.
python main.py --model_type mlp
or
python main.py --model_type cnn
python main.py --model_type mlp --conditioned True
python mlp.py
or
python cnn.py
- DeepMind's post on Decoupled Neural Interfaces Using Synthetic Gradients
- Decoupled Neural Interfaces using Synthetic Gradients
- Understanding Synthetic Gradients and Decoupled Neural Interfaces