scverse/anndata

Utility for loading dask array from SparseDataset

ivirshup opened this issue · 1 comments

Please describe your wishes and possible alternatives to achieve the desired result.

We should have a utility available for creating a sparse dask array from a sparse dataset.

Basic implementation (needs CSC support too)
from scipy import sparse
import numpy as np
import anndata as ad
import dask.array as da

class CSRCallable:
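    """Dummy class to bypass dask checks.

    Passed as `meta`: dask calls it with (shape, dtype) and gets back an
    empty CSR matrix instead of a numpy array.
    """

    def __new__(cls, shape, dtype) -> sparse.csr_matrix:
        # scipy sparse matrices are always 2D, so pad shorter shapes.
        if len(shape) == 0:
            shape = (0, 0)
        elif len(shape) == 1:
            shape = (shape[0], 0)
        elif len(shape) != 2:
            raise ValueError(shape)
        return sparse.csr_matrix(shape, dtype=dtype)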


def make_dask_chunk(
    x: ad.experimental.SparseDataset, start: int, end: int
) -> da.Array:
    """Wrap rows [start, end) of `x` as a lazily loaded, single-chunk dask array."""
    from dask import delayed

    def take_slice(x, idx):
        return x[idx]

    return da.from_delayed(
        delayed(take_slice)(x, slice(start, end)),
        dtype=x.dtype,
        shape=(end - start, x.shape[1]),
        meta=CSRCallable,
    )


def sparse_dataset_as_dask(x, stride: int) -> da.Array:
    """Build a row-chunked dask array over `x`, with `stride` rows per chunk."""

    n_chunks, rem = divmod(x.shape[0], stride)

    chunks = []
    cur_pos = 0
    for i in range(n_chunks):
        chunks.append(make_dask_chunk(x, cur_pos, cur_pos + stride))
        cur_pos += stride
    if rem:
        chunks.append(make_dask_chunk(x, cur_pos, x.shape[0]))

    return da.concatenate(chunks, axis=0)
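
The row ranges that `sparse_dataset_as_dask` walks through can be checked without touching dask or a backing store. A minimal sketch, where `chunk_bounds` is a hypothetical helper (not part of anndata) that mirrors the `divmod` loop above and adds a guard against a non-positive stride:

```python
def chunk_bounds(n_rows: int, stride: int) -> list[tuple[int, int]]:
    """Return (start, end) row ranges covering [0, n_rows) in steps of `stride`.

    Hypothetical helper for illustration; mirrors the loop in
    `sparse_dataset_as_dask`.
    """
    if stride <= 0:
        raise ValueError(f"stride must be positive, got {stride}")
    n_chunks, rem = divmod(n_rows, stride)
    # Full-size chunks first.
    bounds = [(i * stride, (i + 1) * stride) for i in range(n_chunks)]
    # Then one shorter chunk for any leftover rows.
    if rem:
        bounds.append((n_chunks * stride, n_rows))
    return bounds


print(chunk_bounds(10, 4))  # [(0, 4), (4, 8), (8, 10)]
```

Note that a dataset with zero rows yields an empty list here, so the implementation above would likely need a special case before calling `da.concatenate`.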

I'm not sure what this should be called, or where it should go. Probably experimental for now.

From #1348 (comment)

cc: @ilan-gold

Closing as duplicate of #1430