xarray-contrib/xbatcher

Slow performance using concat_input_dims=True

Opened this issue · 10 comments

What is your issue?

I've been doing some ad-hoc performance profiling, and it looks like _drop_input_dims is always the culprit for why my batch generation runs slow. I blame the deep copy here.

However, in light of #164, I don't really see the point of this subroutine. Can someone explain what this subroutine does?

Hi Chris! Thanks for looking into this. Can you share your performance profiling script so others can reproduce the issue?

I blame the deep copy here.

That is not a deep copy. It's a very shallow copy. In general, copying an xarray object is extremely fast and cheap. No actual data is ever copied.
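For anyone who wants to check this on a small example, here is a minimal sketch. It assumes a NumPy-backed Dataset (the variable name SST is borrowed from this thread and the data is random), and it uses np.shares_memory to test whether a copy still points at the original buffer. It does not exercise the xbatcher code path itself.

```python
import numpy as np
import xarray as xr

# Small stand-in dataset; the real data in this thread is much larger.
ds = xr.Dataset({"SST": (("nlat", "nlon"), np.random.rand(4, 5))})

shallow = ds.copy(deep=False)  # new Dataset object, same underlying arrays
deep = ds.copy(deep=True)      # new Dataset object, data duplicated

# True when both arrays are backed by the same memory buffer.
shallow_shares = np.shares_memory(ds["SST"].values, shallow["SST"].values)
deep_shares = np.shares_memory(ds["SST"].values, deep["SST"].values)
print(shallow_shares, deep_shares)
```

If the first value is True and the second is False, the copy in question is not duplicating data, and the memory growth has to come from somewhere else.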

Can someone explain what this subroutine does?

This is an internal function. Its purpose is described in a comment:

# remove input_dims coordinates from datasets, rename the dimensions
# then put intput_dims back in as coordinates

It was very ad hoc profiling, basically just running, stopping, and checking where the stack trace went a number of times. I can try to get something more official but that would take a bit of doing.

If it's not a deep copy, there's something else very wrong. I'm getting gigs of ram use training with a few megs of data with concat_input_dims=True. That line got the most hits, and the next most common was line 317.

The docs make it sound like it does something and then undoes it again. As far as I can see, aside from appending "_input" in some circumstances (which is an issue of its own #164), it doesn't do anything, but that could just be another facet of the "all dims are input dims" problem.

If you're doing profiling, I highly recommend using snakeviz to visualize where your program is spending time.
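Outside a notebook, a profile can be collected with the standard-library cProfile module. The sketch below is only an illustration: slow_batches is a hypothetical stand-in for iterating over a batch generator, and the file name batch.prof in the last comment is made up.

```python
import cProfile
import io
import pstats


def slow_batches(n):
    """Hypothetical stand-in for a batch generator that does some work per batch."""
    for _ in range(n):
        yield sum(j * j for j in range(10_000))


profiler = cProfile.Profile()
profiler.enable()
first = next(iter(slow_batches(3)))  # profile fetching a single batch
profiler.disable()

# Print the ten most expensive entries, sorted by cumulative time.
buf = io.StringIO()
pstats.Stats(profiler, stream=buf).sort_stats("cumulative").print_stats(10)
report = buf.getvalue()
print(report)

# To inspect the result visually, save it and open it with snakeviz:
# profiler.dump_stats("batch.prof")   # then run: snakeviz batch.prof
```

Sorting by cumulative time makes it easier to see which high-level call dominates, which is more reliable than sampling stack traces by hand.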

If you're using more RAM than you expect, it's quite likely that it's a chunk-related issue. Remember that Zarr chunks have to be loaded all at once. It's not possible to do a partial read of a chunk.

If you could share a reproducible snippet of code, it would help us help you.

On it. You can see my surface_currents notebook here.

As far as chunking, everything was chunked by time slice, so I should be good there, right?

I am trying to run your notebook now and I see what you're saying about the performance.

I am just trying to profile getting a single batch

%%prun
for batch in bgen:
    break

It has already been running for 10 minutes... 🙄 This is a tiny piece of data. Something is seriously wrong.

It looks like concat_input_dims=True has a major impact on the speed here. Basically, what is happening is a huge and very inefficient reshaping of the data. This is almost exactly what Xarray's rolling construct function does.

Here is a minimal reproducer for this issue

import numpy as np
import xarray as xr
import xbatcher as xb


# original size was 300, 250; scaled down to make debugging faster
nlat, nlon = 30, 250

ds = xr.Dataset(
    {
        "SST": (('nlat', 'nlon'), np.random.rand(nlat, nlon)),
        "SSH": (('nlat', 'nlon'), np.random.rand(nlat, nlon))
    }
)

bgen = xb.BatchGenerator(
    ds,
    input_dims={'nlon': 3, 'nlat': 3},
    input_overlap={'nlon': 2, 'nlat': 2},
    concat_input_dims=True
)

%time batch = next(iter(bgen))

On the LEAP hub, I'm getting 3.3 s for that very small size example and 36.1 s for the original 300, 250 shape.

This is suitable for profiling with snakeviz.

%%snakeviz
batch = next(iter(bgen))

[Screenshot: snakeviz profile of the call above]

Most of the time is spent on concat, although _drop_input_dims is also significant.

It would be interesting to compare this to rolling.construct to see if it's any more efficient.

If not, I would consider trying to bypass xarray completely internally. It's creating lots of overhead that we don't necessarily need.

Here's an alternative way to accomplish almost the same thing using xarray rolling.construct

batch = (
    ds
    .rolling({"nlat": 3, "nlon": 3})
    .construct({"nlat": "nlat_input", "nlon": "nlon_input"})
    .stack({"input_batch": ("nlat", "nlon")}, create_index=False)
)

For me this ran in 6.95 ms for the full 300 x 250 input (compared to 30 s for the xbatcher method).
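One reason this is only almost the same: rolling.construct keeps one window per grid point and pads the incomplete leading windows with NaN. The sketch below, on a tiny made-up 6 x 5 grid, shows one way to keep only the complete 3 x 3 windows. It assumes the default center=False, so the first two positions along each dimension are the padded ones. Whether the remaining windows match xbatcher's patches one to one is an assumption I have not verified here.

```python
import numpy as np
import xarray as xr

nlat, nlon = 6, 5  # tiny grid so the shapes are easy to follow
ds = xr.Dataset({"SST": (("nlat", "nlon"), np.random.rand(nlat, nlon))})

windows = ds.rolling({"nlat": 3, "nlon": 3}).construct(
    {"nlat": "nlat_input", "nlon": "nlon_input"}
)

# With center=False, the window stored at position i covers i-2..i,
# so positions 0 and 1 along each dimension contain NaN padding.
has_padding = bool(windows["SST"].isnull().any())

# Drop the incomplete windows along both dimensions.
full = windows.isel(nlat=slice(2, None), nlon=slice(2, None))
all_valid = bool(full["SST"].notnull().all())

print(dict(windows.sizes), dict(full.sizes), has_padding, all_valid)
```

After trimming, the first window corresponds to the top-left 3 x 3 corner of the original array, and the number of windows along each dimension drops by 2.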

With rolling, we don't have the ability to vary input_overlap explicitly. I'd be more than willing to give that up for a 5000x performance boost. 🚀 😉

This post on numpy stride tricks is extremely relevant.
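As a rough illustration of the stride-tricks idea, here is a sketch using numpy.lib.stride_tricks.sliding_window_view (NumPy 1.20 or later). The helper name patches is hypothetical, and this is not how xbatcher is implemented. It only shows that a variable overlap can be emulated by subsampling the window view with a step of size minus overlap, without copying any data.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def patches(arr, size, overlap):
    """Return a view of square patches of `arr` with the given overlap.

    Hypothetical helper: the step between patches is size - overlap.
    """
    step = size - overlap
    if step < 1:
        raise ValueError("overlap must be smaller than size")
    windows = sliding_window_view(arr, (size, size))  # all stride-1 windows
    return windows[::step, ::step]                    # keep every step-th one


arr = np.arange(30, dtype=float).reshape(6, 5)

dense = patches(arr, size=3, overlap=2)   # step 1, like input_overlap=2
sparse = patches(arr, size=3, overlap=1)  # step 2, like input_overlap=1

print(dense.shape, sparse.shape)
print(np.shares_memory(arr, sparse))  # the result is still a view on arr
```

Because the result is a view, building it costs almost nothing regardless of the grid size. The copy only happens later, if the patches are reshaped or written to.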