Utility for loading dask array from SparseDataset
ivirshup opened this issue · 1 comments
ivirshup commented
Please describe your wishes and possible alternatives to achieve the desired result.
We should have a utility available for creating a sparse dask array from a sparse dataset.
Basic implementation (CSC support would still need to be added):
from scipy import sparse
import numpy as np
import dask.array as da
class CSRCallable:
    """Constructor stand-in that lets dask treat CSR as its chunk meta type.

    Dask instantiates ``meta`` like an array class; ``__new__`` normalizes
    the requested shape to 2-D and hands back an empty
    ``scipy.sparse.csr_matrix`` of that shape.
    """

    def __new__(cls, shape, dtype) -> sparse.csr_matrix:
        # csr_matrix only supports 2-D shapes; pad 0-d/1-d requests out.
        ndim = len(shape)
        if ndim == 0:
            shape = (0, 0)
        elif ndim == 1:
            shape = (shape[0], 0)
        elif ndim != 2:
            # Anything higher-dimensional cannot be represented as CSR.
            raise ValueError(shape)
        return sparse.csr_matrix(shape, dtype=dtype)
def make_dask_chunk(
    x: "ad.experimental.SparseDataset", start: int, end: int
) -> "DaskArray":
    """Wrap rows ``[start, end)`` of *x* as a single lazily-read dask chunk.

    Parameters
    ----------
    x
        On-disk sparse dataset supporting row slicing, with ``.dtype`` and
        ``.shape`` attributes.
    start, end
        Half-open row range this chunk covers.

    Returns
    -------
    A dask array of shape ``(end - start, x.shape[1])`` whose data is only
    read from *x* when the chunk is computed.
    """
    # NOTE: annotations are string forward references — `ad` (anndata) and
    # `DaskArray` are not imported here; unquoted they would raise NameError
    # when this function is defined.
    import dask.array as da
    from dask import delayed

    def take_slice(x, idx):
        # Executed lazily by dask; this is where the on-disk read happens.
        return x[idx]

    return da.from_delayed(
        delayed(take_slice)(x, slice(start, end)),
        dtype=x.dtype,
        shape=(end - start, x.shape[1]),
        meta=CSRCallable,
    )
def sparse_dataset_as_dask(x, stride: int):
    """Expose a row-sliceable sparse dataset as a chunked dask array.

    Parameters
    ----------
    x
        On-disk sparse dataset with ``.shape``, ``.dtype``, and row slicing.
    stride
        Number of rows per chunk (must be >= 1).

    Returns
    -------
    A dask array concatenating ``ceil(x.shape[0] / stride)`` lazy chunks
    along axis 0; the final chunk holds the remainder rows, if any.
    """
    import dask.array as da

    n_chunks, rem = divmod(x.shape[0], stride)
    chunks = []
    cur_pos = 0
    for _ in range(n_chunks):
        chunks.append(make_dask_chunk(x, cur_pos, cur_pos + stride))
        cur_pos += stride
    if rem:
        chunks.append(make_dask_chunk(x, cur_pos, x.shape[0]))
    if not chunks:
        # Zero-row input: da.concatenate([]) raises, so emit one empty chunk.
        chunks.append(make_dask_chunk(x, 0, 0))
    return da.concatenate(chunks, axis=0)
I'm not sure what this should be called, or where it should go. Probably experimental for now.
From #1348 (comment)
cc: @ilan-gold