NREL/sup3r

Temporal slicing in Validation data

malihass opened this issue · 2 comments

Bug Description
When using a temporal slice of (None, None, x) with x > 1, some tests fail. The failure appears to come from the tuple_index handling in ValidationData.

Full Traceback

Traceback (most recent call last):
  File "test_train_gan_tslice.py", line 69, in <module>
    test_train_st_weight_update(log=True, n_epoch=1, temporal_slice=(None, None, 3))
  File "test_train_gan_tslice.py", line 63, in test_train_st_weight_update
    adaptive_update_fraction=0.05)
  File "/Users/mhassana/Desktop/GitHub/sup3r_nov24_issue/sup3r/models/base.py", line 1140, in train
    loss_details)
  File "/Users/mhassana/Desktop/GitHub/sup3r_nov24_issue/sup3r/models/base.py", line 901, in calc_val_loss
    for val_batch in batch_handler.val_data:
  File "/Users/mhassana/Desktop/GitHub/sup3r_nov24_issue/sup3r/preprocessing/batch_handling.py", line 316, in __next__
    val_index['tuple_index']]
ValueError: could not broadcast input array from shape (18,18,14,3) into shape (18,18,24,3)
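
The broadcast failure itself is easy to reproduce in isolation. Below is a minimal numpy sketch (the array names are illustrative, not sup3r internals): the preallocated batch array expects sample_shape's 24 time steps, but only 14 time steps survive the validation split plus slicing.

import numpy as np

# Preallocated for sample_shape's 24 time steps ...
batch = np.zeros((18, 18, 24, 3))
# ... but only 14 time steps are left in the validation data
sample = np.zeros((18, 18, 14, 3))
batch[...] = sample  # ValueError: could not broadcast input array
                     # from shape (18,18,14,3) into shape (18,18,24,3)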

Code Sample

import os
import tempfile

from rex import init_logger

from sup3r import TEST_DATA_DIR
from sup3r import CONFIG_DIR
from sup3r.models import Sup3rGan
from sup3r.preprocessing.data_handling import DataHandlerH5
from sup3r.preprocessing.batch_handling import BatchHandler


FP_WTK = os.path.join(TEST_DATA_DIR, 'test_wtk_co_2012.h5')
TARGET_COORD = (39.01, -105.15)
FEATURES = ['U_100m', 'V_100m', 'BVF2_200m']


def test_train_st_weight_update(n_epoch=5, log=False,
                                temporal_slice=slice(None, None, 1)):
    """Test basic spatiotemporal model training with discriminators and
    adversarial loss updating."""
    if log:
        init_logger('sup3r', log_level='DEBUG')

    fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json')
    fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json')

    Sup3rGan.seed()
    model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4,
                     learning_rate_disc=3e-4)

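    # NOTE: with val_split=0.005 on 8784 hourly time steps, a temporal step
    # of 3 leaves only ~14 validation time steps, fewer than the 24 that
    # sample_shape requests below.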
    handler = DataHandlerH5(FP_WTK, FEATURES, target=TARGET_COORD,
                            shape=(20, 20),
                            sample_shape=(18, 18, 24),
                            temporal_slice=temporal_slice,
                            val_split=0.005,
                            max_workers=1)

    batch_handler = BatchHandler([handler], batch_size=4,
                                 s_enhance=3, t_enhance=4,
                                 n_batches=4)

    adaptive_update_bounds = (0.9, 0.99)
    with tempfile.TemporaryDirectory() as td:
        model.train(batch_handler, n_epoch=n_epoch,
                    weight_gen_advers=1e-6,
                    train_gen=True, train_disc=True,
                    checkpoint_int=10,
                    out_dir=os.path.join(td, 'test_{epoch}'),
                    adaptive_update_bounds=adaptive_update_bounds,
                    adaptive_update_fraction=0.05)

if __name__ == "__main__":
    print("\n\n DOING temporal_slice=(None, None, 1) \n\n")
    test_train_st_weight_update(log=True, n_epoch=1, temporal_slice=(None, None, 1))
    print("\n\n DOING temporal_slice=(None, None, 3) \n\n")
    test_train_st_weight_update(log=True, n_epoch=1, temporal_slice=(None, None, 3))

To Reproduce
Steps to reproduce the behavior:

  1. Copy the code sample into tests/
  2. Run the Python script

Expected behavior
Any temporal slice should work (within the limits of the dataset size).

bnb32 commented

It looks like, with val_split=0.005 and a temporal step of 3, only 14 time steps are left in the validation data (8784 * 0.005 // 3 = 14), while sample_shape requests 24 time steps. There should definitely be a clearer error message for this.
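
As a quick sanity check of that arithmetic (8784 is the number of hourly time steps in the 2012 leap-year test file):

val_steps = int(8784 * 0.005)  # 43 time steps go to validation
print(val_steps // 3)          # 14 after a temporal step of 3; sample_shape asks for 24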

Added a warning here:

def _val_split_check(self):

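For anyone hitting this before the fix is merged, the check amounts to comparing the validation time dimension against the requested sample shape. A hypothetical standalone sketch (not the actual implementation; the function and argument names here are my own):

import logging

logger = logging.getLogger(__name__)

def val_split_check(val_data_shape, sample_shape):
    """Warn when the validation split leaves fewer time steps than
    sample_shape requests (hypothetical sketch of the added check)."""
    t_avail = val_data_shape[2]   # time axis of (rows, cols, time, features)
    t_needed = sample_shape[2]
    if t_avail < t_needed:
        logger.warning('Validation data has only %s time steps but '
                       'sample_shape requests %s; reduce sample_shape, '
                       'raise val_split, or use a smaller temporal step.',
                       t_avail, t_needed)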
malihass commented

Ah, you are right, this was the issue. I'll let you close this when the PR with _val_split_check is merged.
Thanks!