Infinite decoding
vandedok opened this issue · 5 comments
Hello,
I am using the JTVAE from your repo and ran into a problem. I was doing random search over the latent space, but for some points the decoding does not stop. The problem seems to be in the enum_assemble function from the chemutils.py file (weighted-retraining/weighted_retraining/chem/jtnn/chemutils.py): the internal recursive search function either doesn't converge or converges too slowly.
This is the code to reproduce the issue (I am assuming your conda environment is set up and the preprocessing is done):
from pathlib import Path
import torch
from weighted_retraining.chem.chem_data import Vocab
from weighted_retraining.chem.chem_model import JTVAE
weighted_retraining_dir = Path("../../weighted-retraining/")
with open(weighted_retraining_dir / "data/chem/zinc/orig_model/vocab.txt") as f:
    vocab = Vocab([x.strip() for x in f.readlines()])

pretrained_model_file = str(weighted_retraining_dir / "assets/pretrained_models/chem.ckpt")
model = JTVAE.load_from_checkpoint(pretrained_model_file, vocab=vocab)
model.beta = model.hparams.beta_final
bad_latents = torch.tensor([[ -6.2894, -45.1619, 11.9765, 11.6767, 37.5106, -24.5908, -25.9559,
40.9180, 18.4495, -30.9735, -10.9526, -31.6441, -49.6980, -36.1106,
12.7674, 6.1417, -44.0838, -34.6051, -9.2435, 47.8085, 41.7193,
-44.4102, 15.3359, -38.5631, 7.2546, -48.9917, 16.5505, -45.4565,
-49.4582, 11.6730, 13.2594, -37.0152, 39.9500, -39.3020, -16.2288,
23.3959, -36.6568, -48.8145, 13.4714, 19.7008, 30.5797, -42.0284,
-28.3188, -29.0985, 18.7675, -7.5038, 10.2781, 1.0429, -24.5770,
-15.5115, 10.9733, -18.1378, -34.5497, -25.7164, -21.9990, 14.0688]])
with torch.no_grad():
    smiles = model.decode_deterministic(bad_latents)
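For now I keep the script from hanging forever by bounding the call with an alarm. This is just my stopgap, sketched below (Unix-only, main thread only, and the 60-second cutoff is an arbitrary choice):

import signal

class DecodeTimeout(Exception):
    pass

def on_alarm(signum, frame):
    raise DecodeTimeout

signal.signal(signal.SIGALRM, on_alarm)  # install the handler
signal.alarm(60)  # deliver SIGALRM after 60 seconds
try:
    with torch.no_grad():
        smiles = model.decode_deterministic(bad_latents)
except DecodeTimeout:
    smiles = None  # treat non-terminating points as failed decodes
finally:
    signal.alarm(0)  # cancel any pending alarm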
Have you faced anything similar? Do you know what can be done in such situations?
P.S. Thanks for your JTVAE implementation; it's the most convenient one I have found.
Hello, happy to hear that you found this implementation useful. I did not encounter this particular issue but am not surprised by it. I don't have a fix available, unfortunately. My best suggestion would just be to check how many solutions enum_assemble iterates through, and perhaps force it to terminate after a certain maximum number of iterations. The effective behaviour here would be to return None at these points rather than just hanging for a long time. Would this be a good workaround for you?
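On the caller side, the effective pattern would be roughly the sketch below (candidate_latents and score_molecule are hypothetical names; the decode call mirrors yours from above):

# hypothetical search loop that tolerates failed decodes
for z in candidate_latents:
    with torch.no_grad():
        smiles = model.decode_deterministic(z.unsqueeze(0))
    if smiles is None:
        continue  # decoding failed or was cut off; skip this latent point
    score_molecule(smiles)  # hypothetical downstream use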
Thank you for the reply!
Do you mean that I need to check how many iterations the search function inside enum_assemble does?
Anyway, this may work. However, I wonder why you didn't face the problem -- it didn't take that many iterations to find these points. Also, it seems that Bayesian optimisation finds them faster than random search does.
Btw, sometimes I do get None as the decoding result. However, I didn't check what causes that.
Yes, that is what I meant. Maybe you are finding these points because the Python libraries have been updated? I last ran this code ~2 years ago.
Let me know if this solution is helpful and I will close the issue. However, I don't really want to make changes to this codebase at the moment, since I think it is important to be able to use it to reproduce the results of our paper, and changes to the code may change the behaviour.
Okay, I tried to limit the search calls in a crude way:
def enum_assemble(node, neighbors, prev_nodes=[], prev_amap=[], max_search_calls=20):
    ...
    ...
    n_search_calls = 0
    def search(cur_amap, depth):
        nonlocal n_search_calls
        n_search_calls += 1
        ...
        ...
        for new_amap in candidates:
            if n_search_calls < max_search_calls:  # stop recursing once the budget is spent
                search(new_amap, depth + 1)
It sort of worked: the decoding now finishes in reasonable time for the latents I've presented, but the output molecule looks like garbage. However, RDKit can even compute the LogP for this thing.
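For reference, this is roughly the check (a sketch; smiles here is the decoded output from above, assuming it parses):

from rdkit import Chem
from rdkit.Chem import Descriptors

mol = Chem.MolFromSmiles(smiles)  # returns None if the SMILES is invalid
if mol is not None:
    print(Descriptors.MolLogP(mol))  # Wildman-Crippen LogP estimate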
Would you like me to create a pull request, or should we just leave it be?
Thanks for the code snippet @vandedok! If the search does not complete, then it makes sense that a sub-optimal molecule may be returned, as you've demonstrated. I think this is just a limitation of the JT-VAE.
My suggestion would just be to leave this, instead of submitting a PR. I will make a note of it on the README for this project. Feel free to re-open the issue if you feel that is appropriate.