scverse/anndata

`read_elem_as_dask`

ivirshup opened this issue · 8 comments


We should probably add some read_elem_as_dask methods to make reading stored objects as dask arrays easier.

This largely needs special casing due to:

  • Sparse arrays
  • Dask not handling h5py objects in process-based mode: #1105

I think we can skip dataframe support for now. It may be worth thinking about how to read objects into GPU memory as well.

I'm not completely set on the API, so input is welcome. My initial thought is that we could just have a smaller registry of methods which only supports reads for dense and sparse arrays.
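
One way to picture the "smaller registry" idea is a mapping from an element's stored encoding type to a reader function. The sketch below is only an illustration: `LAZY_READERS`, `register_lazy_reader`, `read_elem_lazy`, and the dict stand-ins for stored elements are hypothetical names, not anndata's API, and the readers return strings where a real implementation would return lazy arrays.

```python
from typing import Any, Callable

# Hypothetical registry: encoding-type string -> reader function.
LAZY_READERS: dict[str, Callable[[dict], Any]] = {}


def register_lazy_reader(encoding_type: str):
    """Decorator that registers a reader for one encoding type."""

    def decorator(func):
        LAZY_READERS[encoding_type] = func
        return func

    return decorator


@register_lazy_reader("array")
def read_dense(elem: dict) -> str:
    # A real reader would wrap the stored dataset in a lazy array here.
    return f"lazy dense array of shape {elem['shape']}"


@register_lazy_reader("csr_matrix")
def read_csr(elem: dict) -> str:
    return f"lazy CSR matrix of shape {elem['shape']}"


def read_elem_lazy(elem: dict):
    """Dispatch on the encoding type; unsupported types fail loudly."""
    encoding_type = elem["encoding-type"]
    try:
        reader = LAZY_READERS[encoding_type]
    except KeyError:
        raise NotImplementedError(f"no lazy reader for {encoding_type!r}") from None
    return reader(elem)
```

Keeping the registry small means an unsupported element, such as a dataframe, raises an error rather than being read eagerly by accident.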

cc: @ilan-gold

Starting to remember why I was initially against this kind of thing.

We probably also want:

  • read in as cupy
  • read in as cupy wrapped by dask

So this becomes three functions quite quickly. But I also think there's a ton of value in exposing this, so maybe we can start here and transition to a better API in the future.

@ivirshup I agree, I think that we kind of have to start somewhere. I think the fact that we need three functions won't change internally, so exposing it via experimental and giving us a chance to iron out kinks could be good. I understand the maintenance burden, but since they will be experimental, we can always yank them from the public API.

Looking forward to these changes!

Would this support arbitrary chunk sizes of sparse matrices with zarr+dask?

I am looking for a solution that would allow loading [random 50 cells, random block of 400 variables] from 50 layers quickly as part of a DataLoader (scvi-tools), without loading all 50 sparse matrices into memory in full. The anndata object I have has few obs (2k) but many vars (500k+), 50 layers, and ~20 other obs/obsm/var/varm entries that are loaded as part of the minibatch for the same sample of cells and variables. Everything except layers and adata.X is fine to load into memory in full.

In particular, it would be nice to be able to do the following:

def read_w_sparse_dask_layers(
    group: h5py.Group, obs_chunk: int = 1000, axis: int = 0
) -> ad.AnnData:
    return ad.AnnData(
        X=sparse_dataset_as_dask(sparse_dataset(group["X"]), obs_chunk, axis=axis),
        **{
            k: {
                key: sparse_dataset_as_dask(sparse_dataset(dat), obs_chunk, axis=axis)
                for key, dat in group[k].items()
            }
            if k in group
            else {}
            for k in ["layers"]
        },
        **{
            k: read_elem(group[k]) if k in group else {}
            for k in ["obs", "var", "uns", "obsp", "varp", "obsm", "varm"]
        }
    )

using the code @ivirshup shared with me during the scverse hackathon:

import dask.array as da
from dask import delayed
from scipy import sparse


def csr_callable(shape: tuple[int, int], dtype) -> sparse.csr_matrix:
    if len(shape) == 0:
        shape = (0, 0)
    if len(shape) == 1:
        shape = (shape[0], 0)
    elif len(shape) == 2:
        pass
    else:
        raise ValueError(shape)

    return sparse.csr_matrix(shape, dtype=dtype)


class CSRCallable:
    """Dummy class to bypass dask checks"""

    def __new__(cls, shape, dtype):
        return csr_callable(shape, dtype)


def make_dask_chunk(x, start: int, end: int) -> da.Array:  # x: "SparseDataset"
    def take_slice(x, idx):
        return x[idx]

    return da.from_delayed(
        delayed(take_slice)(x, slice(start, end)),
        dtype=x.dtype,
        shape=(end - start, x.shape[1]),
        meta=CSRCallable,
    )


def make_dask_var_chunk(x, start: int, end: int) -> da.Array:  # x: "SparseDataset"
    def take_slice(x, idx):
        return x[:, idx]

    return da.from_delayed(
        delayed(take_slice)(x, slice(start, end)),
        dtype=x.dtype,
        shape=(x.shape[0], end - start),
        meta=CSRCallable,
    )


def sparse_dataset_as_dask(x, stride: int, axis: int = 0):
    n_chunks, rem = divmod(x.shape[axis], stride)

    if axis == 0:
        make_dask_chunk_func = make_dask_chunk
    elif axis == 1:
        make_dask_chunk_func = make_dask_var_chunk
    else:
        raise ValueError("axis must be 0 or 1")

    chunks = []
    cur_pos = 0
    for i in range(n_chunks):
        chunks.append(make_dask_chunk_func(x, cur_pos, cur_pos + stride))
        cur_pos += stride
    if rem:
        chunks.append(make_dask_chunk_func(x, cur_pos, x.shape[axis]))

    return da.concatenate(chunks, axis=axis)
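
The chunk boundaries that `sparse_dataset_as_dask` produces can be checked in isolation. The helper below is a minimal sketch that mirrors the `divmod` logic above without any dask dependency; `chunk_bounds` is a hypothetical name introduced here for illustration.

```python
def chunk_bounds(n: int, stride: int) -> list[tuple[int, int]]:
    """Return (start, end) pairs covering range(n) in steps of `stride`."""
    if stride <= 0:
        raise ValueError("stride must be positive")
    n_chunks, rem = divmod(n, stride)
    bounds = [(i * stride, (i + 1) * stride) for i in range(n_chunks)]
    if rem:
        # The last, shorter chunk picks up the remainder.
        bounds.append((n_chunks * stride, n))
    return bounds


print(chunk_bounds(2500, 1000))  # [(0, 1000), (1000, 2000), (2000, 2500)]
```

Each pair becomes one delayed slice of the whole chunk, so selecting even a single row from a chunk computes that entire chunk. With a stride of 1000 along obs and roughly 2k obs, a random sample of 50 cells is likely to touch every chunk.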

This code is ~1000 times slower at creating minibatches than reading from RAM (2 min instead of a few ms per batch).

> This code is ~1000 times slower at creating minibatches than reading from RAM (2 min instead of a few ms per batch).

This makes some sense because RAM is, well, RAM. Even so, zarr + dask may not be the solution. HDF5 is usually memory mapped, which is the secret sauce to making random access fast. Random access with dask is not amazing.

> Would this support arbitrary chunk sizes of sparse matrices with zarr+dask?

Yes it should, but chunking the zarr store on disk is also important!

> This makes some sense because RAM is, well, RAM.

Still, a 1000x slowdown makes this approach completely unusable, and I was told to expect only a ~3-4x slowdown (even 10x would still be usable).

> Even so, zarr + dask may not be the solution. HDF5 is usually memory mapped, which is the secret sauce to making random access fast. Random access with dask is not amazing.

Does this mean that you recommend using normal h5ad instead of zarr+dask?
I see that if I simply use anndata.read_h5ad(..., backed='r') all adata.layers slots are loaded into memory. Maybe I am using something incorrectly?

> but chunking the zarr store on disk is also important!

How are the sparse matrices in layers currently chunked when using the code above (both zarr and h5ad)?
My understanding is that sparse matrices are not affected by the chunking described in this tutorial https://anndata.readthedocs.io/en/latest/tutorials/notebooks/%7Bread%2Cwrite%7D_dispatched.html and I see that they are read in full when I use the code from the tutorial.

Given that I am using batches of vars (e.g. 0:200, then 15200:15400, then 1200:1400, ...) for a relatively small number of obs (random 50/1600), would it make sense to save the transposed object and then transpose back after loading?

> random 50 cells, random block of 400 variables

For random access you're going to want to just use the SparseDataset class on top of uncompressed hdf5 and a nice SSD sitting next to your compute. If you really want random variables, you're going to need to choose whether you want to load all cells for each variable or all variables for each cell.
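
The trade-off comes from how compressed sparse formats lay out their values: a contiguous block along the major axis is one contiguous range of the `data` and `indices` arrays. The sketch below shows this for CSR rows using an in-memory scipy matrix as a stand-in for the stored arrays; `csr_row_block` is a hypothetical helper written for illustration, not part of anndata.

```python
import numpy as np
from scipy import sparse

rng = np.random.default_rng(0)
dense = rng.poisson(0.3, size=(6, 8)).astype(np.float32)
csr = sparse.csr_matrix(dense)


def csr_row_block(mat: sparse.csr_matrix, start: int, end: int) -> sparse.csr_matrix:
    """Build rows [start, end) from one contiguous range of data/indices."""
    lo, hi = mat.indptr[start], mat.indptr[end]
    return sparse.csr_matrix(
        # Shift indptr so the block's first row starts at offset 0.
        (mat.data[lo:hi], mat.indices[lo:hi], mat.indptr[start : end + 1] - lo),
        shape=(end - start, mat.shape[1]),
    )


block = csr_row_block(csr, 2, 5)
```

A block of columns from the same CSR matrix has no such contiguous range and needs a pass over every row. In CSC the roles are swapped, so column blocks become the contiguous case.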

This in theory should be ~10x slower than doing the equivalent in memory (given fancy SSD hardware), but in practice will probably be significantly slower than that.

> I was told to expect only ~3-4x slowdown

Was this me? If it was I would have been referring specifically to using dask vs. SparseDataset (e.g. disk vs disk, not disk vs memory), and even then it can be situational.

> Given that I am using batches of vars (e.g. 0:200, then 15200:15400, then 1200:1400, ...) for a relatively small number of obs (random 50/1600), would it make sense to save the transposed object and then transpose back after loading?

Just to understand, is it a different set of obs per var sample, or is it always the same?

If it's always the same, just load those obs into memory then select from them. If it differs, it would probably depend on the total number of observations you have. Given the stable chunking pattern along the variables, I think storing a CSC matrix could make sense.
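
Both suggestions can be sketched with in-memory scipy matrices standing in for a stored layer. The matrix sizes, index values, and variable names below are made up for illustration.

```python
import numpy as np
from scipy import sparse

rng = np.random.default_rng(0)
# Toy stand-in for one layer: 20 obs x 1000 vars.
dense = rng.poisson(0.1, size=(20, 1000)).astype(np.float32)
layer_csr = sparse.csr_matrix(dense)

# Case 1: the same obs for every batch.
# Load those rows once, then slice variable blocks in memory.
fixed_obs = np.array([3, 7, 11, 15])
in_memory = layer_csr[fixed_obs].tocsc()
batch_a = in_memory[:, 0:200]
batch_b = in_memory[:, 600:800]

# Case 2: obs differ per batch.
# With CSC, the variable block is the contiguous slice; obs are picked from it.
layer_csc = layer_csr.tocsc()
obs = np.array([1, 4, 9])
batch_c = layer_csc[:, 600:800][obs]
```

In the second case every batch still reads the full variable block for all obs, which is why the total number of observations matters.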

One of the places where dask is going to be able to start giving better performance is if you can use parallelism. Is your sampling written in a way where you can take advantage of loading multiple samples in parallel? Are you training in parallel? If neither of these apply, dask is likely only going to add overhead + a nicer API.
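
As a rough illustration of "loading multiple samples in parallel", the sketch below submits several batch requests at once. It uses the standard library's `ThreadPoolExecutor` rather than dask to stay self-contained, and `load_batch`, `store`, and the index values are hypothetical stand-ins. Whether this helps in practice depends on the storage backend and on how much of the read path can actually run concurrently.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np
from scipy import sparse

rng = np.random.default_rng(0)
dense = rng.poisson(0.1, size=(100, 2000)).astype(np.float32)
store = sparse.csc_matrix(dense)  # stand-in for an on-disk matrix


def load_batch(obs: np.ndarray, var: slice) -> np.ndarray:
    """Hypothetical loader: take a variable block, then pick the obs."""
    return store[:, var][obs].toarray()


requests = [
    (np.array([54, 23, 11, 56]), slice(0, 200)),
    (np.array([94, 92, 45, 12]), slice(1200, 1400)),
]

# Submit all requests at once; results come back in request order.
with ThreadPoolExecutor(max_workers=2) as pool:
    batches = list(pool.map(lambda req: load_batch(*req), requests))
```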


But also, we are fairly new to this, so we are definitely interested to hear about your experiences!

Thank you for your insightful feedback @ivirshup!

> load all cells for each variable or all variables for each cell

How would you do that in practice? Is there an option you can set? Storing as CSC?

> Just to understand, is it a different set of obs per var sample, or is it always the same?

# Batch one:
obs = [545, 23, 11, 56, 55, 8, 22]
var = slice(0, 200)  # i.e. 0:200

# Batch two:
obs = [94, 924, 545, 12, 42, 50, 55]
var = slice(15200, 15400)  # i.e. 15200:15400

> Are you training in parallel?

Training is done on several GPUs and each GPU gets a different batch of indices (3 options below).

# Option 1: different obs and different var per GPU
# Batch one, GPU one:
obs = [545, 23, 11, 56, 55, 8, 22]
var = slice(0, 200)  # i.e. 0:200

# Batch one, GPU two:
obs = [94, 924, 545, 12, 42, 50, 55]
var = slice(15200, 15400)  # i.e. 15200:15400

# Option 2: same obs, different var per GPU
# Batch one, GPU one:
obs = [545, 23, 11, 56, 55, 8, 22]
var = slice(0, 200)

# Batch one, GPU two:
obs = [545, 23, 11, 56, 55, 8, 22]
var = slice(15200, 15400)

# Option 3: different obs, same var per GPU
# Batch one, GPU one:
obs = [545, 23, 11, 56, 55, 8, 22]
var = slice(0, 200)

# Batch one, GPU two:
obs = [94, 924, 545, 12, 42, 50, 55]
var = slice(0, 200)

Are multiple GPUs (DDP, separate process per GPU) going to be a problem for using the "SparseDataset class on top of uncompressed hdf5"?