Weight Agnostic Neural Networks (WANNs) are a new type of neural network in which all connections share the same weight, so the structure of the network is what matters.
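To make the idea concrete, here is a minimal sketch in Go. It is not code from this package: the tiny topology, the tinyNet and meanError names, and the list of weight values are all assumptions chosen for illustration. Every connection is multiplied by the same weight w, and the network is scored across several shared weight values, so that a structure which works over the whole range is preferred over one that only works for a single tuned weight.

```go
package main

import (
	"fmt"
	"math"
)

// sigmoid is the logistic function.
func sigmoid(x float64) float64 { return 1.0 / (1.0 + math.Exp(-x)) }

// tinyNet is a hypothetical fixed topology: two inputs, one hidden node
// (tanh) and one output node (sigmoid). Every connection uses the same
// shared weight w; only the wiring and the activation functions matter.
func tinyNet(x []float64, w float64) float64 {
	hidden := math.Tanh(w*x[0] + w*x[1])
	return sigmoid(w*hidden + w*x[0])
}

// meanError returns the average absolute error of the network against
// target, taken over every shared weight value in weights.
func meanError(x []float64, target float64, weights []float64) float64 {
	sum := 0.0
	for _, w := range weights {
		sum += math.Abs(tinyNet(x, w) - target)
	}
	return sum / float64(len(weights))
}

func main() {
	weights := []float64{-2, -1, -0.5, 0.5, 1, 2} // illustrative values only
	input := []float64{1.0, 0.0}
	for _, w := range weights {
		fmt.Printf("w=%5.2f  output=%.4f\n", w, tinyNet(input, w))
	}
	fmt.Printf("mean absolute error vs. target 1.0: %.4f\n", meanError(input, 1.0, weights))
}
```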
This package implements Weight Agnostic Neural Networks for Go and is inspired by this paper from June 2019:
"Weight Agnostic Neural Networks" by Adam Gaier and David Ha.
- All activation functions are benchmarked at the start of the program and the results are taken into account when calculating the complexity of a network.
- All networks can be translated to a Go statement, using the wonderful jennifer package (work in progress; there are a few kinks that need to be ironed out).
- Networks can be saved as SVG diagrams. This feature needs more testing.
- Neural networks can be trained and used. See the cmd folder for examples.
- A random weight is chosen when training, instead of looping over the range of the weight. The paper describes both methods.
- After the network has been trained, the optimal weight is found by looping over all weights (with a step size of 0.0001).
- Higher complexity is penalized when evolving networks. Since the benchmarked cost of each activation function is part of the complexity measure, this favors not only simpler networks but also faster ones.
- The diagram drawing routine plots the activation functions directly onto the nodes, together with a label. This can be saved as an SVG file.
Here is a simple example that evolves a network for recognizing one of four shapes (the "up" shape):
package main

import (
	"fmt"
	"os"

	"github.com/xyproto/wann"
)

func main() {
	// Here are four shapes, representing: up, down, left and right:

	up := []float64{
		0.0, 1.0, 0.0, //  o
		1.0, 1.0, 1.0} // ooo

	down := []float64{
		1.0, 1.0, 1.0, // ooo
		0.0, 1.0, 0.0} //  o

	left := []float64{
		1.0, 1.0, 1.0, // ooo
		0.0, 0.0, 1.0} //   o

	right := []float64{
		1.0, 1.0, 1.0, // ooo
		1.0, 0.0, 0.0} // o

	// Prepare the input data as a 2D slice
	inputData := [][]float64{
		up,
		down,
		left,
		right,
	}

	// Target scores for: up, down, left, right
	correctResultsForUp := []float64{1.0, -1.0, -1.0, -1.0}

	// Prepare a neural network configuration struct
	config := &wann.Config{
		InitialConnectionRatio: 0.2,
		Generations:            2000,
		PopulationSize:         500,
		Verbose:                true,
	}

	// Evolve a network, using the input data and the sought after results
	trainedNetwork, err := config.Evolve(inputData, correctResultsForUp)
	if err != nil {
		fmt.Fprintf(os.Stderr, "error: %s\n", err)
		os.Exit(1)
	}

	// Now to test the trained network on 4 different inputs and see if it passes the test
	upScore := trainedNetwork.Evaluate(up)
	downScore := trainedNetwork.Evaluate(down)
	leftScore := trainedNetwork.Evaluate(left)
	rightScore := trainedNetwork.Evaluate(right)

	if config.Verbose {
		if upScore > downScore && upScore > leftScore && upScore > rightScore {
			fmt.Println("Network training complete, the results are good.")
		} else {
			fmt.Println("Network training complete, but the results did not pass the test.")
		}
	}

	// Save the trained network as an SVG image
	if config.Verbose {
		fmt.Print("Writing network.svg...")
	}
	if err := trainedNetwork.WriteSVG("network.svg"); err != nil {
		fmt.Fprintf(os.Stderr, "error: %s\n", err)
		os.Exit(1)
	}
	if config.Verbose {
		fmt.Println("ok")
	}
}
Here is the resulting network generated by the above program:
This makes sense: the third number in the input data (index 2) is 0 for the up shape and 1 for the other three shapes, so running it through a swish function and then inverting it gives a usable detector for the up pattern.
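The reasoning can be checked by hand. The sketch below re-implements the described detector directly in Go, without the wann package, assuming swish(x) = x · sigmoid(x), which is the common definition; the package's own activation function may be defined differently. The upDetector name is hypothetical.

```go
package main

import (
	"fmt"
	"math"
)

// swish is x * sigmoid(x), written as x / (1 + e^-x).
func swish(x float64) float64 { return x / (1.0 + math.Exp(-x)) }

// upDetector mirrors the network described above: read the input at
// index 2, apply swish, then invert the result.
func upDetector(input []float64) float64 { return -swish(input[2]) }

func main() {
	shapes := map[string][]float64{
		"up":    {0, 1, 0, 1, 1, 1},
		"down":  {1, 1, 1, 0, 1, 0},
		"left":  {1, 1, 1, 0, 0, 1},
		"right": {1, 1, 1, 1, 0, 0},
	}
	// Iterate in a fixed order, since map iteration order is random in Go.
	for _, name := range []string{"up", "down", "left", "right"} {
		fmt.Printf("%-5s %.4f\n", name, upDetector(shapes[name]))
	}
}
```

The up shape scores 0, while each of the other three shapes scores about -0.73, so up gets the highest score.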
- The generated networks may differ for each run.
This requires Go 1.11 or later.
Clone the repository:
git clone https://github.com/xyproto/wann
Enter the cmd/evolve directory:
cd wann/cmd/evolve
Build and run the example:
go build && ./evolve
Take a look at the best network for judging whether a set of numbers, each either 0 or 1, belongs to one category:
xdg-open network.svg
(If needed, use your favorite SVG viewer instead of the xdg-open command.)
- Adding convolution nodes might give interesting results.
Generating Go code from a trained network is an experimental feature and a work in progress!
The idea is to generate one large expression from all the expressions that each node in the network represents.
Right now, this only works for networks that have a depth of 1.
For example, adding these two lines to cmd/evolve/main.go:
// Output a Go function for this network
fmt.Println(trainedNetwork.GoFunction())
Produces this output:
func f(x float64) float64 { return -x }
The plan is to output a function that takes the input data instead, and refers to the input data by index. Support for deeper networks also needs to be added.
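As a rough illustration of that plan, the sketch below composes per-node expression templates into one large expression and wraps it in a function that takes the whole input slice and refers to it by index. Everything here is hypothetical: the node type, the {} placeholder convention, compose and goFunction are not part of the wann package, and the generated text assumes that a swish function is defined elsewhere. The chain mirrors the example network above (swish, then inversion).

```go
package main

import (
	"fmt"
	"strings"
)

// node is a hypothetical description of one network node: a Go expression
// template for its activation function, where {} marks the node's input.
type node struct {
	template string
}

// compose starts from the expression for one input value and wraps it in
// each node's template in turn, producing one combined expression.
func compose(input string, chain []node) string {
	expr := input
	for _, n := range chain {
		expr = strings.Replace(n.template, "{}", expr, 1)
	}
	return expr
}

// goFunction renders the combined expression as Go source for a function
// that takes the whole input slice.
func goFunction(expr string) string {
	return "func f(input []float64) float64 { return " + expr + " }"
}

func main() {
	// A hard-coded chain mirroring the example network: swish, then invert.
	chain := []node{
		{template: "swish({})"},
		{template: "-({})"},
	}
	fmt.Println(goFunction(compose("input[2]", chain)))
}
```

Running it prints func f(input []float64) float64 { return -(swish(input[2])) }. A real generator would also need to handle nodes with several incoming connections, which is where support for deeper networks comes in.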
There is a complete example for outputting Go code in cmd/gofunction.
- Version: 0.3.2
- License: BSD-3
- Author: Alexander F. Rødseth <xyproto@archlinux.org>