Repeating tokens in TextStreamer
Opened this issue · 0 comments
chandeldivyam commented
Question
import {
AutoTokenizer,
AutoModelForCausalLM,
TextStreamer,
InterruptableStoppingCriteria,
} from "@huggingface/transformers";
/**
 * Lazy singleton holder for the tokenizer + model pair.
 *
 * BUG FIX: the original assigned `this.tokenizer` / `this.model`
 * unconditionally on every call, so each `getInstance()` kicked off a
 * fresh download/initialization of the model — the static fields never
 * acted as a cache. We now cache the *load promises* with `??=`, the
 * standard transformers.js pattern, so concurrent and repeated callers
 * all share one tokenizer and one model instance.
 */
class TextGenerationPipeline {
  static model = null;
  static tokenizer = null;
  static streamer = null;

  /**
   * Return the shared `[tokenizer, model]` pair, loading them on first use.
   *
   * @param {Function|null} progress_callback - forwarded to `from_pretrained`
   *   for download-progress reporting.
   * @param {string} model_id - HF model id. NOTE: only honored on the first
   *   call; later calls reuse whatever was cached (singleton semantics).
   * @returns {Promise<[any, any]>} resolved `[tokenizer, model]`.
   */
  static async getInstance(
    progress_callback = null,
    model_id = "onnx-community/Phi-3.5-mini-instruct-onnx-web",
  ) {
    // `??=` only assigns when the field is still null, so the load
    // promises are created exactly once.
    this.tokenizer ??= AutoTokenizer.from_pretrained(model_id, {
      progress_callback,
    });
    this.model ??= AutoModelForCausalLM.from_pretrained(model_id, {
      // dtype: "q4",
      dtype: "q4f16",
      device: "webgpu",
      use_external_data_format: true,
      progress_callback,
    });
    // Both fields hold promises; resolve them together.
    return Promise.all([this.tokenizer, this.model]);
  }
}
// Shared stopping criteria so an in-flight generate() call can be interrupted.
const stopping_criteria = new InterruptableStoppingCriteria();
// KV-cache handed back by generate() and re-fed on the next call so earlier
// tokens are not re-processed. Reset semantics are the caller's concern.
let past_key_values_cache = null;
/**
 * Message router for the extension's LLM features.
 *
 * Actions:
 *  - "initializeLlmModel": loads the model and runs a 1-token warm-up.
 *  - "generateText": streams a chat completion back via "chatMessageChunk"
 *    messages, then responds with the fully decoded output.
 *
 * Both branches return `true` to keep the message channel open for the
 * asynchronous `sendResponse`. ROBUSTNESS FIX: the original swallowed any
 * rejection inside the async handlers and never called `sendResponse` on
 * failure, leaving the sender's port hanging; we now catch and report errors.
 *
 * NOTE(review): if the frontend sees each streamed chunk more than once,
 * also verify this listener (and any content script sending/receiving
 * "chatMessageChunk") is not registered multiple times — TODO confirm.
 */
chrome.runtime.onMessage.addListener((request, sender, sendResponse) => {
  if (request.action === "initializeLlmModel") {
    console.log("setting up llm");
    const initialize = async () => {
      try {
        const [tokenizer, model] = await TextGenerationPipeline.getInstance(
          (x) => {
            console.log(x);
          },
          request.model_id,
        );
        // 1-token warm-up generation so shaders/weights are ready.
        const inputs = tokenizer("a");
        const generatedOutput = await model.generate({
          ...inputs,
          max_new_tokens: 1,
        });
        console.log(generatedOutput);
        sendResponse({ status: "success" });
      } catch (err) {
        console.error("initializeLlmModel failed", err);
        sendResponse({ status: "error", error: String(err) });
      }
    };
    initialize();
    return true; // async sendResponse
  }
  if (request.action === "generateText") {
    console.log("generating text");
    // const arrow for consistency with `initialize` above (and to avoid a
    // function declaration inside an `if` block).
    const generateText = async () => {
      try {
        const [tokenizer, model] = await TextGenerationPipeline.getInstance();
        const text_callback_function = (output) => {
          console.log(output);
          if (output) {
            // Forward each decoded chunk to the UI as it streams in.
            chrome.runtime.sendMessage({
              action: "chatMessageChunk",
              chunk: output,
            });
          }
        };
        const streamer = new TextStreamer(tokenizer, {
          skip_prompt: true,
          skip_special_tokens: true,
          callback_function: text_callback_function,
        });
        const inputs = tokenizer.apply_chat_template(request.messages, {
          add_generation_prompt: true,
          return_dict: true,
        });
        const { past_key_values, sequences } = await model.generate({
          ...inputs,
          past_key_values: past_key_values_cache,
          // Sampling
          // do_sample: true,
          // top_k: 3,
          // temperature: 0.2,
          max_new_tokens: 1024,
          stopping_criteria,
          return_dict_in_generate: true,
          streamer,
        });
        // Carry the KV-cache forward so the next turn skips old tokens.
        past_key_values_cache = past_key_values;
        const decoded = tokenizer.batch_decode(sequences, {
          skip_special_tokens: false,
        });
        console.log(decoded);
        sendResponse({ generatedOutput: decoded, status: "success" });
      } catch (err) {
        console.error("generateText failed", err);
        sendResponse({ status: "error", error: String(err) });
      }
    };
    generateText();
    return true; // async sendResponse
  }
});
In the text_callback_function,
the same token is being sent multiple times. What could be the reason? I am handling it on the frontend for the time being, but I was wondering what causes it. What am I doing wrong here?
Thank you so much for the help in advance!