xarray-contrib/xarray-regrid

Improve performance of conservative routine

BSchilperoort opened this issue · 12 comments

We can lessen the performance penalty significantly by doing something like notnull.any([non_grid_dims]) here, in which case we track the NaN fraction based on whether any batch slice has valid data. Could be a reasonable tradeoff, or a configurable argument.

Originally posted by @slevang in #39 (comment)

This is the source of the big performance penalty for now with skipna=True in any case where we have dimensions beyond the regridding dims (e.g. batched regridding over time). The issue is that with this implementation, we are tracking the valid_frac over all the dims, so this normalized weight matrix includes all those extra dimensions and explodes the size of the einsum operations downstream.

Originally posted by @slevang in #39 (comment)
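
A minimal sketch of the idea (function and argument names here are illustrative, not the xarray-regrid API): collapse the validity mask over the non-grid dims before it enters the weight normalization, so the normalized weights only carry the grid dimensions.

import xarray as xr

def collapsed_valid_mask(data: xr.DataArray, grid_dims=("latitude", "longitude")):
    # Reduce the NaN mask over the non-grid (batch) dims, e.g. time: a source
    # cell counts as valid if it has data in any batch slice.
    non_grid_dims = [d for d in data.dims if d not in grid_dims]
    return data.notnull().any(non_grid_dims)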

I've been trying out the conservative method on some more realistic workloads, and found the performance comparisons with xesmf not super compelling. Here's a basic example:

import dask.array as da
import xarray as xr
import xarray_regrid

bounds = dict(south=-90, north=90, west=-180, east=180)

source = xarray_regrid.Grid(
    resolution_lat=0.25,
    resolution_lon=0.25,
    **bounds,
).create_regridding_dataset()

target = xarray_regrid.Grid(
    resolution_lat=1,
    resolution_lon=1,
    **bounds,
).create_regridding_dataset()

n_times = 1000

data = da.random.random(
    size=(n_times, source.latitude.size, source.longitude.size),
    chunks=(1, -1, -1),  # one time step per chunk, full grid per slice
).astype("float32")

source = xr.DataArray(
    data,
    dims=["time", "latitude", "longitude"],
    coords={
        "time": xr.date_range("2000-01-01", periods=n_times, freq="D"),
        "latitude": source.latitude,
        "longitude": source.longitude,
    }
)

xarray-regrid:

%time source.regrid.conservative(target, skipna=False).compute();
%time source.regrid.conservative(target, skipna=True).compute();
CPU times: user 8min 59s, sys: 9min 37s, total: 18min 37s
Wall time: 44.1 s
CPU times: user 1h 6min 47s, sys: 44min 26s, total: 1h 51min 13s
Wall time: 3min 51s

vs xesmf:

import xesmf as xe
regridder = xe.Regridder(source, target, "conservative")
%time regridder(source, skipna=False).compute()
%time regridder(source, skipna=True).compute();
CPU times: user 4min 9s, sys: 21.8 s, total: 4min 30s
Wall time: 34.5 s
CPU times: user 8min, sys: 35 s, total: 8min 35s
Wall time: 1min 3s

Hm, I get much better performance on a small XPS13 laptop (19 seconds wall time for xarray-regrid with skipna=False, 164 seconds for xESMF). What is your Dask setup? Have you tried setting up dask.distributed?

import dask.distributed
client = dask.distributed.Client()

I am using the latest (non-released) xarray-regrid code, and latest xESMF. For both regridders all CPU threads are 100% occupied during most of the benchmark run.

Dask is complaining about large graph sizes with xESMF though.

Interesting! Do you have opt-einsum installed? xr.dot routes to completely different routines depending on that.
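
A quick way to check whether opt-einsum is even present (how and whether xr.dot picks it up depends on your xarray/numpy/dask versions, so treat this only as an environment check):

import importlib.util

print("opt_einsum installed:", importlib.util.find_spec("opt_einsum") is not None)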

I ran these on a 32 core GCP VM, and only with the default threaded scheduler, so it's definitely worth profiling across other uses. I'll try distributed but wouldn't expect much difference since this is a very straightforward task graph and no impact from the GIL.

With dense weights I definitely see all CPUs churning at full speed, but there are a massive number of 0s in those einsums. Sparse multiplication is algorithmically less efficient, but we have far fewer numbers to multiply.
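
As a toy illustration of that sparsity argument using the sparse package (shapes made up; real conservative weights come from grid-cell overlaps):

import numpy as np
import sparse

# Toy conservative weight matrix mapping 720 source rows to 180 target rows:
# each target row only overlaps 4 source rows, so almost every entry is zero.
dense = np.zeros((720, 180))
for j in range(180):
    dense[4 * j : 4 * (j + 1), j] = 0.25

weights = sparse.COO.from_numpy(dense)
print(weights.density)  # ~0.0056, i.e. >99% of the dense einsum inputs are zeros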

Do you have opt-einsum installed

I did not. Installing it also did not seem to matter.

I ran these on a 32 core GCP VM

I would have expected much better performance then. I use a 4-core/8-thread Intel i7, and my compute time was <40% of yours for the same code.

the default threaded scheduler

I stopped using that one because I found it less reliable (sometimes it's a lot less performant) and harder to debug than the distributed scheduler: https://dask-local.readthedocs.io/en/latest/setup/single-distributed.html#single-machine-dask-distributed

Sparse multiplication is algorithmically less efficient, but we have far fewer numbers to multiply.

Yeah, it makes sense. I did not notice an improvement, but also no drop in performance.

My original benchmarks were weird on the VM, but see #49 (comment) for updated numbers with the addition of sparse weights.

On this particular test, we're significantly better than or at least on par with xesmf with skipna=False. skipna=True is where we may still have some room for improvement, so I would welcome any ideas to improve the NaN handling logic.
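
For reference, a minimal sketch of the kind of NaN handling involved (my own simplification under assumed weight shapes, not the actual implementation): fill NaNs with zero, regrid both the data and the validity mask with the same weights, then renormalize by the regridded valid fraction.

import xarray as xr

def conservative_skipna_sketch(data, lat_weights, lon_weights, dims=("latitude", "longitude")):
    # Hypothetical weights: each maps a source grid dim to a target grid dim.
    valid = data.notnull().astype(data.dtype)
    filled = data.fillna(0.0)

    # Regrid the data and the validity mask with the same weights.
    numerator = xr.dot(filled, lat_weights, lon_weights, dim=dims)
    valid_frac = xr.dot(valid, lat_weights, lon_weights, dim=dims)

    # Renormalize so partially-NaN cells aren't diluted toward zero,
    # and mask target cells with no valid source data at all.
    return (numerator / valid_frac).where(valid_frac > 0)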

There are of course many other regridding problems to benchmark. xesmf can now handle chunking in the grid dimensions, although in my experience it is quite slow. I tried flipping the chunking scheme: for xarray-regrid it was only a few times slower than the pancake-style chunks, and gave something like a 10x advantage over xesmf.

That's wild, why is xesmf so inefficient?

Yeah I'm not sure. Here's the xesmf implementation, where you have a 4D weight matrix chunked (y_out, x_out, y_in, x_in), and then apply np.tensordot.
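
To make the contrast concrete, here is a toy NumPy comparison of the two contraction patterns (the 4D tensor is built as an outer product purely so the two results match; real xesmf weights are not separable like this):

import numpy as np

nt, ny_in, nx_in, ny_out, nx_out = 3, 20, 40, 5, 10
data = np.random.random((nt, ny_in, nx_in))
w_lat = np.random.random((ny_in, ny_out))   # latitude weights
w_lon = np.random.random((nx_in, nx_out))   # longitude weights

# xesmf-style: one 4D weight tensor with shape (y_out, x_out, y_in, x_in),
# contracted against both input grid dims at once via tensordot.
w_4d = np.einsum("ya,xb->abyx", w_lat, w_lon)
out_4d = np.tensordot(data, w_4d, axes=[(1, 2), (2, 3)])

# xarray-regrid-style: two separable 2D weight matrices, one per grid dim.
out_2d = np.einsum("tyx,ya,xb->tab", data, w_lat, w_lon)

np.testing.assert_allclose(out_4d, out_2d)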

If I change the example to chunks={"time": -1, "latitude": 100, "longitude": 100} I get (Macbook):

  • xarray-regrid: 5s
  • xesmf: 35s

vs chunks={"time": 10, "latitude": -1, "longitude": -1}

  • xarray-regrid: 2.5s
  • xesmf: 4.3s
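
Assuming the source DataArray from the example above, the two chunking schemes can be reproduced like this:

# Spatially tiled chunks (whole time series per chunk, 100x100 grid tiles):
tiled = source.chunk({"time": -1, "latitude": 100, "longitude": 100})

# Pancake-style chunks (a few time steps per chunk, full grid per slice):
pancake = source.chunk({"time": 10, "latitude": -1, "longitude": -1})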

Interesting comparisons!
xESMF chunking management might not be ideal. There's some flaky code that tries to preserve the spatial chunk sizes when the chunking is not (-1, -1) (the first case above). It's possible that the choice is suboptimal here, which might add a lot of extra work for the scheduler?

I added a more detailed benchmark matrix here following the addition of sparse weights (#49) and smarter weight aggregation in xr.dot (#51).

I can't exactly make sense of why xesmf is slow in the chunked case, but I think it's either:

  1. The chunking scheme for the weights isn't optimized
  2. There is a fundamental difference between dot(data, weights_4d) (xesmf) vs dot(data, weights_2d, weights_2d) (xarray-regrid, with separate latitude and longitude weights)

In any case we can probably close the issue here since performance is looking pretty good now.

Xesmf tried to be clever about chunks. I'd bet on that being problematic. We should upstream such heuristics now that dask.array is active again.