tensorflow/transform

tft.quantiles(reduce_instance_dims=False) errors with columns that are all NaN-valued

cyc opened this issue · 1 comment

cyc commented

TF 2.7.0 and TFT 1.5.0

import pprint
import tempfile

import tensorflow as tf
import tensorflow_transform as tft
import tensorflow_transform.beam as tft_beam
from tensorflow_transform.tf_metadata import dataset_metadata
from tensorflow_transform.tf_metadata import schema_utils


def main():
  """Runs the repro: per-column quantiles over a column that is entirely NaN.

  Builds a ten-row in-memory dataset with three float columns (one without
  NaNs, one with NaNs in every other row, one with only NaNs), runs it
  through `tft_beam.AnalyzeAndTransformDataset`, and pretty-prints the
  transformed rows.
  """
  column_names = ('no_nans', 'some_nans', 'all_nans')

  def preprocessing_fn(inputs):
    """Stacks the input columns and divides each by its first quantile boundary."""
    stacked = tf.stack([inputs[name] for name in column_names], axis=-1)
    # With num_buckets=2 the first boundary is what the report treats as the
    # per-column median; reduce_instance_dims=False requests one result per
    # stacked column.
    boundaries = tft.quantiles(
        x=stacked,
        num_buckets=2,
        epsilon=0.01,
        reduce_instance_dims=False)
    per_column_median = boundaries[:, 0]
    return {'normalized': stacked / per_column_median}

  nan = float('nan')
  # Rows 0..9: 'no_nans' counts up, 'some_nans' is NaN on odd rows, and
  # 'all_nans' is NaN everywhere.
  data = [
      {
          'no_nans': float(row),
          'some_nans': nan if row % 2 else float(row),
          'all_nans': nan,
      }
      for row in range(10)
  ]

  feature_spec = {
      name: tf.io.FixedLenFeature([], tf.float32) for name in column_names
  }
  raw_data_metadata = dataset_metadata.DatasetMetadata(
      schema_utils.schema_from_feature_spec(feature_spec))

  with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
    raw_dataset = (data, raw_data_metadata)
    # The transform function (second element) is not needed for the repro.
    transformed_dataset, _ = (
        raw_dataset | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn))

  # Only the transformed rows are printed; the output metadata is discarded.
  transformed_data, _ = transformed_dataset

  pprint.pprint(transformed_data)

# Run the repro only when this file is executed as a script, not on import.
if __name__ == '__main__':
  main()

If you run the code above, you will get an exception like `ValueError: cannot reshape array of size 2 into shape (3,1)`. However, if you comment out the "all_nans" feature (line 17 of the script), you will get the expected output:

[{'normalized': array([0., 0.], dtype=float32)},
 {'normalized': array([0.2, nan], dtype=float32)},
 {'normalized': array([0.4, 0.5], dtype=float32)},
 {'normalized': array([0.6, nan], dtype=float32)},
 {'normalized': array([0.8, 1. ], dtype=float32)},
 {'normalized': array([ 1., nan], dtype=float32)},
 {'normalized': array([1.2, 1.5], dtype=float32)},
 {'normalized': array([1.4, nan], dtype=float32)},
 {'normalized': array([1.6, 2. ], dtype=float32)},
 {'normalized': array([1.8, nan], dtype=float32)}]

I also confirmed that the problem does not occur with tft.quantiles(reduce_instance_dims=True), even when all of the input values are NaN.

Note: this may be related to commit ea54bc2, which made quantiles ignore NaNs.
We may need to update the code to supply a default value for the edge case where a dataset feature is all NaNs.