
Chemlactica / Chemma: Large Language Models for Small Molecules

TL;DR

  • A family of models that "understand" small organic molecules (SMILES), their basic properties (molecular weight, QED, SAS, TPSA, CLogP, ...), and similarities between molecules (Tanimoto over ECFC4).
  • Chemlactica-125M 🤗 and -1.3B 🤗 are trained on top of Meta's Galactica models.
  • Chemma-2B 🤗 is built on top of Google's Gemma-2B.
  • All models are trained on 40B tokens covering 100M+ molecules from PubChem. Check the corpus at 🤗.
  • A prompt like </s>[SAS]2.25[/SAS][SIMILAR]CC(=O)OC1=CC=CC=C1C(=O)O 0.62[/SIMILAR][START_SMILES] generates a molecule with an SAS score of ~2.25 and ~0.62 Tanimoto similarity to the given molecule (see the generation sketch below).
  • The models can easily be fine-tuned for property prediction (~0.3 RMSE on FreeSolv from MoleculeNet).
  • Wrapped in a genetic-like optimization algorithm, the models beat all the molecular optimization benchmarks we tried.
  • Read the details in the paper Small Molecule Optimization with Large Language Models.
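
A minimal generation sketch for this prompt format, assuming the standard Hugging Face transformers API; the hub ID below is a hypothetical placeholder for the linked 🤗 checkpoints:

from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "yerevann/chemlactica-125m"  # hypothetical hub ID; use a linked 🤗 checkpoint
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)

# Condition generation on a target SAS score and a similarity constraint.
prompt = "</s>[SAS]2.25[/SAS][SIMILAR]CC(=O)OC1=CC=CC=C1C(=O)O 0.62[/SIMILAR][START_SMILES]"
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, do_sample=True, max_new_tokens=100)
print(tokenizer.decode(outputs[0]))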

We are looking forward to the community utilizing these models for solving various problems in molecular design.

Table of contents

  • Description
  • Prerequisites
  • Installation
  • Usage
  • Tests

Description

Fine-tuning the Galactica and Gemma models on chemistry data from PubChem.

Prerequisites

  • Python 3.11
  • conda

Installation

conda env create -n chemlactica -f environment.yml
conda activate chemlactica

Usage

Pretraining

Instructions coming soon...

Fine-tuning

Instructions coming soon...

Molecular Optimization 🎯

Running the optimization algorithm requires two steps:

Step 1. Define the Oracle, which is responsible for computing the oracle scores for given molecules. The Oracle implementation scheme is shown below.

class ExampleOracle:
    def __init__(self, ...):
        # maximum number of oracle calls to make
        self.max_oracle_calls: int = ...

        # the frequency with which to log
        self.freq_log: int = ...

        # the buffer to keep track of all unique molecules generated
        self.mol_buffer: Dict = ...

        # the maximum possible oracle score or an upper bound
        self.max_possible_oracle_score: float = ... 

    def __call__(self, molecules):
        """
            Evaluate and return the oracle scores for molecules. Log the intermediate results if necessary.
        """
        ...
        return oracle_scores

    @property
    def finish(self):
        """ 
            Specify the stopping condition for the optimization process.
        """
        ...
        return stopping_condition
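
As a concrete illustration, here is a minimal Oracle sketch that scores molecules with RDKit's QED. It assumes molecules arrive as SMILES strings (adapt the input handling if the optimizer passes molecule objects); QEDOracle and its defaults are our own illustrative names:

from typing import Dict, List

from rdkit import Chem
from rdkit.Chem import QED

class QEDOracle:
    def __init__(self, max_oracle_calls: int = 1000, freq_log: int = 100):
        self.max_oracle_calls = max_oracle_calls
        self.freq_log = freq_log
        self.mol_buffer: Dict[str, float] = {}  # unique SMILES -> score
        self.max_possible_oracle_score = 1.0    # QED is bounded above by 1

    def __call__(self, molecules: List[str]) -> List[float]:
        oracle_scores = []
        for smiles in molecules:
            if smiles not in self.mol_buffer:
                mol = Chem.MolFromSmiles(smiles)
                # invalid SMILES receive the lowest possible score
                self.mol_buffer[smiles] = QED.qed(mol) if mol is not None else 0.0
                if len(self.mol_buffer) % self.freq_log == 0:
                    best = max(self.mol_buffer.values())
                    print(f"{len(self.mol_buffer)} molecules evaluated, best QED so far: {best:.3f}")
            oracle_scores.append(self.mol_buffer[smiles])
        return oracle_scores

    @property
    def finish(self) -> bool:
        # stop once the budget of unique oracle calls is spent
        return len(self.mol_buffer) >= self.max_oracle_calls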

Step 2. Define the hyperparameters for the optimization algorithm (such as the pool size, number of similar molecules to have in the prompts, sampling temperature, etc.) in a .yaml file.

# model_hparams.yaml
checkpoint_path: /path/to/model_dir
tokenizer_path: /path/to/tokenizer_dir

# ... optimization algorithm hyperparameters (pool size, number of similar molecules to use, etc.) ...

generation_config:
  # ... molecule generation hyperparameters ...

strategy: [rej-sample-v2] # or [default] to skip the fine-tuning step during optimization

rej_sample_config:
  # ... fine-tuning hyperparameters ...
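
For orientation, a filled-in version of such a config might look like the sketch below. The hyperparameter key names outside the scheme above are hypothetical placeholders (consult example_run.py for the exact schema), and the generation_config entries follow standard Hugging Face generation arguments:

# model_hparams.yaml (illustrative values only)
checkpoint_path: /path/to/model_dir
tokenizer_path: /path/to/tokenizer_dir

pool_size: 50        # hypothetical key name: size of the molecule pool
num_similars: 5      # hypothetical key name: similar molecules per prompt

generation_config:
  do_sample: true
  temperature: 1.0
  max_new_tokens: 100

strategy: [rej-sample-v2]

rej_sample_config:
  learning_rate: 1.0e-4   # hypothetical key name: fine-tuning learning rate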

Finally, call the optimize function:

import yaml
from transformers import AutoModelForCausalLM, AutoTokenizer

from chemlactica.mol_opt.optimization import optimize

# Load the optimization config
config = yaml.safe_load(open(path_to_yaml_config))

# Load the model and the tokenizer
model = AutoModelForCausalLM.from_pretrained(...)
tokenizer = AutoTokenizer.from_pretrained(...)

# Create Oracle
oracle = ExampleOracle(...)

# Call the optimize function to optimize against the defined oracle
optimize(
    model, tokenizer,
    oracle, config
)
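
With the QEDOracle sketch above, for example, the optimization loop would repeatedly prompt the model for candidate molecules, score them through the oracle, and stop once oracle.finish becomes true.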

Refer to example_run.py for a full working example of an optimization run. For more complex examples, see the mol_opt and retmol directories of the ChemlacticaTestSuit repository.

Tests

A test that runs a small model with the same architecture as Galactica on a small dataset is located at /tests/precommit_test.py and can be run as follows:

python -m unittest precommit_test.py

This test is also run as part of the CI pipeline on the main branch, on a public GitHub runner.