microsoft/onnxruntime-extensions

exported huggingface tokenizer generates different results

patricianing opened this issue

Certain fields of the Hugging Face tokenizer are not picked up when it is exported with the onnxruntime-extensions pnp module, which causes a mismatch in cls_token and sep_token.
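To make the mismatch concrete, the sketch below wraps a tokenized sentence with special tokens the way a BERT-style tokenizer does for a single sequence. It uses plain strings rather than real vocabulary IDs, the word pieces are illustrative, and `add_special_tokens` is a hypothetical helper, not part of onnxruntime-extensions. The MPNet tokenizer behind all-mpnet-base-v2 uses `<s>`/`</s>` where BERT uses `[CLS]`/`[SEP]`, so falling back to the BERT defaults changes only the first and last tokens.

```python
from typing import List

# BERT-style defaults, which the exported op falls back to when they are not supplied.
BERT_DEFAULTS = {"cls_token": "[CLS]", "sep_token": "[SEP]"}
# Special tokens used by the MPNet tokenizer for a single sequence.
MPNET_TOKENS = {"cls_token": "<s>", "sep_token": "</s>"}


def add_special_tokens(tokens: List[str], cls_token: str, sep_token: str) -> List[str]:
    """Wrap one tokenized sequence as: cls_token, tokens..., sep_token (hypothetical helper)."""
    return [cls_token, *tokens, sep_token]


# Illustrative word pieces for "This is a test sentence" (lower-cased).
words = ["this", "is", "a", "test", "sentence"]
exported = add_special_tokens(words, **BERT_DEFAULTS)
reference = add_special_tokens(words, **MPNET_TOKENS)

# Print every position where the two encodings disagree.
for i, (a, b) in enumerate(zip(exported, reference)):
    if a != b:
        print(i, a, b)
```

Only positions 0 and 6 differ, which matches the symptom of an otherwise correct tokenization with the wrong special tokens.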

Code reproducing the difference:

import onnxruntime
import onnxruntime_extensions
import numpy as np

from onnxruntime_extensions import pnp
from transformers import AutoTokenizer

def map_token_output(input_ids, attention_mask, token_type_ids):
    return input_ids.unsqueeze(0), token_type_ids.unsqueeze(0), attention_mask.unsqueeze(0)

model_name = 'sentence-transformers/all-mpnet-base-v2'
output_name = 'all-mpnet-base-v2-aug.onnx'

symbolic = {0: 'batch_size', 1: 'sequence_length'}

tokenizer = AutoTokenizer.from_pretrained(model_name)
bert_tokenizer = pnp.PreHuggingFaceBert(hf_tok=tokenizer)

augmented_model = pnp.SequentialProcessingModule(bert_tokenizer, map_token_output)

test_input = ["This is a test sentence"]

augmented_model = pnp.export(augmented_model,
                             test_input,
                             opset_version=12,
                             input_names=['input'],
                             output_names=['input_ids', 'attention_mask', 'token_type_ids'],
                             output_path=output_name,
                             dynamic_axes={'input_ids': symbolic, 'attention_mask': symbolic,
                                           'token_type_ids': symbolic})

session_options = onnxruntime.SessionOptions()
session_options.register_custom_ops_library(onnxruntime_extensions.get_library_path())
session = onnxruntime.InferenceSession(output_name, session_options)
results = session.run([], {"input": test_input})

encoded_input = tokenizer(test_input, padding=True, truncation=True, return_tensors='pt')
np.testing.assert_allclose(encoded_input.get('input_ids'), results[0], rtol=1e-04, atol=1e-05)
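Because the assertion only reports that the arrays differ, a small helper that lists the differing positions makes it easier to see that only the special tokens disagree. This is a hypothetical diagnostic, not part of transformers or onnxruntime-extensions, and the IDs in the example are made up.

```python
from typing import List, Optional, Sequence, Tuple


def id_mismatches(expected: Sequence[int],
                  actual: Sequence[int]) -> List[Tuple[int, Optional[int], Optional[int]]]:
    """Return (position, expected_id, actual_id) for every position that differs.

    If one sequence is longer, the extra positions are reported with None
    on the side that is missing.
    """
    length = max(len(expected), len(actual))
    diffs = []
    for i in range(length):
        e = expected[i] if i < len(expected) else None
        a = actual[i] if i < len(actual) else None
        if e != a:
            diffs.append((i, e, a))
    return diffs


# Made-up IDs: identical word pieces, different special tokens at both ends.
print(id_mismatches([0, 11, 22, 2], [101, 11, 22, 102]))
```

With the script above, it could be called as `id_mismatches(encoded_input['input_ids'][0].tolist(), results[0][0].tolist())`, assuming both outputs have shape (1, sequence_length).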

I also tried passing the vocab file directly, which made no difference:
onnx_tokenizer = pnp.PreHuggingFaceBert(vocab_file='./all-mpnet-base-v2/vocab.txt')

I think this could be fixed by modifying the PreHuggingFaceBert constructor to pass the following fields from the Hugging Face tokenizer through to the BertTokenizer op:

            self.onnx_bert_tokenizer = create_op_function('BertTokenizer', bert_tokenizer,
                                                          hf_tok=hf_tok,
                                                          # use the tokenizer's own special tokens
                                                          # instead of the BERT defaults
                                                          sep_token=hf_tok.sep_token,
                                                          cls_token=hf_tok.cls_token,
                                                          pad_token=hf_tok.pad_token)

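Currently the constructor leaves all of those fields at their defaults, which are the BERT-style tokens ([CLS], [SEP], [PAD]). For all-mpnet-base-v2, whose tokenizer uses <s>, </s> and <pad>, those defaults are wrong.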
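A minimal sketch of the proposed behaviour, assuming the fix reads the tokenizer's own `cls_token`, `sep_token` and `pad_token` attributes (for MPNet these equal its `bos_token`/`eos_token`) and falls back to the BERT defaults only when an attribute is missing or `None`. `resolve_special_tokens` is a hypothetical helper, and `SimpleNamespace` objects stand in for real tokenizers so the example runs without downloading a model.

```python
from types import SimpleNamespace
from typing import Dict, Optional

# BERT-style defaults, used only when the tokenizer does not define a token.
DEFAULTS = {"cls_token": "[CLS]", "sep_token": "[SEP]", "pad_token": "[PAD]"}


def resolve_special_tokens(hf_tok) -> Dict[str, str]:
    """Pick special tokens from a Hugging Face-style tokenizer object (hypothetical helper)."""
    resolved = {}
    for name, default in DEFAULTS.items():
        value: Optional[str] = getattr(hf_tok, name, None)
        resolved[name] = value if value is not None else default
    return resolved


# Stand-ins for real tokenizers; the attribute values mirror their special tokens.
mpnet_like = SimpleNamespace(cls_token="<s>", sep_token="</s>", pad_token="<pad>")
bert_like = SimpleNamespace(cls_token="[CLS]", sep_token="[SEP]", pad_token="[PAD]")

print(resolve_special_tokens(mpnet_like))
print(resolve_special_tokens(bert_like))
```

The resulting dictionary could then be forwarded as keyword arguments, for example `create_op_function('BertTokenizer', bert_tokenizer, hf_tok=hf_tok, **resolve_special_tokens(hf_tok))`, assuming the op accepts these attribute names as in the constructor snippet above. Falling back per field keeps BERT models working unchanged while letting other tokenizers supply their own tokens.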