H2Oxford/h2ox-ai

Error in `_reshape_data_ohe`

Closed this issue · 2 comments

I haven't had time to investigate this fully yet, but running the dataset test fails with the following error:

(ml) ➜  h2ox-ai git:(main) ipython --pdb tests/test_dataset.py
2022-02-09 14:22:45.766 | INFO     | h2ox.ai.dataset.dataset:engineer_arrays:199 - soft data transforms - maybe normalise
2022-02-09 14:22:46.060 | INFO     | h2ox.ai.dataset.dataset:engineer_arrays:205 - soft data transforms - validate datetimes
2022-02-09 14:22:57.748 | INFO     | h2ox.ai.dataset.dataset:engineer_arrays:208 - soft data transforms - interpolate_1d
2022-02-09 14:22:58.371 | INFO     | h2ox.ai.dataset.dataset:engineer_arrays:214 - soft data transforms - reshape with one-hot-encoding
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
~/github/h2ox/h2ox-ai/tests/test_dataset.py in <module>
     39         ds = create_dummy_data()
     40
---> 41     data = FcastDataset(
     42         ds,
     43         **cfg["dataset_parameters"],

~/github/h2ox/h2ox-ai/h2ox/ai/dataset/dataset.py in __init__(self, data, select_sites, historical_seq_len, forecast_horizon, future_horizon, target_var, historic_variables, forecast_variables, future_variables, max_consecutive_nan, ohe_or_multi, normalise, time_dim, horizon_dim, **kwargs)
     64
     65         # turn the data into a dictionary for model input
---> 66         self.engineer_arrays(data)
     67
     68     def __len__(self) -> int:

~/github/h2ox/h2ox-ai/h2ox/ai/dataset/dataset.py in engineer_arrays(self, data)
    213         elif self.ohe_or_multi == "ohe":
    214             logger.info("soft data transforms - reshape with one-hot-encoding")
--> 215             historic, forecast, future, targets = self._reshape_data_ohe(data)
    216
    217         idx = historic["date-site"].data[np.isin(historic["date"].data, valid_dates)]

~/github/h2ox/h2ox-ai/h2ox/ai/dataset/dataset.py in _reshape_data_ohe(self, data)
    161
    162         future = self._onehotencode(
--> 163             self._get_future_data(data), "steps"
    164         )  # DATES*sites x STEPS x var+ohe
    165

~/github/h2ox/h2ox-ai/h2ox/ai/dataset/dataset.py in _get_future_data(self, data)
    319
    320         return (
--> 321             data[self.future_variables]
    322             .to_array()
    323             .sel({"steps": future_period})

~/miniconda3/envs/ml/lib/python3.9/site-packages/xarray/core/dataarray.py in sel(self, indexers, method, tolerance, drop, **indexers_kwargs)
   1330         Dimensions without coordinates: points
   1331         """
-> 1332         ds = self._to_temp_dataset().sel(
   1333             indexers=indexers,
   1334             drop=drop,

~/miniconda3/envs/ml/lib/python3.9/site-packages/xarray/core/dataset.py in sel(self, indexers, method, tolerance, drop, **indexers_kwargs)
   2502         """
   2503         indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "sel")
-> 2504         pos_indexers, new_indexes = remap_label_indexers(
   2505             self, indexers=indexers, method=method, tolerance=tolerance
   2506         )

~/miniconda3/envs/ml/lib/python3.9/site-packages/xarray/core/coordinates.py in remap_label_indexers(obj, indexers, method, tolerance, **indexers_kwargs)
    419     }
    420
--> 421     pos_indexers, new_indexes = indexing.remap_label_indexers(
    422         obj, v_indexers, method=method, tolerance=tolerance
    423     )

~/miniconda3/envs/ml/lib/python3.9/site-packages/xarray/core/indexing.py in remap_label_indexers(data_obj, indexers, method, tolerance)
    118     for dim, index in indexes.items():
    119         labels = grouped_indexers[dim]
--> 120         idxr, new_idx = index.query(labels, method=method, tolerance=tolerance)
    121         pos_indexers[dim] = idxr
    122         if new_idx is not None:

~/miniconda3/envs/ml/lib/python3.9/site-packages/xarray/core/indexes.py in query(self, labels, method, tolerance)
    240                 indexer = get_indexer_nd(self.index, label, method, tolerance)
    241                 if np.any(indexer < 0):
--> 242                     raise KeyError(f"not all values found in index {coord_name!r}")
    243
    244         return indexer, None

KeyError: "not all values found in index 'steps'"
> /Users/tommylees/miniconda3/envs/ml/lib/python3.9/site-packages/xarray/core/indexes.py(242)query()
    240                 indexer = get_indexer_nd(self.index, label, method, tolerance)
    241                 if np.any(indexer < 0):
--> 242                     raise KeyError(f"not all values found in index {coord_name!r}")
    243
    244         return indexer, None

ipdb>

I have confirmed the error on both a remote Google Cloud instance and a local machine (MacBook Air).

This should be a simple fix; I'm just leaving it here so we're both aware of it.

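Note that the `steps` coordinate in the data created by the data factory is a `timedelta64[ns]` index running from 0 days to 89 days (90 steps; see the repr below). Any label in `future_period` that is not one of these values makes `.sel({"steps": future_period})` in `_get_future_data` raise the `KeyError: "not all values found in index 'steps'"` shown above.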

ipdb> data
<xarray.Dataset>
Dimensions:                 (steps: 90, date: 4384, global_sites: 6)
Coordinates:
  * steps                   (steps) timedelta64[ns] 0 days 1 days ... 89 days
  * date                    (date) datetime64[ns] 2010-01-01 ... 2022-01-01
  * global_sites            (global_sites) object 'bhadra' ... 'lower_bhawani'
Data variables:
    doy_sin                 (global_sites, steps, date) float64 0.01721 ... 0...
    doy_cos                 (global_sites, steps, date) float64 0.9999 ... 0....
    forecast_t2m            (date, global_sites, steps) float64 0.3341 ... nan
    forecast_tp             (date, global_sites, steps) float64 -0.5631 ... nan
    historic_t2m            (steps, date, global_sites) float64 0.06162 ... nan
    historic_tp             (steps, date, global_sites) float64 -0.2409 ... nan
    targets_WATER_VOLUME    (steps, global_sites, date) float64 1.317 ... nan
    targets_RESERVOIR_NAME  (steps, global_sites, date) object 'bhadra' ... ''