seq-lang/seq

Help with adding Seq implementation of FASTQ count and BED coverage

jelber2 opened this issue · 9 comments

Hi,

This might be of interest for promoting the utility of the Seq language. Heng Li and others are benchmarking different programming languages on some standard bioinformatics tasks, such as counting the number of reads and bases in a FASTQ file and calculating coverage between two BED files. I wrote a very, very simple FASTQ counter for Seq (lh3/biofast#15). A BED coverage function is a little more complicated, but it would be great to have given commit 5926893 in the development branch, which adds a native BED file parser.

It might be nice to improve upon what I did and add BED coverage code through a formal pull request.

Best,
Jean

Hey @jelber2, thanks for the suggestions -- that sounds like a great idea! Let me take a look at those benchmarks and see if I can come up with a Seq version.

So the FASTQ implementation looks good -- you can also try FASTQ(..., validate=False) to make it slightly faster. The Seq version is still somewhat slower than C here because it performs an additional copy. I think we can look into adding a copy=False option like we have for FASTA.
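
As a loose analogy for what an extra copy costs and what skipping it risks, here is a minimal sketch in plain Python (not Seq, and not a description of how Seq's FASTQ parser is implemented). The buffer contents and variable names are made up for illustration.

```python
# Plain-Python analogy only: a parser can hand out either a copy of the
# bytes it read or a view into its internal buffer.
buf = bytearray(b"@r1\nACGT\n+\nIIII\n")  # hypothetical parser buffer

copied = bytes(buf[4:8])      # copy: owns its own "ACGT"
view = memoryview(buf)[4:8]   # no copy: a window onto buf

# Simulate the parser overwriting its buffer with the next record.
for i in range(4, 8):
    buf[i] = ord("T")

print(copied)        # the copy is unaffected by the overwrite
print(bytes(view))   # the view reflects the new buffer contents
```

The copy is safe to keep around but costs an allocation per record; the view is free but only valid until the buffer changes.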

Here's a first attempt at an implementation for the BED benchmark (confirmed it gives correct results on biofast's test data):

# cgranges implementation adapted from
# https://github.com/lh3/cgranges/blob/master/cpp/IITree.h

type StackCell(_x: int, _k: i32, _w: i32):
    def __init__(self: StackCell, k: int, x: int, w: int) -> StackCell:
        return (x, i32(k), i32(w))

    @property
    def k(self: StackCell):
        return int(self._k)

    @property
    def x(self: StackCell):
        return int(self._x)

    @property
    def w(self: StackCell):
        return int(self._w)

type Interval(st: int, en: int, max: int):
    def __init__(self: Interval, st: int, en: int) -> Interval:
        return (st, en, en)

    def with_max(self: Interval, max: int) -> Interval:
        return (self.st, self.en, max)

    @property
    def start(self: Interval):
        return int(self.st)

    @property
    def end(self: Interval):
        return int(self.en)

class IntervalTree:
    a: list[Interval]
    max_level: int

    def __init__(self: IntervalTree):
        self.a = list[Interval]()
        self.max_level = 0

    def _index_core(a: list[Interval]):
        if not a:
            return -1
        N = len(a)
        i = 0
        last_i = 0  # last_i points to the rightmost node in the tree
        last = 0    # last is the max value at node last_i
        while i < N:  # leaves (i.e. at level 0)
            last_i = i
            ai = a[i]
            last = ai.en
            a[i] = ai.with_max(last)
            i += 2

        k = 1
        while (1 << k) <= N:
            x = 1 << (k - 1)
            i0 = (x << 1) - 1
            step = x << 2

            i = i0
            while i < N:  # traverse all nodes at level k
                el = a[i - x].max  # max value of left child
                er = a[i + x].max if i + x < N else last  # of the right child
                e = a[i].en
                e = e if e > el else el
                e = e if e > er else er
                a[i] = a[i].with_max(e)  # set the max value for node i
                i += step

            last_i = last_i - x if (last_i >> k) & 1 else last_i + x  # last_i now points to the parent of the original last_i
            if last_i < N and a[last_i].max > last:  # update last accordingly
                last = a[last_i].max
            k += 1
        return k - 1

    def add(self: IntervalTree, start: int, end: int):
        self.a.append(Interval(start, end))

    def index(self: IntervalTree):
        self.a.sort()
        self.max_level = IntervalTree._index_core(self.a)

    def overlap(self: IntervalTree, start: int, end: int):
        a = self.a
        st = start
        en = end
        stack = __array__[StackCell](64)
        t = 0

        stack[t] = StackCell(self.max_level, (1 << self.max_level) - 1, 0)  # push the root; this is a top down traversal
        t += 1

        while t:  # this traversal order guarantees that intervals are yielded in sorted order
            t -= 1
            z = stack[t]
            if z.k <= 3:  # we are in a small subtree; traverse every node in this subtree
                i0 = z.x >> z.k << z.k
                i1 = i0 + (1 << (z.k + 1)) - 1
                if i1 >= len(a):
                    i1 = len(a)
                i = i0
                while i < i1 and a[i].st < en:
                    if st < a[i].en:  # if it overlaps the query, yield it
                        yield a[i]
                    i += 1
            elif z.w == 0:  # if left child not processed
                y = z.x - (1 << (z.k - 1))  # the left child of z.x; NB: y may be out of range (i.e. y >= len(a))
                stack[t] = StackCell(z.k, z.x, 1)  # re-add node z.x, but mark the left child having been processed
                t += 1
                if y >= len(a) or a[y].max > st:  # push the left child if y is out of range or may overlap with the query
                    stack[t] = StackCell(z.k - 1, y, 0)
                    t += 1
            elif z.x < len(a) and a[z.x].st < en:  # need to push the right child
                if st < a[z.x].en:  # test if z.x overlaps the query; if yes, yield it
                    yield a[z.x]
                stack[t] = StackCell(z.k - 1, z.x + (1 << (z.k - 1)), 0)  # push the right child
                t += 1

    def __len__(self: IntervalTree):
        return len(self.a)

    def __bool__(self: IntervalTree):
        return len(self) > 0

    def __getitem__(self: IntervalTree, idx: int):
        return self.a[idx]

    def start(self: IntervalTree, idx: int):
        return self.a[idx].st

    def end(self: IntervalTree, idx: int):
        return self.a[idx].en

from sys import argv
from time import timing

with timing('bed coverage'):
    bed = dict[str, IntervalTree]()

    for record in BED(argv[1]):
        if record.chrom not in bed:
            bed[record.chrom] = IntervalTree()
        bed[record.chrom].add(record.chrom_start, record.chrom_end)

    for tree in bed.values():
        tree.index()

    for record in BED(argv[2]):
        if record.chrom not in bed:
            print f'{record.chrom}\t{record.chrom_start}\t{record.chrom_end}\t0\t0'
        else:
            cov, cov_st, cov_en, n = 0, 0, 0, 0
            st1, en1 = record.chrom_start, record.chrom_end
            for item in bed[record.chrom].overlap(st1, en1):
                n += 1
                # calculate overlap length/coverage
                st0, en0 = item.start, item.end
                if st0 < st1: st0 = st1
                if en0 > en1: en0 = en1
                if st0 > cov_en:  # no overlap with previously found intervals
                    # flush the finished run, then start a new one at the current interval
                    cov += cov_en - cov_st
                    cov_st, cov_en = st0, en0
                elif cov_en < en0:  # overlaps previously found intervals
                    # only the end needs checking, since overlap() yields intervals in sorted order
                    cov_en = en0
            cov += cov_en - cov_st
            # print chrom, start, end, count, number of covered nt
            print f'{record.chrom}\t{record.chrom_start}\t{record.chrom_end}\t{n}\t{cov}'
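
To make the coverage loop at the end of the listing easier to follow, here is a minimal sketch of the same merging logic in plain Python (not Seq). The function name `covered_length` and the tuple-based input are assumptions for illustration; the hits are assumed to arrive sorted by start, as the tree traversal is said to guarantee.

```python
def covered_length(query_start, query_end, hits):
    """Return (number of hits, number of query bases covered).

    hits: (start, end) pairs that overlap the query, sorted by start.
    """
    n = 0
    cov = 0
    cov_st = cov_en = 0  # current merged run, clipped to the query
    for st0, en0 in hits:
        n += 1
        st0 = max(st0, query_start)  # clip the hit to the query
        en0 = min(en0, query_end)
        if st0 > cov_en:
            # Gap before this hit: flush the finished run, start a new one.
            cov += cov_en - cov_st
            cov_st, cov_en = st0, en0
        elif en0 > cov_en:
            # Hit extends the current run.
            cov_en = en0
    cov += cov_en - cov_st  # flush the last run
    return n, cov


# Query [10, 50) with three hits: two overlap each other, one stands alone.
print(covered_length(10, 50, [(5, 20), (15, 30), (40, 60)]))
```

After clipping, the first two hits merge into [10, 30) and the third contributes [40, 50), so three hits cover 30 bases of the query.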

This one is currently noticeably slower than the C version in the biofast repo (though I'd expect it to be a lot faster than the Python/PyPy versions; I haven't tried it). I think there are a few things we can do to speed it up. Ideally, we can add an improved IntervalTree to the standard library that uses array instead of list, and avoid several duplicate dictionary lookups in the actual query code. (We're looking into adding some compiler optimizations for this, since duplicate dictionary lookups are a pretty common pattern in Python but can have a big performance impact.)
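
The duplicate-lookup pattern mentioned above can be sketched in plain Python (not Seq; whether Seq's dict offers the same methods is not confirmed here). Plain lists stand in for IntervalTree, and the helper names `group_intervals` and `count_hits` are hypothetical.

```python
def group_intervals(records):
    """Group (chrom, start, end) records by chromosome."""
    trees = {}
    for chrom, start, end in records:
        # One lookup via setdefault instead of `in`, `[...] =`, and `[...]`.
        trees.setdefault(chrom, []).append((start, end))
    return trees


def count_hits(trees, queries):
    """Count stored intervals overlapping each (chrom, start, end) query."""
    out = []
    for chrom, start, end in queries:
        intervals = trees.get(chrom)  # single lookup; None if chrom is absent
        if intervals is None:
            out.append((chrom, start, end, 0))
            continue
        # Brute-force overlap test, standing in for a tree query.
        n = sum(1 for s, e in intervals if s < end and start < e)
        out.append((chrom, start, end, n))
    return out


trees = group_intervals([("chr1", 0, 10), ("chr1", 5, 20), ("chr2", 0, 5)])
print(count_hits(trees, [("chr1", 8, 12), ("chr3", 0, 1)]))
```

The script in the listing does a membership test followed by one or two indexing operations per record; fetching the value once and reusing it is the manual version of the optimization described.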

Edit: copy=False, validate=False on the BED parser makes it about 50% faster.

Oh wow! Interesting, so with copy=False and validate=False, it is on par with the C version?

Edit: I tried

import sys
if len(sys.argv) == 1:
    print("Usage: fqcnt <in.fq.gz>")
    sys.exit(0)
fn = sys.argv[1]
n, slen, qlen = 0, 0, 0
for r in FASTQ(fn, validate=False):
    n += 1
    slen += len(r.read)
    qlen += len(r.qual)
print f'{n}\t{slen}\t{qlen}'

This is twice as fast as with validate=True, but adding copy=False results in the following error on Seq 0.9.7:

ValueError: cannot iterate over FASTQ records with copy=False

raised from: __iter__
/home/jelber2/.seq/stdlib/bio/fastq.seq:101:13
Aborted (core dumped)

I just added the IntervalTree to the standard library in ecc5ee6. Benchmarks are also under the test/bench folder now.

That commit also makes the bedcov benchmark substantially faster by fixing some inefficiencies in the BED parser. I'm still working on making these even faster -- including proper copy=False for FASTQ!

Oh cool! Will there be a new release of Seq soon, maybe 0.9.8?

Yes we're working on the new release now actually. It'll most likely be 0.10, as we're planning on introducing an improved type system that further closes the gap with Python.

Sounds great!

Just checking whether copy=False has been enabled in 0.9.10. As far as I can tell, it has not.

copy=False is added in #132, which will soon be merged into the main branch.