tensorflow/adanet

Allow one to forward features to predictions

jankrynauw opened this issue · 3 comments

We would like to forward a particular 'key' column, which is part of the features, so that it appears alongside the predictions. This makes it possible to identify which set of features a given prediction belongs to. Here is an example of the prediction output using tensorflow.contrib.estimator.multi_class_head:

{"classes": ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"],
 "scores": [0.06819603592157364, 0.0864366963505745, 0.12838752567768097, 0.046013250946998596, 0.03129083290696144, 0.1518409103155136, 0.1248951405286789, 0.15043732523918152, 0.0821763351559639, 0.13032598793506622]}

We would therefore like to add a key attribute to this prediction.
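For example, the desired output would then look something like this (illustrative; the value of 'key' echoes the corresponding input feature):

{"key": "example-id-123",
 "classes": ["0", "1", ...],
 "scores": [0.068, 0.086, ...]}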

However, wrapping the estimator in the usual way,

estimator = tf.contrib.estimator.forward_features(estimator, ['key'])

gives the following error:

The adanet.Estimator's model_fn should not be called directly in TRAIN mode, because its behavior is undefined outside the context of its train method.

The current workaround is to subclass the head; see the code posted below.

@jankrynauw Thank you for the feature request. This looks like it could be implemented in a way similar to adanet.Estimator(metric_fn=...) which we added in 53f5f5b.
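For reference, metric_fn is used roughly like this (a minimal sketch following the tf.estimator add_metrics convention; head and generator are assumed to be defined elsewhere), and a feature-forwarding hook could take an analogous, hypothetical predict_feature_keys argument:

import adanet
import tensorflow as tf

def my_metric_fn(features, labels, predictions):
    # Return extra metrics to merge into the evaluation results.
    return {'my_accuracy': tf.metrics.accuracy(labels=labels,
                                               predictions=predictions['class_ids'])}

estimator = adanet.Estimator(head=head,  # assumed defined elsewhere
                             subnetwork_generator=generator,  # likewise assumed
                             max_iteration_steps=1000,
                             metric_fn=my_metric_fn)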

Unfortunately, we don't have much time to add this ourselves, but if anyone has cycles to send a PR, they are very welcome to.

@jankrynauw Can you send me the code you were working on, so that I can test it and build on it?

See the PREDICT branch of the _create_tpu_estimator_spec method, as well as the predict_feature_keys=None parameter of __init__:

# Implement a custom Head to control the prediction output. Most of this code is copied directly from
# `_MultiClassHeadWithSoftmaxCrossEntropyLoss` in `tensorflow_estimator.python.estimator.canned.head`.
# NOTE: these imports target the TF 1.x estimator internals; the exact module paths vary between
# TensorFlow versions (older releases ship the same symbols under `tensorflow.python.estimator`).
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import string_ops
from tensorflow.python.ops.losses import losses
from tensorflow.python.summary import summary
from tensorflow.python.training import training_util
from tensorflow_estimator.python.estimator import model_fn
from tensorflow_estimator.python.estimator.canned import metric_keys
from tensorflow_estimator.python.estimator.canned.head import (
    _MultiClassHeadWithSoftmaxCrossEntropyLoss, _append_update_ops,
    _check_logits_final_dim, _classification_output, _create_eval_metrics_tuple,
    _summary_key, _CLASSIFY_SERVING_KEY, _DEFAULT_SERVING_KEY, _PREDICT_SERVING_KEY)
from tensorflow_estimator.python.estimator.canned.prediction_keys import PredictionKeys
from tensorflow_estimator.python.estimator.export import export_output


class MultiClassHead(_MultiClassHeadWithSoftmaxCrossEntropyLoss):
    def __init__(self,
                 n_classes,
                 predict_feature_keys=None,
                 weight_column=None,
                 label_vocabulary=None,
                 loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE,
                 loss_fn=None,
                 name=None):
        """Creates a '_Head' for multi class classification.

        The head expects `logits` with shape `[D0, D1, ... DN, n_classes]`. In many applications, the shape is
        `[batch_size, n_classes]`.

        `labels` must be a dense `Tensor` with shape matching `logits`, namely
        `[D0, D1, ... DN, 1]`. If `label_vocabulary` given, `labels` must be a string `Tensor` with values from the
        vocabulary. If `label_vocabulary` is not given, `labels` must be an integer `Tensor` with values specifying the
        class index.

        If `weight_column` is specified, weights must be of shape `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`.

        The loss is the weighted sum over the input dimensions. Namely, if the input labels have shape
        `[batch_size, 1]`, the loss is the weighted sum over `batch_size`.

        Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or `(labels, logits, features)` as arguments
        and returns unreduced loss with shape `[D0, D1, ... DN, 1]`. `loss_fn` must support integer `labels` with shape
        `[D0, D1, ... DN, 1]`. Namely, the head applies `label_vocabulary` to the input labels before passing them to
        `loss_fn`.

        Args:
            n_classes: Number of classes, must be greater than 2 (for 2 classes,
                use `_BinaryLogisticHeadWithSigmoidCrossEntropyLoss`).
            predict_feature_keys: A `string` or a `list` of `string`. If `None`, every key in the `features`
                dict is forwarded to the `predictions`. If it is a `string`, only the given key is forwarded.
                If it is a `list` of strings, all of the given keys are forwarded.
            weight_column: A string or a `_NumericColumn` created by `tf.feature_column.numeric_column` defining feature
                column representing weights. It is used to down weight or boost examples during training. It will be
                multiplied by the loss of the example.
            label_vocabulary: A list or tuple of strings representing possible label values. If it is not given, that
                means labels are already encoded as an integer within [0, n_classes). If given, labels must be of string
                type and have any value in `label_vocabulary`. Note that errors will be raised if `label_vocabulary` is
                not provided but labels are strings.
            loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to reduce training loss over
                batch. Defaults to `SUM_OVER_BATCH_SIZE`.
            loss_fn: Optional loss function.
            name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. Also used as
                `name_scope` when creating ops.
        Returns:
            An instance of `_Head` for multi class classification.
        """
        self._predict_feature_keys = predict_feature_keys  # Customisation done by alis.
        super(MultiClassHead, self).__init__(n_classes=n_classes,
                                             weight_column=weight_column,
                                             label_vocabulary=label_vocabulary,
                                             loss_reduction=loss_reduction,
                                             loss_fn=loss_fn,
                                             name=name)

    def _create_tpu_estimator_spec(self,
                                   features,
                                   mode,
                                   logits,
                                   labels=None,
                                   optimizer=None,
                                   train_op_fn=None,
                                   regularization_losses=None):
        """Returns a `model_fn._TPUEstimatorSpec`.

        Args:
            features: Input `dict` of `Tensor` or `SparseTensor` objects.
            mode: Estimator's `ModeKeys`.
            logits: logits `Tensor` with shape `[D0, D1, ... DN, logits_dimension]`. For many applications, the shape is
                `[batch_size, logits_dimension]`.
            labels: Labels integer or string `Tensor` with shape matching `logits`, namely `[D0, D1, ... DN, 1]` or
                `[D0, D1, ... DN]`. `labels` is a required argument when `mode` equals `TRAIN` or `EVAL`.
            optimizer: `Optimizer` instance to optimize the loss in TRAIN mode. Namely, sets `train_op =
                optimizer.minimize(loss, global_step)`, which updates variables and increments `global_step`.
            train_op_fn: Function that takes a scalar loss `Tensor` and returns `train_op`. Used if `optimizer` is `None`.
            regularization_losses: A list of additional scalar losses to be added to the training loss, such as
                regularization losses. These losses are usually expressed as a batch average, so for best results users
                need to set `loss_reduction=SUM_OVER_BATCH_SIZE` or `loss_reduction=SUM_OVER_NONZERO_WEIGHTS` when
                creating the head to avoid scaling errors.
        Returns:
            A `model_fn._TPUEstimatorSpec` instance.
        """
        with ops.name_scope(self._name, 'head'):
            logits = _check_logits_final_dim(logits, self.logits_dimension)

            with ops.name_scope(None, 'predictions', (logits,)):
                # Get the predicted class IDs and names
                class_ids = math_ops.argmax(logits, axis=-1, name=PredictionKeys.CLASS_IDS)
                class_ids = array_ops.expand_dims(class_ids, axis=-1)
                if self._label_vocabulary:
                    table = lookup_ops.index_to_string_table_from_tensor(vocabulary_list=self._label_vocabulary,
                                                                         name='class_string_lookup')
                    classes = table.lookup(class_ids)
                else:
                    classes = string_ops.as_string(class_ids, name='str_classes')

                # Compute the predicted probabilities
                probabilities = nn.softmax(logits, name=PredictionKeys.PROBABILITIES)

                # Compute the predicted score we use for ranking
                predicted_class = math_ops.argmax(logits, axis=-1) + 1
                predicted_class_probability = math_ops.reduce_max(probabilities, axis=-1)
                predicted = array_ops.expand_dims(
                    math_ops.cast(predicted_class, dtypes.float32) - predicted_class_probability, axis=-1)

                predictions = {
                    PredictionKeys.CLASS_IDS: class_ids,
                    PredictionKeys.CLASSES: classes,
                    PredictionKeys.LOGITS: logits,
                    'predicted': predicted,  # Custom ranking key; not defined on the standard PredictionKeys.
                    PredictionKeys.PROBABILITIES: probabilities,
                }

            # Predict
            if mode == model_fn.ModeKeys.PREDICT:

                # CUSTOMISATION DONE HERE. Attach the requested feature values to the predictions.
                # Mirror the docstring contract: None forwards every feature, a single string
                # forwards just that key, and a list forwards each listed key.
                if self._predict_feature_keys is None:
                    forward_keys = list(features.keys())
                elif isinstance(self._predict_feature_keys, str):
                    forward_keys = [self._predict_feature_keys]
                else:
                    forward_keys = self._predict_feature_keys
                for key in forward_keys:
                    predictions[key] = features[key]

                classifier_output = _classification_output(scores=probabilities,
                                                           n_classes=self._n_classes,
                                                           label_vocabulary=self._label_vocabulary)
                return model_fn._TPUEstimatorSpec(mode=model_fn.ModeKeys.PREDICT,
                                                  predictions=predictions,
                                                  export_outputs={
                                                      _DEFAULT_SERVING_KEY: export_output.PredictOutput(predictions),
                                                      _CLASSIFY_SERVING_KEY: classifier_output,
                                                      _PREDICT_SERVING_KEY: export_output.PredictOutput(predictions)
                                                  })

            # Compute loss
            training_loss, unreduced_loss, weights, label_ids = self.create_loss(features=features,
                                                                                 mode=mode,
                                                                                 logits=logits,
                                                                                 labels=labels)
            if regularization_losses:
                regularization_loss = math_ops.add_n(regularization_losses)
                regularized_training_loss = math_ops.add_n([training_loss, regularization_loss])
            else:
                regularization_loss = None
                regularized_training_loss = training_loss

            # Eval
            if mode == model_fn.ModeKeys.EVAL:
                return model_fn._TPUEstimatorSpec(mode=model_fn.ModeKeys.EVAL,
                                                  predictions=predictions,
                                                  loss=regularized_training_loss,
                                                  eval_metrics=_create_eval_metrics_tuple(self._eval_metric_ops, {
                                                      'labels': label_ids,
                                                      'class_ids': class_ids,
                                                      'weights': weights,
                                                      'unreduced_loss': unreduced_loss,
                                                      'regularization_loss': regularization_loss
                                                  }))

            # Train
            if optimizer is not None:
                if train_op_fn is not None:
                    raise ValueError('train_op_fn and optimizer cannot both be set.')
                train_op = optimizer.minimize(regularized_training_loss, global_step=training_util.get_global_step())
            elif train_op_fn is not None:
                train_op = train_op_fn(regularized_training_loss)
            else:
                raise ValueError('train_op_fn and optimizer cannot both be None.')
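            # Also run any ops registered in the UPDATE_OPS collection (e.g. batch norm
            # moving-average updates) alongside the train op.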
            train_op = _append_update_ops(train_op)

            # Only summarize mean_loss for SUM reduction to preserve backwards compatibility. Otherwise skip it to
            # avoid unnecessary computation.
            if self._loss_reduction == losses.Reduction.SUM:
                example_weight_sum = math_ops.reduce_sum(weights * array_ops.ones_like(unreduced_loss))
                mean_loss = training_loss / example_weight_sum
            else:
                mean_loss = None

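        # Write the loss summaries at the graph's top-level name scope.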
        with ops.name_scope(''):
            keys = metric_keys.MetricKeys
            summary.scalar(_summary_key(self._name, keys.LOSS), regularized_training_loss)
            if mean_loss is not None:
                summary.scalar(_summary_key(self._name, keys.LOSS_MEAN), mean_loss)
            if regularization_loss is not None:
                summary.scalar(_summary_key(self._name, keys.LOSS_REGULARIZATION), regularization_loss)

        return model_fn._TPUEstimatorSpec(mode=model_fn.ModeKeys.TRAIN,
                                          predictions=predictions,
                                          loss=regularized_training_loss,
                                          train_op=train_op)
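With this head, usage would look something like the following (a minimal sketch; my_generator, train_input_fn and predict_input_fn are hypothetical and assumed to be defined elsewhere):

import adanet

head = MultiClassHead(n_classes=10, predict_feature_keys=['key'])
estimator = adanet.Estimator(head=head,
                             subnetwork_generator=my_generator,
                             max_iteration_steps=1000)
estimator.train(input_fn=train_input_fn, max_steps=10000)

# Each prediction now carries the forwarded 'key' alongside the usual outputs.
for prediction in estimator.predict(input_fn=predict_input_fn):
    print(prediction['key'], prediction['classes'], prediction['probabilities'])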