Apply interpolation
mmann1123 opened this issue · 4 comments
@jgrss I have been working on alternative ways to interpolate missing values in a time series. The `apply` method initially seemed like a good option. However, it seems to be able to write only a single date back out, hence the use of `return array[self.index_to_write].squeeze()`.
Just wondering if there is another way I could apply the function as quickly, or whether I need to use something like `map_blocks` instead (which I think is much slower).
```python
import geowombat as gw
import jax.numpy as jnp
import numpy as np


# not working because I can't write out multiple observations
def _interpolate_nans_linear(array):
    # An all-NaN series has nothing to interpolate from, so return it unchanged
    if np.isnan(array).all():
        return array
    else:
        return np.interp(
            np.arange(len(array)),
            np.arange(len(array))[~np.isnan(array)],
            array[~np.isnan(array)],
        )


class interpolate_nan(gw.TimeModule):
    """Interpolate missing values in the time series using linear interpolation.

    Args:
        missing_value (int, optional): The value to be replaced by NaNs. Default is None.
        interp_type (str, optional): The type of interpolation to use. Default is "linear".
        index_to_write (int, optional): The index of the interpolated array to return. Default is 0.
    """

    def __init__(self, missing_value=None, interp_type="linear", index_to_write=0):
        super(interpolate_nan, self).__init__()
        self.missing_value = missing_value
        self.interp_type = interp_type
        self.index_to_write = index_to_write

    def calculate(self, array):
        # check if missing_value is not None and not np.nan
        if self.missing_value is not None:
            if not np.isnan(self.missing_value):
                array = jnp.where(array == self.missing_value, np.nan, array)
        if self.interp_type == "linear":
            # Interpolate each pixel's series along the time axis
            array = np.apply_along_axis(_interpolate_nans_linear, axis=0, arr=array)
        # Return one of the interpolated arrays based on index_to_write
        return array[self.index_to_write].squeeze()


with gw.series(
    files,
    nodata=9999,
) as src:
    src.apply(
        func=interpolate_nan(missing_value=0),
        outfile="/home/mmann1123/Downloads/test.tif",
        num_workers=5,
        bands=1,
    )
```
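To make the per-pixel step concrete, here is a minimal NumPy-only sketch of the same idea on a toy stack. The array values and the helper name `interpolate_nans_linear` are made up for illustration; this is not geowombat code, just the `np.interp` plus `np.apply_along_axis` pattern from the post run on its own.

```python
import numpy as np


def interpolate_nans_linear(series):
    """Linearly fill NaNs in a 1d series; an all-NaN series is returned unchanged."""
    valid = ~np.isnan(series)
    if not valid.any():
        return series
    positions = np.arange(series.size)
    # np.interp holds the first/last valid value constant outside the valid range
    return np.interp(positions, positions[valid], series[valid])


# Hypothetical toy stack: 4 dates x 1 row x 2 columns
stack = np.array(
    [
        [[1.0, np.nan]],
        [[np.nan, np.nan]],
        [[3.0, np.nan]],
        [[np.nan, np.nan]],
    ]
)

# Apply the 1d function to every pixel's time series (axis 0)
filled = np.apply_along_axis(interpolate_nans_linear, axis=0, arr=stack)
print(filled.shape)     # (4, 1, 2)
print(filled[:, 0, 0])  # [1. 2. 3. 3.]
```

Note that the result keeps the full time axis, which is why writing it out needs an output with one band per date rather than a single band.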
Hey @mmann1123, I think I follow your setup, except for the shape of your data. What is the shape of `array[self.index_to_write]`? Is `array` 4d (time x bands x height x width), which you then slice to get (1 x bands (1?) x height x width) and squeeze to (time x height x width)?
You can control a bit of the output profile in your user function. For example, the output band count is set by `TimeModule.count` (by default, 1), and that value gets passed to the `rasterio` profile.
If you want a multi-band output, you can specify that in your user function. I don't think you need `index_to_write`, so I replaced it with `count` below in your `__init__` method and in the return. But I think that assumes you are processing multi-temporal, single-band data. See my comment at the bottom.
```python
class interpolate_nan(gw.TimeModule):
    def __init__(self, missing_value=None, interp_type="linear", count=1):
        super(interpolate_nan, self).__init__()
        self.missing_value = missing_value
        self.interp_type = interp_type
        # Overrides the default output band count
        self.count = count

    def calculate(self, array):
        # check if missing_value is not None and not np.nan
        if self.missing_value is not None:
            if not np.isnan(self.missing_value):
                array = jnp.where(array == self.missing_value, np.nan, array)
        if self.interp_type == "linear":
            array = np.apply_along_axis(_interpolate_nans_linear, axis=0, arr=array)
        # Return the interpolated array (3d -> time/bands x height x width)
        # If the array is (time x 1 x height x width) then squeeze to 3d
        return array.squeeze()
```
Then you should be able to use it like this:
```python
with gw.series(
    files,
    nodata=9999,
) as src:
    src.apply(
        func=interpolate_nan(
            missing_value=0,
            # not sure if your output length matches your input file length
            # whatever your case is, this is where you define the output band count
            count=len(src.filenames),
        ),
        outfile="/home/mmann1123/Downloads/test.tif",
        num_workers=5,
        # Note that this is the band, or bands, to read
        bands=1,
    )
```
Note that you can only write a 3d array. Therefore, you can either write a single interpolated date and multiple bands, or all the dates for a single band.
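The two options can be illustrated with plain NumPy shapes. This is only a sketch: the block sizes are hypothetical, and it assumes the array handed to `calculate` is laid out as (time x bands x height x width), as discussed above.

```python
import numpy as np

# Hypothetical block for 6 dates and 1 band: (time x bands x height x width)
single_band = np.zeros((6, 1, 4, 5))

# All dates for the single band: drop the band axis, so the output band count is 6
all_dates = single_band.squeeze(axis=1)
print(all_dates.shape)  # (6, 4, 5)

# Hypothetical block for 6 dates and 3 bands
multi_band = np.zeros((6, 3, 4, 5))

# One date, all bands: index the time axis, so the output band count is 3
one_date = multi_band[2]
print(one_date.shape)  # (3, 4, 5)
```

In both cases the first axis of the returned 3d array is what the output band count has to match. Passing `axis=1` to `squeeze` is a deliberate choice here: a bare `squeeze()` would also drop any other axis that happens to have length 1.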
Did you also try xarray's `interpolate_na` method?
```python
with gw.open(files, chunks={'time': -1}) as src:
    interp = (
        src.interpolate_na(dim='time', method='linear', fill_value='extrapolate')
        .bfill(dim='time')
        .ffill(dim='time')
        # Interpolate to new grid
        # .interp(time=smooth_range, method='slinear')
    )
```
I had been using xarray's `interpolate_na`, but now I have files that are too big to bring into memory, and I am not sure how to apply it to chunks and write the result out.
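One reason chunked processing can work here is that the interpolation only runs along time, so spatial windows are independent as long as each window keeps the full time axis. The sketch below checks that with NumPy alone. The data, window size, and helper name `fill_time_axis` are all made up, and the in-memory `cube` merely stands in for windows that a real workflow would read from and write to disk one at a time.

```python
import numpy as np


def fill_time_axis(block):
    """Linearly interpolate NaNs along axis 0 of a (time x height x width) block."""

    def fill(series):
        valid = ~np.isnan(series)
        if not valid.any():
            return series
        positions = np.arange(series.size)
        return np.interp(positions, positions[valid], series[valid])

    return np.apply_along_axis(fill, 0, block)


# Hypothetical (time x height x width) data with roughly 30% missing values
rng = np.random.default_rng(0)
cube = rng.random((5, 4, 6))
cube[rng.random(cube.shape) < 0.3] = np.nan

# Process two-column windows one at a time, always keeping the full time axis
window_width = 2
out = np.empty_like(cube)
for start in range(0, cube.shape[2], window_width):
    window = slice(start, start + window_width)
    out[:, :, window] = fill_time_axis(cube[:, :, window])

# The window-by-window result matches processing the whole cube at once
print(np.allclose(out, fill_time_axis(cube), equal_nan=True))
```

Splitting along time instead would change the result, because a gap could lose the valid neighbours it interpolates between.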
OK, you are a lifesaver as always. Yes, I was applying the interpolation to a single band for multiple periods. Your `apply` example worked! Thanks again.