Ensure pyramids have the right attributes/metadata
Closed this issue · 1 comments
andersy005 commented
Over in https://github.com/carbonplan/cmip6-downscaling/blob/main/cmip6_downscaling/methods/common/tasks.py, we have a utility function that post-processes the produced pyramid by adding metadata/attributes to it. We've talked about including this function in ndpyramid directly. I'm opening this issue for clarity/future reference.
def _pyramid_postprocess(dt: dt.DataTree, levels: int, other_chunks: dict = None) -> dt.DataTree:
    '''Finalize a data pyramid

    Records the tile size in the root multiscales metadata, rechunks each
    level, applies Zarr encoding, and updates the root's global attributes.

    Parameters
    ----------
    dt : dt.DataTree
        Input data pyramid
    levels : int
        Number of levels in pyramid
    other_chunks : dict
        Chunks for non-spatial dims

    Returns
    -------
    dt.DataTree
        Updated data pyramid with metadata / encoding set
    '''
    # spatial dims are always tiled at PIXELS_PER_TILE; callers may add or override entries
    chunk_spec = dict.fromkeys(("x", "y"), PIXELS_PER_TILE)
    chunk_spec.update(other_chunks if other_chunks is not None else {})

    for level in range(levels):
        # record the tile size for this level in the root multiscales metadata
        dt.ds.attrs['multiscales'][0]['datasets'][level]['pixels_per_tile'] = PIXELS_PER_TILE

        node = dt[str(level)]

        # rechunk the level; 'date_str' (when present) is rechunked separately with -1
        node.ds = node.ds.chunk(chunk_spec)
        if 'date_str' in node.ds:
            node.ds['date_str'] = node.ds['date_str'].chunk(-1)

        # compression / dtype encoding for every variable of the level
        node.ds = set_zarr_encoding(
            node.ds, codec_config={"id": "zlib", "level": 1}, float_dtype="float32"
        )

        # time-like variables are stored as int32 when they exist
        for name in ('time', 'time_bnds'):
            if name not in node.ds:
                continue
            node.ds[name].encoding['dtype'] = 'int32'

    # global metadata on the root: the CF attributes are applied on top of the title
    dt.ds.attrs.update(
        {'title': 'multiscale data pyramid', **get_cf_global_attrs(version=version)}
    )
    return dt
andersy005 commented
The `set_zarr_encoding` function resides in https://github.com/carbonplan/data/blob/5ba3b28de96b206ae2cdb0bb5a84b91a2f75e034/carbonplan_data/utils.py#L184
def set_zarr_encoding(
    ds: xr.Dataset,
    codec_config: dict | None = None,
    float_dtype: DTypeLike | None = None,
    int_dtype: DTypeLike | None = None,
) -> xr.Dataset:
    """Set zarr encoding for each variable in the dataset

    Works on a copy made with ``ds.copy()``. Every variable has its existing
    encoding cleared and replaced by a compressor and a ``_FillValue`` entry.

    Parameters
    ----------
    ds : xr.Dataset
        Input dataset
    codec_config : dict, optional
        Dictionary of parameters to pass to numcodecs.get_codec, default is {'id': 'zlib', 'level': 1}
    float_dtype : str or dtype, optional
        Dtype to cast floating point variables to
    int_dtype : str or dtype, optional
        Dtype to cast integer variables to

    Returns
    -------
    ds : xr.Dataset
        Output dataset with updated variable encodings
    """
    import numcodecs

    ds = ds.copy()

    if codec_config is None:
        codec_config = {"id": "zlib", "level": 1}
    compressor = numcodecs.get_codec(codec_config)

    # Iterate over a snapshot: each variable is assigned back into ``ds`` below.
    for k, da in list(ds.variables.items()):
        # maybe cast float type
        if np.issubdtype(da.dtype, np.floating) and float_dtype is not None:
            da = da.astype(float_dtype)

        # maybe cast integer type
        if np.issubdtype(da.dtype, np.integer) and int_dtype is not None:
            da = da.astype(int_dtype)

        # remove old encoding
        da.encoding.clear()

        # update with new encoding
        da.encoding["compressor"] = compressor

        # Drop any fill value stored as a plain attribute so that the encoding
        # entry set below is the only ``_FillValue`` on the variable.
        # ``pop`` with a default is a no-op when the attribute is absent.
        da.attrs.pop("_FillValue", None)

        # Fill value is looked up by the dtype's kind+itemsize code (e.g. "f4");
        # dtypes without an entry get None.
        da.encoding["_FillValue"] = default_fillvals.get(
            da.dtype.str[-2:], None
        )  # TODO: handle date/time types

        ds[k] = da

    return ds