cambridge-mlg/weighted-retraining

Infinite decoding

vandedok opened this issue · 5 comments

Hello,

I am using the JTVAE from your repo and ran into a problem. I was doing random search over the latent space, but for some points the decoding never stops. The problem seems to be in the enum_assemble function in chemutils.py (weighted-retraining/weighted_retraining/chem/jtnn/chemutils.py): its internal recursive search function either does not terminate or converges too slowly.

This is the code to reproduce the issue (I am assuming your conda environment is present and the preprocessing is done):

from pathlib import Path
import torch
from weighted_retraining.chem.chem_data import Vocab
from weighted_retraining.chem.chem_model import JTVAE

weighted_retraining_dir = Path("../../weighted-retraining/")

with open(weighted_retraining_dir/"data/chem/zinc/orig_model/vocab.txt") as f:
    vocab = Vocab([x.strip() for x in f.readlines()])

pretrained_model_file = str(weighted_retraining_dir / "assets/pretrained_models/chem.ckpt")

model = JTVAE.load_from_checkpoint(
    pretrained_model_file, vocab=vocab
)

model.beta = model.hparams.beta_final 

bad_latents = torch.tensor([[ -6.2894, -45.1619,  11.9765,  11.6767,  37.5106, -24.5908, -25.9559,
          40.9180,  18.4495, -30.9735, -10.9526, -31.6441, -49.6980, -36.1106,
          12.7674,   6.1417, -44.0838, -34.6051,  -9.2435,  47.8085,  41.7193,
         -44.4102,  15.3359, -38.5631,   7.2546, -48.9917,  16.5505, -45.4565,
         -49.4582,  11.6730,  13.2594, -37.0152,  39.9500, -39.3020, -16.2288,
          23.3959, -36.6568, -48.8145,  13.4714,  19.7008,  30.5797, -42.0284,
         -28.3188, -29.0985,  18.7675,  -7.5038,  10.2781,   1.0429, -24.5770,
         -15.5115,  10.9733, -18.1378, -34.5497, -25.7164, -21.9990,  14.0688]])

with torch.no_grad():
    smiles = model.decode_deterministic(bad_latents)

Have you faced anything similar? Do you know what can be done in such situations?

P.S. Thanks for your JTVAE implementation; it's the most convenient one I have found.

Hello, happy to hear that you found this implementation useful. I did not encounter this particular issue but am not surprised by it. I don't have a fix available unfortunately. My best suggestion would just be to check how many solutions enum_assemble iterates through, and perhaps force it to terminate after a certain maximum number of iterations. The effective behaviour here would be to return None at these points rather than just hanging for a long time. Would this be a good workaround for you?
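
Alternatively, if patching enum_assemble is awkward, you could bound the wall-clock time of each decode from the outside. A rough sketch of such a guard (not part of this codebase, and Unix-only since it relies on signal.SIGALRM; decode_with_timeout and DecodeTimeout are just hypothetical names):

import signal

import torch

class DecodeTimeout(Exception):
    pass

def _on_alarm(signum, frame):
    raise DecodeTimeout

def decode_with_timeout(model, latents, seconds=30):
    """Return the decoded SMILES, or None if decoding exceeds `seconds`."""
    old_handler = signal.signal(signal.SIGALRM, _on_alarm)
    signal.alarm(seconds)  # schedule the alarm
    try:
        with torch.no_grad():
            return model.decode_deterministic(latents)
    except DecodeTimeout:
        return None  # treat the point as undecodable
    finally:
        signal.alarm(0)  # cancel any pending alarm
        signal.signal(signal.SIGALRM, old_handler)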

Thank you for the reply!

Do you mean that I should check how many iterations the search function inside `enum_assemble` performs?

Anyway, this may work. However, I wonder why you didn't face the problem -- it didn't take that many iterations to find these points. It also seems that Bayesian optimisation finds them faster than random search does.

Btw, sometimes I do get None as the decoding result. However, I haven't checked what causes that.
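
For context, my random-search loop looks roughly like this (a sketch; the 56-dimensional latent and the ~±50 scale match the vector above, and I assume decode_deterministic returns None for failed points):

import torch

n_samples, latent_dim = 100, 56  # latent dim matches the vector above
valid_smiles = []
with torch.no_grad():
    for _ in range(n_samples):
        z = 50 * torch.randn(1, latent_dim)  # rough scale of the latents above
        smiles = model.decode_deterministic(z)
        if smiles is not None:  # decoding occasionally fails and returns None
            valid_smiles.append(smiles)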


Yes, that is what I meant. Maybe you are finding these points because the Python libraries have been updated? I last ran this code ~2 years ago.

Let me know if this solution is helpful, and I will close the issue. However, I don't really want to make changes to this codebase at the moment: I think it is important to be able to use it to reproduce the results of our paper, and changes to the code may alter its behaviour.

Okay, I tried to limit the search calls in a crude way:

def enum_assemble(node, neighbors, prev_nodes=[], prev_amap=[], max_search_calls=20):
    ...
    ...
    n_search_calls = 0  # call budget shared across the whole recursive search
    def search(cur_amap, depth):
        nonlocal n_search_calls
        n_search_calls += 1
        ...
        ...
        for new_amap in candidates:
            # recurse only while the call budget is not exhausted
            if n_search_calls < max_search_calls:
                search(new_amap, depth + 1)

It sort of worked: the decoding now finishes in reasonable time for the latents I presented, but the output looks like garbage:
[screenshot of the decoded molecule, 2022-10-28]

Still, RDKit can even compute the logP for this thing.
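
For reference, the logP check is just the standard RDKit calls (assuming `smiles` holds the decoded output):

from rdkit import Chem
from rdkit.Chem import Descriptors

mol = Chem.MolFromSmiles(smiles)  # `smiles` is the decoded string above
if mol is not None:
    print(Descriptors.MolLogP(mol))  # Wildman-Crippen logP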

Would you like me to create a pull request, or should we just leave it be?

Thanks for the code snippet, @vandedok! If the search does not complete, then it makes sense that a sub-optimal molecule may be returned, as you've demonstrated. I think this is just a limitation of the JT-VAE.
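
If you would rather get None than a garbage molecule in these cases, one variant (a sketch in the same elided style as your snippet; the search(prev_amap, 0) call mirrors the original chemutils code) would be to abort the whole enumeration once the budget is exhausted, so no partial candidates are returned:

def enum_assemble(node, neighbors, prev_nodes=[], prev_amap=[], max_search_calls=20):
    ...
    n_search_calls = 0
    def search(cur_amap, depth):
        nonlocal n_search_calls
        n_search_calls += 1
        if n_search_calls > max_search_calls:
            raise RuntimeError("search budget exceeded")  # abort the enumeration
        ...
    try:
        search(prev_amap, 0)
    except RuntimeError:
        return []  # no candidates: the decode should then fail cleanly instead of assembling garbage
    ...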

My suggestion would just be to leave this instead of submitting a PR. I will make a note of it in the README for this project. Feel free to re-open the issue if you feel that is appropriate.