malariagen/malariagen-data-python

'AssertionError' when trying to return 'variant_allele' from biallelic_snp_calls()


When I try to retrieve 'variant_allele' data from the results of biallelic_snp_calls() with the n_snps parameter specified, I get an 'AssertionError'.

snp_calls() works fine:

import malariagen_data

recach = '/Users/dennistpw/Projects/malariagen_results'
ag3 = malariagen_data.Ag3(pre=True,
                          results_cache=recach)

ds_snps = ag3.snp_calls(sample_sets='AG1000G-AO',
                        region='2L:1000000-1010000')

ds_snps['variant_allele'].compute()

Returns:

array([[b'T', b'A', b'C', b'G'],
       [b'G', b'A', b'C', b'T'],
       [b'T', b'A', b'C', b'G'],
       ...,
       [b'C', b'A', b'T', b'G'],
       [b'G', b'A', b'C', b'T'],
       [b'G', b'A', b'C', b'T']], dtype='|S1')

biallelic_snp_calls() also works fine:

ds_snps_bi = ag3.biallelic_snp_calls(sample_sets='AG1000G-AO',
                                     region='2L:1000000-1010000')

ds_snps_bi['variant_allele'].compute()

Returns:

array([[b'A', b'T'],
       [b'G', b'T'],
       [b'G', b'A'],
       ...,
       [b'C', b'T'],
       [b'G', b'A'],
       [b'G', b'T']], dtype='|S1')

biallelic_snp_calls() with n_snps specified fails:

ds_snps_bi_sub = ag3.biallelic_snp_calls(sample_sets='AG1000G-AO',
                                         region='2L:1000000-5000000',
                                         n_snps=2000)
ds_snps_bi_sub['variant_allele'].compute()

Returns:

{
	"name": "AssertionError",
	"message": "",
	"stack": "---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[12], line 1
----> 1 ds_snps_bi_sub['variant_allele'].compute()

File ~/miniconda3/envs/malariagen_plink/lib/python3.8/site-packages/xarray/core/dataarray.py:1101, in DataArray.compute(self, **kwargs)
   1082 \"\"\"Manually trigger loading of this array's data from disk or a
   1083 remote source into memory and return a new array. The original is
   1084 left unaltered.
   (...)
   1098 dask.compute
   1099 \"\"\"
   1100 new = self.copy(deep=False)
-> 1101 return new.load(**kwargs)

File ~/miniconda3/envs/malariagen_plink/lib/python3.8/site-packages/xarray/core/dataarray.py:1075, in DataArray.load(self, **kwargs)
   1057 def load(self: T_DataArray, **kwargs) -> T_DataArray:
   1058     \"\"\"Manually trigger loading of this array's data from disk or a
   1059     remote source into memory and return this array.
   1060 
   (...)
   1073     dask.compute
   1074     \"\"\"
-> 1075     ds = self._to_temp_dataset().load(**kwargs)
   1076     new = self._from_temp_dataset(ds)
   1077     self._variable = new._variable

File ~/miniconda3/envs/malariagen_plink/lib/python3.8/site-packages/xarray/core/dataset.py:747, in Dataset.load(self, **kwargs)
    744 import dask.array as da
    746 # evaluate all the dask arrays simultaneously
--> 747 evaluated_data = da.compute(*lazy_data.values(), **kwargs)
    749 for k, data in zip(lazy_data, evaluated_data):
    750     self.variables[k].data = data

File ~/miniconda3/envs/malariagen_plink/lib/python3.8/site-packages/dask/base.py:599, in compute(traverse, optimize_graph, scheduler, get, *args, **kwargs)
    596     keys.append(x.__dask_keys__())
    597     postcomputes.append(x.__dask_postcompute__())
--> 599 results = schedule(dsk, keys, **kwargs)
    600 return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])

File ~/miniconda3/envs/malariagen_plink/lib/python3.8/site-packages/dask/threaded.py:89, in get(dsk, keys, cache, num_workers, pool, **kwargs)
     86     elif isinstance(pool, multiprocessing.pool.Pool):
     87         pool = MultiprocessingPoolExecutor(pool)
---> 89 results = get_async(
     90     pool.submit,
     91     pool._max_workers,
     92     dsk,
     93     keys,
     94     cache=cache,
     95     get_id=_thread_get_id,
     96     pack_exception=pack_exception,
     97     **kwargs,
     98 )
    100 # Cleanup pools associated to dead threads
    101 with pools_lock:

File ~/miniconda3/envs/malariagen_plink/lib/python3.8/site-packages/dask/local.py:511, in get_async(submit, num_workers, dsk, result, cache, get_id, rerun_exceptions_locally, pack_exception, raise_exception, callbacks, dumps, loads, chunksize, **kwargs)
    509         _execute_task(task, data)  # Re-execute locally
    510     else:
--> 511         raise_exception(exc, tb)
    512 res, worker_id = loads(res_info)
    513 state[\"cache\"][key] = res

File ~/miniconda3/envs/malariagen_plink/lib/python3.8/site-packages/dask/local.py:319, in reraise(exc, tb)
    317 if exc.__traceback__ is not tb:
    318     raise exc.with_traceback(tb)
--> 319 raise exc

File ~/miniconda3/envs/malariagen_plink/lib/python3.8/site-packages/dask/local.py:224, in execute_task(key, task_info, dumps, loads, get_id, pack_exception)
    222 try:
    223     task, data = loads(task_info)
--> 224     result = _execute_task(task, data)
    225     id = get_id()
    226     result = dumps((result, id))

File ~/miniconda3/envs/malariagen_plink/lib/python3.8/site-packages/dask/core.py:119, in _execute_task(arg, cache, dsk)
    115     func, args = arg[0], arg[1:]
    116     # Note: Don't assign the subtask results to a variable. numpy detects
    117     # temporaries by their reference count and can execute certain
    118     # operations in-place.
--> 119     return func(*(_execute_task(a, cache) for a in args))
    120 elif not ishashable(arg):
    121     return arg

File ~/miniconda3/envs/malariagen_plink/lib/python3.8/site-packages/dask/core.py:119, in <genexpr>(.0)
    115     func, args = arg[0], arg[1:]
    116     # Note: Don't assign the subtask results to a variable. numpy detects
    117     # temporaries by their reference count and can execute certain
    118     # operations in-place.
--> 119     return func(*(_execute_task(a, cache) for a in args))
    120 elif not ishashable(arg):
    121     return arg

File ~/miniconda3/envs/malariagen_plink/lib/python3.8/site-packages/dask/core.py:119, in _execute_task(arg, cache, dsk)
    115     func, args = arg[0], arg[1:]
    116     # Note: Don't assign the subtask results to a variable. numpy detects
    117     # temporaries by their reference count and can execute certain
    118     # operations in-place.
--> 119     return func(*(_execute_task(a, cache) for a in args))
    120 elif not ishashable(arg):
    121     return arg

File ~/miniconda3/envs/malariagen_plink/lib/python3.8/site-packages/dask/optimization.py:990, in SubgraphCallable.__call__(self, *args)
    988 if not len(args) == len(self.inkeys):
    989     raise ValueError(\"Expected %d args, got %d\" % (len(self.inkeys), len(args)))
--> 990 return core.get(self.dsk, self.outkey, dict(zip(self.inkeys, args)))

File ~/miniconda3/envs/malariagen_plink/lib/python3.8/site-packages/dask/core.py:149, in get(dsk, out, cache)
    147 for key in toposort(dsk):
    148     task = dsk[key]
--> 149     result = _execute_task(task, cache)
    150     cache[key] = result
    151 result = _execute_task(out, cache)

File ~/miniconda3/envs/malariagen_plink/lib/python3.8/site-packages/dask/core.py:119, in _execute_task(arg, cache, dsk)
    115     func, args = arg[0], arg[1:]
    116     # Note: Don't assign the subtask results to a variable. numpy detects
    117     # temporaries by their reference count and can execute certain
    118     # operations in-place.
--> 119     return func(*(_execute_task(a, cache) for a in args))
    120 elif not ishashable(arg):
    121     return arg

File ~/Projects/malariagen-data-python/malariagen_data/anoph/snp_data.py:1629, in AnophelesSnpData.biallelic_snp_calls.<locals>.<lambda>(block)
   1626 variant_allele = ds_bi[\"variant_allele\"].data
   1627 variant_allele = variant_allele.rechunk((variant_allele.chunks[0], -1))
   1628 variant_allele_out = da.map_blocks(
-> 1629     lambda block: apply_allele_mapping(block, allele_mapping, max_allele=1),
   1630     variant_allele,
   1631     dtype=variant_allele.dtype,
   1632     chunks=(variant_allele.chunks[0], [2]),
   1633 )
   1634 data_vars[\"variant_allele\"] = (\"variants\", \"alleles\"), variant_allele_out
   1636 # Store allele counts, transformed, so we don't have to recompute.

File ~/Projects/malariagen-data-python/malariagen_data/util.py:1281, in apply_allele_mapping()
   1279 n_sites = x.shape[0]
   1280 n_alleles = x.shape[1]
-> 1281 assert mapping.shape[0] == n_sites
   1282 assert (
   1283     mapping.shape[1] == n_alleles
   1284 )  # these are not the same, work out what's going on - try running code with debugger? or print statementsd
   1286 # Create output array.

AssertionError: "
}

This looks like a mismatch between the expected array sizes in apply_allele_mapping().

Sorry it's taken me a while to come back to this!

To my (very inexperienced) eye, the bug seems to arise when da.map_blocks applies apply_allele_mapping() to the variant_allele and allele_mapping arrays. variant_allele is a chunked dask array, whereas allele_mapping is an in-memory numpy array. When map_blocks runs, each call receives one chunk of variant_allele but the entire allele_mapping array, so the size check...

    n_sites = x.shape[0]
    n_alleles = x.shape[1]
    assert mapping.shape[0] == n_sites
    assert mapping.shape[1] == n_alleles

...fails, because it compares the shape of a single chunk of variant_allele with the shape of the entire allele_mapping array.
I've attempted to fix this in my PR by chunking allele_mapping to match variant_allele (see here). It now seems to be working OK.
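
For anyone hitting something similar, here is a minimal sketch of the mechanism described above, using toy data rather than real Ag3 calls. The function apply_mapping_sketch is a simplified, hypothetical stand-in for apply_allele_mapping(), and the array contents and chunk sizes are made up for illustration. It is not the code from the PR, just one way to reproduce the chunk-versus-whole-array mismatch and one way to avoid it.

```python
import numpy as np
import dask.array as da


def apply_mapping_sketch(x, mapping, max_allele):
    """Simplified, hypothetical stand-in for apply_allele_mapping()."""
    n_sites, n_alleles = x.shape
    # Same shape checks as quoted in the issue.
    assert mapping.shape[0] == n_sites
    assert mapping.shape[1] == n_alleles
    out = np.zeros((n_sites, max_allele + 1), dtype=x.dtype)
    for i in range(n_sites):
        for j in range(n_alleles):
            new_index = mapping[i, j]
            if 0 <= new_index <= max_allele:
                out[i, new_index] = x[i, j]
    return out


# Toy data: 6 sites, 4 alleles each, split into 3 chunks of 2 sites.
alleles = np.array([[b"T", b"A", b"C", b"G"]] * 6, dtype="S1")
mapping = np.array([[0, -1, 1, -1]] * 6, dtype="i4")  # keep T at 0, move C to 1
x = da.from_array(alleles, chunks=(2, 4))

# Problem: the lambda closes over the full numpy mapping (6 rows),
# but each call only sees one block of x (2 rows).
broken = da.map_blocks(
    lambda block: apply_mapping_sketch(block, mapping, max_allele=1),
    x,
    dtype=x.dtype,
    chunks=(x.chunks[0], (2,)),
)
try:
    broken.compute(scheduler="synchronous")
    closure_failed = False
except AssertionError:
    closure_failed = True

# One possible fix: make the mapping a dask array with matching row chunks
# and pass it as a second argument, so blocks are paired up.
mapping_chunked = da.from_array(mapping, chunks=(x.chunks[0], (mapping.shape[1],)))
fixed = da.map_blocks(
    lambda block, block_mapping: apply_mapping_sketch(
        block, block_mapping, max_allele=1
    ),
    x,
    mapping_chunked,
    dtype=x.dtype,
    chunks=(x.chunks[0], (2,)),
)
result = fixed.compute(scheduler="synchronous")
```

In this sketch the first compute raises AssertionError because a 2-row block is checked against the 6-row mapping, while the second succeeds because map_blocks hands each call the block of the mapping that lines up with the block of alleles.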

Hi @tristanpwdennis, nice work getting to the bottom of this one, the fix you have in #515 LGTM, thanks so much!

Just to update, I rolled the fix for this into other work I was doing on improving the biallelic SNP calls functions, via #623. Thanks again for figuring it out!