microsoft/evodiff

Missing data files: lengths_and_offsets.npz, consensus.fasta, and splits.json

Linzy19 opened this issue · 3 comments

Hello,
While trying to run train.py, I noticed that the metadata and ds_train used on lines 164-165 require three files: lengths_and_offsets.npz, consensus.fasta, and splits.json. However, I couldn't find these files in the data folder or on UniProt. Could you provide them, or explain how to obtain or generate them?
Thank you!

Dear @yangkky,

I would like to fine-tune the 38M-parameter OADM model by continuing its training on all the sequences we can find in the RCSB database.
I can extract them in FASTA format.
I assume you also had to download the UniRef50 sequences as a FASTA file, but how did you and your team generate the following two files:

'splits.json'
'lengths_and_offsets.npz'

Best,

Loris

Not super elegant, but this is the Snakefile I used to get the consensus sequences, then the lengths and offsets, and then hold out sequences that are similar to things in John Ingraham's CATH test set.
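
# Pipeline overview:
#   flatten_fasta   : uniref50.fasta -> consensus.fasta (one sequence per line)
#   get_offsets     : consensus.fasta -> lengths_and_offsets.npz (byte offsets and lengths)
#   make_db, make_index, get_ingraham_data, make_test_fasta, easy_search:
#                     search the CATH test-set sequences against UniRef50 with MMseqs2
#   split           : report.m8 + uniref50.fasta -> splits.json (train/valid/test/rtest indices)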

import json
import math
from collections import Counter
import subprocess
from tqdm import tqdm

import numpy as np
import pandas as pd

rule all:
    input:
        "lengths_and_offsets.npz",
        "splits.json"

rule flatten_fasta:
    input:
        "uniref50.fasta"
    output:
        "consensus.fasta"
    run:
        # wc -l gives a quick line count for sizing the progress bar
        # (roughly two lines per record).
        result = subprocess.run(['wc', '-l', input[0]], stdout=subprocess.PIPE)
        length = int(result.stdout.decode('utf-8').split(' ')[0]) // 2
        with tqdm(total=length) as pbar:
            with open(input[0], 'r') as f_in, open(output[0], 'w') as f_out:
                seq = ''
                for line in f_in:
                    if line[0] == '>':
                        # New header: flush the previous record's sequence as one line.
                        if len(seq) > 0:
                            f_out.write(seq + '\n')
                            seq = ''
                            pbar.update(1)
                        f_out.write(line)
                    else:
                        seq += line[:-1]
                # Flush the last record, which is not followed by another header.
                if len(seq) > 0:
                    f_out.write(seq + '\n')
                    pbar.update(1)


rule get_offsets:
    input:
        "consensus.fasta"
    output:
        "lengths_and_offsets.npz"
    run:
        # For each record in the flattened FASTA, record:
        #   name_offsets[i] - byte offset of the i-th header line (plus a trailing EOF offset)
        #   seq_offsets[i]  - byte offset of the i-th sequence line
        #   ells[i]         - length of the i-th sequence
        results = {}
        results['name_offsets'] = []
        results['seq_offsets'] = []
        results['ells'] = []
        result = subprocess.run(['wc', '-l', input[0]], stdout=subprocess.PIPE)
        length = int(result.stdout.decode('utf-8').split(' ')[0]) // 2
        with tqdm(total=length) as pbar:
            with open(input[0], 'r') as f:
                results['name_offsets'].append(f.tell())
                line = f.readline()
                while line:
                    if line[0] != '>':
                        # Sequence line: f.tell() is now the offset of the next header.
                        results['name_offsets'].append(f.tell())
                        results['ells'].append(len(line[:-1]))
                    else:
                        # Header line: f.tell() is now the offset of its sequence.
                        results['seq_offsets'].append(f.tell())
                        pbar.update(1)
                    line = f.readline()
        results['name_offsets'] = np.array(results['name_offsets'])
        results['seq_offsets'] = np.array(results['seq_offsets'])
        results['ells'] = np.array(results['ells'])
        np.savez_compressed(output[0], **results)

rule make_db:
    input:
        "uniref50.fasta"
    output:
        "db/targetDB"
    shell:
        """
        mmseqs createdb {input} {output} --shuffle false
        """

rule make_index:
    input:
        "db/targetDB"
    output:
        directory("db/target_index")
    shell:
        """
        mmseqs createindex {input} {output}
        """

rule get_ingraham_data:
    output:
        "cath/chain_set.jsonl",
        "cath/chain_set_splits.json"
    shell:
        """
        wget -P cath http://people.csail.mit.edu/ingraham/graph-protein-design/data/cath/chain_set.jsonl
        wget -P cath http://people.csail.mit.edu/ingraham/graph-protein-design/data/cath/chain_set_splits.json
        """

rule make_test_fasta:
    input:
        "cath/chain_set.jsonl",
        "cath/chain_set_splits.json"
    output:
        "cath/test.fasta"
    run:
        with open(input[1]) as f:
            splits = json.load(f)
        # Use a set for fast membership checks while scanning the chain set.
        test_names = set(splits['test'])
        with tqdm(total=len(test_names)) as pbar:
            with open(input[0]) as f, open(output[0], 'w') as f2:
                for line in f:
                    chain = json.loads(line)
                    if chain['name'] in test_names:
                        f2.write('>' + chain['name'] + '\n')
                        f2.write(chain['seq'] + '\n')
                        pbar.update(1)

rule easy_search:
    input:
        "cath/test.fasta",
        "db/targetDB",
        "db/target_index"
    output:
        "report.m8"
    run:
        subprocess.run(['mmseqs', 'easy-search',
                    input[0],
                    input[1],
                    output[0],
                    input[2],
                    '-s', '1',
                    '--format-output', 'query,target,raw,pident,nident,qlen,alnlen',
                    '--cov-mode', '2',
                    '-c', '0.8'])

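# report.m8 is the tab-separated MMseqs2 hit table (query, target, raw, pident, nident,
# qlen, alnlen). UniRef50 sequences that hit the CATH test set go to 'test'; the remaining
# sequences are assigned at random to 'rtest' (~0.5%), 'valid' (~0.2%), or 'train'.
# splits.json maps each split name to a list of integer sequence indices.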
rule split:
    input:
        "report.m8",
        "uniref50.fasta"
    output:
        "splits.json"
    run:
        np.random.seed(0)
        # parse through results
        df = pd.read_csv(input[0], header=None, delimiter='\t')[[0, 1]]
        df.columns = ['qname', 'tname']
        all_names = []
        splits = {'train': [], 'test': [], 'valid': [], 'rtest': []}
        print('Getting names...')
        with open(input[1]) as f:
            for line in f:
                if line[0] == '>':
                    all_names.append(line[1:-1].split(' ')[0])
        n = len(all_names)
        print('Splitting...')
        tnames = set(df.tname.values)
        for i in tqdm(range(n)):
            if all_names[i] in tnames:
                splits['test'].append(i)
                tnames.remove(all_names[i])
            elif np.random.random() < 5e-3:
                splits['rtest'].append(i)
            elif np.random.random() < 2e-3:
                splits['valid'].append(i)
            else:
                splits['train'].append(i)
        for k in splits:
            print(k + ': %d sequences' %len(splits[k]))
        with open(output[0], 'w') as f:
            json.dump(splits, f)
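
In case it helps, here is how the three generated files fit together downstream. This is a minimal illustrative sketch of random access into the flattened FASTA, assuming plain ASCII input; the helper name get_record is made up for the example and is not the actual evodiff data loader.

import json

import numpy as np

# lengths_and_offsets.npz: byte offsets of header/sequence lines plus sequence lengths
data = np.load('lengths_and_offsets.npz')
# splits.json: split name -> list of integer sequence indices
with open('splits.json') as f:
    splits = json.load(f)

def get_record(i, fasta_path='consensus.fasta'):
    """Return (name, sequence) for record i via the precomputed offsets."""
    with open(fasta_path) as f:
        f.seek(int(data['name_offsets'][i]))
        name = f.readline()[1:].strip()
        f.seek(int(data['seq_offsets'][i]))
        seq = f.read(int(data['ells'][i]))
    return name, seq

# e.g. fetch the first training sequence
name, seq = get_record(splits['train'][0])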