
Protein-AWD-LSTM

Description

This repo contains implementations of AWD-LSTM and SHA-RNN to be used with TAPE for protein language modeling.
Model code is based on the original AWD-LSTM and SHA-RNN reference implementations.

Base classes are inherited from TAPE, which needs to be installed: https://github.com/songlab-cal/tape

Features

Hidden state resetting

These models are trained with stochastic truncated backpropagation through time (BPTT): all sequences are concatenated into a single token stream, and subsegments of varying length are cut from this concatenation to yield training minibatches.
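As a minimal sketch of this scheme (the function names and the 95%/half-length jitter heuristic follow the original AWD-LSTM training loop and are illustrative assumptions, not this repo's exact code):

```python
import numpy as np
import torch

def batchify(token_ids, batch_size):
    """Reshape a flat token stream into (n_steps, batch_size) columns.
    Trailing tokens that don't fill a full column are dropped."""
    n_steps = len(token_ids) // batch_size
    data = torch.as_tensor(token_ids[: n_steps * batch_size])
    return data.view(batch_size, n_steps).t().contiguous()

def stochastic_bptt_batches(data, bptt=70, std=5, rng=None):
    """Yield (inputs, targets) segments of stochastically varied length
    from the concatenated stream."""
    rng = rng or np.random.default_rng()
    i = 0
    while i < data.size(0) - 2:
        # 95% of the time use the base length, otherwise halve it, then jitter.
        base = bptt if rng.random() < 0.95 else bptt // 2
        seq_len = max(5, int(rng.normal(base, std)))
        seq_len = min(seq_len, data.size(0) - 1 - i)
        inputs = data[i : i + seq_len]
        targets = data[i + 1 : i + 1 + seq_len]  # next-token targets
        yield inputs, targets
        i += seq_len
```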

It's unclear whether this is a good idea for proteins, as it leaves the model to learn that it should forget everything that came before a SEP token. These implementations support hidden state resetting, where the model resets the respective positions of the hidden states to 0 after encountering a SEP token.
There is also experimental support for this in SHA-RNN, where in addition the self-attention mask is modified to ignore memory before a SEP. The behaviour is controlled by specifying a reset_token_id in the model config.
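A minimal sketch of the resetting idea, assuming a plain nn.LSTM; the class and its layout are hypothetical and only illustrate zeroing the state of sequences that just consumed the reset token:

```python
import torch
import torch.nn as nn

class ResettingLSTM(nn.Module):
    """Illustrative sketch, not this repo's actual model class: after a step
    whose input is the reset token (e.g. SEP), that sequence's hidden state
    is zeroed, so the next protein starts from a clean state."""

    def __init__(self, vocab_size, embed_dim, hidden_dim, reset_token_id):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim)
        self.reset_token_id = reset_token_id

    def forward(self, tokens, state=None):
        # tokens: (seq_len, batch), matching this repo's tensor layout
        outputs = []
        for step_tokens in tokens:                      # (batch,)
            emb = self.embed(step_tokens).unsqueeze(0)  # (1, batch, embed)
            out, state = self.lstm(emb, state)
            outputs.append(out)
            # Zero h and c for batch entries that just consumed SEP.
            keep = (step_tokens != self.reset_token_id).float()
            keep = keep.view(1, -1, 1)                  # broadcast over hidden
            state = (state[0] * keep, state[1] * keep)
        return torch.cat(outputs, dim=0), state
```

Stepping the RNN one position at a time makes the reset easy to express; an actual implementation may batch this differently for speed.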

Memory-mapped datasets

The default implementations load full datasets into memory for truncated BPTT, which doesn't work for large datasets. We instead use VirtualBatchTruncatedBPTTHdf5Dataset, a class that reads pre-tokenized and pre-concatenated HDF5 datasets to emulate the stochastic minibatch generation. The HDF5 files are created with the static methods provided by the class.
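As a rough illustration of the approach (not the actual VirtualBatchTruncatedBPTTHdf5Dataset API), a dataset can slice fixed-length segments directly from an HDF5 token stream; the file layout and the "tokens" key below are assumptions:

```python
import h5py
import numpy as np
import torch
from torch.utils.data import Dataset

class Hdf5BpttSketch(Dataset):
    """Illustrative sketch: serve truncated-BPTT segments from a
    pre-tokenized, pre-concatenated HDF5 token stream without loading
    the whole stream into memory."""

    def __init__(self, h5_path, batch_size, bptt=70, dataset_key="tokens"):
        self.h5_path = h5_path
        self.dataset_key = dataset_key
        self.batch_size = batch_size
        self.bptt = bptt
        with h5py.File(h5_path, "r") as f:
            n_tokens = f[dataset_key].shape[0]
        # Each batch column reads from its own contiguous region of the stream.
        self.col_len = n_tokens // batch_size
        self.n_segments = (self.col_len - 1) // bptt

    def __len__(self):
        return self.n_segments

    def __getitem__(self, idx):
        # Reopen per item so the dataset stays safe with multi-worker loading.
        with h5py.File(self.h5_path, "r") as f:
            tokens = f[self.dataset_key]
            start = idx * self.bptt
            cols = []
            for b in range(self.batch_size):
                offset = b * self.col_len + start
                cols.append(tokens[offset : offset + self.bptt + 1])
            block = np.stack(cols, axis=1)  # (bptt + 1, batch)
        block = torch.as_tensor(block, dtype=torch.long)
        return block[:-1], block[1:]        # inputs, next-token targets
```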

Notes

Models are trained with tensors in (sequence_length, batch_size) format.
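For example, a batch of four 70-step segments would be laid out as:

```python
import torch

# (sequence_length, batch_size): four segments of 70 steps each
tokens = torch.randint(0, 30, (70, 4))
# a batch-first tensor can be adapted with a simple transpose
batch_first_input = tokens.t()  # (batch_size, sequence_length)
```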

The performance of the SHA-RNN implementation has not been tested.