google-research/weatherbench2

Possible bug with xarray dataset selection within beam

jdwillard19 opened this issue · 2 comments

I am running "evaluate_with_beam" and it doesn't appear to be able to select variables in the observations dataset.

This is the line causing the error: the dataset[sel_variables].sel(...) call in _impose_data_selection() in evaluation.py (see the stack trace below).

If I load the same dataset and execute the line in question outside of the Beam context, using the same Python container environment, it works just fine (see the sketch below).

Also, if I go into _impose_data_selection() and manually add

    print("Variables in dataset:", list(dataset.data_vars.keys()))

it prints exactly the list of variables I am trying to select:

['10m_u_component_of_wind', '10m_v_component_of_wind', '2m_temperature', 'geopotential', 'mean_sea_level_pressure', 'specific_humidity', 'surface_pressure', 'temperature', 'total_column_water_vapour', 'u_component_of_wind', 'v_component_of_wind']
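
For reference, here is roughly what I ran outside of Beam to check the selection (a sketch, not the exact notebook cell; the path and variable names are the ones from my config):

    import xarray as xr

    # Open the same observations zarr that obs_path points to in the config below.
    obs = xr.open_zarr("/pscratch/sd/j/jwillard/healpix_era5/data/latlon_1deg_combined_wb.zarr")

    variables = [
        "geopotential", "temperature", "u_component_of_wind", "v_component_of_wind",
        "specific_humidity", "2m_temperature", "10m_u_component_of_wind",
        "10m_v_component_of_wind", "mean_sea_level_pressure",
    ]

    # Selecting with a plain Python list works fine here; the KeyError only
    # shows up when the same call runs inside evaluate_with_beam.
    subset = obs[variables]
    print(list(subset.data_vars))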

Code snippet

    def score_deterministic(self):
        import logging  # assumed to be available at module level; shown here so the snippet is self-contained
        import xarray as xr  # used below for xr.open_zarr
        import weatherbench2
        from weatherbench2.metrics import MSE, ACC, MAE, Bias
        from weatherbench2 import config as wb2_config
        from weatherbench2 import evaluation as wb2_evaluation
        if self.log_to_screen:
            logging.info("Beginning scoring with WB2....")
        # Define WB2 configs
        paths = wb2_config.Paths(
                            forecast=self.forecast_output_path,
                            obs=self.obs_path,
                            output_dir=self.inf_dir,  
                            climatology=self.climatology_path
                       )
        selection = wb2_config.Selection(
                                    variables=self.score_variables,
                                    levels=self.score_levels,
                                    time_slice=slice(self.score_start_date, self.score_end_date))
        
        data_config = wb2_config.Data(selection=selection, paths=paths)
        
        climatology = None
        if self.climatology_path:
            climatology = xr.open_zarr(self.climatology_path)
        
        metrics = {}
        if 'mse' in self.score_metrics:
            metrics['mse'] = MSE()
        if 'acc' in self.score_metrics:
            if climatology is not None:
                metrics['acc'] = ACC(climatology=climatology)
            else:
                raise ValueError("Climatology path must be provided if 'acc' metric is specified.")
        if 'mae' in self.score_metrics:
            metrics['mae'] = MAE()
        if 'bias' in self.score_metrics:
            metrics['bias'] = Bias()


        regions = {}
        if 'global' in self.score_regions:
            regions['global'] = weatherbench2.regions.SliceRegion()
        if 'tropics' in self.score_regions:
            regions['tropics'] = weatherbench2.regions.SliceRegion(lat_slice=slice(-20, 20))
        if 'extra-tropics' in self.score_regions:
            regions['extra-tropics'] = weatherbench2.regions.ExtraTropicalRegion()  # no trailing comma, which would store a 1-tuple instead of the region
        
        # Create the eval_configs dictionary
        eval_config = {
            'deterministic': wb2_config.Eval(metrics=metrics,
                                                       regions=regions,
                                                       evaluate_persistence=self.evaluate_persistence,
                                                       evaluate_climatology=self.evaluate_climatology)
        }
        
        if self.use_beam:
            direct_runner_options = [
                    f'--direct_num_workers={self.direct_num_workers}',
                    '--direct_running_mode=multi_processing',
            ]

            # Build the argv passed to Beam from the DirectRunner options
            argv = list(direct_runner_options)
            wb2_evaluation.evaluate_with_beam(
                data_config,
                eval_config,
                runner='DirectRunner',
                input_chunks={'init_time': 1, 'lead_time': 1},
                fanout=self.fanout,
                argv=argv
            )
        else:
            wb2_evaluation.evaluate_in_memory(data_config, eval_config)

Stack trace

Traceback (most recent call last):
File "/global/u2/j/jwillard/healpixnat/inference.py", line 22, in
run()
File "/usr/local/lib/python3.10/dist-packages/hydra/main.py", line 94, in decorated_main
_run_hydra(
File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/utils.py", line 394, in _run_hydra
_run_app(
File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/utils.py", line 457, in _run_app
run_and_report(
File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/utils.py", line 223, in run_and_report
raise ex
File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/utils.py", line 220, in run_and_report
return func()
File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/utils.py", line 458, in
lambda: hydra.run(
File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/hydra.py", line 132, in run
_ = ret.return_value
File "/usr/local/lib/python3.10/dist-packages/hydra/core/utils.py", line 260, in return_value
raise self._return_value
File "/usr/local/lib/python3.10/dist-packages/hydra/core/utils.py", line 186, in run_job
ret.return_value = task_function(task_cfg)
File "/global/u2/j/jwillard/healpixnat/inference.py", line 13, in run
inferencer.launch()
File "/global/u2/j/jwillard/healpixnat/utils/inferencer.py", line 108, in launch
self.build_and_run()
File "/global/u2/j/jwillard/healpixnat/utils/inferencer.py", line 140, in build_and_run
self.inference()
File "/global/u2/j/jwillard/healpixnat/utils/inferencer.py", line 339, in inference
self.score_deterministic()
File "/global/u2/j/jwillard/healpixnat/utils/inferencer.py", line 258, in score_deterministic
wb2_evaluation.evaluate_with_beam(
File "/global/u2/j/jwillard/.local/perlmutter/dasrepo_pharring_deepspeed_pytorch_24.05/lib/python3.10/site-packages/weatherbench2/evaluation.py", line 824, in evaluate_with_beam
root
File "/global/u2/j/jwillard/.local/perlmutter/dasrepo_pharring_deepspeed_pytorch_24.05/lib/python3.10/site-packages/apache_beam/transforms/ptransform.py", line 1110, in ror
return self.transform.ror(pvalueish, self.label)
File "/global/u2/j/jwillard/.local/perlmutter/dasrepo_pharring_deepspeed_pytorch_24.05/lib/python3.10/site-packages/apache_beam/transforms/ptransform.py", line 623, in ror
result = p.apply(self, pvalueish, label)
File "/global/u2/j/jwillard/.local/perlmutter/dasrepo_pharring_deepspeed_pytorch_24.05/lib/python3.10/site-packages/apache_beam/pipeline.py", line 679, in apply
return self.apply(transform, pvalueish)
File "/global/u2/j/jwillard/.local/perlmutter/dasrepo_pharring_deepspeed_pytorch_24.05/lib/python3.10/site-packages/apache_beam/pipeline.py", line 732, in apply
pvalueish_result = self.runner.apply(transform, pvalueish, self._options)
File "/global/u2/j/jwillard/.local/perlmutter/dasrepo_pharring_deepspeed_pytorch_24.05/lib/python3.10/site-packages/apache_beam/runners/runner.py", line 203, in apply
return self.apply_PTransform(transform, input, options)
File "/global/u2/j/jwillard/.local/perlmutter/dasrepo_pharring_deepspeed_pytorch_24.05/lib/python3.10/site-packages/apache_beam/runners/runner.py", line 207, in apply_PTransform
return transform.expand(input)
File "/global/u2/j/jwillard/.local/perlmutter/dasrepo_pharring_deepspeed_pytorch_24.05/lib/python3.10/site-packages/weatherbench2/evaluation.py", line 773, in expand
forecast, truth, climatology = open_forecast_and_truth_datasets(
File "/global/u2/j/jwillard/.local/perlmutter/dasrepo_pharring_deepspeed_pytorch_24.05/lib/python3.10/site-packages/weatherbench2/evaluation.py", line 376, in open_forecast_and_truth_datasets
obs_all_times = _impose_data_selection(
File "/global/u2/j/jwillard/.local/perlmutter/dasrepo_pharring_deepspeed_pytorch_24.05/lib/python3.10/site-packages/weatherbench2/evaluation.py", line 170, in _impose_data_selection
dataset = dataset[sel_variables].sel(
File "/global/u2/j/jwillard/.local/perlmutter/dasrepo_pharring_deepspeed_pytorch_24.05/lib/python3.10/site-packages/xarray/core/dataset.py", line 1484, in getitem
return self._construct_dataarray(key)
File "/global/u2/j/jwillard/.local/perlmutter/dasrepo_pharring_deepspeed_pytorch_24.05/lib/python3.10/site-packages/xarray/core/dataset.py", line 1395, in _construct_dataarray
_, name, variable = _get_virtual_variable(self._variables, name, self.dims)
File "/global/u2/j/jwillard/.local/perlmutter/dasrepo_pharring_deepspeed_pytorch_24.05/lib/python3.10/site-packages/xarray/core/dataset.py", line 192, in _get_virtual_variable
raise KeyError(key)
KeyError: ['geopotential', 'temperature', 'u_component_of_wind', 'v_component_of_wind', 'specific_humidity', '2m_temperature', '10m_u_component_of_wind', '10m_v_component_of_wind', 'mean_sea_level_pressure']

Environment

Package Version Editable project location


absl-py 2.1.0
aiobotocore 2.13.0
aiohttp 3.9.5
aioitertools 0.11.0
aiosignal 1.3.1
annotated-types 0.6.0
antlr4-python3-runtime 4.9.3
apache-beam 2.57.0
apex 0.1
argon2-cffi 23.1.0
argon2-cffi-bindings 21.2.0
asciitree 0.3.3
astropy 6.1.0
astropy-iers-data 0.2024.6.10.0.30.47
asttokens 2.4.1
astunparse 1.6.3
async-timeout 4.0.3
attrs 23.2.0
audioread 3.0.1
beautifulsoup4 4.12.3
black 24.4.2
bleach 6.1.0
blis 0.7.11
bokeh 3.5.0
botocore 1.34.106
cachetools 5.3.3
cads-api-client 1.0.3
Cartopy 0.23.0
catalogue 2.0.10
cdsapi 0.7.0
certifi 2024.2.2
cffi 1.16.0
cftime 1.6.4
charset-normalizer 3.3.2
click 8.1.7
cloudpathlib 0.16.0
cloudpickle 2.2.1
cmake 3.29.2
comm 0.2.2
confection 0.1.4
contourpy 1.2.1
crcmod 1.7
cuda-python 12.4.0
cudf 24.4.0
cugraph 24.4.0
cugraph-dgl 24.4.0
cugraph-equivariant 24.4.0
cugraph-pyg 24.4.0
cugraph-service-client 24.4.0
cugraph-service-server 24.4.0
cuml 24.4.0
cupy-cuda12x 13.0.0
cycler 0.12.1
cymem 2.0.8
Cython 3.0.10
dask 2024.1.1
dask-cuda 24.4.0
dask-cudf 24.4.0
dask-expr 0.4.0
debugpy 1.8.1
decorator 5.1.1
deepspeed 0.14.3
defusedxml 0.7.1
dgl 2.2.1
dill 0.3.1.1
distributed 2024.1.1
dm-tree 0.1.8
dnspython 2.6.1
docker-pycreds 0.4.0
docopt 0.6.2
earth2-grid 2024.5.2
einops 0.8.0
entrypoints 0.4
exceptiongroup 1.2.1
execnet 2.1.1
executing 2.0.1
expecttest 0.1.3
fastavro 1.9.5
fasteners 0.19
fastjsonschema 2.19.1
fastrlock 0.8.2
filelock 3.14.0
flash-attn 2.4.2
fonttools 4.51.0
frozenlist 1.4.1
fsspec 2024.6.0
gast 0.5.4
gitdb 4.0.11
GitPython 3.1.43
gnureadline 8.2.10
google-auth 2.29.0
google-auth-oauthlib 0.4.6
grpcio 1.64.1
h5py 3.11.0
hdfs 2.7.3
healpy 1.17.1
hjson 3.1.0
httplib2 0.22.0
huggingface-hub 0.23.3
hydra-core 1.3.2
hypothesis 5.35.1
idna 3.7
igraph 0.11.4
imageio 2.34.1
immutabledict 4.2.0
importlib_metadata 7.1.0
iniconfig 2.0.0
intel-openmp 2021.4.0
ipykernel 6.29.4
ipython 8.21.0
ipython-genutils 0.2.0
jax 0.4.30
jaxlib 0.4.30
jedi 0.19.1
Jinja2 3.1.3
jmespath 1.0.1
joblib 1.4.0
Js2Py 0.74
json5 0.9.25
jsonpickle 3.2.2
jsonschema 4.22.0
jsonschema-specifications 2023.12.1
jupyter_client 8.6.1
jupyter_core 5.7.2
jupyter-tensorboard 0.2.0
jupyterlab 2.3.2
jupyterlab_pygments 0.3.0
jupyterlab-server 1.2.0
jupytext 1.16.1
kiwisolver 1.4.5
kvikio 24.4.0
langcodes 3.4.0
language_data 1.2.0
lark 1.1.9
lazy_loader 0.4
librosa 0.10.1
lightning-thunder 0.2.0.dev0
lightning-utilities 0.11.2
littleutils 0.2.2
llvmlite 0.42.0
locket 1.0.0
looseversion 1.3.0
marisa-trie 1.1.0
Markdown 3.6
markdown-it-py 3.0.0
MarkupSafe 2.1.5
matplotlib 3.8.4
matplotlib-inline 0.1.7
mdit-py-plugins 0.4.0
mdurl 0.1.2
mistune 3.0.2
mkl 2021.1.1
mkl-devel 2021.1.1
mkl-include 2021.1.1
ml-dtypes 0.4.0
mock 5.1.0
mpi4py 3.1.6
mpmath 1.3.0
msgpack 1.0.8
multidict 6.0.5
multiurl 0.3.1
murmurhash 1.0.10
mypy-extensions 1.0.0
natten 0.17.1 /opt/NATTEN/src
nbclient 0.10.0
nbconvert 7.16.4
nbformat 5.10.4
nest-asyncio 1.6.0
netCDF4 1.6.5
networkx 3.3
ninja 1.11.1.1
notebook 6.4.10
numba 0.59.1
numcodecs 0.11.0
numpy 1.24.4
nvfuser 0.2.0a0+0ff5802
nvidia-cudnn-frontend 1.3.0
nvidia-dali-cuda110 1.38.0
nvidia-dali-cuda120 1.37.1
nvidia-ml-py 12.555.43
nvidia-modulus 0.3.0
nvidia-nvimgcodec-cu11 0.2.0.7
nvidia-nvimgcodec-cu12 0.2.0.7
nvidia-pyindex 1.0.9
nvtx 0.2.5
nx-cugraph 24.4.0
oauthlib 3.2.2
objsize 0.7.0
ogb 1.3.6
omegaconf 2.3.0
onnx 1.16.0
opencv 4.7.0
opt-einsum 3.3.0
optree 0.11.0
orjson 3.10.6
outdated 0.2.2
packaging 24.0
pandas 2.0.3
pandocfilters 1.5.1
parso 0.8.4
partd 1.4.1
pathspec 0.12.1
pexpect 4.9.0
Pillow 9.5.0
pip 24.0
platformdirs 4.2.1
pluggy 1.5.0
ply 3.11
polygraphy 0.49.10
pooch 1.8.1
preshed 3.0.9
prettytable 3.10.0
prometheus_client 0.20.0
prompt-toolkit 3.0.43
properscoring 0.1
proto-plus 1.24.0
protobuf 4.25.3
psutil 5.9.8
ptyprocess 0.7.0
pure-eval 0.2.2
py-cpuinfo 9.0.0
pyarrow 14.0.2
pyarrow-hotfix 0.6
pyasn1 0.6.0
pyasn1_modules 0.4.0
pybind11 2.12.0
pybind11_global 2.12.0
pycocotools 2.0+nv0.8.0
pycparser 2.22
pydantic 2.7.1
pydantic_core 2.18.2
pydot 1.4.2
pyerfa 2.0.1.4
Pygments 2.17.2
pygrib 2.1.5
pyjsparser 2.7.1
pylibcugraph 24.4.0
pylibcugraphops 24.4.0
pylibraft 24.4.0
pylibwholegraph 24.4.0
pymongo 4.8.0
pynvjitlink 0.1.13
pynvml 11.4.1
pyparsing 3.1.2
pyproj 3.6.1
pyshp 2.3.1
pytest 8.1.1
pytest-flakefinder 1.1.0
pytest-rerunfailures 14.0
pytest-shard 0.1.2
pytest-xdist 3.6.1
python-dateutil 2.9.0.post0
python-hostlist 1.23.0
pytorch-quantization 2.1.2
pytorch-triton 3.0.0+989adb9a2
pytz 2024.1
pyvista 0.43.9
PyYAML 6.0.1
pyzmq 26.0.3
raft-dask 24.4.0
rapids-dask-dependency 24.4.0a0
readline 6.2.4.1
rechunker 0.5.2
redis 5.0.7
referencing 0.35.1
regex 2024.4.28
requests 2.31.0
requests-oauthlib 2.0.0
rich 13.7.1
rmm 24.4.0
rpds-py 0.18.0
rsa 4.9
ruamel.yaml 0.18.6
ruamel.yaml.clib 0.2.8
ruff 0.4.8
s3fs 2024.6.0
safetensors 0.4.3
scikit_build_core 0.9.4
scikit-image 0.23.2
scikit-learn 1.4.2
scipy 1.13.0
scooby 0.10.0
Send2Trash 1.8.3
sentry-sdk 2.5.1
setproctitle 1.3.3
setuptools 68.2.2
shapely 2.0.4
six 1.16.0
smart-open 6.4.0
smmap 5.0.1
sortedcontainers 2.4.0
soundfile 0.12.1
soupsieve 2.5
soxr 0.3.7
spacy 3.7.4
spacy-legacy 3.0.12
spacy-loggers 1.0.5
sphinx_glpi_theme 0.6
srsly 2.4.8
stack-data 0.6.3
sympy 1.12
tabulate 0.9.0
tbb 2021.12.0
tblib 3.0.0
tensorboard 2.9.0
tensorboard-data-server 0.6.1
tensorboard-plugin-wit 1.8.1
tensorly 0.8.1
tensorly-torch 0.5.0
tensorrt 10.0.1
terminado 0.18.1
texttable 1.7.0
thinc 8.2.3
threadpoolctl 3.5.0
thriftpy2 0.4.20
tifffile 2024.5.22
timm 1.0.3
tinycss2 1.3.0
toml 0.10.2
tomli 2.0.1
toolz 0.12.1
torch 2.4.0a0+07cecf4168.nv24.5
torch_geometric 2.5.3
torch-harmonics 0.6.5
torch-tensorrt 2.4.0a0
torchdata 0.7.1
torchinfo 1.8.0
torchvision 0.19.0a0
tornado 6.4
tqdm 4.66.4
traitlets 5.9.0
transformer-engine 1.6.0+c81733f
treelite 4.1.2
triton 2.3.1
typer 0.9.4
types-dataclasses 0.6.6
typing_extensions 4.11.0
tzdata 2024.1
tzlocal 5.2
ucx-py 0.37.0
urllib3 2.0.7
vtk 9.3.0
wandb 0.17.1
wasabi 1.1.2
wcwidth 0.2.13
weasel 0.3.4
weatherbench2 0.2.0
webencodings 0.5.1
Werkzeug 3.0.2
wheel 0.43.0
wrapt 1.16.0
xarray 2023.7.0
xarray-beam 0.6.3
xdoctest 1.0.2
xgboost 2.0.3
xyzservices 2024.6.0
yarl 1.9.4
zarr 2.17.2
zict 3.0.0
zipp 3.18.1
zstandard 0.22.0

Update: apparently this error does not occur when I use the command-line script with the same data config, which is strange.

    python evaluate.py \
      --forecast_path=/pscratch/sd/j/jwillard/healpix_era5/results/nat1d_1deg_e512_w513_lr5em4cos/00/inference/forecasts.zarr \
      --obs_path=/pscratch/sd/j/jwillard/healpix_era5/data/latlon_1deg_combined_wb.zarr \
      --output_dir=/pscratch/sd/j/jwillard/FCN_exp/wb2_eval/ \
      --output_file_prefix=test \
      --input_chunks=init_time=1,lead_time=1 \
      --runner=DirectRunner \
      --fanout=27 \
      --regions=all \
      --eval_configs=deterministic \
      --evaluate_climatology=False \
      --evaluate_persistence=False \
      --time_start=2020-01-01 \
      --time_stop=2022-12-31 \
      --pressure_level_suffixes=False \
      --variables=geopotential,temperature,u_component_of_wind,v_component_of_wind,specific_humidity,2m_temperature,10m_u_component_of_wind,10m_v_component_of_wind,mean_sea_level_pressure \
      --use_beam=True

Issue resolved.

My configuration file was passing the variables in as <class 'omegaconf.listconfig.ListConfig'> rather than a plain <class 'list'>, which is why xarray couldn't use it as a list of variable names. I didn't notice because a ListConfig prints and reprs exactly like a list. The issue wasn't "in Beam" versus "out of Beam": the cases that worked (Jupyter, the WB2 command-line script) simply didn't go through my configuration loading. Sorry for the false alarm.
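
For anyone who hits the same thing: converting the OmegaConf list to a plain Python list before building the Selection is enough. A minimal sketch against my own config attributes (names like score_variables are from my setup, not from WB2):

    from omegaconf import ListConfig, OmegaConf

    variables = self.score_variables
    if isinstance(variables, ListConfig):
        # OmegaConf.to_container (or just list()) returns a builtin list,
        # which xarray then accepts as a list of variable names.
        variables = OmegaConf.to_container(variables, resolve=True)

    selection = wb2_config.Selection(
        variables=variables,
        levels=list(self.score_levels),
        time_slice=slice(self.score_start_date, self.score_end_date),
    )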