Add Support for Prompt Caching in SpiceMessages Class
mentatai opened this issue · 2 comments
Summary
Add support for Anthropic's prompt caching feature to the `SpiceMessages` class in the spice library. This will enable faster and more cost-efficient API calls by reusing cached prompt prefixes. Additionally, track cache performance metrics to verify cache hits and the number of input tokens cached.
Changes Required
- Update `SpiceMessages` Class:
  - Add a `cache` argument to the message creation methods.
  - Set the `cache_control` parameter based on the `cache` argument.
- Modify Message Creation Functions:
  - Update the message creation functions to handle the `cache` argument.
- Track Cache Performance Metrics:
  - Update the `get_response` method in the `Spice` class to handle the new API response fields related to caching.
  - Log the number of input tokens cached and verify cache hits using the `client.extract_text_and_tokens` method.
Implementation Details
1. Update SpiceMessages Class
Modify the `SpiceMessages` class to include the `cache` argument:
```python
from collections import UserList
from typing import Literal

# SpiceMessage and the message helpers below live in spice.spice_message.
class SpiceMessages(UserList[SpiceMessage]):
    ...
    def add_message(self, role: Literal["user", "assistant", "system"], content: str, cache: bool = False):
        self.data.append(create_message(role, content, cache))

    def add_user_message(self, content: str, cache: bool = False):
        """Appends a user message with the given content."""
        self.data.append(user_message(content, cache))

    def add_system_message(self, content: str, cache: bool = False):
        """Appends a system message with the given content."""
        self.data.append(system_message(content, cache))

    def add_assistant_message(self, content: str, cache: bool = False):
        """Appends an assistant message with the given content."""
        self.data.append(assistant_message(content, cache))
    ...
```
2. Modify Message Creation Functions
Update the message creation functions to handle the `cache` argument:
```python
from typing import Literal

from openai.types.chat import (
    ChatCompletionAssistantMessageParam,
    ChatCompletionMessageParam,
    ChatCompletionSystemMessageParam,
    ChatCompletionUserMessageParam,
)

def create_message(role: Literal["user", "assistant", "system"], content: str, cache: bool = False) -> ChatCompletionMessageParam:
    message = {"role": role, "content": content}
    if cache:
        # Mark this message for caching; only the Anthropic client acts on this key.
        message["cache_control"] = {"type": "ephemeral"}
    return message

def user_message(content: str, cache: bool = False) -> ChatCompletionUserMessageParam:
    """Creates a user message with the given content."""
    return create_message("user", content, cache)

def system_message(content: str, cache: bool = False) -> ChatCompletionSystemMessageParam:
    """Creates a system message with the given content."""
    return create_message("system", content, cache)

def assistant_message(content: str, cache: bool = False) -> ChatCompletionAssistantMessageParam:
    """Creates an assistant message with the given content."""
    return create_message("assistant", content, cache)
```
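Note that Anthropic's Messages API attaches `cache_control` to individual content blocks rather than to the message itself, so the Anthropic client will likely need a small translation step when building the request. A minimal sketch, assuming message dicts shaped as above (the `convert_messages_for_anthropic` helper is hypothetical, not an existing spice function; system messages, which spice routes to Anthropic's separate `system` parameter, would need the same treatment there):

```python
from typing import Any, Dict, List

def convert_messages_for_anthropic(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Moves a message-level cache_control flag onto an Anthropic content block."""
    converted: List[Dict[str, Any]] = []
    for message in messages:
        cache_control = message.get("cache_control")
        if cache_control is None:
            converted.append(message)
            continue
        # Anthropic expects cache_control inside a content block, e.g.
        # {"role": "user", "content": [{"type": "text", "text": "...", "cache_control": {...}}]}
        block: Dict[str, Any] = {"type": "text", "text": message["content"], "cache_control": cache_control}
        converted.append({"role": message["role"], "content": [block]})
    return converted
```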
3. Track Cache Performance Metrics
Update the `get_response` method in the `Spice` class to handle the new API response fields related to caching:
```python
async def get_response(
    self,
    messages: Collection[SpiceMessage],
    model: Optional[TextModel | str] = None,
    provider: Optional[Provider | str] = None,
    temperature: Optional[float] = None,
    max_tokens: Optional[int] = None,
    response_format: Optional[ResponseFormat] = None,
    name: Optional[str] = None,
    validator: Optional[Callable[[str], bool]] = None,
    converter: Callable[[str], T] = string_identity,
    streaming_callback: Optional[Callable[[str], None]] = None,
    retries: int = 0,
    retry_strategy: Optional[RetryStrategy[T]] = None,
    cache_control: Optional[Dict[str, Any]] = None,  # New parameter
) -> SpiceResponse[T]:
    ...
    call_args = self._fix_call_args(
        messages, text_model, streaming_callback is not None, temperature, max_tokens, response_format
    )
    ...
    while True:
        ...
        with client.catch_and_convert_errors():
            if streaming_callback is not None:
                stream = await client.get_chat_completion_or_stream(call_args)
                stream = cast(AsyncIterator, stream)
                streaming_spice_response = StreamingSpiceResponse(
                    text_model, call_args, client, stream, None, streaming_callback
                )
                chat_completion = await streaming_spice_response.complete_response()
                text, input_tokens, output_tokens = (
                    chat_completion.text,
                    chat_completion.input_tokens,
                    chat_completion.output_tokens,
                )
            else:
                chat_completion = await client.get_chat_completion_or_stream(call_args)
                text, input_tokens, output_tokens = client.extract_text_and_tokens(chat_completion, call_args)

        # Handle cache performance metrics. This assumes the completion exposes a usage
        # object carrying Anthropic's prompt-caching counters; for providers without
        # prompt caching the attributes are absent and the counts default to 0.
        usage = getattr(chat_completion, "usage", None)
        cache_creation_input_tokens = getattr(usage, "cache_creation_input_tokens", 0) or 0
        cache_read_input_tokens = getattr(usage, "cache_read_input_tokens", 0) or 0
        print(f"Cache creation input tokens: {cache_creation_input_tokens}")
        print(f"Cache read input tokens: {cache_read_input_tokens}")
        ...
```
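For reference, here is a rough sketch of what the Anthropic side of this could look like. At the time of writing, prompt caching is a beta feature that is requested with the `anthropic-beta: prompt-caching-2024-07-31` header, and the cache counters come back on `response.usage`; the wiring below is illustrative only, not spice's actual client code:

```python
import anthropic

client = anthropic.AsyncAnthropic()

async def call_with_cache() -> None:
    # The beta header opts the request in to prompt caching (assumed current beta name).
    response = await client.messages.create(
        model="claude-3-5-sonnet-20240620",
        max_tokens=1024,
        system=[
            {
                "type": "text",
                "text": "You are an AI assistant tasked with analyzing literary works.",
                "cache_control": {"type": "ephemeral"},
            }
        ],
        messages=[{"role": "user", "content": "Analyze the major themes in 'Pride and Prejudice'."}],
        extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},
    )
    # Usage fields reported when caching is active (absent or zero otherwise).
    print(getattr(response.usage, "cache_creation_input_tokens", None))
    print(getattr(response.usage, "cache_read_input_tokens", None))
```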
Example Usage
Here's an example of how you might use the updated `SpiceMessages` class with caching:
```python
from spice import Spice
from spice.spice_message import SpiceMessages

# Inside an async context (get_response is a coroutine):
client = Spice()
messages = SpiceMessages(client)
messages.add_system_message("You are an AI assistant tasked with analyzing literary works.", cache=True)
messages.add_user_message("Analyze the major themes in 'Pride and Prejudice'.", cache=True)
response = await client.get_response(messages=messages, model="claude-3-5-sonnet-20240620")
print(response.text)
```
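With caching enabled, a follow-up request that reuses the same prefix should show a non-zero cache read count in the metrics logged by step 3. Continuing the example above (the follow-up prompt is just an illustration):

```python
# Reuse the cached prefix and ask a follow-up question.
messages.add_assistant_message(response.text)
messages.add_user_message("Now compare those themes with 'Emma'.")

# On this call the cached prefix should be read back, so the step-3 logging is
# expected to report cache_read_input_tokens > 0 for Anthropic models.
followup = await client.get_response(messages=messages, model="claude-3-5-sonnet-20240620")
print(followup.text)
```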
Acceptance Criteria
- The `SpiceMessages` class should support the `cache` argument.
- The `get_response` method should log cache performance metrics.
- The implementation should be backward compatible and ignore the `cache` argument for non-Anthropic clients (a sketch follows below).
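For the backward-compatibility criterion, one simple approach is for non-Anthropic clients to drop the extra key before sending the request, so OpenAI-style providers never see it. A rough sketch (the helper name is hypothetical):

```python
from typing import Any, Dict, List

def strip_cache_control(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Removes cache_control so providers without prompt caching receive unchanged messages."""
    return [{k: v for k, v in message.items() if k != "cache_control"} for message in messages]
```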
I will start working on this issue
done in #108