Temporal slicing in Validation data
malihass opened this issue · 2 comments
Bug Description
When using temporal slicing (None, None, x)
where x>1
, some tests fail. It seems to be caused by the `tuple_index` handling
in `ValidationData`.
Full Traceback
Traceback (most recent call last):
File "test_train_gan_tslice.py", line 69, in <module>
test_train_st_weight_update(log=True, n_epoch=1, temporal_slice=(None, None, 3))
File "test_train_gan_tslice.py", line 63, in test_train_st_weight_update
adaptive_update_fraction=0.05)
File "/Users/mhassana/Desktop/GitHub/sup3r_nov24_issue/sup3r/models/base.py", line 1140, in train
loss_details)
File "/Users/mhassana/Desktop/GitHub/sup3r_nov24_issue/sup3r/models/base.py", line 901, in calc_val_loss
for val_batch in batch_handler.val_data:
File "/Users/mhassana/Desktop/GitHub/sup3r_nov24_issue/sup3r/preprocessing/batch_handling.py", line 316, in __next__
val_index['tuple_index']]
ValueError: could not broadcast input array from shape (18,18,14,3) into shape (18,18,24,3)
Code Sample
import os
import json
import numpy as np
import pytest
import tempfile
import tensorflow as tf
from tensorflow.python.framework.errors_impl import InvalidArgumentError
from rex import init_logger
from sup3r import TEST_DATA_DIR
from sup3r import CONFIG_DIR
from sup3r.models import Sup3rGan
from sup3r.models.data_centric import Sup3rGanDC, Sup3rGanSpatialDC
from sup3r.preprocessing.data_handling import (DataHandlerH5,
DataHandlerDCforH5)
from sup3r.preprocessing.batch_handling import (BatchHandler,
BatchHandlerDC,
SpatialBatchHandler,
BatchHandlerSpatialDC)
from sup3r.utilities.loss_metrics import MmdMseLoss
# Path to the bundled WTK (Wind Toolkit) H5 test dataset used as training input.
FP_WTK = os.path.join(TEST_DATA_DIR, 'test_wtk_co_2012.h5')
# (latitude, longitude) of the lower-left corner of the extraction grid.
TARGET_COORD = (39.01, -105.15)
# Feature names requested from the data handler (wind components + BVF2).
FEATURES = ['U_100m', 'V_100m', 'BVF2_200m']
def test_train_st_weight_update(n_epoch=5, log=False,
                                temporal_slice=slice(None, None, 1)):
    """Test basic spatiotemporal model training with discriminators and
    adversarial loss updating.

    Parameters
    ----------
    n_epoch : int
        Number of training epochs to run.
    log : bool
        If True, enable DEBUG-level logging for the sup3r logger.
    temporal_slice : slice | tuple
        Temporal slice applied by the data handler; a step > 1 is the
        configuration that reproduces the reported validation-data bug.
    """
    if log:
        init_logger('sup3r', log_level='DEBUG')

    # Model configs: 3x spatial / 4x temporal enhancement generator.
    gen_config = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json')
    disc_config = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json')

    Sup3rGan.seed()
    gan = Sup3rGan(gen_config, disc_config, learning_rate=1e-4,
                   learning_rate_disc=3e-4)

    # NOTE(review): with val_split=0.005 and a strided temporal_slice the
    # validation set can end up shorter than the 24-step sample_shape,
    # which is what triggers the reported broadcast error.
    data_handler = DataHandlerH5(FP_WTK, FEATURES, target=TARGET_COORD,
                                 shape=(20, 20),
                                 sample_shape=(18, 18, 24),
                                 temporal_slice=temporal_slice,
                                 val_split=0.005,
                                 max_workers=1)
    batches = BatchHandler([data_handler], batch_size=4,
                           s_enhance=3, t_enhance=4,
                           n_batches=4)

    with tempfile.TemporaryDirectory() as tmp_dir:
        gan.train(batches, n_epoch=n_epoch,
                  weight_gen_advers=1e-6,
                  train_gen=True, train_disc=True,
                  checkpoint_int=10,
                  out_dir=os.path.join(tmp_dir, 'test_{epoch}'),
                  adaptive_update_bounds=(0.9, 0.99),
                  adaptive_update_fraction=0.05)
if __name__ == "__main__":
print("\n\n DOING temporal_slice=(None, None, 1) \n\n")
test_train_st_weight_update(log=True, n_epoch=1, temporal_slice=(None, None, 1))
print("\n\n DOING temporal_slice=(None, None, 3) \n\n")
test_train_st_weight_update(log=True, n_epoch=1, temporal_slice=(None, None, 3))
To Reproduce
Steps to reproduce the problem behavior
- Copy code sample to
tests/
- Execute the python script
Expected behavior
Any temporal slicing should work (in the limit of the dataset size)
It looks like, since val_split=0.005 and the step is 3, only 14 time steps are left in the validation data (8784 * 0.005 // 3), while sample_shape requests 24 time steps. There should definitely be a clearer error message for this case.
Added warning here -
sup3r/sup3r/preprocessing/data_handling.py
Line 708 in 117479a
Ah, you are right — this was the issue. I'll let you close this once the PR adding val_split_check
is merged.
Thanks!