xarray-contrib/xbatcher

Generating the batches seems slow


I've just come across xbatcher, and I think it could be just what I need for using CNNs on data stored in dask-backed xarrays. I've got a number of questions about how it works, and some issues I'm having. If this isn't the appropriate place for these questions then please let me know, and I'll direct them elsewhere. I decided not to create a separate issue for each question, as I expect a number of them are problems with my understanding rather than with xbatcher, and separate issues would clog up the issues board. If some of these questions need extracting into a separate issue, I'm happy to do that.

Firstly, thanks for putting this together - it has already solved a lot of problems for me.

To give some context, I'm trying to use xbatcher to run batches of image data through a pytorch CNN on Microsoft Planetary Computer. I'm not doing training here, I'm just doing inference - so I just need to push the raw array data through the ML model and get the results out.

Now on to the questions:

1. Generating the batches seems slow
I'm trying to create batches from a DataArray of size (8172, 8082), which is a single band of a satellite image. I'm using the following call to create a BatchGenerator:

patch_size = 64
batch_size = 10
bgen = b1.batch.generator(input_dims=dict(x=patch_size, y=patch_size),
                          batch_dims=dict(x=batch_size*patch_size, y=batch_size*patch_size),
                          concat_input_dims=True, preload_batch=False)

That should create DataArrays that are 64 x 64 (in x and y), with 100 of those patches in each batch.

I'm then running a loop over the batch generator, doing something with the batches. We'll come to what I'm doing later, but for the moment let's just append the result to a list:

import tqdm

results = []
for batch in tqdm.tqdm(bgen):
    results.append(batch)

This takes around 1s per batch, and creates a very small Dask task that goes away and generates the batch (I've already run b1.persist() to ensure all the data is on the Dask cluster). I have a few questions about this:

a) Is this sort of speed expected? From some rough calculations, at 1s per batch of 64 x 64 patches, it'll take hours to batch up my ~8000 x 8000 array.
b) With preload_batch=False I'd expect these to be generated lazily - and it does seem that the underlying data in the DataArray is a dask array - however it still seems to take around a second per batch.
c) Should I be approaching this in a different way to get a better speed?

2. How do you put batches back together after processing?
My machine learning model is producing a single value as an output, so for a batch of 100 64x64 patches, I get an output of a 100-element array. What's the best way of putting this back into a DataArray that has the same format/co-ordinates as the original input array? I'd be happy with either an array with dimensions of original_size / 64 in both the x and y dimension, or an array of the same size as the input with the single output value repeated for each of the input pixels in that batch.

I've tried to put some of this together myself, but it seems that the x co-ordinate value in the batch DataArray is the same for each batch. I'd have thought this would represent the x co-ordinates that had been extracted from the original DataArray, but it doesn't seem to. For example, if I run:

batches = []
for i, batch in enumerate(bgen):
    batches.append(batch)
    if i == 1:
        break

to get the first two batches, I can then compare their x co-ordinate values:

np.all(batches[0].to_array().squeeze().x == batches[1].to_array().squeeze().x)

and it shows that they're all equal.

Do you have any ideas as to what I could do to be able to put the batches back together?
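For the regular, non-overlapping case in question 2, one option is to bypass xbatcher and use plain NumPy reshapes, so the patch order is known and the output grid can be rebuilt directly. This is a minimal sketch under assumptions: a 2-D DataArray with `y` and `x` dimension coordinates, edge pixels that don't fill a whole patch are trimmed, and `predict` is a hypothetical stand-in for the CNN (here just the patch mean). It returns both forms asked for above: a coarse grid with one value per patch, and a full-resolution array with each value repeated over its patch.

```python
import numpy as np
import xarray as xr


def predict(patches):
    # Hypothetical stand-in for the CNN: one value per (patch, patch) input
    return patches.mean(axis=(1, 2))


def patchify_predict_unpatchify(da, patch=64):
    ny, nx = da.sizes["y"] // patch, da.sizes["x"] // patch
    # Trim edge pixels that don't fill a whole patch
    trimmed = da.isel(y=slice(0, ny * patch), x=slice(0, nx * patch))
    arr = trimmed.transpose("y", "x").values

    # (ny*patch, nx*patch) -> (ny, patch, nx, patch) -> (ny, nx, patch, patch)
    patches = arr.reshape(ny, patch, nx, patch).transpose(0, 2, 1, 3)
    # Flatten to (n_patches, patch, patch) in row-major (y, then x) order
    preds = predict(patches.reshape(-1, patch, patch))

    # Undo the flattening: one prediction per patch on a (ny, nx) grid,
    # labelled with the mean coordinate of each patch
    y_centres = trimmed["y"].values.reshape(ny, patch).mean(axis=1)
    x_centres = trimmed["x"].values.reshape(nx, patch).mean(axis=1)
    coarse = xr.DataArray(preds.reshape(ny, nx), dims=("y", "x"),
                          coords={"y": y_centres, "x": x_centres})

    # Repeat each prediction over its patch to match the trimmed input grid
    full = xr.DataArray(
        np.repeat(np.repeat(coarse.values, patch, axis=0), patch, axis=1),
        dims=("y", "x"),
        coords={"y": trimmed["y"].values, "x": trimmed["x"].values})
    return coarse, full


# Usage on synthetic data standing in for the satellite band
rng = np.random.default_rng(0)
da = xr.DataArray(rng.random((300, 260)), dims=("y", "x"),
                  coords={"y": np.arange(300.0), "x": np.arange(260.0)})
coarse, full = patchify_predict_unpatchify(da, patch=64)
```

With the mean as the stand-in model, `coarse` should match `da.coarsen(y=64, x=64, boundary="trim").mean()`, which is a handy correctness check before swapping in a real model.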

3. Documentation and tutorial notebook
It took me quite a while to find the example notebook that is sitting in #31, but the notebook was really helpful (actually a lot more helpful than the documentation on ReadTheDocs). Could this be merged soon, with a prominent link to it added to the docs/README? I think this would significantly help any other users trying to get to grips with xbatcher.

4. Overlap seems to cause hang
Trying to batch the array with an overlap seems to take ages to do anything - I'm not sure whether it has hung or is just taking a long time to do the computations. If I run:

patch_size = 64
batch_size = 10
bgen = b1.batch.generator(input_dims=dict(x=patch_size, y=patch_size),
                          batch_dims=dict(x=batch_size*patch_size, y=batch_size*patch_size),
                          input_overlap=dict(x=patch_size-1, y=patch_size-1),
                          concat_input_dims=True,
                          preload_batch=False)

and then try and get the first two batches:

batches = []
for i, batch in enumerate(bgen):
    batches.append(batch)
    if i == 1:
        break

I leave it running for a few minutes, and nothing seems to happen. When I interrupt it, it seems to be deep inside xarray/pandas index handling code.

Any idea what's going on?

Is that overlapping a lot, and so generating a really huge number of patches?

@RichardScottOZ Yes, it's overlapping a lot, but I thought it would do a lazy generation of batches, so wouldn't take long to generate the first couple?

1) IME, batch generation does seem slow; I saw a major slowdown when I implemented it in my current project.

4) It's probably just taking a long time, I'm seeing the expected behavior when I use overlapping. Are you sure you want input_overlap=dict(x=patch_size-1, y=patch_size-1) and not input_overlap=dict(x=1, y=1)?
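Some rough arithmetic shows why the overlapping case looks like a hang. This sketch counts full-size windows along each dimension of the (8172, 8082) array, assuming windows start every `patch - overlap` pixels and partial windows at the edges are dropped; `n_windows` is a hypothetical helper, not an xbatcher function.

```python
def n_windows(dim_size, size, overlap=0):
    """Number of full windows of `size` along one dimension.

    Windows start every `size - overlap` elements; partial windows at the
    edge are dropped (an assumption about how edges are handled).
    """
    stride = size - overlap
    if stride <= 0:
        raise ValueError("overlap must be smaller than the window size")
    if dim_size < size:
        return 0
    return (dim_size - size) // stride + 1


dim_sizes = (8172, 8082)  # the two spatial dimensions from the original post
patch = 64
for overlap in (0, 1, patch - 1):
    total = 1
    for n in dim_sizes:
        total *= n_windows(n, patch, overlap)
    print(f"overlap={overlap}: {total:,} patches")

# overlap=0: 16,002 patches
# overlap=1: 16,512 patches
# overlap=63: 65,026,071 patches
```

With `input_overlap` set to `patch_size - 1` the stride is a single pixel, so the patch count grows by a factor of roughly 4,000 compared with no overlap.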

Thanks @cmdupuis3 - interesting you found it slow too.

I thought I had the input_overlap definition right - but I'll have a play.

The thing that is causing me the biggest problem at the moment is the difficulty in putting the results back together into something like the original array - I can't understand why the co-ordinates seem to be the same in the different patches. Do you have any ideas on that?

Presumably the x coordinates could legitimately be the same. Are both x and y the same?

@RichardScottOZ Oh wow, I can't believe I was that silly. Yes, the y co-ordinates are different and the x co-ordinates are the same, which makes perfect sense.

I still can't seem to work out what exactly I need to do to stitch them back together, but at least now I know that I have the information required.

Sorry for wasting people's time with my mistake - but I hope the other parts of my questions are still valid.

Thanks @robintw for this very useful issue. You have not wasted anyone's time. On the contrary your feedback is very helpful. This is a very new package, totally experimental, with tons of room to improve. I hope you will keep guinea pigging and perhaps even consider contributing to the package as a developer. Xbatcher was created to scratch an itch that many of us had--you have the same itch.

You raise a lot of points in your original post. Ideally these should become separate issues eventually. But let's just continue the original discussion here for now.

  1. Generating the batches seems slow

Our approach has been to start with the API design--as users of both {pytorch / keras} and Xarray, how do we want these libraries to interact? What would be the most convenient and intuitive way to exchange data? The fact that you have used xbatcher and found that "it has already solved a lot of problems" suggests we are on the right track. But of course there is more to do.

Performance is a complicated subject. In general, my approach is to first establish correctness via testing. Once we are convinced that the software is working correctly, we can think about optimizing performance. There is a pretty clear recipe for optimizing performance.

  1. Establish an easily reproducible benchmark. The example should be as simple as possible while also reproducing the performance challenge. Ideally it would use synthetic data. Perhaps something like:
  • generate a random synthetic dataset of sufficient size using random values
  • write it to disk using zarr
  • open it back up with xarray
  • use the workflow described above in part 1
  2. Profile the hell out of it. Snakeviz is my personal favorite tool.
  3. Figure out where the code is spending most of its time and try to make that part faster.
  4. Rinse and repeat.
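Steps 1 and 2 of this recipe can be sketched with NumPy, xarray and the standard library's `cProfile`/`pstats`. To keep it dependency-light, this sketch skips the zarr round-trip and profiles a plain `isel`-based patch loop (`iterate_patches` is a hypothetical stand-in); in a real benchmark you would open the zarr store with `xr.open_zarr` and iterate over the xbatcher `BatchGenerator` instead.

```python
import cProfile
import io
import pstats

import numpy as np
import xarray as xr

# 1. Reproducible synthetic data (seeded), standing in for the satellite band
rng = np.random.default_rng(42)
da = xr.DataArray(rng.random((1024, 1024), dtype=np.float32), dims=("y", "x"))


def iterate_patches(da, patch=64):
    # Hypothetical stand-in for the batch generator under test
    for i in range(0, da.sizes["y"] - patch + 1, patch):
        for j in range(0, da.sizes["x"] - patch + 1, patch):
            yield da.isel(y=slice(i, i + patch), x=slice(j, j + patch))


def consume(batches):
    # Touch every batch so any lazy work is actually done
    return sum(float(b.sum()) for b in batches)


# 2. Profile the workload and report the 10 most expensive calls
profiler = cProfile.Profile()
profiler.enable()
total = consume(iterate_patches(da))
profiler.disable()

out = io.StringIO()
pstats.Stats(profiler, stream=out).sort_stats("cumulative").print_stats(10)
print(out.getvalue())
```

To browse the same profile in Snakeviz, write it out with `profiler.dump_stats("batches.prof")` and run `snakeviz batches.prof`.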

Since we have done zero performance optimization so far, there is likely some very low-hanging fruit that will be revealed by this.

An important point to raise at this stage is that Xarray and Dask in general have a performance cost compared to just using numpy. Xarray provides convenience by keeping track of the data's labels. This increases data scientist productivity. But from a computational pov, it's more expensive. Likewise, dask allows you to scale out analysis to large data. But it is almost never faster to use dask over numpy once your data fit in memory. A key question for xbatcher is--at what point do we just load data into memory eagerly? That is an area that is ripe for optimization.
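A rough way to see the labeled-indexing overhead described here is to time the same patch extraction on an in-memory DataArray and on its underlying NumPy array with `timeit`. This is a sketch on synthetic data; absolute timings depend on the machine and the xarray version, so none are quoted.

```python
import timeit

import numpy as np
import xarray as xr

rng = np.random.default_rng(0)
da = xr.DataArray(rng.random((2048, 2048)), dims=("y", "x"))
arr = da.values  # the same data, without labels
patch = 64


def patches_xarray():
    # Labeled slicing: each isel builds a new DataArray with its metadata
    return [da.isel(y=slice(i, i + patch), x=slice(j, j + patch))
            for i in range(0, 2048, patch) for j in range(0, 2048, patch)]


def patches_numpy():
    # Plain NumPy views of the same windows
    return [arr[i:i + patch, j:j + patch]
            for i in range(0, 2048, patch) for j in range(0, 2048, patch)]


t_xr = timeit.timeit(patches_xarray, number=3)
t_np = timeit.timeit(patches_numpy, number=3)
print(f"xarray isel: {t_xr:.3f} s, numpy slicing: {t_np:.3f} s")
```

If the data start out dask-backed and fit in memory, calling `.load()` once before the loop is the simplest version of the "load eagerly" option raised above.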

  2. How do you put batches back together after processing?

This deserves its own issue. It's an important part of xbatcher that has not been implemented yet at all. We welcome your input on the API design.

... More thoughts soon.

@robintw There will be exactly zero people reading this that have not made basic x and y mistakes many times. :)

Doing this sort of thing (putting things back together after inference) is often the hard bit of this part of the workflow.

hi all :) I've been following along with the development of xbatcher and I really like the API so far. I had been thinking of writing something which uses zarr to store data batched-up and ready for being consumed by pytorch, so I thought I'd have a go at adding this to xbatcher (pull-request: #40). I've added BatchGenerator.to_zarr(...) and BatchGenerator.from_zarr(...), which work like this:

import numpy as np
import xarray as xr
import xbatcher

da = xr.DataArray(
    np.random.rand(1000, 100, 100), name="foo", dims=["time", "y", "x"]
).chunk({"time": 1})

# use `preload_batch=False` so that we can work with data larger than memory
# otherwise all dask-arrays are turned into vanilla xr.DataArrays
bgen = xbatcher.BatchGenerator(da, {"time": 10}, preload_batch=False)
# to_zarr / from_zarr are the methods proposed in #40
bgen.to_zarr(path="my_batched_data.zarr", chunks=dict(batch="1Gb"))
bgen_loaded = xbatcher.BatchGenerator.from_zarr("my_batched_data.zarr")

for batch in bgen_loaded:
    ...  # use batches as before

By exposing the chunking in the saved datastore I was trying to follow the suggestions of @nbren12 on #2 about making the loaded chunks of batches fit in memory.

This might not be the kind of API you're going for, but I was wondering what you think?

(sorry for hijacking this issue @rabernat, I thought this might be relevant to your comments above)

I do see @rabernat's point about establishing APIs and then optimizing, but I'm not sure many will use this project without convincing performance benchmarks.

The competition for xbatcher is a simple for-loop like this:

ny, nx = ds.sizes['y'], ds.sizes['x']

# step of 10 pixels between window starts; + 1 keeps the last full window
for i in range(0, nx - window_size + 1, 10):
    for j in range(0, ny - window_size + 1, 10):
        ds.isel(x=slice(i, i + window_size), y=slice(j, j + window_size))\
            .to_netcdf(f"{i}-{j}.nc")

I suspect this code is immediately clear to most xarray users...and any bugs can be quickly fixed without interacting with upstream and learning a new code base. I personally would only replace this code with an external dependency if it were much faster. Clean abstractions should be weighed against the substantial maintenance burden added by taking on a dependency like xbatcher.

I'm glad this issue is leading to some interesting discussions here - thank you everyone.

After realising that the co-ordinates were properly available in the batches, I've managed to stitch them back together using some code like this:

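import xarray as xr
from tqdm import tqdm

def concat_batch(batch):
    # Split a batch (dims: input_batch, y_input, x_input) into its patches,
    # make the original y/x coordinates the dimensions again, and merge them
    to_concat = []
    for batch_i in range(batch.sizes['input_batch']):
        sl = batch.isel(input_batch=batch_i)
        sl = sl.swap_dims({'y_input': 'y', 'x_input': 'x'})
        to_concat.append(sl)

    return xr.combine_by_coords(to_concat)

def stitch_patches(batches):
    # Stitch each batch back together, then combine all batches by coordinates
    batches_concatted = []
    for batch in tqdm(batches):
        batches_concatted.append(concat_batch(batch.to_array().squeeze()))

    return xr.combine_by_coords(batches_concatted)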

(Interestingly, I spent ages trying to use the various array reshaping commands available in xarray, and then looked at the xbatcher code and realised it just did a load of for loops and xr.concat calls, and by taking a similar approach I solved the problem very quickly).

There are probably better ways to do this, and it is relatively slow to join them all at the moment - but at least it works. Any feedback very welcome.
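Since each patch carries its own `y`/`x` coordinates, one possibly faster alternative to repeated `combine_by_coords` calls is to preallocate the full output grid and write each patch into place by looking up the positions of its coordinates. This is a sketch under assumptions: the patches are 2-D DataArrays with `y` and `x` dimension coordinates (as after the `swap_dims` step above), every patch coordinate appears exactly in the full grid, and `place_patches` is a hypothetical helper, not an xbatcher function.

```python
import numpy as np
import pandas as pd
import xarray as xr


def place_patches(patches, y, x, fill_value=np.nan):
    """Write 2-D (y, x) patches into a preallocated array on the full grid."""
    data = np.full((len(y), len(x)), fill_value, dtype=float)
    y_index, x_index = pd.Index(y), pd.Index(x)
    for p in patches:
        p = p.transpose("y", "x")
        # Map each patch's coordinate labels to integer positions in the grid
        iy = y_index.get_indexer(p["y"].values)
        ix = x_index.get_indexer(p["x"].values)
        if (iy < 0).any() or (ix < 0).any():
            raise KeyError("patch coordinates not found in the output grid")
        # Orthogonal assignment: rows iy crossed with columns ix
        data[np.ix_(iy, ix)] = p.values
    return xr.DataArray(data, dims=("y", "x"), coords={"y": y, "x": x})


# Usage: cut a synthetic array into 64x64 patches, shuffle, and reassemble
rng = np.random.default_rng(1)
da = xr.DataArray(rng.random((128, 192)), dims=("y", "x"),
                  coords={"y": np.arange(128) * 10.0, "x": np.arange(192) * 10.0})
patches = [da.isel(y=slice(i, i + 64), x=slice(j, j + 64))
           for i in range(0, 128, 64) for j in range(0, 192, 64)]
stitched = place_patches(patches[::-1], da["y"].values, da["x"].values)
```

Each patch is written once, with no intermediate concatenation. If patches overlap, later patches simply overwrite earlier ones, so any averaging over overlaps would need extra bookkeeping.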

It'd be interesting to think about an API for joining things back together - as the results we may have after machine learning may be of various different shapes and sizes. For my simple machine learning example I just copied the batch DataArray that was given to the machine learning algorithm and filled it with the output data - even though I was only getting a single value out of the ML for each 64x64 patch, I just filled the patch with that value. There are lots of other situations we could be in though, so it'd be good to have an API that would work with a variety of inputs.
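The "copy the batch and fill it with the outputs" approach described here can be written generically for the one-value-per-patch case. This is a sketch assuming the patches are stacked along a dimension called `input_batch` (as in the stitching code above); `fill_patches_with_predictions` is a hypothetical helper. Because the result keeps the batch's coordinates, it can go through the same stitching step as the inputs.

```python
import numpy as np
import xarray as xr


def fill_patches_with_predictions(batch, preds, batch_dim="input_batch"):
    """Copy of `batch` where every pixel of patch i holds preds[i]."""
    preds = np.asarray(preds)
    n = batch.sizes[batch_dim]
    if preds.shape != (n,):
        raise ValueError(f"expected {n} predictions, got shape {preds.shape}")
    # Reshape to (1, ..., n, ..., 1) so it broadcasts along batch_dim only
    shape = [1] * batch.ndim
    shape[batch.get_axis_num(batch_dim)] = n
    data = np.broadcast_to(preds.reshape(shape), batch.shape).copy()
    # copy(data=...) keeps the dims, coordinates and attrs of the input batch
    return batch.copy(data=data)


# Usage with a small synthetic batch of three 4x4 patches
batch = xr.DataArray(
    np.zeros((3, 4, 4)), dims=("input_batch", "y_input", "x_input"),
    coords={"y": ("y_input", np.arange(4)), "x": ("x_input", np.arange(4))})
filled = fill_patches_with_predictions(batch, [0.1, 0.5, 0.9])
```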

I am having trouble keeping up with this issue because we are discussing at least three (maybe four) separate things at once in a single thread. I think we need to take some time to split this up into multiple distinct issues.

@leifdenby - thanks for #40 - it's a good idea and someone will try to review it asap and give you feedback.

@nbren12 - It is probably impossible for xbatcher to ever beat the baseline you defined above without bypassing Xarray completely. I think you know this. If that's a dealbreaker for you, no one is going to force you to use xbatcher. You seem to be implying that we should abandon the project entirely. Is that the correct interpretation of your comment?

@robintw - it would be fantastic if you could edit your original issue and split items 2, 3, 4, into distinct issues.

You seem to be implying that we should abandon the project entirely. Is that the correct interpretation of your comment?

Not exactly, just pointing out that "use a for-loop" might be a better path for folks like OP until the library matures. I'm also trying to lightly nudge you in a certain direction that I think many would find compelling. Apologies if I rubbed you the wrong way. I can be blunt, but I think a library like this has a definite path forward. Not sure why you don't see opportunities for improving the performance... maybe you just meant single-threaded. I think xbatcher would be awesome if it were bundled with parallel execution frameworks, like rechunker is.

Ok understood. I'm glad I asked for clarification. That's a useful perspective. I agree that performance is important and should be a goal for the project.

Yes, I see your point that there could be performance optimizations that utilize multithreading and / or async. I guess I don't have a good feel for where the bottlenecks are. My assumption has been that model training itself is the bottleneck, but that could be wrong. That's where profiling becomes very useful.
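One cheap form of the multithreading mentioned here is to overlap batch generation with model execution by prefetching batches in a background thread. This is a minimal sketch with the standard library's `threading` and `queue`; `prefetch` is a hypothetical helper rather than an xbatcher feature, and how much it helps depends on how much of the batch generation and model code releases the GIL.

```python
import queue
import threading

_DONE = object()  # sentinel marking the end of the stream


class _Raised:
    # Wraps an exception from the producer thread so it can be re-raised
    def __init__(self, exc):
        self.exc = exc


def prefetch(iterable, max_ahead=2):
    """Yield items from `iterable` while a background thread produces up to
    `max_ahead` items in advance."""
    q = queue.Queue(maxsize=max_ahead)

    def producer():
        try:
            for item in iterable:
                q.put(item)
        except Exception as exc:
            q.put(_Raised(exc))
        finally:
            q.put(_DONE)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = q.get()
        if item is _DONE:
            return
        if isinstance(item, _Raised):
            raise item.exc
        yield item


# Usage: the consumer (e.g. the model) runs while the next batches are built
def slow_batches(n):
    for i in range(n):
        yield [i] * 3  # stand-in for an expensive batch

results = [sum(b) for b in prefetch(slow_batches(5))]
```

Order is preserved because a single producer feeds a FIFO queue. If the consumer stops early, the daemon producer thread stays blocked on `put` rather than shutting down cleanly, which a fuller implementation would need to handle.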

I've opened #42 to discuss the concept of benchmarking xbatcher. As I say there, I think this is particularly important as we talk about adding performance improvements.