pydata/xarray

why time grouping doesn't preserve chunks

rabernat opened this issue · 30 comments

Code Sample, a copy-pastable example if possible

I am continuing my quest to obtain more efficient time grouping for calculation of climatologies and climatological anomalies. I believe this is one of the major performance bottlenecks facing xarray users today. I have raised this in other issues (e.g. #1832), but I believe I have narrowed it down here to a more specific problem.

The easiest way to summarize the problem is with an example. Consider the following dataset

import xarray as xr
ds = xr.Dataset({'foo': (['x'], [1, 1, 1, 1])},
                coords={'x': (['x'], [0, 1, 2, 3]),
                        'bar': (['x'], ['a', 'a', 'b', 'b']),
                        'baz': (['x'], ['a', 'b', 'a', 'b'])})
ds = ds.chunk({'x': 2})
ds
<xarray.Dataset>
Dimensions:  (x: 4)
Coordinates:
  * x        (x) int64 0 1 2 3
    bar      (x) <U1 dask.array<shape=(4,), chunksize=(2,)>
    baz      (x) <U1 dask.array<shape=(4,), chunksize=(2,)>
Data variables:
    foo      (x) int64 dask.array<shape=(4,), chunksize=(2,)>

One non-dimension coordinate (bar) is contiguous with respect to x while the other (baz) is not. This is important: baz is structured similarly to the way that month would be distributed on a timeseries dataset.

Now let's do a trivial groupby operation on bar that does nothing, just returns the group unchanged:

ds.foo.groupby('bar').apply(lambda x: x)
<xarray.DataArray 'foo' (x: 4)>
dask.array<shape=(4,), dtype=int64, chunksize=(2,)>
Coordinates:
  * x        (x) int64 0 1 2 3
    bar      (x) <U1 dask.array<shape=(4,), chunksize=(2,)>
    baz      (x) <U1 dask.array<shape=(4,), chunksize=(2,)>

This operation preserved the original chunks in foo. But if we group by baz we see something different:

ds.foo.groupby('baz').apply(lambda x: x)
<xarray.DataArray 'foo' (x: 4)>
dask.array<shape=(4,), dtype=int64, chunksize=(4,)>
Coordinates:
  * x        (x) int64 0 1 2 3
    bar      (x) <U1 dask.array<shape=(4,), chunksize=(2,)>
    baz      (x) <U1 dask.array<shape=(4,), chunksize=(2,)>

Problem description

When grouping over a non-contiguous variable (baz), the result is collapsed into a single chunk. That means we can't lazily access a single item without computing the whole array. This has major performance consequences that make it hard to calculate anomaly values in a more realistic case. What we really want to do is often something like

ds = xr.open_mfdataset('lots/of/files/*.nc')
ds_anom = ds.groupby('time.month').apply(lambda x: x - x.mean(dim='time'))

It is currently impossible to do this lazily due to the issue described above.

Expected Output

We would like to preserve the original chunk structure of foo.

Output of xr.show_versions()

xr.show_versions() is triggering a segfault right now on my system for unknown reasons! I am using xarray 0.10.7.

Nice write-up @rabernat! Note that the behavior is the same with chunks of size 1 (first thing I tried).

Short understanding question: while your example shows that chunks are lost after the groupby, does that prove for sure that the groupby operation does not use the original chunks?

(side note: the quest for climatologies is a rightful quest: see my comment about the cds)

while your example shows that chunks are lost after the groupby, does that prove for sure that the groupby operation does not use the original chunks?

One way to answer that is the following:

Here is the dask graph for groupby('bar'):
image

Here is the dask graph for groupby('baz'):
image

I agree that single-value chunks illustrate the problem more clearly. I think this example is cleanest if you do it like this:

import xarray as xr
import dask.array as dsa
ds = xr.Dataset({'foo': (['x'], dsa.ones(4, chunks=1))},
                coords={'x': (['x'], [0, 1, 2, 3]),
                        'bar': (['x'], ['a', 'a', 'b', 'b']),
                        'baz': (['x'], ['a', 'b', 'a', 'b'])})

ds.foo.groupby('bar').apply(lambda x: x).data.visualize():
image

ds.foo.groupby('baz').apply(lambda x: x).data.visualize()
image

And just because it's fun, I will show what the anomaly calculation looks like

ds.foo.groupby('bar').apply(lambda x: x - x.mean()).data.visualize():

image

ds.foo.groupby('baz').apply(lambda x: x - x.mean()).data.visualize():

image

It looks like everything is really ok up until the very end, where all the tasks aggregate into a single getitem call.

The source of the indexing operation that brings all the chunks together is the _maybe_reorder helper function, which "scatters" array elements back into the correct positions after applying the grouped function:

def _maybe_reorder(xarray_obj, dim, positions):
    order = _inverse_permutation_indices(positions)
    if order is None:
        return xarray_obj
    else:
        return xarray_obj[{dim: order}]

So basically the issue comes down to indexing with dask.array, which creates a single chunk when integer indices are not all in order:

import dask.array as da
import numpy as np

x = da.ones(4, chunks=1)
print(x[np.arange(4)])
# dask.array<getitem, shape=(4,), dtype=float64, chunksize=(1,)>
print(x[np.arange(4)[::-1]])
# dask.array<getitem, shape=(4,), dtype=float64, chunksize=(4,)>

As a work-around in xarray, we could use explicit indexing + concatenation.
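
Here is a minimal sketch of that work-around, assuming we split the indexer into contiguous runs, index each run separately, and concatenate the pieces (illustrative only, not the code xarray actually uses):

import dask.array as da
import numpy as np

x = da.ones(4, chunks=1)
index = np.arange(4)[::-1]  # out-of-order indexer

# split the indexer into contiguous runs, index each run separately,
# then concatenate; each piece keeps its own chunks
runs = np.split(index, np.flatnonzero(np.diff(index) != 1) + 1)
result = da.concatenate([x[run] for run in runs])
print(result.chunks)
# ((1, 1, 1, 1),) instead of ((4,),)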

Thanks for the explanation @shoyer! Yes, that appears to be the root of the issue. After literally years of struggling with this, I am happy to finally get to this level of clarity.

So basically the issue comes down to indexing with dask.array, which creates a single chunk when integer indices are not all in order

Do we think dask is happy with that behavior? If not, then an upstream fix would be best. Pinging @mrocklin.

Otherwise we can try to work around in xarray.

I vaguely recall discussing chunks that result from indexing somewhere in the dask issue tracker (when we added the special case for a monotonic increasing indexer to preserve chunks), but I can't find it now.

I think the challenge is that it isn't obvious what the right chunksizes should be. Chunks that are too small also have negative performance implications. Maybe the automatic chunking logic that @mrocklin has been looking into recently would be relevant here.

With groupby in xarray, we have two main cases:

  1. groupby with reduction -- (e.g. ds.groupby('baz').mean(dim='x')). There is currently no problem here. The new dimension becomes baz and the array is chunked as {'baz': 1}.
  2. groupby with no reduction -- (e.g. ds.groupby('baz').apply(lambda x: x - x.mean())). In this case, the point of the out-of-order indexing is actually to put the array back together in its original order. In my last example above, according to the dot graph, it looks like there are four chunks right up until the end. They just have to be re-ordered. I imagine this should be cheap and simple, but I am probably overlooking something.

Case 2 seems similar to @shoyer's example: x[np.arange(4)[::-1]]. Here we would just want to reorder the existing chunks.

If the chunk size before reindexing is not 1, then yes, one needs to do something more sophisticated. But I would argue that, if the array is being re-indexed along a dimension in which the chunk size is 1, a sensible default behavior would be to avoid aggregating into a big chunk and instead just pass the original chunks through in a new order.
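
A small sketch of what that default could look like from user code, assuming the chunks along the indexed dimension all have size 1, so reordering blocks is the same as reordering elements (this uses dask's public blocks accessor purely for illustration; it is not how dask implements indexing):

import dask.array as da
import numpy as np

x = da.ones(4, chunks=1)
order = np.array([3, 1, 0, 2])  # out-of-order indexer

# pick out the existing size-1 blocks in the new order and concatenate them
reordered = da.concatenate([x.blocks[int(i)] for i in order])
print(reordered.chunks)
# ((1, 1, 1, 1),) -- the original chunks, just rearranged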

OK, so lowering down to a dask array conversation, let's look at a couple of examples. First, let's look at the behavior of a sorted index:

import dask.array as da
x = da.ones((20, 20), chunks=(4, 5))
x.chunks
# ((4, 4, 4, 4, 4), (5, 5, 5, 5))

If we index that array with a sorted index, we are able to efficiently preserve chunking:

import numpy as np

x[np.arange(20), :].chunks
# ((4, 4, 4, 4, 4), (5, 5, 5, 5))

x[np.arange(20) // 2, :].chunks
# ((8, 8, 4), (5, 5, 5, 5))

However if the index isn't sorted then everything goes into one big chunk:

x[np.arange(20) % 3, :].chunks
# ((20,), (5, 5, 5, 5))

We could imagine a few alternatives here:

  1. Make a chunk for every element in the index
  2. Make a chunk for every contiguous run in the index. So here we would have chunk dimensions of size 3 matching the 0, 1, 2, 0, 1, 2, 0, 1, 2 pattern of our index.

I don't really have a strong intuition for how the xarray operations transform into dask array operations (my brain is a bit tired right now, so thinking is hard) but my guess is that they would benefit from the second case. (A pure dask.array example would be welcome).

Now we have to consider how enacting a policy like "put contiguous index regions into the same chunk" might go wrong, and how we might defend against it generally.

x = da.ones(10000, chunks=(100,))  # 100 chunks of size 100
index = np.array([0, 100, 200, 300, ..., 1, 101, 201, 301, ..., 2, 102, 202, 302, ...])
x[index]

In the example above we have a hundred input chunks and a hundred contiguous regions in our index. Seems good. However each output chunk touches each input chunk, so this will likely create 10,000 tasks, which we should probably consider a fail case here.

So we learn that we need to look pretty carefully at how the values within the index interact with the chunk structure in order to know if we can do this well. This isn't an insurmountable problem, but isn't trivial either.

In principle we're looking for a function that takes in two inputs:

  1. The chunks of a single dimension like x.chunks[i] or (4, 4, 4, 4, 4) from our first example
  2. An index like np.arange(20) % 3 from our first example

And outputs a bunch of smaller indexes to pass on to various chunks. However, it hopefully does this in a way that is efficient, and fails early if it's going to emit a bunch of very small slices.
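
A rough sketch of what such a function might look like, under the "one output chunk per contiguous run of the index" policy; the name, signature, and the max_fragments guard are all made up for illustration:

import numpy as np

def plan_index(chunks, index, max_fragments=10000):
    """Split ``index`` into contiguous runs and map each run onto the
    input chunks it touches, failing early if the plan gets too fragmented."""
    bounds = np.cumsum((0,) + tuple(chunks))  # chunk boundaries along this dimension
    runs = np.split(index, np.flatnonzero(np.diff(index) != 1) + 1)
    plan = []  # list of (input_chunk_id, index local to that chunk)
    for run in runs:
        owners = np.searchsorted(bounds, run, side='right') - 1
        for chunk_id in np.unique(owners):
            plan.append((int(chunk_id), run[owners == chunk_id] - bounds[chunk_id]))
    if len(plan) > max_fragments:
        raise ValueError('index is too fragmented relative to the chunks')
    return plan

For example, plan_index((4, 4, 4, 4, 4), np.arange(20) % 3) produces seven small slices, all drawn from the first input chunk.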

It's also probably worth thinking about the kinds of operations you're trying to do, and how streamable they are. For example, if you were to take a dataset that was partitioned chronologically by month and then do some sort of day-of-month grouping, that would require the full dataset to be in memory at once.

If you're doing something like grouping on every month (keeping months of different years separate) then presumably your index is already sorted, and so you should be fine with the current behavior.

It might be useful to take a look at how the various XArray cases you care about convert to dask array slicing operations.

Here's an example of what these indices look like for a slightly more realistic groupby example:

import xarray
import pandas
import numpy as np

array = xarray.DataArray(
    range(1000), [('time', pandas.date_range('2000-01-01', freq='D', periods=1000))])

# this works with xarray 0.10.7
xarray.core.groupby._inverse_permutation_indices(
    array.groupby('time.month')._group_indices)
array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  93,  94,  95,  96,  97,  98,  99, 100,
       101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113,
       114, 115, 116, 117, 118, 119, 120, 121, 178, 179, 180, 181, 182,
       183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195,
       196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208,
       271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
       284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296,
       297, 298, 299, 300, 361, 362, 363, 364, 365, 366, 367, 368, 369,
       370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382,
       383, 384, 385, 386, 387, 388, 389, 390, 391, 454, 455, 456, 457,
       458, 459, 460, 461, 462, 463, 464, 465, 466, 467, 468, 469, 470,
       471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483,
       544, 545, 546, 547, 548, 549, 550, 551, 552, 553, 554, 555, 556,
       557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, 569,
       570, 571, 572, 573, 574, 637, 638, 639, 640, 641, 642, 643, 644,
       645, 646, 647, 648, 649, 650, 651, 652, 653, 654, 655, 656, 657,
       658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 730, 731, 732,
       733, 734, 735, 736, 737, 738, 739, 740, 741, 742, 743, 744, 745,
       746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757, 758,
       759, 816, 817, 818, 819, 820, 821, 822, 823, 824, 825, 826, 827,
       828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840,
       841, 842, 843, 844, 845, 846, 878, 879, 880, 881, 882, 883, 884,
       885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897,
       898, 899, 900, 901, 902, 903, 904, 905, 906, 907, 938, 939, 940,
       941, 942, 943, 944, 945, 946, 947, 948, 949, 950, 951, 952, 953,
       954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966,
       967, 968,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
        42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,
        55,  56,  57,  58,  59,  60,  61, 122, 123, 124, 125, 126, 127,
       128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140,
       141, 142, 143, 144, 145, 146, 147, 148, 149, 209, 210, 211, 212,
       213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
       226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238,
       239, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312,
       313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
       326, 327, 328, 329, 330, 392, 393, 394, 395, 396, 397, 398, 399,
       400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412,
       413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 484, 485, 486,
       487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499,
       500, 501, 502, 503, 504, 505, 506, 507, 508, 509, 510, 511, 512,
       513, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586,
       587, 588, 589, 590, 591, 592, 593, 594, 595, 596, 597, 598, 599,
       600, 601, 602, 603, 604, 605, 668, 669, 670, 671, 672, 673, 674,
       675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687,
       688, 689, 690, 691, 692, 693, 694, 695, 696, 697, 698, 760, 761,
       762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774,
       775, 776, 777, 778, 779, 780, 781, 782, 783, 784, 785, 786, 787,
       788, 789, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857,
       858, 859, 860, 861, 862, 863, 864, 865, 866, 867, 868, 869, 870,
       871, 872, 873, 874, 875, 876, 877, 908, 909, 910, 911, 912, 913,
       914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925, 926,
       927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 969, 970,
       971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983,
       984, 985, 986, 987, 988, 989, 990, 991, 992, 993, 994, 995, 996,
       997, 998, 999,  62,  63,  64,  65,  66,  67,  68,  69,  70,  71,
        72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,  84,
        85,  86,  87,  88,  89,  90,  91,  92, 150, 151, 152, 153, 154,
       155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167,
       168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 240, 241, 242,
       243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255,
       256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268,
       269, 270, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341,
       342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354,
       355, 356, 357, 358, 359, 360, 423, 424, 425, 426, 427, 428, 429,
       430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442,
       443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 514, 515,
       516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, 527, 528,
       529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541,
       542, 543, 606, 607, 608, 609, 610, 611, 612, 613, 614, 615, 616,
       617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629,
       630, 631, 632, 633, 634, 635, 636, 699, 700, 701, 702, 703, 704,
       705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, 716, 717,
       718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 790,
       791, 792, 793, 794, 795, 796, 797, 798, 799, 800, 801, 802, 803,
       804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815])

I think it would work with the "put contiguous index regions into the same chunk" heuristic.

On the other hand, this could break pretty badly for other group-by operations, e.g., calculating those anomalies by day of year instead:

xarray.core.groupby._inverse_permutation_indices(
    array.groupby('time.dayofyear')._group_indices)
array([  0,   3,   6,   9,  12,  15,  18,  21,  24,  27,  30,  33,  36,
        39,  42,  45,  48,  51,  54,  57,  60,  63,  66,  69,  72,  75,
        78,  81,  84,  87,  90,  93,  96,  99, 102, 105, 108, 111, 114,
       117, 120, 123, 126, 129, 132, 135, 138, 141, 144, 147, 150, 153,
       156, 159, 162, 165, 168, 171, 174, 177, 180, 183, 186, 189, 192,
       195, 198, 201, 204, 207, 210, 213, 216, 219, 222, 225, 228, 231,
       234, 237, 240, 243, 246, 249, 252, 255, 258, 261, 264, 267, 270,
       273, 276, 279, 282, 285, 288, 291, 294, 297, 300, 303, 306, 309,
       312, 315, 318, 321, 324, 327, 330, 333, 336, 339, 342, 345, 348,
       351, 354, 357, 360, 363, 366, 369, 372, 375, 378, 381, 384, 387,
       390, 393, 396, 399, 402, 405, 408, 411, 414, 417, 420, 423, 426,
       429, 432, 435, 438, 441, 444, 447, 450, 453, 456, 459, 462, 465,
       468, 471, 474, 477, 480, 483, 486, 489, 492, 495, 498, 501, 504,
       507, 510, 513, 516, 519, 522, 525, 528, 531, 534, 537, 540, 543,
       546, 549, 552, 555, 558, 561, 564, 567, 570, 573, 576, 579, 582,
       585, 588, 591, 594, 597, 600, 603, 606, 609, 612, 615, 618, 621,
       624, 627, 630, 633, 636, 639, 642, 645, 648, 651, 654, 657, 660,
       663, 666, 669, 672, 675, 678, 681, 684, 687, 690, 693, 696, 699,
       702, 705, 708, 711, 714, 717, 720, 723, 726, 729, 732, 735, 738,
       741, 744, 747, 750, 753, 756, 759, 762, 765, 768, 771, 774, 777,
       780, 783, 786, 789, 792, 795, 798, 801, 804, 807, 809, 811, 813,
       815, 817, 819, 821, 823, 825, 827, 829, 831, 833, 835, 837, 839,
       841, 843, 845, 847, 849, 851, 853, 855, 857, 859, 861, 863, 865,
       867, 869, 871, 873, 875, 877, 879, 881, 883, 885, 887, 889, 891,
       893, 895, 897, 899, 901, 903, 905, 907, 909, 911, 913, 915, 917,
       919, 921, 923, 925, 927, 929, 931, 933, 935, 937, 939, 941, 943,
       945, 947, 949, 951, 953, 955, 957, 959, 961, 963, 965, 967, 969,
       971, 973, 975, 977, 979, 981, 983, 985, 987, 989, 991, 993, 995,
       997, 999,   1,   4,   7,  10,  13,  16,  19,  22,  25,  28,  31,
        34,  37,  40,  43,  46,  49,  52,  55,  58,  61,  64,  67,  70,
        73,  76,  79,  82,  85,  88,  91,  94,  97, 100, 103, 106, 109,
       112, 115, 118, 121, 124, 127, 130, 133, 136, 139, 142, 145, 148,
       151, 154, 157, 160, 163, 166, 169, 172, 175, 178, 181, 184, 187,
       190, 193, 196, 199, 202, 205, 208, 211, 214, 217, 220, 223, 226,
       229, 232, 235, 238, 241, 244, 247, 250, 253, 256, 259, 262, 265,
       268, 271, 274, 277, 280, 283, 286, 289, 292, 295, 298, 301, 304,
       307, 310, 313, 316, 319, 322, 325, 328, 331, 334, 337, 340, 343,
       346, 349, 352, 355, 358, 361, 364, 367, 370, 373, 376, 379, 382,
       385, 388, 391, 394, 397, 400, 403, 406, 409, 412, 415, 418, 421,
       424, 427, 430, 433, 436, 439, 442, 445, 448, 451, 454, 457, 460,
       463, 466, 469, 472, 475, 478, 481, 484, 487, 490, 493, 496, 499,
       502, 505, 508, 511, 514, 517, 520, 523, 526, 529, 532, 535, 538,
       541, 544, 547, 550, 553, 556, 559, 562, 565, 568, 571, 574, 577,
       580, 583, 586, 589, 592, 595, 598, 601, 604, 607, 610, 613, 616,
       619, 622, 625, 628, 631, 634, 637, 640, 643, 646, 649, 652, 655,
       658, 661, 664, 667, 670, 673, 676, 679, 682, 685, 688, 691, 694,
       697, 700, 703, 706, 709, 712, 715, 718, 721, 724, 727, 730, 733,
       736, 739, 742, 745, 748, 751, 754, 757, 760, 763, 766, 769, 772,
       775, 778, 781, 784, 787, 790, 793, 796, 799, 802, 805, 808, 810,
       812, 814, 816, 818, 820, 822, 824, 826, 828, 830, 832, 834, 836,
       838, 840, 842, 844, 846, 848, 850, 852, 854, 856, 858, 860, 862,
       864, 866, 868, 870, 872, 874, 876, 878, 880, 882, 884, 886, 888,
       890, 892, 894, 896, 898, 900, 902, 904, 906, 908, 910, 912, 914,
       916, 918, 920, 922, 924, 926, 928, 930, 932, 934, 936, 938, 940,
       942, 944, 946, 948, 950, 952, 954, 956, 958, 960, 962, 964, 966,
       968, 970, 972, 974, 976, 978, 980, 982, 984, 986, 988, 990, 992,
       994, 996, 998,   2,   5,   8,  11,  14,  17,  20,  23,  26,  29,
        32,  35,  38,  41,  44,  47,  50,  53,  56,  59,  62,  65,  68,
        71,  74,  77,  80,  83,  86,  89,  92,  95,  98, 101, 104, 107,
       110, 113, 116, 119, 122, 125, 128, 131, 134, 137, 140, 143, 146,
       149, 152, 155, 158, 161, 164, 167, 170, 173, 176, 179, 182, 185,
       188, 191, 194, 197, 200, 203, 206, 209, 212, 215, 218, 221, 224,
       227, 230, 233, 236, 239, 242, 245, 248, 251, 254, 257, 260, 263,
       266, 269, 272, 275, 278, 281, 284, 287, 290, 293, 296, 299, 302,
       305, 308, 311, 314, 317, 320, 323, 326, 329, 332, 335, 338, 341,
       344, 347, 350, 353, 356, 359, 362, 365, 368, 371, 374, 377, 380,
       383, 386, 389, 392, 395, 398, 401, 404, 407, 410, 413, 416, 419,
       422, 425, 428, 431, 434, 437, 440, 443, 446, 449, 452, 455, 458,
       461, 464, 467, 470, 473, 476, 479, 482, 485, 488, 491, 494, 497,
       500, 503, 506, 509, 512, 515, 518, 521, 524, 527, 530, 533, 536,
       539, 542, 545, 548, 551, 554, 557, 560, 563, 566, 569, 572, 575,
       578, 581, 584, 587, 590, 593, 596, 599, 602, 605, 608, 611, 614,
       617, 620, 623, 626, 629, 632, 635, 638, 641, 644, 647, 650, 653,
       656, 659, 662, 665, 668, 671, 674, 677, 680, 683, 686, 689, 692,
       695, 698, 701, 704, 707, 710, 713, 716, 719, 722, 725, 728, 731,
       734, 737, 740, 743, 746, 749, 752, 755, 758, 761, 764, 767, 770,
       773, 776, 779, 782, 785, 788, 791, 794, 797, 800, 803, 806])

This looks like @mrocklin's second case.

That said, it's still probably more graceful to fail by creating too many small tasks rather than one giant task.

That said, it's still probably more graceful to fail by creating too many small tasks rather than one giant task.

Maybe. We'll blow out the scheduler with too many tasks. With one large task we'll probably just start losing workers from memory errors.

In your example, what is the chunking of the indexed array likely to look like? How do the contiguous regions of the index interact with the chunk structure of the indexed array?

In your example, what is the chunking of the indexed array likely to look like? How do the contiguous regions of the index interact with the chunk structure of the indexed array?

Assuming the original array is chunked into one file per year-month (which is probably a reasonable starting point):

  • For the groupby('time.month') example: each contiguous run of indices should be indexing a contiguous chunk. This case should work nicely.
  • For the groupby('time.dayofyear') example: each index will be pulling data from a different chunk. This is still a bit of a fail case for the scheduler.

Another option would be to rewrite how xarray does groupby/transform operations to make it more dask friendly. Currently it looks roughly like:

def groupby_transform(array, list_of_group_indices, func):
    # create a list of sub-arrays for each group
    subarrays = [array[indices] for indices in list_of_group_indices]
    # apply the function
    applied = [func(x) for x in subarrays]
    # concatenate applied arrays together
    concatenated = np.concatenate(applied)
    # restore original order
    reordered = concatenated[indices_to_restore_orig_order]
    return reordered

For example, we could reverse the order of the last two steps.

So my question was "if you're grouping data by month, and it's already partitioned by month, then why are the indices out of order?" However, it may be that you've answered this in your most recent comment; I'm not sure. It may also be that I'm not understanding the situation.

Some sort of automatic rechunking could also make a big difference for performance, in cases where the groupby operation splits the original chunks into small pieces (like my groupby('time.dayofyear') example). Applying dask functions on arrays with many small chunks will be slow.

So if you're willing to humor me for a moment with dask.array examples, if you have an array that's currently partitioned by month:

x = da.ones((1000, ...), chunks=(30, ...))  # approximately

And you do something by time.dayofyear, what do you end up doing to the array in dask array operations? Sorry to be a bit slow here. I'm not as familiar with how XArray translates its groupby operations to dask.array operations under the hood.

I'm not as familiar with how XArray translates its groupby operations to dask.array operations under the hood.

No worries, this is indeed pretty confusing!

For time.dayofyear in my groupby_transform pseudocode above (#2237 (comment)):

# suppose N is the number of years of data
list_of_group_indices = [
    [0, 365, 730, ..., (N-1)*365],  # day 1, ordered by year
    [1, 366, 731, ..., (N-1)*365 + 1],  # day 2, ordered by year
    ...
]
indices_to_restore_orig_order = [
    0, N, 2*N, 3*N, ...,  # year 1, ordered by day
    1, N+1, 2*N+1, 3*N+1, ...,  # year 2, ordered by day
    ...
]

As you can see, if you concatenate together the first set of indices and index by the second set of indices, it would arrange them into sequential integers.
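
Here is a concrete (toy) version of those indices, just to make the pattern explicit; the choice of N and the exactly-365-day years are simplifications for illustration:

import numpy as np

N = 3  # number of years, each exactly 365 days long for simplicity
days = np.arange(N * 365)

list_of_group_indices = [days[d::365] for d in range(365)]  # one group per day of year, ordered by year
indices_to_restore_orig_order = np.argsort(np.concatenate(list_of_group_indices))

# concatenating the groups and indexing by the restore indices recovers 0, 1, 2, ...
assert (np.concatenate(list_of_group_indices)[indices_to_restore_orig_order] == days).all()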

Thanks. This example helps.

As you can see, if you concatenate together the first set of indices and index by the second set of indices, it would arrange them into sequential integers.

I'm not sure I understand this.

The situation on the whole does seem sensible to me though. This starts to look a little bit like a proper shuffle situation (using dataframe terminology). Each of your 365 output partitions would presumably touch 1/12th of your input partitions, leading to a quadratic number of tasks. If after doing something you then wanted to rearrange your data back then presumably that would require an equivalent number of extra tasks.

Am I understanding the situation correctly?

As you can see, if you concatenate together the first set of indices and index by the second set of indices, it would arrange them into sequential integers.

I'm not sure I understand this.

Maybe it helps to think about these as matrices. The nth row of indices_to_restore_orig_order pulls out elements corresponding to the nth column of list_of_group_indices.

The situation on the whole does seem sensible to me though. This starts to look a little bit like a proper shuffle situation (using dataframe terminology). Each of your 365 output partitions would presumably touch 1/12th of your input partitions, leading to a quadratic number of tasks. If after doing something you then wanted to rearrange your data back then presumably that would require an equivalent number of extra tasks.

Yes, this is definitely a shuffle.

I'm glad to see that this has generated so much serious discussion and thought! I will try to catch up on it in the morning when I have some hope of understanding.

I've implemented something here: dask/dask#3648

Playing with it would be welcome.

Can this be closed or is there something to do on the xarray side now that dask/dask#3648 has been merged?

The original issue has been fixed, at least in the toy example:

>>> ds.foo.groupby('baz').apply(lambda x: x)
<xarray.DataArray 'foo' (x: 4)>
dask.array<shape=(4,), dtype=int64, chunksize=(1,)>
Coordinates:
  * x        (x) int64 0 1 2 3
    bar      (x) <U1 dask.array<shape=(4,), chunksize=(2,)>
    baz      (x) <U1 dask.array<shape=(4,), chunksize=(2,)>

I don't know if it's still an issue in more realistic scenarios.

We had a long iteration on this in Pangeo, and big progress was made in dask. Definitely closed for now.

I'm reviving this classic issue to report another quasi-failure of dask chunking, this time in the opposite direction.

Consider this dataset:

import xarray as xr
import numpy as np
import dask.array as dsa

ds = xr.Dataset({'foo': (['time'], dsa.ones(120, chunks=60))},
                coords={'year': (['time'], np.repeat(np.arange(10), 12))})
ds
<xarray.Dataset>
Dimensions:  (time: 120)
Coordinates:
    year     (time) int64 0 0 0 0 0 0 0 0 0 0 0 0 1 ... 9 9 9 9 9 9 9 9 9 9 9 9
Dimensions without coordinates: time
Data variables:
    foo      (time) float64 dask.array<chunksize=(60,), meta=np.ndarray>

There are just two big chunks.

Now let's try to take an "annual mean" by grouping over year:

ds.foo.groupby('year').mean(dim='time')
<xarray.DataArray 'foo' (year: 10)>
dask.array<stack, shape=(10,), dtype=float64, chunksize=(1,), chunktype=numpy.ndarray>
Coordinates:
  * year     (year) int64 0 1 2 3 4 5 6 7 8 9

Now we have a chunksize of 1 and 10 chunks. That's bad: we should still just have two chunks, since we are aggregating only within chunks. Taken to the limit of very high temporal resolution, this example will blow up in terms of number of tasks. I wish dask could figure out that it doesn't have to create all those tasks.

The graph looks like this
image

In contrast, coarsen is smart enough, probably because it relies on dask's underlying coarsen function

ds.foo.coarsen(time=12).mean()
<xarray.DataArray (time: 10)>
dask.array<mean_agg-aggregate, shape=(10,), dtype=float64, chunksize=(5,), chunktype=numpy.ndarray>
Coordinates:
    year     (time) float64 0.0 1.0 2.0 3.0 4.0 5.0 6.0 7.0 8.0 9.0
Dimensions without coordinates: time
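
For comparison, here is roughly what that looks like at the dask.array level (a sketch, not the exact graph xarray builds): dask's coarsen applies the reduction blockwise, so the two input chunks map cleanly onto two output chunks.

import dask.array as da
import numpy as np

x = da.ones(120, chunks=60)
coarse = da.coarsen(np.mean, x, {0: 12})  # mean over non-overlapping windows of 12
print(coarse.chunks)
# ((5, 5),) -- one output chunk per input chunk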

I think the behaviour in Ryan's most recent comment is a consequence of groupby.mean being roughly

results = []
for group_idx in group_indices:  # one group per year
    group = ds.isel(time=group_idx)  # (SPLIT)
    results.append(group.mean())  # (APPLY)
return xr.concat(results, dim="year")  # (COMBINE) one chunk per element in results, i.e. one chunk per year

I think the fundamental question is: is it really possible for dask to recognize that the chunk structure after the combine step could be consolidated, with an arbitrary number of apply steps in the middle? Or, when a computation maps a single chunk to many chunks, should dask consolidate the output chunks (using array.chunk-size)?

We can explicitly ask for consolidation of chunks by saying the output should be chunked with size 5 along year:

import dask

dask.config.set({"optimization.fuse.ave-width": 6})  # note > 5
(
    ds.foo.groupby("year")
    .mean(dim="time")
    .chunk({"year": 5})  # really important; why and how would dask choose this automatically?
    .data.visualize(optimize_graph=False)
)

image

Then if we set optimization.fuse.ave-width appropriately, we get the graph we want after optimization

dask.config.set({"optimization.fuse.ave-width": 6})
(
    ds.foo.groupby("year")
    .mean(dim="time")
    .chunk({"year": 5})  # really important 
    .data.visualize(optimize_graph=True)
)

image

Can we make dask recognize that the 5 getitem tasks from input-chunk-0, at the bottom of each tower, can be fused to a single task? In that case, fuse the 5 getitem tasks and "propagate" that fusion up the tower.

I guess another failure here: when fuse.ave-width is 3 (less than the width of a tower), why isn't dask fusing to make three "sub-towers" per tower? Even that would help reduce the number of tasks.

dask.config.set({"optimization.fuse.ave-width": 3})
(
    ds.foo.groupby("year")
    .mean(dim="time")
    .chunk({"year": 5})  # really important 
    .data.visualize(optimize_graph=True)
)

image

Reading up on fusion, the docstring says

This optimization applies to all reductions–tasks that have at most one dependent–so it may be viewed as fusing “multiple input, single output” groups of tasks into a single task.

So we need the opposite: fuse "single input, multiple output" groups of tasks into a single task when some appropriate heuristic is satisfied.

Fixed on main with ds.groupby("year").mean(method="blockwise")

image
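
For reference, a minimal reproduction of the fix on the toy example above (this assumes a recent xarray with flox installed, which is what provides the method="blockwise" option; the exact chunking shown is an expectation, not a verified output):

import dask.array as dsa
import numpy as np
import xarray as xr

ds = xr.Dataset({'foo': (['time'], dsa.ones(120, chunks=60))},
                coords={'year': (['time'], np.repeat(np.arange(10), 12))})

result = ds.groupby('year').mean(method='blockwise')
print(result.foo.chunks)
# expect one output chunk per input chunk, e.g. ((5, 5),)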