Tracking token usage?
oneilsh opened this issue · 6 comments
I see that the API supports `.message_token_len()` for an individual `ChatMessage`; it would be nice to be able to query total token usage over the course of a conversation for cost-tracking purposes.

I'm not entirely sure of the best way to handle it - maybe something like `.next_message_tokens_cost(message: ChatMessage)` that would return the total prompt tokens (system + function defs + chat history) plus the tokens in `message` that would be incurred? If it could be done over the course of a chat (accumulating after each full round), maybe something like `.conversation_history_total_prompt_tokens()` and `.conversation_history_total_response_tokens()` so a user could compute a running chat cost?
Thanks for considering, and for developing Kani! It really is the 'right' API interface to tool-enabled LLMs in my opinion :)
Thanks for the kind words! I should note that oftentimes the internals of the LLM providers (in particular OpenAI) are a bit of a mystery, so Kani's token counting is really just a best guess, accurate to within a couple of percent.
You have a couple of options if you want to track tokens as accurately as possible, which I'll lay out here:
- Overriding `Kani.get_model_completion` - this is the method that the Kani instance uses to go to the underlying LLM, and it returns a `Completion`, which includes the prompt token length and completion token length as returned by the engine. You could, for example, add `tokens_used_prompt` and `tokens_used_completion` attributes in a subclass of Kani and increment those after a super call; this has the disadvantage of being post-hoc counting, though. I use a similar approach in one of my projects here: https://github.com/zhudotexe/kanpai/blob/cc603705d353e4e9b9aa3cf9fbb12e3a46652c55/kanpai/base_kani.py#L48 (see the sketch after this list).
- You could also use an estimation like `sum(self.message_token_len(m) for m in await self.get_prompt()) + self.engine.token_reserve + self.engine.function_token_reserve(list(self.functions.values()))` if you wanted a token estimate before sending the prompt to the LLM. The instance caches the token lengths, so this won't result in a major slowdown.
- Use an external gateway - in our lab we've been trying out Helicone for token counting. If you're using OpenAI, you can integrate it with Kani pretty easily by specifying the `api_base` and `headers` when constructing an `OpenAIEngine`. I've also been interested in Cloudflare AI Gateway, though I haven't used it yet. These solutions also require a bit more engineering, and I believe they're also only post-hoc.
I'll have to think a bit more about how to implement an official token counting interface if we decide to - maybe `Kani.prompt_len_estimate(msgs: list[ChatMessage]) -> int` to perform the estimation detailed above?
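Roughly, that would just be the estimation from the second bullet wrapped in a method - a sketch only, since `prompt_len_estimate` isn't an existing kani API:

```python
from kani import ChatMessage, Kani

class EstimatingKani(Kani):
    def prompt_len_estimate(self, msgs: list[ChatMessage]) -> int:
        """Estimate the prompt tokens that sending `msgs` (plus function definitions) would incur."""
        return (
            sum(self.message_token_len(m) for m in msgs)
            + self.engine.token_reserve
            + self.engine.function_token_reserve(list(self.functions.values()))
        )

# usage: estimate = k.prompt_len_estimate(await k.get_prompt())
```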
Wonderful, thank you! Post-hoc is fine for my case, I used your first suggestion and it works great. I did need to remember to update the counts manually when calling out to sub-kanis. (Maybe engine-level counting?)
Hello again :) I've started playing with the newer streaming functionality, and it's been nice! However, it looks like `get_model_completion` isn't called when streaming, so the override described above isn't processing token counts. Perhaps the new `get_model_stream` could be overridden similarly? I suppose I could increment the completion token count as tokens are yielded, but I'm not sure how to get the prompt token count accurately.

Thanks as always-
Good callout - this is a little less elegant with streaming. The `get_model_stream` method is a little lower-level (returning a mixed iterator which is managed by a `StreamManager`), with no guarantee that a `Completion` will be yielded or that each yield is exactly one token.

In the current version (1.0.1), your best option is probably to look at `await stream.completion()`, which should have the `prompt_tokens` and `completion_tokens` attributes set, with some caveats:

- for API models, these attributes will only be set if the underlying API returns them in the stream (which was not the case for OpenAI models until recently; I'll patch this in)
- this is a different interface than for the non-streaming case, which is kind of ugly
example:

```python
stream = ai.chat_round_stream("What is the airspeed velocity of an unladen swallow?")
async for token in stream:
    print(token, end="")
completion = await stream.completion()
prompt_tokens = completion.prompt_tokens
completion_tokens = completion.completion_tokens
# ...
# msg = await stream.message()
```
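If you want a running per-conversation total (the original use case), you could accumulate those values after each round - a minimal sketch, with the caveat above that the counts may be missing for some engines (the helper name is just for illustration):

```python
total_prompt_tokens = 0
total_completion_tokens = 0

async def chat_round_with_usage(ai, query: str):
    global total_prompt_tokens, total_completion_tokens
    stream = ai.chat_round_stream(query)
    async for token in stream:
        print(token, end="")
    print()
    completion = await stream.completion()
    # prompt_tokens / completion_tokens may be None (see caveats above)
    total_prompt_tokens += completion.prompt_tokens or 0
    total_completion_tokens += completion.completion_tokens or 0
```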
In v1.0.0 I added a private `_add_completion_to_history` method that acts like `add_to_history`, but is only called on model completions with a `Completion` rather than on each message. I'll make this method public, since it's called on both stream and non-stream completions and token counting is a good use case for it.

I'll update this thread with new code snippets (probably later today?) once that's done.
As of v1.0.2, `Kani.add_completion_to_history` is called after each completion for both streaming and non-streaming generations. You can implement token counting by overriding it like so:
```python
class TokenCountingKani(Kani):
    # ...
    async def add_completion_to_history(self, completion):
        prompt_tokens = completion.prompt_tokens
        completion_tokens = completion.completion_tokens
        # ...
        return await super().add_completion_to_history(completion)
```
Note that `completion.[prompt|completion]_tokens` might be `None` for user-implemented engines and llama.cpp streams (abetlen/llama-cpp-python#1498), but it should be present for all kani-bundled OpenAI, Anthropic, and Hugging Face engines.
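If you need a count even in those cases, one option is to fall back to kani's own estimate - a sketch, assuming `completion.message` holds the generated `ChatMessage`:

```python
from kani import Kani

class TokenCountingKani(Kani):
    # ...
    async def add_completion_to_history(self, completion):
        completion_tokens = completion.completion_tokens
        if completion_tokens is None:
            # engine didn't report usage; estimate from the generated message instead
            completion_tokens = self.message_token_len(completion.message)
        # ... accumulate completion_tokens (and completion.prompt_tokens) here ...
        return await super().add_completion_to_history(completion)
```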
Wonderful, I will give it a try, thank you!