fidler-lab/polyrnn-pp-pytorch

AttributeError: 'dict' object has no attribute 'grad_fn' occurred when visualizing the network structure

Jacoobr opened this issue · 4 comments

Hi @amlankar. I am trying to visualize the MLE training network structure with the **make_dot()** function from the torchviz package. I added the following code to the train(self, epoch) function in the train_ce.py script:

    def train(self, epoch):
        print 'Starting training'
        self.model.train()

        accum = defaultdict(float)
        # To accumulate stats for printing
        for step, data in enumerate(self.train_loader):
            if self.global_step % self.opts['val_freq'] == 0:
                self.validate()
                self.save_checkpoint(epoch)             

            # Forward pass
            # data['img'] = Variable(data['img'], requires_grad=True)
            # data['fwd_poly'] = Variable(data['fwd_poly'], requires_grad=True)   # Variable data['fwd_poly'] used for correction interactive
            input1 = data['img']
            input2 = data['fwd_poly']
            output = self.model(input1.to(device), input2.to(device))
            ## used for generating '.dot' format network structure graph
            g = make_dot(output, params=dict(list(polyrnnpp.PolyRNNpp(self.opts).named_parameters())+[('input1', input1), ('input2', input2)]))
            g.render('./graph', view=False)

Then I get this 'grad_fn' error:

 g = make_dot(output, params=dict(list(polyrnnpp.PolyRNNpp(self.opts).named_parameters())+[('input1', input1), ('input2', input2)]))
  File "/home/tzq-lxj/softWare/anaconda3/envs/polygonRNNPP/lib/python2.7/site-packages/torchviz/dot.py", line 37, in make_dot
    output_nodes = (var.grad_fn,) if not isinstance(var, tuple) else tuple(v.grad_fn for v in var)
AttributeError: 'dict' object has no attribute 'grad_fn'

Why is the grad_fn of output missing when make_dot() is called, and how can I fix this 'grad_fn' error? I also tried the SummaryWriter from tensorboardX to do the same thing, both when testing with the generate_annotation.py script and when training with the mle_ce.py script, and got the same error. More importantly, if you know another way to visualize the MLE network structure, please let me know. Thank you, I appreciate your reply.

I'm not very familiar with torchviz, will look into it.

Meanwhile,

  1. Shouldn't you use self.model.named_parameters()?

  2. Does it work if you don't add the inputs to the list of parameters?

  3. PyTorch 0.4 no longer uses Variables, and a torch.Tensor does not require grad by default. Can you try enabling grad for the two inputs (after the forward pass and before calling make_dot)?
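
One more thing that may be worth checking: the error message says make_dot received a dict, while it reads `.grad_fn` from a tensor (or from each element of a tuple). A minimal sketch of picking the graph-attached tensors out of a dict output follows. The toy `Linear` model, the dict keys, and the helper name `graph_roots` are assumptions for illustration, not part of this repository.

```python
import torch


def graph_roots(output):
    """Keep only the dict entries that are tensors still attached to autograd."""
    return {
        name: value
        for name, value in output.items()
        if torch.is_tensor(value) and value.grad_fn is not None
    }


# Toy stand-in for a model that returns a dict (hypothetical keys).
model = torch.nn.Linear(4, 2)
logits = model(torch.randn(3, 4))
output = {
    "logits": logits,                # differentiable, has a grad_fn
    "pred": logits.argmax(dim=1),    # integer indices, no grad_fn
    "lengths": [3],                  # not a tensor at all
}

roots = graph_roots(output)
print(sorted(roots))
```

Any one of the surviving entries could then be passed as the first argument, for example `make_dot(roots["logits"], params=dict(model.named_parameters()))`, instead of the whole dict.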

Hi @amlankar, thanks for your reply. I tried your suggestions one by one. Unfortunately, I got the same error. I also checked that the 'requires_grad' attribute of both inputs is True, so I guess the error is caused by the first argument of make_dot, output. The output variable is a dict with these keys: ['vertex_logits', 'edge_logits', 'lengths', 'attention', 'poly_class', 'log_probs', 'pred_polys', 'logprob_sums', 'logits']. I then changed the first argument of make_dot to output['pred_polys'], with grad enabled, like this:

    output = self.model(torch.tensor(input1.to(device), requires_grad=True), torch.tensor(input2.to(device), requires_grad=True))
    input1 = torch.tensor(input1.to(device), requires_grad=True)  #make input with grad
    input2 = torch.tensor(input2.to(device), requires_grad=True)
    pred_polys = torch.tensor(output['pred_polys'], requires_grad=True)
    g = make_dot(pred_polys, params=dict(self.model.named_parameters()))
    g.render('./graph', view=False)
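
As a side note on the `torch.tensor(..., requires_grad=True)` calls: `torch.tensor` copies the data of an existing tensor, so the result is a fresh leaf that is not connected to the graph that produced the original. The sketch below illustrates this on a toy `Linear` model, which is an assumption for illustration and not the PolyRNN++ model.

```python
import warnings

import torch

model = torch.nn.Linear(4, 2)
out = model(torch.randn(3, 4))  # produced by the model, so it has a grad_fn

with warnings.catch_warnings():
    # torch.tensor(existing_tensor) emits a UserWarning about copy construction.
    warnings.simplefilter("ignore")
    rewrapped = torch.tensor(out, requires_grad=True)

attached = out.grad_fn is not None
rewrapped_is_fresh_leaf = rewrapped.is_leaf and rewrapped.grad_fn is None
print(attached, rewrapped_is_fresh_leaf)
```

The re-wrapped copy requires grad but carries none of the model's history, so a graph drawn from it would not show the network.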

The previous error is gone, but a new KeyError occurs:

  File "/home/tzq-lxj/workStation/polygonRNN_pluss/code/Scripts/train/train_ce.py", line 176, in train
    g = make_dot(pred_polys, params=dict(self.model.named_parameters()))
  File "/home/tzq-lxj/softWare/anaconda3/envs/polygonRNNPP/lib/python2.7/site-packages/torchviz/dot.py", line 70, in make_dot
    add_nodes(var.grad_fn)
  File "/home/tzq-lxj/softWare/anaconda3/envs/polygonRNNPP/lib/python2.7/site-packages/torchviz/dot.py", line 59, in add_nodes
    add_nodes(u[0])
  File "/home/tzq-lxj/softWare/anaconda3/envs/polygonRNNPP/lib/python2.7/site-packages/torchviz/dot.py", line 59, in add_nodes
    add_nodes(u[0])
  File "/home/tzq-lxj/softWare/anaconda3/envs/polygonRNNPP/lib/python2.7/site-packages/torchviz/dot.py", line 59, in add_nodes
    add_nodes(u[0])
  File "/home/tzq-lxj/softWare/anaconda3/envs/polygonRNNPP/lib/python2.7/site-packages/torchviz/dot.py", line 59, in add_nodes
    add_nodes(u[0])
  File "/home/tzq-lxj/softWare/anaconda3/envs/polygonRNNPP/lib/python2.7/site-packages/torchviz/dot.py", line 59, in add_nodes
    add_nodes(u[0])
  File "/home/tzq-lxj/softWare/anaconda3/envs/polygonRNNPP/lib/python2.7/site-packages/torchviz/dot.py", line 59, in add_nodes
    add_nodes(u[0])
  File "/home/tzq-lxj/softWare/anaconda3/envs/polygonRNNPP/lib/python2.7/site-packages/torchviz/dot.py", line 59, in add_nodes
    add_nodes(u[0])
  File "/home/tzq-lxj/softWare/anaconda3/envs/polygonRNNPP/lib/python2.7/site-packages/torchviz/dot.py", line 59, in add_nodes
    add_nodes(u[0])
  File "/home/tzq-lxj/softWare/anaconda3/envs/polygonRNNPP/lib/python2.7/site-packages/torchviz/dot.py", line 59, in add_nodes
    add_nodes(u[0])
  File "/home/tzq-lxj/softWare/anaconda3/envs/polygonRNNPP/lib/python2.7/site-packages/torchviz/dot.py", line 59, in add_nodes
    add_nodes(u[0])
  File "/home/tzq-lxj/softWare/anaconda3/envs/polygonRNNPP/lib/python2.7/site-packages/torchviz/dot.py", line 47, in add_nodes
    name = param_map[id(u)] if params is not None else ''
KeyError: 140081754346832

I'm sorry to bother you again. Why does the KeyError occur when make_dot is called? I would also really appreciate it if you could suggest a way to visualize the MLE network structure. It would help me a lot. Thanks.
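
The failing line in the traceback looks up `param_map[id(u)]`, so one possible cause is a leaf tensor that requires grad and takes part in the forward pass but is missing from the `params` dict (for example an input with `requires_grad=True`). This is a sketch under that assumption, not a confirmed diagnosis. It walks the autograd graph with plain PyTorch and lists such leaves; the helper name `leaf_variables` and the toy model are hypothetical.

```python
import torch


def leaf_variables(root_fn):
    """Collect the leaf tensors (AccumulateGrad nodes) reachable from root_fn."""
    seen, leaves, stack = set(), [], [root_fn]
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        if hasattr(fn, "variable"):       # AccumulateGrad nodes expose the leaf tensor
            leaves.append(fn.variable)
        for next_fn, _ in fn.next_functions:
            stack.append(next_fn)
    return leaves


model = torch.nn.Linear(4, 2)
x = torch.randn(3, 4, requires_grad=True)  # an input that requires grad
loss = model(x).sum()

params = dict(model.named_parameters())
known_ids = {id(p) for p in params.values()}

# Leaves in the graph that a params-only lookup would not find.
missing = [t for t in leaf_variables(loss.grad_fn) if id(t) not in known_ids]
print([tuple(t.shape) for t in missing])
```

If this list is non-empty for the real model, adding those tensors to `params` under names of your choice, or leaving `requires_grad` off for the inputs, may avoid the lookup failure.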

@Jacoobr for visualising the network, I guess you can also use add_graph from tensorboardX while training. I have never tried it with this training code, but that is what I normally use for visualising networks. :)

This is a great question and we haven't done this. We'd love it if someone could use existing tools to visualize the network in a friendly way.