Exported Hugging Face tokenizer generates different results
patricianing opened this issue · 2 comments
Some fields of the Hugging Face tokenizer are not taken into account when it is exported with the onnxruntime-extensions pnp module, so the exported model emits the wrong ids for cls_token and sep_token.
Code to reproduce the difference:
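To make the symptom concrete, here is a minimal pure-Python sketch of how a BERT-style tokenizer wraps a sentence: it looks up the ids of the CLS and SEP token strings and places them at the start and end of the sequence. The toy vocabulary, its ids, and the encode helper are hypothetical, chosen only to mimic MPNet's <s>/</s> convention next to BERT's [CLS]/[SEP]. If the exporter uses the wrong pair of token strings, only the first and last ids change; the word ids in between stay the same.

```python
from typing import Dict, List

# Hypothetical toy vocabulary: MPNet-style special tokens (<s>, </s>) plus
# BERT-style ones ([CLS], [SEP]). The ids are made up for illustration.
TOY_VOCAB: Dict[str, int] = {
    "<s>": 0, "<pad>": 1, "</s>": 2, "[UNK]": 3,
    "[CLS]": 4, "[SEP]": 5,
    "this": 10, "is": 11, "a": 12, "test": 13, "sentence": 14,
}


def encode(words: List[str], vocab: Dict[str, int],
           cls_token: str, sep_token: str, unk_token: str = "[UNK]") -> List[int]:
    """Map lower-cased words to ids and wrap them as cls_token ... sep_token."""
    body = [vocab.get(w.lower(), vocab[unk_token]) for w in words]
    return [vocab[cls_token]] + body + [vocab[sep_token]]


words = "This is a test sentence".split()
# Special tokens as the Hugging Face MPNet tokenizer defines them.
expected = encode(words, TOY_VOCAB, cls_token="<s>", sep_token="</s>")
# Special tokens an exporter would use if it fell back to BERT defaults.
exported = encode(words, TOY_VOCAB, cls_token="[CLS]", sep_token="[SEP]")

print(expected)  # [0, 10, 11, 12, 13, 14, 2]
print(exported)  # [4, 10, 11, 12, 13, 14, 5]
```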
import onnxruntime
import onnxruntime_extensions
import numpy as np
from onnxruntime_extensions import pnp
from transformers import AutoTokenizer
# Add a batch dimension to each tokenizer output.
def map_token_output(input_ids, attention_mask, token_type_ids):
    return input_ids.unsqueeze(0), token_type_ids.unsqueeze(0), attention_mask.unsqueeze(0)
model_name = 'sentence-transformers/all-mpnet-base-v2'
output_name = 'all-mpnet-base-v2-aug.onnx'
symbolic = {0: 'batch_size', 1: 'sequence_length'}
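tokenizer = AutoTokenizer.from_pretrained(model_name)
# Wrap the Hugging Face tokenizer as an ONNX pre-processing module,
# then chain the batching step after it.
bert_tokenizer = pnp.PreHuggingFaceBert(hf_tok=tokenizer)
augmented_model = pnp.SequentialProcessingModule(bert_tokenizer, map_token_output)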
test_input = ["This is a test sentence"]
# Export the tokenizer pipeline to an ONNX model file.
augmented_model = pnp.export(augmented_model,
                             test_input,
                             opset_version=12,
                             input_names=['input'],
                             output_names=['input_ids', 'attention_mask', 'token_type_ids'],
                             output_path=output_name,
                             dynamic_axes={'input_ids': symbolic, 'attention_mask': symbolic,
                                           'token_type_ids': symbolic})
# Load the exported model with the onnxruntime-extensions custom-op library registered.
session_options = onnxruntime.SessionOptions()
session_options.register_custom_ops_library(onnxruntime_extensions.get_library_path())
session = onnxruntime.InferenceSession(output_name, session_options)
results = session.run([], {"input": test_input})
# Compare against the original Hugging Face tokenizer; this assertion fails.
encoded_input = tokenizer(test_input, padding=True, truncation=True, return_tensors='pt')
np.testing.assert_allclose(encoded_input.get('input_ids'), results[0], rtol=1e-04, atol=1e-05)
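The assertion above only summarizes the mismatch. A small helper like the one below lists the exact (row, column) positions where two id tensors disagree, which makes it easy to confirm that only the first and last columns, i.e. the CLS and SEP slots, differ. This is a sketch: mismatched_positions is a hypothetical name, and the id arrays are made up for illustration.

```python
import numpy as np


def mismatched_positions(expected, actual) -> list:
    """Return (row, column) index pairs where two id tensors disagree."""
    expected = np.asarray(expected)
    actual = np.asarray(actual)
    if expected.shape != actual.shape:
        raise ValueError(f"shape mismatch: {expected.shape} vs {actual.shape}")
    # np.argwhere yields one [row, col] entry per differing element.
    return [tuple(int(i) for i in idx) for idx in np.argwhere(expected != actual)]


# Hypothetical ids for one sentence: only the first and last tokens differ.
hf_ids = np.array([[0, 10, 11, 12, 13, 14, 2]])
onnx_ids = np.array([[4, 10, 11, 12, 13, 14, 5]])
print(mismatched_positions(hf_ids, onnx_ids))  # [(0, 0), (0, 6)]
```

In the reproduction script it could be called as mismatched_positions(encoded_input['input_ids'].numpy(), results[0]).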
I also tried building the tokenizer from the vocab file instead; it made no difference:
onnx_tokenizer = pnp.PreHuggingFaceBert(vocab_file='./all-mpnet-base-v2/vocab.txt')
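A likely reason the vocab file does not help: a WordPiece vocab.txt only lists one token per line, with the zero-based line number as its id. It says nothing about which tokens play the CLS, SEP or PAD role; Hugging Face keeps that in the tokenizer's configuration files. The sketch below loads such a file into a dict. The file contents and the load_vocab helper are hypothetical and only illustrate the format.

```python
import tempfile
from pathlib import Path
from typing import Dict


def load_vocab(path: str) -> Dict[str, int]:
    """Read a WordPiece vocab.txt: one token per line, id = zero-based line number."""
    vocab: Dict[str, int] = {}
    with open(path, encoding="utf-8") as f:
        for idx, line in enumerate(f):
            vocab[line.rstrip("\n")] = idx
    return vocab


# Hypothetical vocab file contents; a real vocab.txt has tens of thousands of lines.
with tempfile.TemporaryDirectory() as tmp:
    vocab_path = Path(tmp) / "vocab.txt"
    vocab_path.write_text("<s>\n<pad>\n</s>\n[UNK]\n[CLS]\n[SEP]\nthis\n", encoding="utf-8")
    vocab = load_vocab(str(vocab_path))

# Both conventions can be present, so the file alone cannot tell which
# pair of tokens is meant to act as CLS/SEP.
for name in ("<s>", "</s>", "[CLS]", "[SEP]"):
    print(name, vocab.get(name))
```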
I think this could be fixed by changing the PreHuggingFaceBert constructor to read the following special tokens from the Hugging Face tokenizer and pass them to the op:
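self.onnx_bert_tokenizer = create_op_function('BertTokenizer', bert_tokenizer,
                                              hf_tok=hf_tok,
                                              # Take the special tokens from the HF tokenizer instead of
                                              # the op's defaults. For all-mpnet-base-v2 these are
                                              # <s>, </s> and <pad>. cls_token/sep_token are used rather
                                              # than bos_token/eos_token because BERT tokenizers leave
                                              # bos_token/eos_token unset.
                                              sep_token=hf_tok.sep_token,
                                              cls_token=hf_tok.cls_token,
                                              pad_token=hf_tok.pad_token)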
At the moment the constructor leaves all of these fields at their defaults, and for all-mpnet-base-v2 (whose tokenizer uses <s>, </s> and <pad>) those defaults are wrong.
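The proposed change amounts to reading each special token from the Hugging Face tokenizer and falling back to a BERT default only when the tokenizer does not define it. A minimal sketch of that logic follows. special_token_attrs and BERT_DEFAULTS are hypothetical names, and SimpleNamespace objects stand in for real tokenizers (transformers tokenizers expose cls_token, sep_token, pad_token and unk_token attributes under the same names). The resulting dict could then be passed as keyword arguments wherever the exporter builds the tokenizer op.

```python
from types import SimpleNamespace
from typing import Dict, Optional

# Hypothetical fallbacks mirroring the usual BERT special tokens.
BERT_DEFAULTS: Dict[str, str] = {
    "cls_token": "[CLS]",
    "sep_token": "[SEP]",
    "pad_token": "[PAD]",
    "unk_token": "[UNK]",
}


def special_token_attrs(hf_tok) -> Dict[str, str]:
    """Collect special-token strings from a Hugging Face-style tokenizer.

    Each value comes from the tokenizer when it is set and falls back to the
    BERT default only when the attribute is missing or None.
    """
    attrs: Dict[str, str] = {}
    for name, default in BERT_DEFAULTS.items():
        value: Optional[str] = getattr(hf_tok, name, None)
        attrs[name] = value if value is not None else default
    return attrs


# Stand-in objects; a real tokenizer exposes the same attribute names.
mpnet_like = SimpleNamespace(cls_token="<s>", sep_token="</s>",
                             pad_token="<pad>", unk_token="[UNK]")
partial = SimpleNamespace(cls_token=None, sep_token="</s>")

print(special_token_attrs(mpnet_like))
print(special_token_attrs(partial))
```

With this approach a BERT tokenizer still gets [CLS]/[SEP], while MPNet gets <s>/</s>, so the exported op would no longer silently assume BERT conventions.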