IndexError in mpn.py
maxbernhard opened this issue · 3 comments
maxbernhard commented
Hey,
when running your model on my own dataset, I'm getting the following error:
Traceback (most recent call last):
  File "pretrain.py", line 69, in <module>
    loss, kl_div, wacc, tacc, sacc, dacc, pacc = model(batch, beta=0)
  File "/home/.conda/envs/jtnn_pytorch/lib/python2.7/site-packages/torch/nn/modules/module.py", line 224, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/molopt/jtnn/jtprop_vae.py", line 76, in forward
    tree_mess, tree_vec, mol_vec = self.encode(mol_batch)
  File "/home/molopt/jtnn/jtprop_vae.py", line 60, in encode
    mol_vec = self.mpn(mol2graph(smiles_batch))
  File "/home/molopt/jtnn/mpn.py", line 75, in mol2graph
    agraph[a,i] = b
IndexError: index 6 is out of range for dimension 0 (of size 6)
PS: I changed MAX_NB in jtnn_dec and jtnn_enc to 32.
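The traceback hints at the failure mode: `mol2graph` writes neighbor indices into a fixed-width table with one row per atom and MAX_NB columns, so an atom with more bonds than MAX_NB overflows its row. Because the error is raised in mpn.py, the MAX_NB values in jtnn_enc and jtnn_dec presumably do not affect it. The stdlib-only sketch below reproduces that behaviour; `build_neighbor_table` is a hypothetical, simplified stand-in, not the repository's code, and it assumes each row holds one incoming-bond index per neighbor, as the line `agraph[a,i] = b` suggests.

```python
# Hypothetical, simplified stand-in for the neighbor table that mol2graph in
# mpn.py appears to fill (inferred from the traceback line `agraph[a,i] = b`).
MAX_NB = 6  # width implied by "dimension 0 (of size 6)" in the traceback


def build_neighbor_table(num_atoms, bonds, max_nb=MAX_NB):
    """Return a num_atoms x max_nb table of incoming directed-bond indices.

    `bonds` lists undirected (x, y) atom-index pairs. Each becomes two
    directed bonds, x->y and y->x, numbered from 1 so that 0 can act as
    padding. Row `a` collects the indices of bonds pointing into atom `a`.
    """
    table = [[0] * max_nb for _ in range(num_atoms)]
    fill = [0] * num_atoms  # next free column in each row
    next_bond = 1           # index 0 is reserved for padding in this sketch
    for x, y in bonds:
        # Directed bond x->y points into y; y->x points into x.
        for dst in (y, x):
            # Raises IndexError once atom `dst` has more than max_nb neighbors,
            # the same failure mode as `agraph[a,i] = b` in the traceback.
            table[dst][fill[dst]] = next_bond
            fill[dst] += 1
            next_bond += 1
    return table


if __name__ == "__main__":
    # A 3-atom chain: no atom has more than 2 neighbors, so it fits.
    print(build_neighbor_table(3, [(0, 1), (1, 2)]))

    # One hub atom bonded to 7 others overflows a 6-wide row.
    hub = [(0, k) for k in range(1, 8)]
    try:
        build_neighbor_table(8, hub)
    except IndexError as err:
        print("overflow:", err)
```

Widening the table (a larger `max_nb`) is the only way to fit the hub atom here, which matches the fix suggested further down the thread.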
maxbernhard commented
Ok, so I figured out that the error is related to the input dataset. Are there any known limitations on the input dataset?
wengong-jin commented
Hi,
I think you need to increase the MAX_NB value in mpn.py. This variable is defined here:
https://github.com/wengong-jin/icml18-jtnn/blob/master/jtnn/mpn.py#L12
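Rather than guessing a larger value, one could scan the dataset for its highest atom degree and set MAX_NB in mpn.py at least that high. A minimal sketch, assuming each molecule is already available as a list of (atom_i, atom_j) bonds; the helper names are hypothetical. With RDKit, the same per-atom count would come from `atom.GetDegree()` over `mol.GetAtoms()`.

```python
from collections import Counter


def max_atom_degree(molecules):
    """Largest number of bonds on any single atom across all molecules.

    `molecules` is an iterable of bond lists, each bond an (x, y) pair of
    atom indices within that molecule (an assumption of this sketch).
    """
    best = 0
    for bonds in molecules:
        degree = Counter()
        for x, y in bonds:
            degree[x] += 1
            degree[y] += 1
        if degree:
            best = max(best, max(degree.values()))
    return best


def check_max_nb(molecules, max_nb):
    """Fail early, with a readable message, if max_nb is too small."""
    needed = max_atom_degree(molecules)
    if needed > max_nb:
        raise ValueError(
            "dataset has an atom with %d neighbors but MAX_NB is %d; "
            "increase MAX_NB in mpn.py to at least %d" % (needed, max_nb, needed)
        )
    return needed
```

Running something like `check_max_nb(train_bonds, 6)` (where `train_bonds` is a hypothetical list of bond lists for the training set) before training would surface the problem with a clearer message than the IndexError above. The `%` formatting keeps the sketch usable under Python 2.7, which the traceback shows is in use.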
maxbernhard commented
Thank you, that resolved the issue!