tft.quantiles(reduce_instance_dims=False) errors with columns that are all NaN-valued
cyc opened this issue · 1 comments
cyc commented
TF 2.7.0 and TFT 1.5.0
import pprint
import tempfile
import tensorflow as tf
import tensorflow_transform as tft
import tensorflow_transform.beam as tft_beam
from tensorflow_transform.tf_metadata import dataset_metadata
from tensorflow_transform.tf_metadata import schema_utils
def main():
    """Run a small in-memory dataset through a TensorFlow Transform pipeline.

    The pipeline stacks three float columns, asks `tft.quantiles` for
    per-column bucket boundaries (`reduce_instance_dims=False`), divides the
    stacked columns by the first boundary, and pretty-prints the transformed
    rows.
    """

    def preprocessing_fn(inputs):
        """Stack the input columns and divide them by their quantile boundary."""
        stacked = tf.stack(
            [
                inputs['no_nans'],
                inputs['some_nans'],
                inputs['all_nans'],
            ],
            axis=-1,
        )
        boundaries = tft.quantiles(
            x=stacked,
            num_buckets=2,
            epsilon=0.01,
            reduce_instance_dims=False,
        )
        # With num_buckets=2 the first boundary is used as each column's median.
        medians = boundaries[:, 0]
        return {'normalized': stacked / medians}

    nan = float('nan')
    # Ten rows: 'no_nans' counts 0.0..9.0, 'some_nans' mirrors it on even rows
    # and is NaN on odd rows, 'all_nans' is NaN in every row.
    rows = [
        {
            'no_nans': float(i),
            'some_nans': float(i) if i % 2 == 0 else nan,
            'all_nans': nan,
        }
        for i in range(10)
    ]

    # Every column is a scalar float32 feature.
    feature_spec = {
        column: tf.io.FixedLenFeature([], tf.float32)
        for column in ('no_nans', 'some_nans', 'all_nans')
    }
    metadata = dataset_metadata.DatasetMetadata(
        schema_utils.schema_from_feature_spec(feature_spec))

    with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
        analyze_and_transform = tft_beam.AnalyzeAndTransformDataset(
            preprocessing_fn)
        # The result is ((data, metadata), transform_fn); only the transformed
        # data is needed here.
        (transformed_data, _), _ = (rows, metadata) | analyze_and_transform

    pprint.pprint(transformed_data)
# Run the reproduction only when this file is executed as a script.
if __name__ == '__main__':
    main()
If you run the code above, you will get an exception like `ValueError: cannot reshape array of size 2 into shape (3,1)`.
However, if you comment out the "all_nans" feature (the `inputs['all_nans']` entry in `preprocessing_fn`), you will get the expected output:
[{'normalized': array([0., 0.], dtype=float32)},
{'normalized': array([0.2, nan], dtype=float32)},
{'normalized': array([0.4, 0.5], dtype=float32)},
{'normalized': array([0.6, nan], dtype=float32)},
{'normalized': array([0.8, 1. ], dtype=float32)},
{'normalized': array([ 1., nan], dtype=float32)},
{'normalized': array([1.2, 1.5], dtype=float32)},
{'normalized': array([1.4, nan], dtype=float32)},
{'normalized': array([1.6, 2. ], dtype=float32)},
{'normalized': array([1.8, nan], dtype=float32)}]
I also confirmed that the problem does not occur with `tft.quantiles(reduce_instance_dims=True)`,
even when the values passed in are all NaN.