torch-nngraph

This package provides graph-based computation for the nn library in Torch7.

Requirements

### torch-graph

This library requires the torch-graph package to be installed.

http://github.com/koraykv/torch-graph

### graphviz

You do not need graphviz to use this library, but if it is installed you can display the graphs you create.

Installation

Right now, this repository is not distributed through the torch-pkg or luarocks package systems. To install it, follow these steps:

	git clone git://github.com/koraykv/torch-graph.git
	cd torch-graph
	torch-pkg deploy 
	cd ..
	git clone git://github.com/koraykv/torch-nngraph.git
	cd torch-nngraph
	torch-pkg deploy 

Usage

The aim of this library is to give users of the nn library tools for easily building complicated architectures. Every nn module or criterion is bundled into a graph node. The __call operator of an nn.Module or nn.Criterion instance creates that node, so an architecture can be written as if it were a chain of function calls: calling a module with no arguments creates an input node, and calling it with one node (or a table of nodes) connects it to those parents.
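To make the call-style construction more concrete, here is a minimal plain-Lua sketch of how a __call metamethod can turn "calling a module" into "creating a graph node". It does not use Torch. The names `Node`, `Module`, `new_node` and `new_module` are hypothetical stand-ins invented for illustration; they are not the nn or nngraph classes, and the real implementation may differ.

```lua
-- Hypothetical stand-ins for illustration only; NOT the nn/nngraph classes.
local Node = {}
Node.__index = Node

-- A node records the wrapped module and the nodes that feed into it.
local function new_node(module, parents)
  return setmetatable({ module = module, parents = parents }, Node)
end

local Module = {}
Module.__index = Module

-- Calling a module instance wraps it in a node.
--   no argument     -> input node without parents
--   a single node   -> one parent
--   a table of nodes -> several parents
Module.__call = function(self, input)
  local parents
  if input == nil then
    parents = {}
  elseif getmetatable(input) == Node then
    parents = { input }
  else
    parents = input
  end
  return new_node(self, parents)
end

local function new_module(name)
  return setmetatable({ name = name }, Module)
end

-- Build a tiny graph in the same style as the examples in this README.
local x   = new_module('Linear')()            -- input node
local h   = new_module('Tanh')(x)             -- one parent
local out = new_module('CAddTable')({ x, h }) -- two parents

print(out.module.name, #out.parents)
```

The point of the sketch is only the mechanism: because each call returns a node that remembers its parents, nesting calls builds the whole graph, which is what lets gModule later be given just the input and output nodes.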

One hidden layer network

	require 'nngraph'

	x1 = nn.Linear(20,10)()
	mout = nn.Linear(10,1)(nn.Tanh()(nn.Linear(10,10)(nn.Tanh()(x1))))
	mlp = nn.gModule({x1},{mout})

	x = torch.rand(20)
	dx = torch.rand(1)
	mlp:updateOutput(x)
	mlp:updateGradInput(x,dx)
	mlp:accGradParameters(x,dx)

	-- draw graph
	graph.dot(mlp.fg,'MLP')

A net with 2 inputs and 2 outputs

	require 'nngraph'

	x1 = nn.Linear(20,20)()
	x2 = nn.Linear(10,10)()
	m0 = nn.Linear(20,1)(nn.Tanh()(x1))
	m1 = nn.Linear(10,1)(nn.Tanh()(x2))
	madd = nn.CAddTable()({m0,m1})
	m2 = nn.Sigmoid()(madd)
	m3 = nn.Tanh()(madd)
	gmod = nn.gModule({x1,x2},{m2,m3})

	x = torch.rand(20)
	y = torch.rand(10)

	gmod:updateOutput({x,y})
	gmod:updateGradInput({x,y},{torch.rand(1),torch.rand(1)})
	graph.dot(gmod.fg,'Big MLP')

Another net that uses container modules (like ParallelTable) that output a table of outputs.

	require 'nngraph'

	m = nn.Sequential()
	m:add(nn.SplitTable(1))
	m:add(nn.ParallelTable():add(nn.Linear(10,20)):add(nn.Linear(10,30)))
	input = nn.Identity()()
	input1,input2 = m(input,2)
	m3 = nn.JoinTable(1)({input1,input2})

	g = nn.gModule({input},{m3})

	indata = torch.rand(2,10)
	gdata = torch.rand(50)
	g:forward(indata)
	g:backward(indata,gdata)

	graph.dot(g.fg,'Forward Graph')
	graph.dot(g.bg,'Backward Graph')

A multi-layer network where each layer takes the outputs of the previous two layers as input.

	require 'nngraph'

	input = nn.Identity()()
	L1 = nn.Tanh()(nn.Linear(10,20)(input))
	L2 = nn.Tanh()(nn.Linear(30,60)(nn.JoinTable(1)({input,L1})))
	L3 = nn.Tanh()(nn.Linear(80,160)(nn.JoinTable(1)({L1,L2})))

	g = nn.gModule({input},{L3})

	indata = torch.rand(10)
	gdata = torch.rand(160)
	g:forward(indata)
	g:backward(indata,gdata)

	graph.dot(g.fg,'Forward Graph')
	graph.dot(g.bg,'Backward Graph')
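
The input sizes of the nn.Linear layers in the last example follow from the concatenation: each layer after the first receives the joined outputs of the two preceding stages, so its input width is the sum of their widths. The plain-Lua sketch below reproduces that arithmetic without Torch. The helper `joined_input_size` and the variable names are hypothetical and exist only for this illustration.

```lua
-- Sizes taken from the example above: a 10-dimensional input and three
-- layers with 20, 60 and 160 output units.
local input_size = 10
local out_sizes = { 20, 60, 160 }

-- Hypothetical helper: input width of layer i when it reads the joined
-- outputs of the two preceding stages (the raw input counts as stage 0).
local function joined_input_size(i)
  if i == 1 then
    return input_size                 -- the first layer sees only the input
  end
  local prev  = out_sizes[i - 1]
  local prev2 = (i == 2) and input_size or out_sizes[i - 2]
  return prev + prev2
end

for i = 1, #out_sizes do
  print(string.format('L%d: nn.Linear(%d,%d)', i, joined_input_size(i), out_sizes[i]))
end
```

The printed sizes match the constructor arguments used in the example (10, 30 and 80 inputs), which is a quick way to check the widths before adding another layer to such a network.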