xarray-contrib/xbatcher

Batch generation with batch_dims in v0.2.0 is about 10-20 times slower

Closed · 1 comment

What is your issue?

I have been using xbatcher for a few weeks, and happened to be around when version 0.1.0 was replaced by version 0.2.0. After the upgrade I noticed a massive slowdown and a large increase in memory usage in my workflow.

The problem seems to occur when the batch_dims option is used.

The minimal code example below takes about 6 s to cycle through all the batches in the for-loop with xbatcher 0.1.0, but about 1 min 35 s to do the same with xbatcher 0.2.0.

Along with the slowdown, memory usage also shoots up massively with version 0.2.0 (often leading to a kernel crash on a small cluster). Judging by the memory-used indicator in JupyterLab, memory usage hovers around 2.3 GB constantly for 0.1.0, but cycles up to 16 GB for 0.2.0.

Minimal example code:

import xarray as xr
import numpy as np
import xbatcher

da = xr.DataArray(np.ones((100, 200, 50, 200)), dims=['X', 'Y', 'Z', 'time']).rename('da')

# batch_dims option
bgen = xbatcher.BatchGenerator(ds=da, input_dims={}, batch_dims={'time': 10})

%%time
nbat = 0
for batch in bgen:
    print(nbat)
    nbat += 1

Thanks for the bug report and MVCE @dhruvbalwada. I don't have a patch yet, but have an explanation and will work to get this fixed promptly. I expect we'll make a patch release afterwards.

In v0.1.0, the batches were generated lazily but many features were missing (e.g., __getitem__ and __len__, in addition to the data loaders). In order to allow arbitrary batch access, #25 made batch generation eager. #112 generally improved performance by supporting lazy batch generation along with the features added since v0.1.0. The strategy for lazy batch generation is to store the indices of the batches rather than the xarray data objects themselves. Your bug report reveals that only the indices associated with dimensions in input_dims are appropriately represented in the selector object, but not the indices of dimensions in batch_dims. The consequence is that batch_dims only impacts the number of batches returned, but not the size of the DataArray returned as a batch.
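To make the index-based strategy concrete, here is a minimal pure-Python sketch of the idea, not xbatcher's actual implementation. The helper names (`dim_slices`, `make_selectors`, `selected_shape`) are hypothetical, and dropping a trailing partial chunk is an assumption made for simplicity. Each batch is represented by a dict of `{dim: slice}`, and the slices cover dimensions from both batch_dims and input_dims, so the selector fixes the size of the batch as well as the number of batches.

```python
from itertools import product


def dim_slices(size, length):
    """Non-overlapping slices of `length` over range(size).

    Assumption for this sketch: a trailing partial chunk is dropped.
    """
    return [slice(start, start + length) for start in range(0, size - length + 1, length)]


def make_selectors(sizes, input_dims, batch_dims):
    """Return one {dim: slice} dict per batch; only indices are stored, never data."""
    # Both kinds of dimension are sliced; leaving batch_dims out of this dict
    # is the kind of omission described in the paragraph above.
    sliced = {**batch_dims, **input_dims}
    per_dim = [dim_slices(sizes[dim], length) for dim, length in sliced.items()]
    return [dict(zip(sliced, combo)) for combo in product(*per_dim)]


def selected_shape(sizes, selector):
    """Shape that an isel-style lookup with `selector` would produce."""
    return {
        dim: len(range(*selector[dim].indices(size))) if dim in selector else size
        for dim, size in sizes.items()
    }


# Dimension sizes taken from the minimal example in the report.
sizes = {"X": 100, "Y": 200, "Z": 50, "time": 200}
selectors = make_selectors(sizes, input_dims={}, batch_dims={"time": 10})

print(len(selectors))                       # number of batches
print(selected_shape(sizes, selectors[0]))  # per-batch shape, time limited to 10
```

With real data, each stored dict could be passed to something like `ds.isel(selector)` only when the batch is requested, which is what keeps generation lazy.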

Two requirements for resolving this issue are:

  • Add a test for batch_dims (which should pass on v0.1.0 and fail on v0.2.0) (xref #83)
  • Modify the batches attribute to store indices associated with both batch_dims and input_dims