How to build a custom tokenizer on top of an existing Llama 3.2 tokenizer?
yakhyo opened this issue · 7 comments
Hi,
I am trying to create a custom tokenizer for a language that is not covered by the Llama 3.2 tokenizer.
I could not find out which tokenizer from Hugging Face is an exact equivalent of Llama's tokenizer (link), so that I can train a new one with it.
Currently I am using the following code to train a tokenizer, but the final result does not match what the Llama 3.2 tokenizer produces.
It would be nice if anyone could share their experience of adapting a Llama model to a new language.
import json
import argparse
from datasets import load_dataset, concatenate_datasets
from tokenizers import SentencePieceBPETokenizer
from transformers import LlamaTokenizerFast, AutoTokenizer
from tqdm import tqdm
from typing import List
hf_datasets = ["yakhyo/uz-wiki", "yakhyo/uz-news", "agentlans/high-quality-english-sentences"]
def normalize_text(text: str) -> str:
"""
Normalize Uzbek characters, replacing variations of o‘, o', o`, and ’ (curved apostrophe).
"""
return text.replace("‘", "'").replace("`", "'").replace("’", "'").replace("()", "")
def prepare_datasets(datasets_list: List[str]):
all_data = []
for dataset_name in datasets_list:
try:
data = load_dataset(dataset_name)
for split in ["train", "test", "validation"]:
try:
all_data.append(data[split])
except KeyError:
pass
except Exception:
print(f"dataset: `{dataset_name}` not found, skipping...")
concat_data = []
for data in tqdm(all_data):
data = data.map(lambda example: {"text": normalize_text(example["text"])})
data = data.remove_columns([col for col in data.column_names if col != "text"])
concat_data.append(data)
return concatenate_datasets(concat_data)
def main(args):
dataset = prepare_datasets(hf_datasets)
# shuffle the dataset and keep all samples (adjust the range to subsample)
dataset = dataset.shuffle(seed=42).select(range(len(dataset)))
# Create a SentencePieceBPETokenizer
tokenizer = SentencePieceBPETokenizer(
replacement="Ġ"
)
# Train the SentencePieceBPETokenizer on the dataset
tokenizer.train_from_iterator(
iterator=dataset['text'],
vocab_size=args.vocab_size,
show_progress=True,
special_tokens=[
"<unk>",
"<s>",
"</s>",
"<pad>"
],
)
# Save the tokenizer
tokenizer.save("new-sentencepiece-tokenizer.json", pretty=True)
# Load reference tokenizer
if args.reference_tokenizer is not None:
reference_tokenizer = AutoTokenizer.from_pretrained(args.reference_tokenizer)
reference_tokenizer.save_pretrained("reference-tokenizer")
else:
raise ValueError(
"No tokenizer name provided or no hub token provided. Try using --reference_tokenizer 'meta-llama/Llama-2-7b-hf'")
# Read and dump the json file for the new tokenizer and the reference tokenizer
with open("new-sentencepiece-tokenizer.json") as f:
new_llama_tokenizer_json = json.load(f)
with open("reference-tokenizer/tokenizer.json") as f:
reference_tokenizer_json = json.load(f)
# Add the reference tokenizer's config to the new tokenizer's config
new_llama_tokenizer_json["normalizer"] = reference_tokenizer_json["normalizer"]
new_llama_tokenizer_json["pre_tokenizer"] = reference_tokenizer_json["pre_tokenizer"]
new_llama_tokenizer_json["post_processor"] = reference_tokenizer_json["post_processor"]
new_llama_tokenizer_json["decoder"] = reference_tokenizer_json["decoder"]
new_llama_tokenizer_json["model"]['fuse_unk'] = reference_tokenizer_json["model"]['fuse_unk']
new_llama_tokenizer_json["model"]['byte_fallback'] = reference_tokenizer_json["model"]['byte_fallback']
# Dump the new tokenizer's config
with open("new-sentencepiece-tokenizer.json", "w") as f:
json.dump(new_llama_tokenizer_json, f, indent=2, ensure_ascii=False)
# Load the new tokenizer as a LlamaTokenizerFast
new_llama_tokenizer = LlamaTokenizerFast(
tokenizer_file="new-sentencepiece-tokenizer.json",
unk_token="<unk>",
unk_token_id=0,
bos_token="<s>",
bos_token_id=1,
eos_token="</s>",
eos_token_id=2,
pad_token="<pad>",
pad_token_id=3,
padding_side="right",
)
# Save the new tokenizer
new_llama_tokenizer.save_pretrained("new-llama-tokenizer")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Llama Tokenizer using SentencePieceBPE")
parser.add_argument(
"--reference_tokenizer",
type=str,
default=None,
help="The name of the reference tokenizer to use"
)
parser.add_argument(
"--vocab_size",
type=int,
default=None,
help="Vocabulary size to use for the tokenizer"
)
args = parser.parse_args()
main(args)
# How to run:
# python bpe_tokenizer.py --reference_tokenizer "meta-llama/Llama-3.2-3b" --vocab_size 128000
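For reference, a quick way to compare the result against the reference tokenizer afterwards (a minimal sketch; the paths assume the script above has already been run in the same directory):
from transformers import AutoTokenizer

# load the newly trained tokenizer and the reference Llama 3.2 tokenizer saved above
new_tok = AutoTokenizer.from_pretrained("new-llama-tokenizer")
ref_tok = AutoTokenizer.from_pretrained("reference-tokenizer")

sample = "Salom dunyo! O'zbek tili uchun tokenizer."

# compare the produced tokens and the encode/decode round trip
print(new_tok.tokenize(sample))
print(ref_tok.tokenize(sample))
print(new_tok.decode(new_tok.encode(sample)))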
Seems a bit related to huggingface/transformers#27583
@ArthurZucker, thank you for the reference. It really helped me with adding existing tokens. Now I have one remaining problem in reproducing the Llama 3.2 tokenizer (I hope).
I checked the Llama 3.2 tokenizer and it does not have an unk_token. However, when I train a tokenizer it always adds special tokens for eos, bos and unk. I could work around the eos and bos, but I could not figure out how to drop the unk token, which is added by default when training the tokenizer.
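For reference, a quick way to confirm that the stock Llama 3.2 tokenizer defines no unk token (a minimal check, assuming access to the gated meta-llama/Llama-3.2-3B repo):
from transformers import AutoTokenizer

# the stock Llama 3.2 tokenizer has neither an unk token nor an unk id
ref = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B")
print(ref.unk_token, ref.unk_token_id)  # None None
My updated script: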
# additional imports on top of the ones in the first script
from tokenizers import AddedToken
from transformers import PreTrainedTokenizerFast

def main(args):
dataset = prepare_datasets(hf_datasets)
num_reserved_special_tokens = 256
# select num_samples from the dataset
dataset = dataset.shuffle(seed=42).select(range(50000))
# Create a SentencePieceBPETokenizer
tokenizer = SentencePieceBPETokenizer(replacement="Ġ")
# Train the SentencePieceBPETokenizer on the dataset
tokenizer.train_from_iterator(
iterator=dataset['text'],
vocab_size=args.vocab_size,
show_progress=True
)
# Save the tokenizer
tokenizer.save("new-llama-tokenizer.json", pretty=True)
# Load reference tokenizer
if args.reference_tokenizer is not None:
reference_tokenizer = AutoTokenizer.from_pretrained(args.reference_tokenizer)
reference_tokenizer.save_pretrained("reference-tokenizer")
else:
raise ValueError(
"No tokenizer name provided or no hub token provided. \
Try using --reference_tokenizer 'meta-llama/Llama-3.2-3b'"
)
# Read and dump the json file for the new tokenizer and the reference tokenizer
with open("new-llama-tokenizer.json", "r") as f:
new_llama_tokenizer_json = json.load(f)
with open("reference-tokenizer/tokenizer.json", "r") as f:
reference_tokenizer_json = json.load(f)
# Add the reference tokenizer's config to the new tokenizer's config
new_llama_tokenizer_json["normalizer"] = reference_tokenizer_json["normalizer"]
new_llama_tokenizer_json["pre_tokenizer"] = reference_tokenizer_json["pre_tokenizer"]
new_llama_tokenizer_json["post_processor"] = reference_tokenizer_json["post_processor"]
new_llama_tokenizer_json["decoder"] = reference_tokenizer_json["decoder"]
new_llama_tokenizer_json["model"]['fuse_unk'] = reference_tokenizer_json["model"]['fuse_unk']
new_llama_tokenizer_json["model"]['byte_fallback'] = reference_tokenizer_json["model"]['byte_fallback']
# Dump the new tokenizer's config
with open("new-llama-tokenizer.json", "w") as f:
json.dump(new_llama_tokenizer_json, f, indent=2, ensure_ascii=False)
# LlamaTokenizerFast or PreTrainedTokenizerFast
new_llama_tokenizer = PreTrainedTokenizerFast(
tokenizer_file="new-llama-tokenizer.json",
padding_side="right",
model_max_length=131072,
bos_token=AddedToken("<|begin_of_text|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
eos_token=AddedToken("<|end_of_text|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True)
)
added_tokens = [
"<|begin_of_text|>",
"<|end_of_text|>",
"<|reserved_special_token_0|>",
"<|reserved_special_token_1|>",
"<|reserved_special_token_2|>",
"<|reserved_special_token_3|>",
"<|start_header_id|>",
"<|end_header_id|>",
"<|reserved_special_token_4|>",
"<|eot_id|>", # end of turn
] + [
f"<|reserved_special_token_{i}|>" for i in range(5, num_reserved_special_tokens - 5)
]
# Create AddedToken objects for each reserved token with properties similar to your expected output
added_tokens_object = [
AddedToken(token, rstrip=False, lstrip=False, single_word=False, normalized=False, special=True)
for token in added_tokens
]
# Add these reserved tokens to the tokenizer
new_llama_tokenizer.add_tokens(added_tokens_object)
# Save the new tokenizer
new_llama_tokenizer.save_pretrained("new-llama-tokenizer")
When you call train_from_iterator you can pass a trainer object. If you initialize the trainer yourself you should be able to skip the unknown token.
I tried to do it using BpeTrainer, but then found that SentencePieceBPETokenizer.train_from_iterator() does not accept a trainer object.
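A quick way to see where the two configs diverge is to diff the model sections of the two JSON files (a small sketch; the paths assume the script above was run in the same directory):
import json

# compare the "model" sections of the new tokenizer and the reference Llama 3.2 tokenizer
with open("new-llama-tokenizer/tokenizer.json") as f:
    new_model = json.load(f)["model"]
with open("reference-tokenizer/tokenizer.json") as f:
    ref_model = json.load(f)["model"]

for key in ("dropout", "unk_token", "fuse_unk", "byte_fallback", "ignore_merges"):
    print(key, "->", new_model.get(key), "vs", ref_model.get(key))

# the first vocab entries show the stray <unk> at id 0 in the new tokenizer
print(list(new_model["vocab"].items())[:3])
print(list(ref_model["vocab"].items())[:3])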
Currently, this is what my tokenizer.json and the Llama 3 tokenizer.json look like:
My current tokenizer:
"model": {
"type": "BPE",
"dropout": null,
"unk_token": null,
"continuing_subword_prefix": null,
"end_of_word_suffix": null,
"fuse_unk": false,
"byte_fallback": false,
"ignore_merges": false,
"vocab": {
"<unk>": 0, # this is a problem
"!": 1,
My tokenizer object:
PreTrainedTokenizerFast(name_or_path='new-llama-tokenizer', vocab_size=128000, model_max_length=131072, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|begin_of_text|>', 'eos_token': '<|end_of_text|>'}, clean_up_tokenization_spaces=False), added_tokens_decoder={
0: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
128000: AddedToken("<|begin_of_text|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
128001: AddedToken("<|end_of_text|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
128002: AddedToken("<|reserved_special_token_0|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
. . .
Llama 3 tokenizer:
"model": {
"type": "BPE",
"dropout": null,
"unk_token": null,
"continuing_subword_prefix": null,
"end_of_word_suffix": null,
"fuse_unk": false,
"byte_fallback": false,
"ignore_merges": true,
"vocab": {
"!": 0,
"\"": 1,
"#": 2,
Llama tokenizer object:
PreTrainedTokenizerFast(name_or_path='meta-llama/Llama-3.2-3B', vocab_size=128000, model_max_length=131072, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|begin_of_text|>', 'eos_token': '<|end_of_text|>'}, clean_up_tokenization_spaces=True), added_tokens_decoder={
128000: AddedToken("<|begin_of_text|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
128001: AddedToken("<|end_of_text|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
128002: AddedToken("<|reserved_special_token_0|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
128003: AddedToken("<|reserved_special_token_1|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
. . .
What I have tried:
I added the trainer below:
from tokenizers.trainers import BpeTrainer

trainer = BpeTrainer(
vocab_size=args.vocab_size,
special_tokens=["<|begin_of_text|>", "<|end_of_text|>"],
unk_token=None
)
and tried to pass it to:
tokenizer.train_from_iterator(dataset['text'], trainer=trainer)
got an error:
TypeError: SentencePieceBPETokenizer.train_from_iterator() got an unexpected keyword argument 'trainer'
I could simply delete that token, but I do not think that is a wise solution in my case. I would appreciate any ideas.
Thank you.
Yeah, that is probably because you are using the wrapper around tokenizers; I think this is only accepted by tokenizer._tokenizer! Let me check and come back to you!
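Concretely, something along these lines might work (an untested sketch; it assumes the wrapper keeps the underlying tokenizers.Tokenizer in ._tokenizer and reuses the dataset iterator from the script above):
from tokenizers import SentencePieceBPETokenizer
from tokenizers.trainers import BpeTrainer

tokenizer = SentencePieceBPETokenizer(replacement="Ġ")

# build the trainer yourself and leave "<unk>" out of the special tokens
trainer = BpeTrainer(
    vocab_size=128000,
    show_progress=True,
    special_tokens=[],
)

# the wrapper's train_from_iterator has no trainer argument,
# but the underlying tokenizers.Tokenizer does
# (`dataset` is the corpus prepared earlier)
tokenizer._tokenizer.train_from_iterator(dataset["text"], trainer=trainer)
The model-level fields (fuse_unk, byte_fallback, ignore_merges, unk_token) may still need to be aligned with the reference JSON afterwards, as in the script above.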
Yeah, the trainer is accepted by the Tokenizer object directly; train_new_from_iterator is the one that initializes it:
https://github.com/huggingface/transformers/blob/c185608e4209315ea4f02278a318417c1838a421/src/transformers/tokenization_utils_fast.py#L828-L829
In general I think the transformers layer needs an unk_token, but we could also lift that restriction!
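An alternative that sidesteps the wrapper entirely is train_new_from_iterator on the reference tokenizer, which retrains the BPE vocabulary while reusing Llama 3.2's normalizer, pre-tokenizer, post-processor and special tokens, so no <unk> should be introduced (a sketch, assuming the same dataset as above and access to the gated repo):
from transformers import AutoTokenizer

# retrain the BPE vocabulary on the new-language corpus while keeping
# Llama 3.2's normalizer, pre-tokenizer, post-processor and special tokens
# (`dataset` is the corpus prepared earlier)
ref_tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B")
new_tok = ref_tok.train_new_from_iterator(dataset["text"], vocab_size=128000)
new_tok.save_pretrained("new-llama-tokenizer")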