tensorflow/mesh

[Bug Fix] Evaluation and Prediction for Aligned model

agemagician opened this issue · 1 comments

Hello,

Both evaluation and prediction are currently not working with the aligned ("BERT-style") model.

I have fixed this issue by adding a new elif branch in "transformer/utils.py":

    elif mode == tf.estimator.ModeKeys.PREDICT:
      inputs = mtf_features["inputs"]
      if predict_fn:
        mtf_samples = predict_fn(
            model=transformer_model,
            features=mtf_features,
            variable_dtype=get_variable_dtype())
      elif (isinstance(transformer_model, transformer.Unitransformer)
            and model_type == 'aligned'):
        # pad so that there is enough room for the targets
        inputs = mtf.pad(
            inputs, [0, sequence_length["targets"]], length_dim.name)
        logits, _ = transformer_model.call_simple(
            inputs=inputs, variable_dtype=get_variable_dtype(),
            compute_loss=False,
            mode=tf.estimator.ModeKeys.PREDICT)

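        # NOTE: the vocab size (256) is hard-coded here; it has to match
        # the "vocab" dimension of the logits.
        label_c_dim = mtf.Dimension('vocab', 256)
        # Take the highest-scoring vocab entry at each position.
        mtf_samples = mtf.argmax(logits, label_c_dim)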

The signature of call_simple in "transformer/transformer.py" also needs to be modified:

  def call_simple(self,
                  inputs=None,
                  targets=None,
                  compute_loss=False,
                  mode=tf.estimator.ModeKeys.TRAIN,
                  variable_dtype=mtf.VariableDType(tf.float32),
                  sequence_id=None,
                  subsequence_id=None,
                  position=None,
                  encoder_output=None,
                  encoder_sequence_id=None,
                  encoder_inputs=None,
                  shared_params=None,
                  layer_outputs=None,
                  encoder_layer_outputs=None,
                  num_microbatches=1):

The only thing that I am currently defining manually is "label_c_dim".
@adarob @craffel @nshazeer It would be great if you could merge my code or define a better solution, ideally one that finds the vocab size for "label_c_dim" automatically.
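
One possible direction for the automatic lookup: since the argmax above only works if the logits already carry a dimension named 'vocab', "label_c_dim" could be read from the logits' own shape by name instead of being rebuilt with a fixed size. The sketch below only illustrates that idea. It uses a plain namedtuple as a stand-in for mtf.Dimension so that it runs without Mesh TensorFlow, and find_dim_by_name as well as the example sizes are hypothetical. I have not tested the equivalent lookup on a real mtf tensor.

```python
from collections import namedtuple

# Stand-in for mtf.Dimension (a named, sized axis). Hypothetical: used here
# only so the sketch runs without Mesh TensorFlow installed.
Dimension = namedtuple("Dimension", ["name", "size"])


def find_dim_by_name(dims, name):
    """Return the single dimension called `name`; raise if missing or ambiguous."""
    matches = [d for d in dims if d.name == name]
    if len(matches) != 1:
        raise ValueError(
            "expected exactly one dimension named %r, found %d"
            % (name, len(matches)))
    return matches[0]


# Example logits shape [batch, length, vocab]; the sizes are made up.
logits_dims = [
    Dimension("batch", 8),
    Dimension("length", 512),
    Dimension("vocab", 30),
]

# Look the dimension up instead of writing Dimension('vocab', 256) by hand.
label_c_dim = find_dim_by_name(logits_dims, "vocab")
print(label_c_dim.size)
```

With something like this, the same dimension object that the model used for its output would be passed to the argmax, so the size could not drift out of sync with the vocabulary.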

Please create a PR if you'd like to merge this.