carbonplan/ndpyramid

Ensure pyramids have the right attributes/metadata

Closed this issue · 1 comments

Over in https://github.com/carbonplan/cmip6-downscaling/blob/main/cmip6_downscaling/methods/common/tasks.py, we have a utility function that is used to post-process the produced pyramid. This function is used to add metadata/attributes to the pyramid. We've talked about including this function in ndpyramid directly. I'm opening this issue for clarity/future reference.

def _pyramid_postprocess(dt: dt.DataTree, levels: int, other_chunks: dict = None) -> dt.DataTree:
    """Finalize a data pyramid: metadata, chunking and Zarr encoding.

    For each level the tile size is recorded in the root ``multiscales``
    metadata, the level dataset is rechunked and given Zarr encoding, and
    the time variables are encoded as int32. Global attributes are then
    written on the root dataset.

    Parameters
    ----------
    dt : dt.DataTree
        Input data pyramid
    levels : int
        Number of levels in pyramid
    other_chunks : dict
        Chunks for non-spatial dims; applied on top of the x/y tile chunks

    Returns
    -------
    dt.DataTree
        The pyramid that was passed in, with metadata / encoding set
    """
    chunk_spec = dict(x=PIXELS_PER_TILE, y=PIXELS_PER_TILE)
    if other_chunks is not None:
        chunk_spec.update(other_chunks)

    for idx in range(levels):
        # record the tile size for this level in the root multiscales metadata
        level_meta = dt.ds.attrs['multiscales'][0]['datasets'][idx]
        level_meta['pixels_per_tile'] = PIXELS_PER_TILE

        node = dt[str(idx)]

        # rechunk the level; date_str (when present) is rechunked with -1
        node.ds = node.ds.chunk(chunk_spec)
        if 'date_str' in node.ds:
            node.ds['date_str'] = node.ds['date_str'].chunk(-1)

        # apply compression / dtype encoding to every variable of the level
        node.ds = set_zarr_encoding(
            node.ds,
            codec_config={"id": "zlib", "level": 1},
            float_dtype="float32",
        )

        # time coordinates are stored as 32-bit integers
        for time_name in ('time', 'time_bnds'):
            if time_name in node.ds:
                node.ds[time_name].encoding['dtype'] = 'int32'

    # global metadata on the root dataset
    root_attrs = dt.ds.attrs
    cf_attrs = get_cf_global_attrs(version=version)
    root_attrs.update({'title': 'multiscale data pyramid'}, **cf_attrs)
    return dt

Cc @jhamman, @norlandrhagen

The set_zarr_encoding function resides in https://github.com/carbonplan/data/blob/5ba3b28de96b206ae2cdb0bb5a84b91a2f75e034/carbonplan_data/utils.py#L184

def set_zarr_encoding(
    ds: xr.Dataset,
    codec_config: dict | None = None,
    float_dtype: DTypeLike | None = None,
    int_dtype: DTypeLike | None = None,
) -> xr.Dataset:
    """Set zarr encoding for each variable in the dataset

    Every variable's existing encoding is discarded and replaced with a
    compressor and a ``_FillValue`` looked up from ``default_fillvals``.

    Parameters
    ----------
    ds : xr.Dataset
        Input dataset
    codec_config : dict, optional
        Dictionary of parameters to pass to numcodecs.get_codec, default is {'id': 'zlib', 'level': 1}
    float_dtype : str or dtype, optional
        Dtype to cast floating point variables to
    int_dtype : str or dtype, optional
        Dtype to cast integer variables to

    Returns
    -------
    ds : xr.Dataset
        Output dataset with updated variable encodings
    """
    import numcodecs

    ds = ds.copy()

    if codec_config is None:
        codec_config = {"id": "zlib", "level": 1}
    compressor = numcodecs.get_codec(codec_config)

    # Iterate over a snapshot: ds[k] is re-assigned inside the loop.
    for k, da in list(ds.variables.items()):

        # maybe cast float type
        if float_dtype is not None and np.issubdtype(da.dtype, np.floating):
            da = da.astype(float_dtype)

        # maybe cast integer type
        if int_dtype is not None and np.issubdtype(da.dtype, np.integer):
            da = da.astype(int_dtype)

        # remove old encoding
        da.encoding.clear()

        # update with new encoding
        da.encoding["compressor"] = compressor

        # The fill value is carried by the encoding set just below, so drop any
        # copy stored in attrs. pop() tolerates variables that have none.
        da.attrs.pop("_FillValue", None)

        # Key is the last two characters of the dtype string (e.g. "f4");
        # dtypes without an entry get None.
        da.encoding["_FillValue"] = default_fillvals.get(
            da.dtype.str[-2:], None
        )  # TODO: handle date/time types

        ds[k] = da

    return ds