E3SM-Project/e3sm_diags

Replace climatology implementation to use xCDAT

Opened this issue · 0 comments

In #658, I included an implementation of the climo() function that is based on xCDAT. I found large floating point differences when comparing the land-sea-masked climatology calculated using the old climo() and the xCDAT climo().

I have not pinpointed the exact cause yet. I decided to implement a climo() function that uses the old logic to operate on Xarray objects instead of sinking time investigating. In a separate PR, we should investigate this more closely and consider using the xCDAT climatology API once we can get close results.

Started implementation:

"""The climatology function implemented using xCDAT's climatology API.
NOTE: This function has not been incorporated into the codebase yet because
further investigation is needed to figure out why there are large floating
point differences compared to E3SM Diags' climo function (climo.py and
climo_xr.py).
"""
import xarray as xr
import xcdat as xc

from e3sm_diags.driver.utils.climo_xr import CLIMO_FREQ, CLIMO_FREQS


# Two-digit month identifiers ("01" through "12") accepted as frequencies.
_MONTH_FREQS = tuple(f"{month:02d}" for month in range(1, 13))
# Season identifiers accepted as frequencies.
_SEASON_FREQS = ("DJF", "MAM", "JJA", "SON")


def climo(dataset: xr.Dataset, var_key: str, freq: CLIMO_FREQ) -> xr.DataArray:
    """Compute a variable's climatology for the given frequency.

    xCDAT's climatology API operates on a data variable within an
    `xr.Dataset` object by specifying the key of the data variable. It uses
    time bounds to redefine time as the midpoint between bounds values and
    month lengths for proper weighting.

    After averaging, if the time dimension is still present on the result it
    is squeezed and its coordinate is dropped, because it is expected to be a
    singleton at that point.

    Parameters
    ----------
    dataset : xr.Dataset
        The dataset containing the data variable.
    var_key : str
        The key of the data variable in ``dataset``.
    freq : CLIMO_FREQ
        The frequency for calculating climatology: ``"ANN"``, a two-digit
        month (``"01"`` to ``"12"``), or a season (``"DJF"``, ``"MAM"``,
        ``"JJA"``, ``"SON"``).

    Returns
    -------
    xr.DataArray
        The variable's climatology.

    Raises
    ------
    ValueError
        If ``freq`` is not one of the supported frequencies.
    """
    # Work on a copy of the input dataset, with time centered via xCDAT.
    ds = dataset.copy()
    ds = xc.center_times(ds)

    # The variable's time dim key is stored here for reuse in subsetting.
    time_dim = xc.get_dim_keys(ds[var_key], axis="T")

    if freq == "ANN":
        ds_climo = ds.temporal.average(var_key, weighted=True)
    elif freq in _MONTH_FREQS:
        # Keep only the time steps that fall in the requested month before
        # computing the monthly climatology.
        ds = ds.isel({time_dim: (ds[time_dim].dt.month == int(freq))})
        ds_climo = ds.temporal.climatology(var_key, freq="month", weighted=True)
    elif freq in _SEASON_FREQS:
        # Keep only the time steps that fall in the requested season before
        # computing the seasonal climatology.
        ds = ds.isel({time_dim: (ds[time_dim].dt.season == freq)})
        ds_climo = ds.temporal.climatology(var_key, freq="season", weighted=True)
    else:
        raise ValueError(
            f"`freq='{freq}'` is not a valid climatology frequency. Options "
            f"include {CLIMO_FREQS}"
        )

    dv_climo = ds_climo[var_key].copy()

    # The time dimension should be a singleton after averaging so it should
    # be squeezed and dropped from the climatology data variable.
    if time_dim in dv_climo.dims:
        dv_climo = dv_climo.squeeze(dim=time_dim).drop_vars(time_dim)

    return dv_climo
from pathlib import Path

import numpy as np
import pytest
import xarray as xr

from e3sm_diags.driver.utils.climo_xcdat import climo


class TestClimo:
    """Tests for the xCDAT-based ``climo`` function.

    Every test runs against the small single-grid-cell dataset built by the
    autouse ``setup`` fixture and compares the result of ``climo`` with a
    hand-built expected ``xr.DataArray`` using ``identical``.
    """

    @pytest.fixture(autouse=True)
    def setup(self, tmp_path: Path):
        """Create the temporary input directory and the shared test dataset.

        The dataset holds one variable, ``ts``, with five time steps on a
        1x1 lat/lon grid, plus ``time_bnds`` bounds for the time axis.
        """
        # Create temporary directory to save files.
        input_dir = tmp_path / "input_data"
        input_dir.mkdir()

        self.ds = xr.Dataset(
            data_vars={
                "ts": xr.DataArray(
                    data=np.array(
                        [[[2.0]], [[1.0]], [[1.0]], [[1.0]], [[2.0]]], dtype="float64"
                    ),
                    dims=["time", "lat", "lon"],
                    attrs={"test_attr": "test"},
                ),
                "time_bnds": xr.DataArray(
                    name="time_bnds",
                    data=np.array(
                        [
                            [
                                "2000-01-01T00:00:00.000000000",
                                "2000-02-01T00:00:00.000000000",
                            ],
                            [
                                "2000-03-01T00:00:00.000000000",
                                "2000-04-01T00:00:00.000000000",
                            ],
                            [
                                "2000-06-01T00:00:00.000000000",
                                "2000-07-01T00:00:00.000000000",
                            ],
                            [
                                "2000-09-01T00:00:00.000000000",
                                "2000-10-01T00:00:00.000000000",
                            ],
                            [
                                "2001-02-01T00:00:00.000000000",
                                "2001-03-01T00:00:00.000000000",
                            ],
                        ],
                        dtype="datetime64[ns]",
                    ),
                    dims=["time", "bnds"],
                    attrs={"xcdat_bounds": "True"},
                ),
            },
            coords={
                "lat": xr.DataArray(
                    data=np.array([-90]),
                    dims=["lat"],
                    attrs={
                        "axis": "Y",
                        "long_name": "latitude",
                        "standard_name": "latitude",
                    },
                ),
                "lon": xr.DataArray(
                    data=np.array([0]),
                    dims=["lon"],
                    attrs={
                        "axis": "X",
                        "long_name": "longitude",
                        "standard_name": "longitude",
                    },
                ),
                "time": xr.DataArray(
                    data=np.array(
                        [
                            "2000-01-16T12:00:00.000000000",
                            "2000-03-16T12:00:00.000000000",
                            "2000-06-16T00:00:00.000000000",
                            "2000-09-16T00:00:00.000000000",
                            "2001-02-15T12:00:00.000000000",
                        ],
                        dtype="datetime64[ns]",
                    ),
                    dims=["time"],
                    attrs={
                        "axis": "T",
                        "long_name": "time",
                        "standard_name": "time",
                        "bounds": "time_bnds",
                    },
                ),
            },
        )
        self.ds.time.encoding = {
            "units": "years since 2000-01-01",
            "calendar": "standard",
        }

    @staticmethod
    def _expected_climo(ds: xr.Dataset, value: float, **attrs: str) -> xr.DataArray:
        """Build the expected ``ts`` climatology on the lat/lon grid of ``ds``.

        ``attrs`` are added on top of the attributes shared by every expected
        result (``test_attr`` and ``operation``).
        """
        return xr.DataArray(
            name="ts",
            data=np.array([[value]]),
            dims=["lat", "lon"],
            coords={"lat": ds.lat, "lon": ds.lon},
            attrs={"test_attr": "test", "operation": "temporal_avg", **attrs},
        )

    def _assert_season_climo(self, season: str, value: float) -> None:
        """Assert that the climatology for ``season`` is identical to ``value``."""
        ds = self.ds.copy()

        result = climo(ds, "ts", season)
        expected = self._expected_climo(
            ds,
            value,
            mode="climatology",
            freq="season",
            weighted="True",
            dec_mode="DJF",
            drop_incomplete_djf="False",
        )

        # Check DataArray values and attributes align
        assert result.identical(expected)

    def test_returns_annual_cycle_climatology(self):
        """The "ANN" frequency returns the average over all time steps."""
        ds = self.ds.copy()

        result = climo(ds, "ts", "ANN")
        # NOTE(review): if the average is weighted by the lengths of the time
        # bounds in the fixture (31, 31, 30, 30, 28 days), the mean would be
        # 209/150 ~= 1.3933 rather than 1.4 (the unweighted mean) -- confirm
        # the expected value and whether an exact `identical` comparison is
        # intended here.
        expected = self._expected_climo(
            ds,
            1.4,
            mode="average",
            freq="month",
            weighted="True",
        )

        # Check DataArray values and attributes align
        assert result.identical(expected)

    def test_returns_DJF_season_climatology(self):
        """The "DJF" frequency returns the DJF seasonal climatology."""
        self._assert_season_climo("DJF", 2.0)

    def test_returns_MAM_season_climatology(self):
        """The "MAM" frequency returns the MAM seasonal climatology."""
        self._assert_season_climo("MAM", 1.0)

    def test_returns_JJA_season_climatology(self):
        """The "JJA" frequency returns the JJA seasonal climatology."""
        self._assert_season_climo("JJA", 1.0)

    def test_returns_SON_season_climatology(self):
        """The "SON" frequency returns the SON seasonal climatology."""
        self._assert_season_climo("SON", 1.0)

    def test_returns_jan_climatology(self):
        """The "01" frequency returns the January climatology."""
        ds = self.ds.copy()

        result = climo(ds, "ts", "01")
        expected = self._expected_climo(
            ds,
            2.0,
            mode="climatology",
            freq="month",
            weighted="True",
        )

        # Check DataArray values and attributes align
        assert result.identical(expected)