This repo implements seqopt, a lightweight library enabling users to perform model-based optimization using differentiable nucleotide property models written in PyTorch. Its only dependency is torch>=1.9.0, so it can be easily installed into your environment via
pip install git+https://github.com/josephvalencia/RNA-design.git
seqopt addresses the problem of using a trained sequence -> property predictor to identify sequences that the model predicts will satisfy a target value. It supports two basic paradigms for sampling: probabilistic reparameterization and gradient-guided Markov chain Monte Carlo (MCMC). Both strategies leverage the gradient of the oracle model with respect to its inputs, while avoiding the pathology of evaluating the model on infeasible (i.e. non-one-hot) input.
In probabilistic reparameterization, an auxiliary categorical distribution is defined from which discrete sequences are sampled. The distribution parameters are updated in a gradient-ascent-like fashion to optimize the model output, using either the input gradients or the function evaluations of sequences sampled from the distribution. seqopt supports a variety of algorithms for updating the parameters.
- Straight-Through Estimator (STE) (Bengio et al., 2013)
- Softmax STE (Chung et al., 2016)
- Gumbel-softmax STE (Jang et al., 2016)
- REINFORCE/score function estimator
See (Linder and Seelig, 2021) for a TensorFlow implementation of many of these algorithms and experiments on several nucleotide property tasks.
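As a rough illustration of the idea (not the seqopt implementation), the sketch below applies a softmax straight-through update to a toy problem. The names `ste_sample`, `toy_oracle`, and `W` are hypothetical, and the linear oracle stands in for a trained property model. The forward pass always sees exact one-hot samples, while the backward pass routes the oracle's input gradient through the softmax probabilities to the logits.

```python
import torch
import torch.nn.functional as F


def ste_sample(logits: torch.Tensor) -> torch.Tensor:
    """Sample one-hot sequences; gradients reach `logits` via the softmax."""
    probs = F.softmax(logits, dim=-1)
    idx = torch.distributions.Categorical(probs=probs).sample()
    onehot = F.one_hot(idx, num_classes=logits.shape[-1]).to(probs.dtype)
    # Forward value is exactly one-hot (the bracketed term is zero);
    # the backward pass uses the gradient of `probs`.
    return onehot + (probs - probs.detach())


torch.manual_seed(0)
L, K = 20, 4                      # sequence length, alphabet size
W = torch.randn(L, K)             # toy oracle weights (assumption)


def toy_oracle(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical differentiable oracle: a per-position linear score."""
    return (x * W).sum(dim=(-2, -1))


# Parameters of the auxiliary categorical distribution, one row per position.
logits = torch.zeros(L, K, requires_grad=True)
opt = torch.optim.SGD([logits], lr=0.5)

for step in range(200):
    x = ste_sample(logits.expand(8, L, K))   # batch of 8 one-hot samples
    loss = -toy_oracle(x).mean()             # ascend the oracle output
    opt.zero_grad()
    loss.backward()
    opt.step()

# Read out the most probable sequence under the learned distribution.
design = F.one_hot(logits.argmax(-1), K).float()
print(float(toy_oracle(design)))
```

The other estimators in the list differ mainly in how the sample and its surrogate gradient are formed, for example by adding Gumbel noise before the softmax or by weighting log-probabilities with function values in REINFORCE.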
In gradient-guided MCMC, no additional parameters are learned, but sequence gradients are used to bias a proposal distribution for MCMC over discrete distributions. Currently, two such approaches are supported.
- Gibbs with Gradients (Grathwohl et al., 2020)
- Discrete Langevin Proposal (Zhang et al., 2022)
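To make the gradient-biased proposal concrete, here is a minimal sketch in the spirit of Gibbs with Gradients, written independently of the seqopt code. The function `gwg_step`, the toy `log_prob`, and the weights `W` are hypothetical. The input gradient gives a first-order estimate of how the log-probability changes for every single-site substitution, a softmax over those estimates picks the move, and a Metropolis-Hastings test keeps the chain exact. Masking out no-op moves is a choice made for this sketch.

```python
import torch
import torch.nn.functional as F


def gwg_step(x: torch.Tensor, log_prob):
    """One gradient-informed Metropolis-Hastings step on a one-hot (L, K) tensor."""

    def proposal(z):
        z = z.detach().requires_grad_(True)
        f = log_prob(z)
        (grad,) = torch.autograd.grad(f, z)
        # First-order estimate of the change in f from setting site i to class j.
        delta = grad - (grad * z).sum(-1, keepdim=True)
        # Forbid "moves" that keep the current class (a choice in this sketch).
        delta = delta.masked_fill(z.detach().bool(), float("-inf"))
        return f.detach(), torch.log_softmax(delta.flatten() / 2.0, dim=0)

    K = x.shape[-1]
    f_x, logq_fwd = proposal(x)
    move = int(torch.multinomial(logq_fwd.exp(), 1))
    i, j = divmod(move, K)
    old_j = int(x[i].argmax())

    x_new = x.clone()
    x_new[i] = 0.0
    x_new[i, j] = 1.0

    f_new, logq_rev = proposal(x_new)
    # Acceptance ratio: target ratio times reverse/forward proposal ratio.
    log_accept = f_new - f_x + logq_rev[i * K + old_j] - logq_fwd[move]
    if torch.rand(()).log() < log_accept:
        return x_new, True
    return x, False


torch.manual_seed(0)
L, K = 20, 4
W = 2.0 * torch.randn(L, K)       # toy weights standing in for an oracle


def log_prob(z):
    """Hypothetical unnormalized log-probability (a linear toy model)."""
    return (z * W).sum()


# Start from the lowest-scoring sequence so the improvement is easy to see.
x = F.one_hot(W.argmin(-1), K).float()
scores, n_accept = [float(log_prob(x))], 0
for _ in range(300):
    x, accepted = gwg_step(x, log_prob)
    n_accept += int(accepted)
    scores.append(float(log_prob(x)))
print(scores[0], scores[-1], n_accept)
```

Each step costs two oracle evaluations with gradients, one at the current sequence and one at the proposed sequence.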
To equip a PyTorch sequence model with the ability to design sequences, a user must extend the NucleotideDesigner base class.
import torch
from abc import ABC, abstractmethod
from typing import Callable, List, Union

class NucleotideDesigner(ABC):

    def __init__(self,num_classes,class_dim):
        self.num_classes = num_classes
        self.class_dim = class_dim

    @abstractmethod
    def onehot_encode(self,seq : torch.Tensor) -> torch.Tensor:
        '''Convert a dense sequence tensor to a one-hot encoding'''
        pass

    @abstractmethod
    def dense_decode(self,seq : torch.Tensor) -> Union[str,List[str]] :
        ''' Convert a dense sequence tensor to a readable nucleotide sequence'''
        pass

    @property
    @abstractmethod
    def oracles(self) -> List[Union[Callable,torch.nn.Module]]:
        '''Return a list of differentiable oracles that will be applied to the sequence'''
        pass

    @abstractmethod
    def seed_sequence(self) -> torch.Tensor:
        '''Generate a random sequence of a given length'''
        pass
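A minimal subclass might look like the sketch below. Everything in it is illustrative: `ToyGCDesigner`, its GC-content oracle, and the choice of `class_dim=-1` are assumptions, and the compact stand-in for NucleotideDesigner is included only so the snippet runs on its own (in practice the base class would be imported from seqopt).

```python
import torch
import torch.nn.functional as F
from abc import ABC, abstractmethod


class NucleotideDesigner(ABC):
    """Stand-in mirroring the base class above, so this snippet is self-contained."""

    def __init__(self, num_classes, class_dim):
        self.num_classes = num_classes
        self.class_dim = class_dim

    @abstractmethod
    def onehot_encode(self, seq): ...

    @abstractmethod
    def dense_decode(self, seq): ...

    @property
    @abstractmethod
    def oracles(self): ...

    @abstractmethod
    def seed_sequence(self): ...


class ToyGCDesigner(NucleotideDesigner):
    """Hypothetical designer whose single oracle scores GC content."""

    ALPHABET = "ACGU"

    def __init__(self, seq_len=30):
        # Assumption: classes live on the last axis of the one-hot tensor.
        super().__init__(num_classes=4, class_dim=-1)
        self.seq_len = seq_len
        self.gc_mask = torch.tensor([0.0, 1.0, 1.0, 0.0])  # marks C and G

    def onehot_encode(self, seq):
        # (batch, length) integer indices -> (batch, length, 4) floats
        return F.one_hot(seq, num_classes=self.num_classes).float()

    def dense_decode(self, seq):
        if seq.dim() == 1:
            return "".join(self.ALPHABET[i] for i in seq.tolist())
        return [self.dense_decode(s) for s in seq]

    def gc_content(self, onehot):
        # Differentiable with respect to the one-hot input.
        return (onehot * self.gc_mask).sum(-1).mean(-1)

    @property
    def oracles(self):
        return [self.gc_content]

    def seed_sequence(self):
        return torch.randint(0, self.num_classes, (1, self.seq_len))


designer = ToyGCDesigner(seq_len=12)
seed = designer.seed_sequence()
onehot = designer.onehot_encode(seed).requires_grad_(True)
score = designer.oracles[0](onehot)
score.sum().backward()            # gradient of the oracle w.r.t. its input
print(designer.dense_decode(seed)[0], float(score))
```

A real designer would return a trained torch.nn.Module from `oracles` in place of the hand-written scoring function.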
seqopt began as part of a project titled "Extrapolative benchmarking of model-based discrete sampling methods for RNA design", presented at the 2023 Machine Learning in Computational Biology conference. Based on prior works, I implemented basic CNN+LSTM models for predicting ribosome load of 5' UTRs, degradation properties of mRNAs, and toehold switch activities. I evaluated the ability of model-based optimization to produce designs which exceed the property values observed during training, as scored by a more powerful model. Code for training these models and performing sequence optimization using the seqopt API is located in the MLCB folder. See our extended abstract and poster for further details.
Naive input optimization can drift towards regions of input space where model predictions are inaccurate. For the experiments above, I implemented a simple uncertainty estimation procedure using Evidential Regression (Amini et al. 2019) to permit sampling of sequences with high property values and low uncertainty.
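One way such an objective could be assembled is sketched below; it is not the procedure used in the MLCB experiments. `EvidentialHead` and `uncertainty_penalized_score` are hypothetical names, and the penalty form (mean minus a multiple of the epistemic standard deviation) is an assumption. The head outputs the four Normal-Inverse-Gamma parameters used in evidential regression, from which the epistemic variance is beta / (nu * (alpha - 1)).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class EvidentialHead(nn.Module):
    """Maps features to Normal-Inverse-Gamma parameters (gamma, nu, alpha, beta)."""

    def __init__(self, in_features: int):
        super().__init__()
        self.linear = nn.Linear(in_features, 4)

    def forward(self, h: torch.Tensor):
        gamma, raw_nu, raw_alpha, raw_beta = self.linear(h).unbind(-1)
        eps = 1e-6
        nu = F.softplus(raw_nu) + eps              # nu > 0
        alpha = F.softplus(raw_alpha) + 1.0 + eps  # alpha > 1
        beta = F.softplus(raw_beta) + eps          # beta > 0
        return gamma, nu, alpha, beta


def uncertainty_penalized_score(gamma, nu, alpha, beta, lam: float = 1.0):
    """Predicted mean minus lam times the epistemic standard deviation (assumed form)."""
    epistemic_var = beta / (nu * (alpha - 1.0))
    return gamma - lam * epistemic_var.sqrt()


torch.manual_seed(0)
head = EvidentialHead(in_features=16)
features = torch.randn(5, 16, requires_grad=True)  # stand-in for encoder output
gamma, nu, alpha, beta = head(features)
score = uncertainty_penalized_score(gamma, nu, alpha, beta, lam=0.5)
score.sum().backward()   # the penalized score remains differentiable
print(score.shape)
```

Using a score of this kind as the oracle output would steer sampling toward sequences with high predicted values and low predicted uncertainty.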
A probably more robust approach is to use a prior density model. seqopt permits the use of multiple sequence oracles to implement such a procedure, but this is subject to the availability of suitable likelihood models, and testing is ongoing.