(v3) speed up NoBadWordsLogitsProcessor
System Info
v3
Environment/Platform
- Website/web-app
- Browser extension
- Server-side (e.g., Node.js, Deno, Bun)
- Desktop app (e.g., Electron)
- Other (e.g., VSCode extension)
Description
Consider using a Map in NoBadWordsLogitsProcessor.
Reproduction
The nested loops in the NoBadWordsLogitsProcessor class can be slow to run when you have a lot of bad words. That is my case, for instance, with distilgpt2, which has ~800 bad words in its vocabulary.
Building a static Map can speed up the lookups. Something like:
```js
// Built once, e.g. in the constructor: index each bad-word sequence
// by its final token, keeping the prefix that must precede it.
const bad_words_map = new Map();
for (const bad_word_ids of this.bad_words_ids) {
    const key = bad_word_ids.at(-1);
    if (!bad_words_map.has(key)) {
        bad_words_map.set(key, []);
    }
    bad_words_map.get(key).push(bad_word_ids.slice(0, -1));
}
this.bad_words_map = bad_words_map;
```
and then, in `_call`:
```js
_call(input_ids, logits) {
    for (let i = 0; i < input_ids.length; ++i) {
        const batch_logits_data = /** @type {Float32Array} */ (logits[i].data);
        const ids = input_ids[i];
        const last_id = ids.at(-1);
        // Only sequences whose final token equals the last generated id
        // need to be examined.
        if (this.bad_words_map.has(last_id)) {
            const prefixes = this.bad_words_map.get(last_id);
            for (const prefix of prefixes) {
                if (ids.slice(-prefix.length).every((v, idx) => v === prefix[idx])) {
                    batch_logits_data[last_id] = -Infinity;
                    break;
                }
            }
        }
    }
    return logits;
}
```
Example of a generation config with a heavy bad_words_ids list: https://huggingface.co/Mozilla/distilvit/blob/main/generation_config.json
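For anyone who wants to poke at this locally: the snippets above assume the shapes a logits processor receives, i.e. one ids array and one logits tensor per batch entry (input ids are bigints, hence the `!=` note in the library code quoted below). A minimal mock, with made-up sizes:

```js
// Minimal mock of the inputs to _call (sizes are made up):
const input_ids = [[1n, 2n, 7n]];                 // batch of one sequence
const logits = [{ data: new Float32Array(32) }];  // 32-token vocab, all zeros
// `processor` is assumed to be an instance carrying bad_words_map as above.
processor._call(input_ids, logits);
```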
(btw I don't know how to change the label from bug to enhancement)
I've done a bit of testing and it's a bit trickier than that, unfortunately. The bad_words_ids is a list of lists, structured as follows:
```js
[
    [a],       // ALWAYS block a
    [b, c],    // only block c if preceded by [b]
    [d],       // ALWAYS block d
    [e, f, g], // only block g if preceded by [e, f]
    // ...
]
```
This means we still need to iterate over the entire list, especially to handle these "single bad words". This code:
```js
for (let j = 1; j <= bad_word_ids.length - 1 && bad_word_ids.length < ids.length; ++j) {
    // NOTE: We use != instead of !== to compare bigint and number
    // @ts-ignore
    if (bad_word_ids.at(-j - 1) != ids.at(-j)) {
        // We have found a mismatch
        mark = false;
        break;
    }
}
```
will check whether the tokens before the last one in a blocked sequence match the most recent ids; if they don't, we won't block that sequence's last id.
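For what it's worth, here is a hedged sketch (not the library's implementation, and the class name is made up) of how the Map idea could be reconciled with these semantics: a Set handles the "always block" single-token entries, and a Map keyed on the token right before the banned one handles the rest, so only matching candidates are scanned per step:

```js
// Sketch only: combines a Set for unconditional single-token bans with a
// Map for multi-token sequences, keyed by the token preceding the ban.
class FastBadWords {
    constructor(bad_words_ids) {
        this.always_banned = new Set();  // entries like [a]
        this.by_prev_token = new Map();  // entries like [e, f, g]
        for (const seq of bad_words_ids) {
            if (seq.length === 1) {
                this.always_banned.add(Number(seq[0]));
            } else {
                // Normalize keys to Number so bigint input ids can be looked up.
                const prev = Number(seq.at(-2));
                if (!this.by_prev_token.has(prev)) {
                    this.by_prev_token.set(prev, []);
                }
                this.by_prev_token.get(prev).push({
                    prefix: seq.slice(0, -1),   // tokens that must precede the ban
                    banned: Number(seq.at(-1)), // token whose logit gets -Infinity
                });
            }
        }
    }

    _call(input_ids, logits) {
        for (let i = 0; i < input_ids.length; ++i) {
            const data = logits[i].data;
            const ids = input_ids[i];

            // Single bad words are blocked unconditionally.
            for (const t of this.always_banned) data[t] = -Infinity;

            // Multi-token bad words only need checking when the last generated
            // id matches the token right before the banned one.
            const candidates = this.by_prev_token.get(Number(ids.at(-1))) ?? [];
            for (const { prefix, banned } of candidates) {
                if (prefix.length > ids.length) continue; // not enough context yet
                let match = true;
                for (let j = 1; j <= prefix.length; ++j) {
                    // != tolerates bigint vs number, as in the library snippet.
                    if (prefix.at(-j) != ids.at(-j)) { match = false; break; }
                }
                if (match) data[banned] = -Infinity;
            }
        }
        return logits;
    }
}
```

Note that the quoted library loop above skips the comparison entirely when the blocked sequence is not shorter than ids (and so blocks anyway), whereas this sketch simply declines to block in that case.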
The good news is that you shouldn't see a massive difference in performance. For the block list of 800, I only see a ~10 ms difference in the unit test I created. For a block list of 100,000, the difference is more noticeable, but I don't see that happening in practice.
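For anyone who wants to reproduce that kind of measurement, a rough timing harness (this is not the unit test mentioned above; the sizes are illustrative, and it reuses the hypothetical FastBadWords sketch from earlier):

```js
// Rough micro-benchmark: time 1000 calls over a block list of 800 entries,
// mixing single- and two-token sequences to exercise both code paths.
const VOCAB = 50_000;
const bad_words_ids = Array.from({ length: 800 }, (_, k) =>
    k % 2 ? [k % VOCAB, (k + 1) % VOCAB] : [k % VOCAB]);
const processor = new FastBadWords(bad_words_ids);

const input_ids = [[1n, 2n, 3n]];
const logits = [{ data: new Float32Array(VOCAB) }];

const t0 = performance.now();
for (let step = 0; step < 1000; ++step) {
    processor._call(input_ids, logits);
}
console.log(`${(performance.now() - t0).toFixed(1)} ms for 1000 calls`);
```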