evolutionaryscale/esm

Availability of UniRef validation set


Hi! Thank you for the amazing work and the user-friendly codebase!

I would like to reproduce Figure S8, but unfortunately I couldn't find the UniRef validation set of proteins on the Hugging Face Hub page. Is this set available anywhere, for example as a list of UniProt IDs?


Hi Roman!

Our validation set, which has not been released, was constructed from UniRef clustered at 70% sequence identity and purged as described in section "A.2.1.7. Purging of Validation Sequences" of the manuscript. The keyword labels used to assess function prediction were derived from UniProt's InterPro annotations as described in "A.2.1.4. Functional Labels". For this function prediction evaluation, the InterPro annotations were converted to function keywords via the mapping provided in the repository.
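As a rough illustration of what per-residue keyword labels look like, the sketch below turns residue-range annotations into a multi-hot [length, num_keywords] matrix. The InterPro IDs, the keywords, and the dictionary standing in for the repository's mapping are hypothetical placeholders, and 1-based inclusive residue coordinates are an assumption.

```python
import numpy as np

# Hypothetical stand-in for the repository's InterPro -> keyword mapping.
INTERPRO_TO_KEYWORDS = {
    "IPR_A": ["kinase"],
    "IPR_B": ["atp binding", "not in vocab"],
}


def residue_keyword_labels(length, annotations, interpro_to_keywords, vocab):
    """Build a <bool>[length, len(vocab)] multi-hot label matrix.

    annotations: iterable of (interpro_id, start, end), assumed 1-based and inclusive.
    vocab: list of keywords defining the label columns.
    """
    index = {kw: k for k, kw in enumerate(vocab)}
    labels = np.zeros((length, len(vocab)), dtype=bool)
    for interpro_id, start, end in annotations:
        for kw in interpro_to_keywords.get(interpro_id, []):
            if kw in index:  # keywords outside the vocabulary are ignored
                labels[start - 1 : end, index[kw]] = True
    return labels


# Toy example: two overlapping annotations on a 10-residue protein.
vocab = ["kinase", "atp binding", "membrane"]
labels = residue_keyword_labels(
    10, [("IPR_A", 2, 5), ("IPR_B", 4, 8)], INTERPRO_TO_KEYWORDS, vocab
)
```

Residues 4 and 5 carry two keywords at once, which is the "multiple labels in a single location" case mentioned below.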

To compute mean Average Precision (mAP), we computed both a macro average and a micro average using TorchMetrics' MultilabelAveragePrecision with the corresponding average keyword argument ("macro" or "micro"). Both the predictions and the labels are given per residue, and MultilabelAveragePrecision naturally handles the case where there are multiple labels and/or predictions at a single position.
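For readers without TorchMetrics at hand, here is a minimal NumPy sketch of the two averages on a single per-residue prediction matrix. It uses the standard non-interpolated definition, AP = sum over thresholds k of (R_k - R_{k-1}) * P_k, with tied scores treated as one threshold. Micro averaging flattens all (residue, keyword) pairs into one binary problem; macro averaging computes AP per keyword column and averages the results. Skipping keywords with no positive residues in the macro average is an assumption of this sketch, and TorchMetrics' edge-case handling may differ.

```python
import numpy as np


def average_precision(scores, targets):
    """Non-interpolated AP; tied scores form one threshold. nan if no positives."""
    scores = np.asarray(scores, dtype=float).ravel()
    targets = np.asarray(targets, dtype=bool).ravel()
    n_pos = targets.sum()
    if n_pos == 0:
        return float("nan")
    order = np.argsort(-scores, kind="stable")  # descending scores
    scores, targets = scores[order], targets[order]
    tp = np.cumsum(targets)
    fp = np.cumsum(~targets)
    # Keep the last index of each block of tied scores (one per distinct threshold).
    last = np.r_[np.flatnonzero(np.diff(scores)), scores.size - 1]
    tp, fp = tp[last], fp[last]
    precision = tp / (tp + fp)
    recall = tp / n_pos
    recall_prev = np.r_[0.0, recall[:-1]]
    return float(np.sum((recall - recall_prev) * precision))


def multilabel_ap(preds, labels, average):
    """preds, labels: [num_residues, num_keywords] arrays."""
    if average == "micro":
        # One binary problem over every (residue, keyword) pair.
        return average_precision(preds, labels)
    if average == "macro":
        # One AP per keyword column; columns without positives are skipped.
        per_keyword = [
            average_precision(preds[:, k], labels[:, k]) for k in range(labels.shape[1])
        ]
        return float(np.nanmean(per_keyword))
    raise ValueError(f"unsupported average: {average!r}")


# Toy example: 4 residues, 2 keywords.
preds = np.array([[0.9, 0.2], [0.4, 0.8], [0.3, 0.1], [0.1, 0.35]])
labels = np.array([[1, 0], [0, 1], [1, 0], [0, 0]], dtype=bool)
```

On this toy input the micro average is 13/15 while the macro average is 11/12: the per-keyword rankings are nearly perfect, but pooling both keywords lets the 0.35 negative outrank the 0.3 positive.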

For the micro-averaged mAP, we computed the mAP for each individual sample in the validation set and then averaged these values uniformly. Computing the macro-averaged mAP involves computing the mAP for each keyword individually over the entire validation set and then averaging the per-keyword mAPs uniformly. Performing this computation over all ~40k keywords is impractical, both because of the paucity of labels for uncommon keywords and because of memory usage. We therefore computed the average over only the top 1000 most common keywords, excluding uninformative keywords. I provide the full list here: top_keywords.txt
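The two aggregation schemes described above, together with the keyword selection, can be sketched as follows. Here samples is a hypothetical list of (predictions, labels) pairs, each a [num_residues, num_keywords] array for one protein. Ranking keywords by the number of proteins they label, and the exclude argument standing in for the uninformative-keyword filter, are assumptions rather than the authors' exact procedure.

```python
import numpy as np


def average_precision(scores, targets):
    # Non-interpolated AP with tied scores grouped; nan if there are no positives.
    s = np.asarray(scores, dtype=float).ravel()
    t = np.asarray(targets, dtype=bool).ravel()
    if not t.any():
        return float("nan")
    order = np.argsort(-s, kind="stable")
    s, t = s[order], t[order]
    last = np.r_[np.flatnonzero(np.diff(s)), s.size - 1]
    tp = np.cumsum(t)[last]
    fp = np.cumsum(~t)[last]
    recall = tp / t.sum()
    return float(np.sum(np.diff(np.r_[0.0, recall]) * tp / (tp + fp)))


def top_keywords(samples, k, exclude=()):
    """Indices of the k keywords present in the most proteins, skipping `exclude`."""
    counts = np.sum([y.any(axis=0) for _, y in samples], axis=0)
    excluded = {int(j) for j in exclude}
    ranked = [int(j) for j in np.argsort(-counts, kind="stable") if int(j) not in excluded]
    return ranked[:k]


def micro_map(samples):
    """Per protein: AP over all (residue, keyword) pairs; then a uniform mean."""
    return float(np.nanmean([average_precision(p, y) for p, y in samples]))


def macro_map(samples, keyword_ids):
    """Per keyword: AP over all residues of the whole set; then a uniform mean."""
    preds = np.concatenate([p for p, _ in samples])
    labels = np.concatenate([y for _, y in samples])
    return float(
        np.nanmean([average_precision(preds[:, j], labels[:, j]) for j in keyword_ids])
    )


# Toy data: two proteins with two residues each and three keywords.
samples = [
    (np.array([[0.9, 0.1, 0.2], [0.7, 0.3, 0.1]]),
     np.array([[1, 0, 0], [1, 0, 0]], dtype=bool)),
    (np.array([[0.2, 0.5, 0.6], [0.4, 0.1, 0.9]]),
     np.array([[0, 1, 0], [0, 0, 1]], dtype=bool)),
]
```

On this toy data, micro_map(samples) is 11/12 (the second protein's flattened ranking is imperfect), whereas macro_map(samples, [0, 1, 2]) is 1.0 because each keyword is ranked perfectly once residues from both proteins are pooled.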

The code we used to produce the function keyword predictions closely follows the example under "Function Prediction" in esm/examples/raw_forwards.py, with one exception:

We noticed that raw function predictions are often "noisy", similar to segmentation boundaries in image segmentation models. We took measures to smooth the predictions; for this figure, we smoothed the function-token logits using a dense Conditional Random Field from pydensecrf. I've provided the CRF code below. In the repository, we released a simpler smoothing approach that drops short predicted annotations and merges small gaps (see annotation_min_length and annotation_gap_merge_max), which seemed to have essentially the same effect as the CRF smoothing. We did not observe either smoothing method having a significant impact on the final function prediction validation metrics.
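The simpler repository approach can be pictured on one keyword's per-residue boolean mask. This is a hypothetical pure-Python sketch, not the esm implementation: the helper names are made up, and the order of operations (merge gaps first, then drop short runs) and the inclusive thresholds are assumptions. Only the parameter names annotation_min_length and annotation_gap_merge_max come from the thread; they map to min_length and gap_merge_max here.

```python
def runs(mask):
    """Half-open (start, end) index ranges of consecutive True values."""
    out, start = [], None
    for i, v in enumerate(list(mask) + [False]):  # sentinel closes a trailing run
        if v and start is None:
            start = i
        elif not v and start is not None:
            out.append((start, i))
            start = None
    return out


def smooth_mask(mask, min_length, gap_merge_max):
    """Merge short gaps between annotated runs, then drop runs that are too short."""
    mask = [bool(v) for v in mask]
    # 1) Fill interior gaps of at most gap_merge_max residues between two runs.
    rs = runs(mask)
    for (_, end), (next_start, _) in zip(rs, rs[1:]):
        if next_start - end <= gap_merge_max:
            mask[end:next_start] = [True] * (next_start - end)
    # 2) Remove runs shorter than min_length residues.
    for start, end in runs(mask):
        if end - start < min_length:
            mask[start:end] = [False] * (end - start)
    return mask
```

For example, with min_length=3 and gap_merge_max=2, the mask 1110110000100 becomes six annotated residues followed by seven unannotated ones: the one-residue gap is filled and the isolated single-residue hit is dropped.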

import numpy as np
import pydensecrf.densecrf as dcrf
import torch
from pydensecrf.utils import create_pairwise_gaussian, unary_from_softmax

def crf_smooth_function_token_logits(
    function_token_logits: torch.Tensor,
    length_scale: float = 3.0,
    compatibility: float = 3.0,
    num_iterations: int = 10,
) -> torch.Tensor:
    """Smooth function token logits with a Conditional Random Field (CRF).

    See https://arxiv.org/abs/1210.5644 for details.

    Args:
        function_token_logits: <float>[length, depth, vocab] function token logits.
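        length_scale: width, in residues, of the Gaussian pairwise kernel; larger values let the CRF smooth over longer ranges.
        compatibility: weight of the Potts pairwise term; larger values produce smoother labelings.
        num_iterations: number of mean-field inference iterations to run.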
    Returns:
        <float>[length, depth, vocab] function token logits processed with the CRF.
    """
    length, depth, vocab = function_token_logits.shape
    p = torch.softmax(function_token_logits, dim=-1).cpu().float().numpy()

    pairwise_energy = create_pairwise_gaussian((length_scale,), (length,))

    crf_logits = torch.zeros_like(function_token_logits)
    for i in range(depth):
        crf = dcrf.DenseCRF(length, vocab)

        unary = unary_from_softmax(p[:, i, :].T).copy()  # (vocab, length)
        crf.setUnaryEnergy(unary)

        crf.addPairwiseEnergy(pairwise_energy, compat=compatibility)
        Q = crf.inference(num_iterations)
        crf_logits[:, i, :] = torch.log(torch.tensor(np.array(Q)).T)

    return crf_logits
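To get a feel for what length_scale and compatibility control, the sketch below runs naive mean-field inference for a fully connected CRF on a 1D sequence in plain NumPy, for a single depth slice ([length, vocab]). It is not pydensecrf. It evaluates the Gaussian kernel exactly in O(L^2) instead of using a permutohedral-lattice approximation, applies no kernel normalization, and excludes self-interactions, so its numbers will not match the function above. The function name and the toy input are hypothetical.

```python
import numpy as np


def naive_meanfield_1d(probs, length_scale=3.0, compatibility=3.0, num_iterations=10):
    """Naive mean-field inference for a fully connected 1D CRF.

    probs: <float>[length, vocab] per-position class probabilities (unary = -log p).
    Pairwise term: Potts model weighted by k(i, j) = exp(-(i - j)^2 / (2 * length_scale^2)).
    Returns <float>[length, vocab] approximate marginals Q.
    """
    probs = np.asarray(probs, dtype=float)
    length = probs.shape[0]
    unary = -np.log(np.clip(probs, 1e-12, None))
    pos = np.arange(length, dtype=float)
    kernel = np.exp(-((pos[:, None] - pos[None, :]) ** 2) / (2 * length_scale**2))
    np.fill_diagonal(kernel, 0.0)  # no self-messages
    q = probs / probs.sum(axis=1, keepdims=True)  # initialize from the unaries
    for _ in range(num_iterations):
        message = kernel @ q  # message[i, l] = sum_j k(i, j) * Q_j(l)
        # Potts term: rewarding agreement with neighbours' label mass is equivalent,
        # up to a per-position constant, to penalizing disagreement.
        logits = -unary + compatibility * message
        logits -= logits.max(axis=1, keepdims=True)  # numerical stability
        q = np.exp(logits)
        q /= q.sum(axis=1, keepdims=True)
    return q


# Toy input: class 0 everywhere except one isolated, weakly class-1 position.
probs = np.tile([0.8, 0.2], (20, 1))
probs[10] = [0.4, 0.6]
q = naive_meanfield_1d(probs)
```

With the default settings, the isolated position is pulled back to the surrounding class 0; setting compatibility to 0 returns the input probabilities unchanged, and a smaller length_scale weakens the pull from distant residues.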

Thank you very much for the detailed information.