fmassa/optimize-net

graphgen doesn't work with DataParallelTable

szagoruyko opened this issue · 2 comments

repro:

require 'cunn'
require 'cudnn'
local generateGraph = require 'optnet.graphgen'
local iterm = require 'iterm'

local model = nn.DataParallelTable(1)

model:add(cudnn.SpatialConvolution(3,96,7,7,3,3),1)
model:add(cudnn.SpatialConvolution(3,96,7,7,3,3),2)

model:cuda()

local input = torch.randn(32,3,224,224):cuda()

iterm.dot(generateGraph(model, input))

gives

/opt/rocks/distro/install/bin/luajit: /opt/rocks/distro/install/share/lua/5.1/torch/File.lua:141: Unwritable object <userdata> at <?>.<?>.updateOutput.basefunc.errcheck.C
stack traceback:
    [C]: in function 'error'
    /opt/rocks/distro/install/share/lua/5.1/torch/File.lua:141: in function 'writeObject'
    /opt/rocks/distro/install/share/lua/5.1/torch/File.lua:235: in function 'writeObject'
    /opt/rocks/distro/install/share/lua/5.1/torch/File.lua:235: in function 'writeObject'
    /opt/rocks/distro/install/share/lua/5.1/torch/File.lua:200: in function 'writeObject'
    /opt/rocks/distro/install/share/lua/5.1/torch/File.lua:235: in function 'writeObject'
    /opt/rocks/distro/install/share/lua/5.1/torch/File.lua:235: in function 'writeObject'
    /opt/rocks/distro/install/share/lua/5.1/torch/File.lua:200: in function 'writeObject'
    /opt/rocks/distro/install/share/lua/5.1/torch/File.lua:235: in function 'writeObject'
    /opt/rocks/distro/install/share/lua/5.1/torch/File.lua:235: in function 'writeObject'
    /opt/rocks/distro/install/share/lua/5.1/torch/File.lua:200: in function 'writeObject'
    /opt/rocks/distro/install/share/lua/5.1/torch/File.lua:235: in function 'writeObject'
    ...istro/install/share/lua/5.1/cudnn/SpatialConvolution.lua:470: in function 'write'
    /opt/rocks/distro/install/share/lua/5.1/torch/File.lua:210: in function 'writeObject'
    /opt/rocks/distro/install/share/lua/5.1/nn/Module.lua:107: in function 'clone'
    .../distro/install/share/lua/5.1/cunn/DataParallelTable.lua:634: in function 'applyChanges'
    .../distro/install/share/lua/5.1/cunn/DataParallelTable.lua:472: in function 'apply'
    /opt/rocks/distro/install/share/lua/5.1/optnet/graphgen.lua:221: in function 'generateGraph'
    /tmp/graphgen_fail.lua:15: in main chunk

Thanks for the example, Sergey!
I managed to reduce the problem to the following snippet (independent of graphgen or cudnn):

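require 'cunn'
-- one convolution per GPU (GPUs 1 and 2), split along dimension 1
model = nn.DataParallelTable(1)
model:add(nn.SpatialConvolution(3,96,7,7,3,3),1)
model:add(nn.SpatialConvolution(3,96,7,7,3,3),2)
model:cuda()
input = torch.randn(32,3,224,224):cuda()
-- wrap updateOutput in a closure that simply calls the original
function f(m)
  local ff = m.updateOutput
  m.updateOutput = function(self, i)
    return ff(self, i)
  end
end
-- apply the wrapper to every module, then run a forward pass
model:apply(f)
model:forward(input);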

This behaviour differs from that of other modules, where the same wrapping works as expected.
This seems like a bug in nn.DataParallelTable, or am I missing something?
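
To illustrate what the wrapper changes, here is a minimal pure-Lua sketch that runs without Torch. The `Module` table is a hypothetical stand-in for an nn module class, not Torch's actual class system. It only shows that, after wrapping, the instance table itself holds a closure over the original method. The `updateOutput.basefunc.errcheck.C` path in the error message and the `clone` frame in the traceback suggest that this closure and its upvalues are what gets serialized when DataParallelTable clones the module, but that reading is an assumption and has not been checked against the cunn source.

```lua
-- Hypothetical stand-in for an nn module class (plain Lua, not Torch's class system).
local Module = {}
Module.__index = Module

function Module.new()
  return setmetatable({}, Module)
end

function Module:updateOutput(input)
  return input * 2
end

-- Same wrapping pattern as in the reduced snippet above.
local function wrap(m)
  local ff = m.updateOutput
  m.updateOutput = function(self, i)
    return ff(self, i)
  end
end

local m = Module.new()

-- Before wrapping, the method is inherited from the class through the metatable,
-- so the instance table has no 'updateOutput' field of its own.
assert(rawget(m, 'updateOutput') == nil)

wrap(m)

-- After wrapping, the instance table owns a closure whose upvalue is the original method.
assert(type(rawget(m, 'updateOutput')) == 'function')

-- The wrapped method still behaves like the original.
assert(m:updateOutput(21) == 42)
```

Anything that copies a module by walking its own fields will now encounter this function field, which is one plausible reason a clone-by-serialization step could behave differently after `apply`.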

@szagoruyko I proposed a quick fix for this issue in 0c7c216. The test snippet you sent now works. Could you check whether it works for your models?