mllam/neural-lam

Designing a `Datastore` class (for handling different storage backends for datasets)


I am in the process of reading through @sadamov's PR #54 on using zarr-based datasets in neural-lam, and I will use this issue to write down some notes. Everyone is free to read along, but this will only become a coherent piece of information over time, so it is probably best to wait until I comment directly on #54.

Uses of the current neural_lam.config.Config attributes and methods outside the class itself (a sketch of what a shared Datastore interface covering these could look like follows the attribute list below):

$> grep -r 'config\.' *.py       
calculate_statistics.py:        default="neural_lam/data_config.yaml",
calculate_statistics.py:        help="Path to data config file (default: neural_lam/data_config.yaml)",
calculate_statistics.py:    data_config = config.Config.from_file(args.data_config)
calculate_statistics.py:    state_data = data_config.process_dataset("state", split="train")
calculate_statistics.py:    forcing_data = data_config.process_dataset(
create_boundary_mask.py:        default="neural_lam/data_config.yaml",
create_boundary_mask.py:        help="Path to data config file (default: neural_lam/data_config.yaml)",
create_boundary_mask.py:    data_config = config.Config.from_file(args.data_config)
create_boundary_mask.py:    mask = np.zeros(list(data_config.grid_shape_state.values.values()))
create_forcings.py:        "--data_config", type=str, default="neural_lam/data_config.yaml"
create_forcings.py:    data_config = config.Config.from_file(args.data_config)
create_forcings.py:    dataset = data_config.open_zarrs("state")
create_mesh.py:        default="neural_lam/data_config.yaml",
create_mesh.py:        help="Path to data config file (default: neural_lam/data_config.yaml)",
create_mesh.py:    data_config = config.Config.from_file(args.data_config)
create_mesh.py:    xy = data_config.get_xy("static")  # (2, N_y, N_x)
plot_graph.py:        default="neural_lam/data_config.yaml",
plot_graph.py:        help="Path to data config file (default: neural_lam/data_config.yaml)",
plot_graph.py:    data_config = config.Config.from_file(args.data_config)
plot_graph.py:    xy = data_config.get_xy("state")  # (2, N_y, N_x)
train_model.py:        default="neural_lam/data_config.yaml",
train_model.py:        help="Path to data config file (default: neural_lam/data_config.yaml)",
$> grep -r 'config\.' neural_lam 
neural_lam/models/ar_model.py:        self.data_config = config.Config.from_file(args.data_config)
neural_lam/models/ar_model.py:        static = self.data_config.process_dataset("static")
neural_lam/models/ar_model.py:        state_stats = self.data_config.load_normalization_stats(
neural_lam/models/ar_model.py:        self.grid_output_dim = self.data_config.num_data_vars("state")
neural_lam/models/ar_model.py:            self.grid_output_dim = 2 * self.data_config.num_data_vars("state")
neural_lam/models/ar_model.py:            self.grid_output_dim = self.data_config.num_data_vars("state")
neural_lam/models/ar_model.py:            + self.data_config.num_data_vars("forcing")
neural_lam/models/ar_model.py:            * self.data_config.forcing.window
neural_lam/models/ar_model.py:        boundary_mask = self.data_config.load_boundary_mask()
neural_lam/models/ar_model.py:        self.step_length = self.data_config.step_length
neural_lam/models/ar_model.py:                            self.data_config.vars_names("state"),
neural_lam/models/ar_model.py:                            self.data_config.vars_units("state"),
neural_lam/models/ar_model.py:                            self.data_config.vars_names("state"), var_figs
neural_lam/models/ar_model.py:                var = self.data_config.vars_names("state")[var_i]
neural_lam/config.py:        proj_params = proj_config.get("kwargs", {})
neural_lam/config.py:                "vars based on zarr config...\033[0m"
neural_lam/config.py:                    "vars based on zarr config.\033[0m"
neural_lam/config.py:                    "vars based on zarr config.\033[0m"
grep: neural_lam/__pycache__/config.cpython-310.pyc: binary file matches
neural_lam/vis.py:            data_config.vars_names("state"), data_config.vars_units("state")
neural_lam/vis.py:    extent = data_config.get_xy_extent("state")
neural_lam/vis.py:        list(data_config.grid_shape_state.values.values())
neural_lam/vis.py:        subplot_kw={"projection": data_config.coords_projection},
neural_lam/vis.py:            data.reshape(list(data_config.grid_shape_state.values.values()))
neural_lam/vis.py:    extent = data_config.get_xy_extent("state")
neural_lam/vis.py:        list(data_config.grid_shape_state.values.values())
neural_lam/vis.py:        subplot_kw={"projection": data_config.coords_projection},
neural_lam/vis.py:        error.reshape(list(data_config.grid_shape_state.values.values()))
neural_lam/weather_dataset.py:        data_config="neural_lam/data_config.yaml",
neural_lam/weather_dataset.py:        self.data_config = config.Config.from_file(data_config)
neural_lam/weather_dataset.py:        self.state = self.data_config.process_dataset("state", self.split)
neural_lam/weather_dataset.py:        self.forcing = self.data_config.process_dataset("forcing", self.split)
neural_lam/weather_dataset.py:            state_stats = self.data_config.load_normalization_stats(
neural_lam/weather_dataset.py:                forcing_stats = self.data_config.load_normalization_stats(

Methods:

  • from_file(filepath), called in:
    • calculate_statistics.main with filepath=args.data_config
    • create_boundary_mask.main with filepath=args.data_config
    • create_forcings.main with filepath=args.data_config
    • create_mesh.main with filepath=args.data_config
    • neural_lam.models.ar_model.ARModel.__init__ with filepath=args.data_config
    • neural_lam.weather_dataset.WeatherDataset.__init__ with filepath=args.data_config
    • plot_graph.main with filepath=args.data_config
  • process_dataset(category, split, apply_windowing=True), called in:
    • neural_lam.models.ar_model.ARModel.__init__ with category="static"
    • neural_lam.weather_dataset.WeatherDataset with category="state" and category="forcing"
  • open_zarrs(category), called in:
    • create_forcings.main with category="state"
    • also called 7 times within the neural_lam.config.Config class itself
  • get_xy(category, stacked=True), called in:
    • create_mesh.main with category="static"
    • neural_lam.config.Config.get_xy_extent itself, with stacked=False
    • plot_graph.main with category="state"
  • load_normalization_stats(category, datatype="torch"), called in:
    • neural_lam.models.ARModel.__init__ with category="state", datatype="torch"
    • neural_lam.weather_dataset.WeatherDataset.__init__ with category="state", datatype="torch" and category="forcing", datatype="torch"
  • load_boundary_mask(), called in:
    • neural_lam.models.ar_model.ARModel.__init__
  • num_data_vars(category), called in:
    • neural_lam.models.ARModel.__init__ with category="state" three times and with category="forcing" once
  • vars_names(category), called in:
    • neural_lam.config.Config._select_stats_by_category(combined_stats, category) twice
    • neural_lam.models.ar_model.ARModel.plot_examples with category="state" twice
    • neural_lam.models.ar_model.ARModel.create_metric_log_dict with category="state"
  • vars_units(category), called in:
    • neural_lam.models.ar_model.ARModel.plot_examples with category="state"

Attributes:

  • grid_shape_state, defined in neural_lam/data_config.yaml as grid_shape_state=dict(y=589, x=789), referenced in
    • neural_lam.vis.plot_prediction, neural_lam.vis.plot_spatial_error and create_boundary_mask.main
  • forcing, which exposes the entire forcing sub-tree (dict structure) of neural_lam/data_config.yaml, referenced in
    • neural_lam.models.ar_model.ARModel.__init__
  • step_length, computed from the time coordinate values of the "state" category, referenced in
    • neural_lam.models.ar_model.ARModel.__init__
  • coords_projection, defined at the root of neural_lam/data_config.yaml, referenced in
    • neural_lam.vis.plot_prediction and neural_lam.vis.plot_spatial_error
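
Based purely on the usage survey above, a rough sketch of what a shared `Datastore` interface could look like (all names here are placeholders I am making up for illustration, nothing decided yet):

```python
# Rough sketch only: the class and method names below are placeholders derived
# from how neural_lam.config.Config is used above, not an agreed-upon interface.
from abc import ABC, abstractmethod

import cartopy.crs as ccrs
import numpy as np
import xarray as xr


class BaseDatastore(ABC):
    """What a storage backend (e.g. the multi-zarr one in #54) would implement."""

    @abstractmethod
    def process_dataset(
        self, category: str, split: str = "train", apply_windowing: bool = True
    ) -> xr.Dataset:
        """Processed dataset for a given category ("state"/"forcing"/"static") and split."""

    @abstractmethod
    def get_xy(self, category: str, stacked: bool = True) -> np.ndarray:
        """Grid coordinates for `category`, e.g. shape (2, N_y, N_x)."""

    @abstractmethod
    def load_normalization_stats(self, category: str, datatype: str = "torch"):
        """Normalization statistics for the given category."""

    @abstractmethod
    def load_boundary_mask(self):
        """Boundary mask for the domain."""

    @abstractmethod
    def num_data_vars(self, category: str) -> int: ...

    @abstractmethod
    def vars_names(self, category: str) -> list[str]: ...

    @abstractmethod
    def vars_units(self, category: str) -> list[str]: ...

    # attributes currently read directly off Config could become properties
    @property
    @abstractmethod
    def grid_shape_state(self) -> dict: ...

    @property
    @abstractmethod
    def step_length(self) -> int: ...

    @property
    @abstractmethod
    def coords_projection(self) -> ccrs.Projection: ...
```

`open_zarrs` is left out on purpose: apart from create_forcings.main all its callers are inside `Config` itself, so it looks more like an implementation detail of the multi-zarr backend than part of the shared interface.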

@sadamov I made a dataclasses-based class structure which matches the config yaml file you've made: https://github.com/leifdenby/neural-lam/blob/mllam-dataloader/neural_lam/multizarr_datastore_config.py

I will explain this in more detail tomorrow, but the idea is that by using dataclasses the config structure is stored in code and validated, and the nested structure (enabling the dot-syntax access that you achieve with `__getattr__`) is provided through the nested dataclass objects. It doesn't match everything yet because I don't quite understand the whole config. The serialization to/from yaml is handled by dataclass-wizard.
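
To illustrate the idea, here is a minimal sketch (the fields are invented for the example and do not mirror the actual schema in multizarr_datastore_config.py), assuming dataclass-wizard's `YAMLWizard` mixin for the yaml (de)serialization:

```python
# Illustration only: the fields below are made up and do not mirror the real
# config schema in multizarr_datastore_config.py.
from dataclasses import dataclass, field
from typing import Any, Dict

from dataclass_wizard import YAMLWizard


@dataclass
class Projection:
    class_name: str
    kwargs: Dict[str, Any] = field(default_factory=dict)


@dataclass
class ExampleConfig(YAMLWizard):
    name: str
    projection: Projection


config = ExampleConfig.from_yaml(
    """
name: example
projection:
  class_name: LambertConformal
  kwargs:
    central_longitude: 15.0
"""
)
# nested dot-syntax access comes from the nested dataclasses themselves
print(config.projection.class_name)
```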