TL;DR
- A family of models that "understand" small organic molecules (SMILES), their basic properties (molecular weight, QED, SAS, TPSA, CLogP, ...), and similarities between molecules (Tanimoto over ECFC4).
- Chemlactica-125M 🤗 and -1.3B 🤗 are trained on top of Meta's Galactica models.
- Chemma-2B 🤗 is built on top of Google's Gemma-2B.
- All models are trained on 40B tokens covering 100M+ molecules from PubChem. Check the corpus at 🤗.
- A prompt like
  `</s>[SAS]2.25[/SAS][SIMILAR]CC(=O)OC1=CC=CC=C1C(=O)O 0.62[/SIMILAR][START_SMILES]`
  will generate a molecule with an SAS score of ~2.25 and a similarity of ~0.62 to the given molecule.
- The models can be easily fine-tuned to perform property prediction (~0.3 RMSE on FreeSolv from MoleculeNet).
- Wrapped into a genetic-like optimization algorithm, the models outperform the previous state of the art on every molecular optimization benchmark we tried:
  - Practical Molecular Optimization: 17.5 vs. 16.2 (previous SOTA: Genetic-guided GFlowNets).
  - Optimization for docking with AutoDock Vina: 3-4x fewer oracle calls than the previous SOTA (Beam Enumeration) to generate 100 good molecules.
  - QED optimization from the RetMol paper: 99% success rate with 10K oracle calls using Chemlactica-125M (vs. 96% with 50K calls in the original paper).
- Read the details in the paper Small Molecule Optimization with Large Language Models.
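The prompt format above can be assembled programmatically. The sketch below is a hypothetical helper (`build_prompt` is not part of the ChemLactica package) that reproduces the example prompt from a dictionary of property tags and a list of (SMILES, similarity) pairs. It assumes every property follows the `[TAG]value[/TAG]` pattern shown for SAS and that each similar molecule is written as `[SIMILAR]<SMILES> <similarity>[/SIMILAR]`, as in the example.

```python
from typing import Dict, List, Optional, Tuple


def build_prompt(
    properties: Dict[str, float],
    similar: Optional[List[Tuple[str, float]]] = None,
) -> str:
    """Hypothetical helper: assemble a ChemLactica-style generation prompt.

    Assumes each property is written as [TAG]value[/TAG] and each similar
    molecule as [SIMILAR]<SMILES> <similarity>[/SIMILAR], following the
    example prompt in this README.
    """
    parts = ["</s>"]  # the example prompt starts with the </s> token
    for tag, value in properties.items():
        parts.append(f"[{tag}]{value}[/{tag}]")
    for smiles, sim in similar or []:
        parts.append(f"[SIMILAR]{smiles} {sim}[/SIMILAR]")
    # the model is expected to continue with a SMILES string after this tag
    parts.append("[START_SMILES]")
    return "".join(parts)


prompt = build_prompt({"SAS": 2.25}, [("CC(=O)OC1=CC=CC=C1C(=O)O", 0.62)])
print(prompt)
```

The resulting string would then be tokenized and passed to the model for generation; how many decimals the models expect for each value is an assumption not covered here.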
We look forward to the community using these models to solve various problems in molecular design.
Fine-tuning the Galactica models on chemistry data from PubChem.
- Python 3.11
- conda
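```bash
# create the environment (the name must match the one used in `conda activate`)
conda create -n chemlactica python=3.11 -y
# install the dependencies listed in environment.yml
conda env update -n chemlactica -f environment.yml
conda activate chemlactica
```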
Instructions coming soon...
Instructions coming soon...
Running the Optimization Algorithm requires two steps:
Step 1. Define the Oracle, which is responsible for evaluating the oracle scores of the given molecules. The Oracle implementation scheme is shown below.
```python
from typing import Dict, List


class ExampleOracle:
    def __init__(self, *args, **kwargs):
        # maximum number of oracle calls to make
        self.max_oracle_calls: int = ...
        # the frequency with which to log
        self.freq_log: int = ...
        # the buffer to keep track of all unique molecules generated
        self.mol_buffer: Dict = ...
        # the maximum possible oracle score or an upper bound
        self.max_possible_oracle_score: float = ...

    def __call__(self, molecules) -> List[float]:
        """
        Evaluate and return the oracle scores for molecules. Log the intermediate results if necessary.
        """
        oracle_scores: List[float] = ...
        return oracle_scores

    @property
    def finish(self) -> bool:
        """
        Specify the stopping condition for the optimization process.
        """
        stopping_condition: bool = ...
        return stopping_condition
```
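As an illustration of the scheme, here is a hypothetical toy oracle (not shipped with ChemLactica). Several choices are assumptions: molecules are passed in as SMILES strings, the "score" is simply the fraction of uppercase `C` characters in the string (a stand-in for a real oracle such as QED or a docking score), repeated molecules are answered from `mol_buffer` without consuming a call, and logging appends to a list instead of printing.

```python
from typing import Dict, List, Tuple


class ToyCarbonFractionOracle:
    """Hypothetical oracle following the ExampleOracle scheme.

    Scores a SMILES string by the fraction of its characters that are 'C'.
    This is only a stand-in for a real property oracle.
    """

    def __init__(self, max_oracle_calls: int = 100, freq_log: int = 10):
        self.max_oracle_calls = max_oracle_calls
        self.freq_log = freq_log
        # SMILES -> (score, index of the oracle call that produced it)
        self.mol_buffer: Dict[str, Tuple[float, int]] = {}
        # the toy score can never exceed 1.0, so 1.0 is a valid upper bound
        self.max_possible_oracle_score = 1.0
        # (number of oracle calls so far, best score so far)
        self.log: List[Tuple[int, float]] = []

    def _score(self, smiles: str) -> float:
        return smiles.count("C") / len(smiles) if smiles else 0.0

    def __call__(self, molecules: List[str]) -> List[float]:
        oracle_scores = []
        for smiles in molecules:
            if smiles not in self.mol_buffer:
                # only previously unseen molecules consume an oracle call
                call_index = len(self.mol_buffer) + 1
                self.mol_buffer[smiles] = (self._score(smiles), call_index)
                if call_index % self.freq_log == 0:
                    best = max(score for score, _ in self.mol_buffer.values())
                    self.log.append((call_index, best))
            oracle_scores.append(self.mol_buffer[smiles][0])
        return oracle_scores

    @property
    def finish(self) -> bool:
        # stop once the call budget is spent or the upper bound is reached
        if len(self.mol_buffer) >= self.max_oracle_calls:
            return True
        return any(
            score >= self.max_possible_oracle_score
            for score, _ in self.mol_buffer.values()
        )
```

For example, `ToyCarbonFractionOracle(max_oracle_calls=3, freq_log=2)(["CCO", "CCO", "c1ccccc1"])` scores two unique molecules, so only two calls are spent and `finish` stays `False`. Note that this toy does not refuse new molecules once the budget is exhausted; it relies on the caller checking `finish`.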
Step 2. Define the hyperparameters for the optimization algorithm (such as the pool size, the number of similar molecules in the prompts, the sampling temperature, etc.) in a .yaml file.
```yaml
# model_hparams.yaml
checkpoint_path: /path/to/model_dir
tokenizer_path: /path/to/tokenizer_dir
# ... optimization algorithm hyperparameters (pool size, number of similar molecules to use, etc.) ...
generation_config:
  # ... molecule generation hyperparameters ...
strategy: [rej-sample-v2] # or use [default] to skip the fine-tuning step during optimization
rej_sample_config:
  # ... fine-tuning hyperparameters ...
```
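Once loaded with `yaml.safe_load`, the file above becomes a plain Python dictionary. The hypothetical helper below (not part of ChemLactica) sketches a sanity check one might run before calling `optimize`: it verifies the top-level keys shown in the template and requires a non-empty `rej_sample_config` only when the `rej-sample-v2` strategy is selected. The key names come from the template; the checking rules are assumptions.

```python
from typing import Any, Dict, List

# top-level keys taken from the model_hparams.yaml template
REQUIRED_KEYS = ["checkpoint_path", "tokenizer_path", "generation_config", "strategy"]


def check_config(config: Dict[str, Any]) -> List[str]:
    """Hypothetical sanity check; returns a list of problems (empty if none)."""
    problems = [f"missing key: {key}" for key in REQUIRED_KEYS if key not in config]
    strategy = config.get("strategy", [])
    # the template writes the strategy as a one-element YAML list
    if not isinstance(strategy, list):
        problems.append("strategy should be a list, e.g. [rej-sample-v2] or [default]")
    elif "rej-sample-v2" in strategy and not config.get("rej_sample_config"):
        problems.append("rej-sample-v2 needs a non-empty rej_sample_config section")
    return problems
```

Returning a list of problems rather than raising on the first one lets all configuration mistakes be reported at once.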
Finally, call the `optimize` function:
```python
import yaml
from transformers import AutoModelForCausalLM, AutoTokenizer

from chemlactica.mol_opt.optimization import optimize

# Load config (e.g. the model_hparams.yaml file from Step 2)
path_to_yaml_config = "model_hparams.yaml"
with open(path_to_yaml_config) as f:
    config = yaml.safe_load(f)

# Load the model and the tokenizer
model = AutoModelForCausalLM.from_pretrained(...)
tokenizer = AutoTokenizer.from_pretrained(...)

# Create Oracle
oracle = ExampleOracle(...)

# Call the optimize function to optimize against the defined oracle
optimize(
    model, tokenizer,
    oracle, config
)
```
Refer to example_run.py for a full working example of an optimization run. For more complex examples, see the mol_opt and retmol directories of the ChemlacticaTestSuit repository.
The test that runs a small model with the same architecture as Galactica on a small dataset is located at /tests/precommit_test.py and can be run as follows:
```bash
python -m unittest precommit_test.py
```
This test also runs as part of the CI pipeline on the main branch, on a public GitHub runner.