A probably faster way to train the tokenizer (pure Python)
ReinforcedKnowledge opened this issue · 3 comments
Hi!
I'm not sure if this is the appropriate place to post this; I'm sorry if it is not.
I think there is a way to make the training of the tokenizer faster.
Where does the initial code get slow
I think it's mainly at line 33 and line 39 in f50ad93: we're doing whole passes over all the dictionaries and returning new instances of them, which is quite costly.
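To illustrate what I mean by that pattern, here is a generic sketch of the "recount everything and rebuild" style (my own paraphrase, not the exact code at that commit):

```python
# Generic sketch of the pattern I'm referring to (not the repository's exact code):
# every merge step recounts all pairs from scratch and allocates new containers.
def count_pairs(ids):
    counts = {}
    for pair in zip(ids, ids[1:]):  # full pass over the whole sequence
        counts[pair] = counts.get(pair, 0) + 1
    return counts                   # a brand-new dict on every call

def apply_merge(ids, pair, idx):
    new_ids = []                    # a brand-new list on every call
    i = 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
            new_ids.append(idx)
            i += 2
        else:
            new_ids.append(ids[i])
            i += 1
    return new_ids
```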
Possible solution
I think it's possible to do some kind of single pass to find the best pair, then update everything in-place. I'll try to explain my code; unfortunately I can't check for the moment how to integrate it as an example into your own `train` method, but I'll explain my inputs, my outputs, and the code as best I can. I think my `train_dict` is the equivalent of `vocab` in your code, and my `pairs_dict` is the equivalent of `stats`.
Also, the repository's implementation is byte-level BPE while I'm going to talk about standard BPE, but we can go from one to the other easily.
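As a side note on that conversion, here is a minimal sketch of what I mean (helper names of my own choosing): the only difference is whether a word is split into characters or into its UTF-8 bytes.

```python
# Minimal sketch of the character-level vs. byte-level views of a word.
def char_symbols(word):
    return list(word)                  # standard BPE: "hug" -> ['h', 'u', 'g']

def byte_symbols(word):
    return list(word.encode("utf-8"))  # byte-level BPE: "hug" -> [104, 117, 103]
```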
I'm going with the same corpus as the Hugging Face summary about tokenizers (https://huggingface.co/docs/transformers/en/tokenizer_summary): `corpus = "hug " * 10 + "pug " * 5 + "pun " * 12 + "bun " * 4 + "hugs " * 5`.
I start with a `train_dict`, which contains the words' frequencies, in the following format:
```python
{0: [10, [('h', 'u'), ('u', 'g'), ('g', 'Ġ')]],
 1: [5, [('p', 'u'), ('u', 'g'), ('g', 'Ġ')]],
 2: [12, [('p', 'u'), ('u', 'n'), ('n', 'Ġ')]],
 3: [4, [('b', 'u'), ('u', 'n'), ('n', 'Ġ')]],
 4: [5, [('h', 'u'), ('u', 'g'), ('g', 's'), ('s', 'Ġ')]]}
```
The structure is the following: `word id: List[word frequency, List[word symbol pairs]]`. The idea is, once we find the best pair (which gives the merge rule), we then find all the words that contribute to that pair and update them accordingly after the merge. The frequency is useful because we want to keep track of the pairs' frequencies (which are the sum of the frequencies of the words that contribute to them).
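For completeness, here is one way such a `train_dict` could be built from the toy corpus above (a small helper of my own, not part of the repo, with `Ġ` appended to mark the end of each word as in the example):

```python
from collections import Counter

def build_train_dict(corpus):
    """Build {word_id: [frequency, [adjacent symbol pairs]]} from a whitespace-split corpus."""
    train_dict = {}
    for word_id, (word, freq) in enumerate(Counter(corpus.split()).items()):
        symbols = list(word) + ["Ġ"]  # 'Ġ' marks the end-of-word boundary
        train_dict[word_id] = [freq, list(zip(symbols, symbols[1:]))]
    return train_dict

corpus = "hug " * 10 + "pug " * 5 + "pun " * 12 + "bun " * 4 + "hugs " * 5
train_dict = build_train_dict(corpus)  # matches the dictionary shown above
```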
I also start with the following `pairs_dict`, a dictionary that maps each pair to its frequency and to the ids of the words that contribute to it. In our case it is:
```python
{('h', 'u'): [15, {0, 4}],
 ('u', 'g'): [20, {0, 1, 4}],
 ('g', 'Ġ'): [15, {0, 1}],
 ('p', 'u'): [17, {1, 2}],
 ('u', 'n'): [16, {2, 3}],
 ('n', 'Ġ'): [16, {2, 3}],
 ('b', 'u'): [4, {3}],
 ('g', 's'): [5, {4}],
 ('s', 'Ġ'): [5, {4}]}
```
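Such a `pairs_dict` could be derived from the `train_dict` above with something like this (again, a helper of my own, not part of the repo):

```python
def build_pairs_dict(train_dict):
    """Build {pair: [total frequency, {ids of the words containing the pair}]}."""
    pairs_dict = {}
    for word_id, (freq, pairs) in train_dict.items():
        for pair in pairs:
            if pair not in pairs_dict:
                pairs_dict[pair] = [0, set()]
            pairs_dict[pair][0] += freq      # each occurrence contributes the word's frequency
            pairs_dict[pair][1].add(word_id)
    return pairs_dict

pairs_dict = build_pairs_dict(train_dict)  # matches the dictionary shown above
```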
My training loop is simple: find the best pair (the merge rule), then update both the `train_dict` and the `pairs_dict`:
```python
from typing import Dict

def train_loop(train_dict: Dict, pairs_dict: Dict, num_merges: int) -> None:
    for i in range(num_merges):
        # Pick the most frequent pair and record the merge in the vocabulary.
        best = max(pairs_dict, key=lambda pair: pairs_dict[pair][0])
        vocab["".join(best)] = base_vocab_size + i  # `vocab` and `base_vocab_size` are defined in the enclosing scope
        merge_pairs(train_dict, pairs_dict, best)
```
As for the `merge_pairs` function, I'm not so proud of it, but it does the job 😅
```python
def merge_pairs(words_dict, pairs_dict, max_freq_pair):
    max_freq_pair_merged = "".join(max_freq_pair)
    for word_id in words_dict:
        word_freq = words_dict[word_id][0]
        pairs = words_dict[word_id][1]
        new_pairs = []
        i = 0
        while i < len(pairs):
            if pairs[i] == max_freq_pair:
                # Check for preceding pair
                if i > 0 and new_pairs:
                    prev_pair = new_pairs[-1]
                    new_pairs[-1] = (prev_pair[0], max_freq_pair_merged)
                    update_pairs_dict(pairs_dict, prev_pair, -word_freq, word_id)
                    update_pairs_dict(pairs_dict, new_pairs[-1], word_freq, word_id)
                # Check for following pair
                if i < len(pairs) - 1:
                    next_pair = (max_freq_pair_merged, pairs[i + 1][1])
                    new_pairs.append(next_pair)
                    update_pairs_dict(pairs_dict, pairs[i + 1], -word_freq, word_id)
                    update_pairs_dict(pairs_dict, next_pair, word_freq, word_id)
                    i += 1  # Skip the next pair as it's now merged
            else:
                new_pairs.append(pairs[i])
            i += 1
        words_dict[word_id][1] = new_pairs
    # Delete max_freq_pair from pairs_dict
    del pairs_dict[max_freq_pair]
```
As you can see, I only do one pass through the training dictionary and an incomplete pass through the pairs dictionary, instead of two full passes (one through the training dictionary and one through the pairs), and the updates are done in-place. The updates of the `pairs_dict` are based on the frequencies of the words that contribute to the pair:
```python
def update_pairs_dict(pairs_dict, pair, freq_change, word_id):
    if pair in pairs_dict:
        pairs_dict[pair][0] += freq_change
        if freq_change > 0:  # If we are adding frequency, add the word ID
            pairs_dict[pair][1].add(word_id)
        if pairs_dict[pair][0] <= 0:  # If frequency is zero or less, delete the pair
            del pairs_dict[pair]
    else:
        pairs_dict[pair] = (
            [freq_change, {word_id}] if freq_change > 0 else [freq_change, set()]
        )
```
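Putting it all together on the toy corpus (this driver is my own sketch, using the `build_train_dict` / `build_pairs_dict` helpers above, with `vocab` and `base_vocab_size` defined in the same scope as `train_loop` since the loop reads them):

```python
corpus = "hug " * 10 + "pug " * 5 + "pun " * 12 + "bun " * 4 + "hugs " * 5
train_dict = build_train_dict(corpus)
pairs_dict = build_pairs_dict(train_dict)

# Base vocabulary: one id per base symbol appearing in the corpus (including 'Ġ').
base_symbols = sorted({s for _, pairs in train_dict.values() for pair in pairs for s in pair})
vocab = {s: i for i, s in enumerate(base_symbols)}
base_vocab_size = len(vocab)

train_loop(train_dict, pairs_dict, num_merges=3)
print(vocab)  # the first learned merge is 'ug' (frequency 20), followed by 'un' (16)
```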
Context
I know this repository is for educational purposes and that the code in it is already very concise and clear.
I also know that the focus on making tokenizers faster is on the tokenization part rather than training, especially since there is Rust code that is far more performant at that, and works in a distributed fashion as well.
But I thought maybe someone could benefit from this. I'm trying to implement a bunch of stuff myself in order to learn, and I have learned a lot by trying to improve my code for better efficiency. I'm not saying my code is perfect; there surely is a way to use better data structures for this, such as linked lists or graphs. I just couldn't do it yet. And I'm aware of the limits of Python compared to Rust (and some other programming languages, but I think Rust is the one most used for tokenizers) for distributed compute, or for any memory-efficient application for that matter. I just thought it would be great to push the limit on what we can do with tokenizers in pure Python.
It helps to understand this platform more and more.
We can make it better.
Yes, sure! It could be some kind of example of how to improve a tokenizer's training code algorithmically. Since there is a relationship between the pairs and the words, I was thinking of a graph where the edges' values are the words' frequencies, and/or linked lists. But I haven't had the time to think it through completely yet and come up with a better version.
I'll also try to adapt my code to the train loop in minBPE and benchmark both.
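In the meantime, here is a rough timing harness of my own (on a synthetic corpus, not a rigorous benchmark and not using minBPE's actual code), just to get an order of magnitude for the incremental loop:

```python
import random
import string
import time

# Build a synthetic corpus of random "words" so the dictionaries are reasonably large.
random.seed(0)
unique_words = [
    "".join(random.choices(string.ascii_lowercase, k=random.randint(3, 8)))
    for _ in range(2_000)
]
synthetic_corpus = " ".join(random.choices(unique_words, k=50_000))

train_dict = build_train_dict(synthetic_corpus)  # helpers sketched earlier
pairs_dict = build_pairs_dict(train_dict)
vocab, base_vocab_size = {}, 0                   # globals read by train_loop (ids only matter for timing)

start = time.perf_counter()
train_loop(train_dict, pairs_dict, num_merges=200)
print(f"incremental loop, 200 merges: {time.perf_counter() - start:.3f} s")
```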
There's also issue #29, which tries a vectorized version; I haven't tried it, but it might be a better approach.