wengong-jin/icml18-jtnn

Index Error in MPN.py

maxbernhard opened this issue · 3 comments

Hey,
when running your model with my own dataset I'm getting the following error:

Traceback (most recent call last):
  File "pretrain.py", line 69, in <module>
    loss, kl_div, wacc, tacc, sacc, dacc, pacc = model(batch, beta=0)
  File "/home/.conda/envs/jtnn_pytorch/lib/python2.7/site-packages/torch/nn/modules/module.py", line 224, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/molopt/jtnn/jtprop_vae.py", line 76, in forward
    tree_mess, tree_vec, mol_vec = self.encode(mol_batch)
  File "/home/molopt/jtnn/jtprop_vae.py", line 60, in encode
    mol_vec = self.mpn(mol2graph(smiles_batch))
  File "/home/molopt/jtnn/mpn.py", line 75, in mol2graph
    agraph[a,i] = b
IndexError: index 6 is out of range for dimension 0 (of size 6)

PS: I changed MAX_NB in jtnn_dec and jtnn_enc to 32.

OK, so I figured out that the error is related to the input dataset. Are there any known limitations on the input dataset?

Hi,
I think you need to increase the MAX_NB value in mpn.py. This variable is defined here:
https://github.com/wengong-jin/icml18-jtnn/blob/master/jtnn/mpn.py#L12
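
For anyone hitting the same error: the traceback is consistent with an atom that has more incoming bonds than the neighbor table has columns (index 6 into a row of size 6). Below is a minimal, dependency-free sketch of that padding step, not the repository's actual code. The helper names `required_max_nb` and `build_agraph` and the toy bond lists are hypothetical, and it assumes Python 3. It may help for checking how large MAX_NB needs to be for a given dataset before training.

```python
# Hypothetical sketch; only the name MAX_NB comes from the thread.
MAX_NB = 6  # size implied by the traceback ("dimension 0 (of size 6)")


def required_max_nb(in_bonds):
    """Return the largest number of incoming bonds on any atom."""
    return max((len(bonds) for bonds in in_bonds), default=0)


def build_agraph(in_bonds, max_nb=MAX_NB):
    """Build a zero-padded atom-to-bond table with max_nb columns."""
    needed = required_max_nb(in_bonds)
    if needed > max_nb:
        # Fail with a readable message instead of a bare IndexError.
        raise ValueError(
            "an atom has %d incoming bonds but max_nb is %d" % (needed, max_nb)
        )
    agraph = [[0] * max_nb for _ in in_bonds]
    for a, bonds in enumerate(in_bonds):
        for i, b in enumerate(bonds):
            agraph[a][i] = b
    return agraph


# Toy input (made up): atom 0 has seven incoming bonds, atoms 1 and 2 have one.
toy_in_bonds = [[1, 2, 3, 4, 5, 6, 7], [8], [9]]
needed = required_max_nb(toy_in_bonds)
table = build_agraph(toy_in_bonds, max_nb=needed)
```

With the default `MAX_NB = 6`, `build_agraph(toy_in_bonds)` raises the `ValueError`, which mirrors the situation in the traceback. Passing the value reported by `required_max_nb` gives a table wide enough for every atom.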

Thank you, that resolved the issue!