huggingface/transformers

train_new_from_iterator does not properly update the post-processor's token IDs when using a Sequence post-processor

dmcinerney opened this issue · 1 comment

System Info

  • transformers version: 4.36.1
  • Platform: Linux-5.4.0-1123-aws-fips-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.19.4
  • Safetensors version: 0.4.1
  • Accelerate version: 0.25.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.0.1+cu118 (True)
  • Tensorflow version (GPU?): 2.14.1 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: no
  • Using distributed or parallel set-up in script?: no

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Hi, I am encountering an issue when training a new tokenizer based on the 'meta-llama/Meta-Llama-3-8B' tokenizer. In particular, the token IDs stored in the tokenizer's post_processor are not updated to the new tokenizer's IDs. You can reproduce the bug by running the code below.

from transformers import AutoTokenizer
import json

# Download the llama 3 tokenizer
original_tokenizer = AutoTokenizer.from_pretrained('meta-llama/Meta-Llama-3-8B')

# Create a new tokenizer like the old tokenizer and train it
new_tokenizer = original_tokenizer.train_new_from_iterator(iter(['hello', 'world']), 1000)

# set the pad token on both
original_tokenizer.pad_token_id = original_tokenizer.eos_token_id
new_tokenizer.pad_token_id = new_tokenizer.eos_token_id

# try tokenizing with both
text = ['hello world', 'how are you today?']
batch = original_tokenizer(text, return_tensors='pt', padding=True, truncation=True)
print("Original bos_token_id", original_tokenizer.bos_token_id)
print("Original tokenizer input_ids:")
print(batch.input_ids)
print()
batch = new_tokenizer(text, return_tensors='pt', padding=True, truncation=True)
print("New bos_token_id:", new_tokenizer.bos_token_id)
print("New tokenizer input_ids:")
print(batch.input_ids)

# print out the new tokenizer's postprocessing info to show that the bos token was not changed
print("New tokenizer post processing_info",
      json.dumps(json.loads(new_tokenizer._tokenizer.to_str())['post_processor'], indent=2))

This outputs the following:

Original bos_token_id 128000
Original tokenizer input_ids:
tensor([[128000,  15339,   1917, 128001, 128001, 128001],
        [128000,   5269,    527,    499,   3432,     30]])

New bos_token_id: 0
New tokenizer input_ids:
tensor([[128000,    269,    270,      1,      1,      1,      1,      1,      1],
        [128000,    258,    260,    262,    261,    257,    260,    260,    256]])
New tokenizer post processing_info {
  "type": "Sequence",
  "processors": [
    {
      "type": "ByteLevel",
      "add_prefix_space": true,
      "trim_offsets": false,
      "use_regex": true
    },
    {
      "type": "TemplateProcessing",
      "single": [
        {
          "SpecialToken": {
            "id": "<|begin_of_text|>",
            "type_id": 0
          }
        },
        {
          "Sequence": {
            "id": "A",
            "type_id": 0
          }
        }
      ],
      "pair": [
        {
          "SpecialToken": {
            "id": "<|begin_of_text|>",
            "type_id": 0
          }
        },
        {
          "Sequence": {
            "id": "A",
            "type_id": 0
          }
        },
        {
          "SpecialToken": {
            "id": "<|begin_of_text|>",
            "type_id": 1
          }
        },
        {
          "Sequence": {
            "id": "B",
            "type_id": 1
          }
        }
      ],
      "special_tokens": {
        "<|begin_of_text|>": {
          "id": "<|begin_of_text|>",
          "ids": [
            128000
          ],
          "tokens": [
            "<|begin_of_text|>"
          ]
        }
      }
    }
  ]
}

The expected output is that the original bos token ID of 128000 is replaced with the new bos token ID of 0, as in the following:

Original bos_token_id 128000
Original tokenizer input_ids:
tensor([[128000,  15339,   1917, 128001, 128001, 128001],
        [128000,   5269,    527,    499,   3432,     30]])

New bos_token_id: 0
New tokenizer input_ids:
tensor([[0,    269,    270,      1,      1,      1,      1,      1,      1],
        [0,    258,    260,    262,    261,    257,    260,    260,    256]])
New tokenizer post processing_info {
  "type": "Sequence",
  "processors": [
    {
      "type": "ByteLevel",
      "add_prefix_space": true,
      "trim_offsets": false,
      "use_regex": true
    },
    {
      "type": "TemplateProcessing",
      "single": [
        {
          "SpecialToken": {
            "id": "<|begin_of_text|>",
            "type_id": 0
          }
        },
        {
          "Sequence": {
            "id": "A",
            "type_id": 0
          }
        }
      ],
      "pair": [
        {
          "SpecialToken": {
            "id": "<|begin_of_text|>",
            "type_id": 0
          }
        },
        {
          "Sequence": {
            "id": "A",
            "type_id": 0
          }
        },
        {
          "SpecialToken": {
            "id": "<|begin_of_text|>",
            "type_id": 1
          }
        },
        {
          "Sequence": {
            "id": "B",
            "type_id": 1
          }
        }
      ],
      "special_tokens": {
        "<|begin_of_text|>": {
          "id": "<|begin_of_text|>",
          "ids": [
            0
          ],
          "tokens": [
            "<|begin_of_text|>"
          ]
        }
      }
    }
  ]
}

I believe this is caused by the fact that train_new_from_iterator does not handle post-processors of type Sequence (i.e., post-processors that contain multiple sub-processors). The relevant code from that method only inspects the top-level post_processor dict, so the "special_tokens" entries nested inside a Sequence's sub-processors are never touched:

if post_processor is not None:
    trained_tokenizer_json = json.loads(tokenizer.to_str())
    # Almost done, we just have to adjust the token IDs in the post processor
    if "special_tokens" in post_processor:
        for key in post_processor["special_tokens"]:
            tokens = post_processor["special_tokens"][key]["tokens"]
            if special_tokens_map is not None:
                tokens = [special_tokens_map.get(token, token) for token in tokens]
            post_processor["special_tokens"][key]["tokens"] = tokens
            post_processor["special_tokens"][key]["ids"] = [tokenizer.token_to_id(token) for token in tokens]
    for special_token in ["cls", "sep"]:
        if special_token in post_processor:
            token, _ = post_processor[special_token]
            if special_tokens_map is not None and token in special_tokens_map:
                token = special_tokens_map[token]
            token_id = tokenizer.token_to_id(token)
            post_processor[special_token] = [token, token_id]
    trained_tokenizer_json["post_processor"] = post_processor
    tokenizer = TokenizerFast.from_str(json.dumps(trained_tokenizer_json))
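
One possible fix (just a sketch on my part, not the library's actual code; the helper name _adjust_post_processor_ids is hypothetical) would be to recurse into the "processors" list whenever the serialized post-processor is a Sequence, applying the existing ID-adjustment logic to each sub-processor:

def _adjust_post_processor_ids(post_processor, tokenizer, special_tokens_map=None):
    # A Sequence post-processor wraps a list of sub-processors; recurse so
    # that nested TemplateProcessing entries also get their IDs updated.
    if post_processor.get("type") == "Sequence":
        for sub_processor in post_processor["processors"]:
            _adjust_post_processor_ids(sub_processor, tokenizer, special_tokens_map)
        return
    # Existing logic from train_new_from_iterator, applied per processor.
    if "special_tokens" in post_processor:
        for key in post_processor["special_tokens"]:
            tokens = post_processor["special_tokens"][key]["tokens"]
            if special_tokens_map is not None:
                tokens = [special_tokens_map.get(token, token) for token in tokens]
            post_processor["special_tokens"][key]["tokens"] = tokens
            post_processor["special_tokens"][key]["ids"] = [tokenizer.token_to_id(token) for token in tokens]
    for special_token in ["cls", "sep"]:
        if special_token in post_processor:
            token, _ = post_processor[special_token]
            if special_tokens_map is not None and token in special_tokens_map:
                token = special_tokens_map[token]
            post_processor[special_token] = [token, tokenizer.token_to_id(token)]

train_new_from_iterator could then call this helper before writing post_processor back into trained_tokenizer_json.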

Thanks in advance for the help!

Expected behavior

The expected behavior is that train_new_from_iterator properly overwrites the original special token IDs in the fast tokenizer's Sequence post-processor whenever those IDs differ in the new tokenizer.
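
In the meantime, here is a rough user-side workaround sketch (the helper name fix_post_processor_ids is mine; it assumes the Sequence/TemplateProcessing layout shown above and relies on the private _tokenizer attribute):

import json
from tokenizers import Tokenizer

def fix_post_processor_ids(fast_tokenizer):
    # Re-serialize the backend tokenizer, recompute the "ids" of every
    # special token listed in the (possibly nested) post-processor, and
    # load the patched JSON back in.
    tokenizer_json = json.loads(fast_tokenizer._tokenizer.to_str())
    post_processor = tokenizer_json.get("post_processor")
    if post_processor is None:
        return
    # A Sequence keeps its sub-processors under "processors"; treat a plain
    # TemplateProcessing as a one-element list.
    processors = post_processor.get("processors", [post_processor])
    for processor in processors:
        for entry in processor.get("special_tokens", {}).values():
            entry["ids"] = [fast_tokenizer._tokenizer.token_to_id(t) for t in entry["tokens"]]
    fast_tokenizer._tokenizer = Tokenizer.from_str(json.dumps(tokenizer_json))

fix_post_processor_ids(new_tokenizer)

After this patch, tokenizing with new_tokenizer should prepend the new bos token ID (0), as in the expected output above.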