AbanteAI/spice

Add Support for Prompt Caching in SpiceMessages Class

mentatai opened this issue · 2 comments

Summary

Add support for Anthropic's prompt caching feature to the SpiceMessages class in the spice library. This will enable faster and more cost-efficient API calls by reusing cached prompt prefixes. Additionally, track cache performance metrics to verify cache hits and the number of input tokens cached.

Changes Required

  1. Update SpiceMessages Class:
    • Add a cache argument to the message creation methods.
    • Set the cache_control parameter based on the cache argument.
  2. Modify Message Creation Functions:
    • Update the message creation functions to handle the cache argument.
  3. Track Cache Performance Metrics:
    • Update the get_response method in the Spice class to handle the new API response fields related to caching.
    • Log the number of input tokens cached and verify cache hits using the client.extract_text_and_tokens method.

Implementation Details

1. Update SpiceMessages Class

Modify the SpiceMessages class to include the cache argument:

class SpiceMessages(UserList[SpiceMessage]):
    ...
    def add_message(self, role: Literal["user", "assistant", "system"], content: str, cache: bool = False):
        self.data.append(create_message(role, content, cache))
    
    def add_user_message(self, content: str, cache: bool = False):
        """Appends a user message with the given content."""
        self.data.append(user_message(content, cache))
    
    def add_system_message(self, content: str, cache: bool = False):
        """Appends a system message with the given content."""
        self.data.append(system_message(content, cache))
    
    def add_assistant_message(self, content: str, cache: bool = False):
        """Appends an assistant message with the given content."""
        self.data.append(assistant_message(content, cache))
    ...

2. Modify Message Creation Functions

Update the message creation functions to handle the cache argument:

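def create_message(role: Literal["user", "assistant", "system"], content: str, cache: bool = False) -> ChatCompletionMessageParam:
    """Creates a message with the given role and content, optionally marked for caching."""
    message = {"role": role, "content": content}
    if cache:
        # Mark this message as a cache breakpoint. Anthropic's API expects cache_control
        # on a content block, so the Anthropic client must move it there when converting.
        message["cache_control"] = {"type": "ephemeral"}
    return message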

def user_message(content: str, cache: bool = False) -> ChatCompletionUserMessageParam:
    """Creates a user message with the given content."""
    return create_message("user", content, cache)

def system_message(content: str, cache: bool = False) -> ChatCompletionSystemMessageParam:
    """Creates a system message with the given content."""
    return create_message("system", content, cache)

def assistant_message(content: str, cache: bool = False) -> ChatCompletionAssistantMessageParam:
    """Creates an assistant message with the given content."""
    return create_message("assistant", content, cache)
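Because create_message stores cache_control at the message level while Anthropic's Messages API reads it from individual content blocks, the Anthropic client wrapper would need a conversion step. The following is a minimal sketch of that step, not the library's actual conversion code; the helper names to_anthropic_block and split_for_anthropic are hypothetical, and messages are assumed to be plain dicts with string content.

```python
from typing import Any, Dict, List, Tuple


def to_anthropic_block(message: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: wrap a message's string content in a text block,
    carrying over cache_control when the message has it."""
    block: Dict[str, Any] = {"type": "text", "text": message["content"]}
    if "cache_control" in message:
        block["cache_control"] = message["cache_control"]
    return block


def split_for_anthropic(
    messages: List[Dict[str, Any]],
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Hypothetical helper: separate system blocks from conversation messages,
    since Anthropic takes the system prompt as its own parameter."""
    system_blocks: List[Dict[str, Any]] = []
    chat_messages: List[Dict[str, Any]] = []
    for message in messages:
        block = to_anthropic_block(message)
        if message["role"] == "system":
            system_blocks.append(block)
        else:
            chat_messages.append({"role": message["role"], "content": [block]})
    return system_blocks, chat_messages


# Usage: one cached system message and one uncached user message.
system, chat = split_for_anthropic(
    [
        {
            "role": "system",
            "content": "You analyze literary works.",
            "cache_control": {"type": "ephemeral"},
        },
        {"role": "user", "content": "Analyze 'Pride and Prejudice'."},
    ]
)
```

Keeping the flag on the message and converting late means the message creation functions stay provider-agnostic; only the Anthropic wrapper needs to know about content blocks.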

3. Track Cache Performance Metrics

Update the get_response method in the Spice class to handle the new API response fields related to caching:

async def get_response(
    self,
    messages: Collection[SpiceMessage],
    model: Optional[TextModel | str] = None,
    provider: Optional[Provider | str] = None,
    temperature: Optional[float] = None,
    max_tokens: Optional[int] = None,
    response_format: Optional[ResponseFormat] = None,
    name: Optional[str] = None,
    validator: Optional[Callable[[str], bool]] = None,
    converter: Callable[[str], T] = string_identity,
    streaming_callback: Optional[Callable[[str], None]] = None,
    retries: int = 0,
    retry_strategy: Optional[RetryStrategy[T]] = None,
    cache_control: Optional[Dict[str, Any]] = None,  # New parameter
) -> SpiceResponse[T]:
    ...
    call_args = self._fix_call_args(
        messages, text_model, streaming_callback is not None, temperature, max_tokens, response_format
    )
    ...
    while True:
        ...
        with client.catch_and_convert_errors():
            if streaming_callback is not None:
                stream = await client.get_chat_completion_or_stream(call_args)
                stream = cast(AsyncIterator, stream)
                streaming_spice_response = StreamingSpiceResponse(
                    text_model, call_args, client, stream, None, streaming_callback
                )
                chat_completion = await streaming_spice_response.complete_response()
                text, input_tokens, output_tokens = (
                    chat_completion.text,
                    chat_completion.input_tokens,
                    chat_completion.output_tokens,
                )
            else:
                chat_completion = await client.get_chat_completion_or_stream(call_args)
                text, input_tokens, output_tokens = client.extract_text_and_tokens(chat_completion, call_args)
        
        # Handle cache performance metrics. `usage` may be absent (for example on
        # non-Anthropic clients or the streaming path), so default to 0 instead of raising.
        usage = getattr(chat_completion, "usage", None)
        cache_creation_input_tokens = getattr(usage, "cache_creation_input_tokens", 0) or 0
        cache_read_input_tokens = getattr(usage, "cache_read_input_tokens", 0) or 0
        print(f"Cache creation input tokens: {cache_creation_input_tokens}")
        print(f"Cache read input tokens: {cache_read_input_tokens}")
        ...
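The metric handling above could also be factored into a small helper so that both the streaming and non-streaming paths share it. This is a rough sketch under assumptions: CacheMetrics and extract_cache_metrics are hypothetical names, and the usage value is assumed to be an object, a dict, or None. It treats a nonzero cache_read_input_tokens as a cache hit.

```python
import logging
from dataclasses import dataclass
from typing import Any

logger = logging.getLogger(__name__)


@dataclass
class CacheMetrics:
    """Hypothetical container for the two cache-related usage fields."""

    cache_creation_input_tokens: int
    cache_read_input_tokens: int

    @property
    def cache_hit(self) -> bool:
        # Tokens were read from the cache only if a cached prefix matched.
        return self.cache_read_input_tokens > 0


def _read_field(usage: Any, name: str) -> int:
    """Read a usage field from an object or dict, treating missing/None as 0."""
    if usage is None:
        return 0
    if isinstance(usage, dict):
        value = usage.get(name)
    else:
        value = getattr(usage, name, None)
    return int(value or 0)


def extract_cache_metrics(usage: Any) -> CacheMetrics:
    """Build CacheMetrics from a provider usage value and log both counters."""
    metrics = CacheMetrics(
        cache_creation_input_tokens=_read_field(usage, "cache_creation_input_tokens"),
        cache_read_input_tokens=_read_field(usage, "cache_read_input_tokens"),
    )
    logger.info("Cache creation input tokens: %d", metrics.cache_creation_input_tokens)
    logger.info("Cache read input tokens: %d", metrics.cache_read_input_tokens)
    return metrics
```

Using the logging module rather than print lets callers silence or redirect the output, which fits the acceptance criterion that metrics are logged.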

Example Usage

Here's an example of how you might use the updated SpiceMessages class with caching:

import asyncio

from spice import Spice
from spice.spice_message import SpiceMessages


async def main():
    client = Spice()

    messages = SpiceMessages(client)
    messages.add_system_message("You are an AI assistant tasked with analyzing literary works.", cache=True)
    messages.add_user_message("Analyze the major themes in 'Pride and Prejudice'.", cache=True)

    # get_response is a coroutine, so it must be awaited inside an async function.
    response = await client.get_response(messages=messages, model="claude-3-5-sonnet-20240620")
    print(response.text)


asyncio.run(main())

Acceptance Criteria

  • The SpiceMessages class should support the cache argument.
  • The get_response method should log cache performance metrics.
  • The implementation should be backward compatible and ignore the cache argument for non-Anthropic clients.
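For the backward-compatibility criterion, one possible approach is to strip the cache marker before messages reach a client that does not support caching. The sketch below is an assumption about how that could look, not existing spice code; strip_cache_control and the supports_caching flag are hypothetical.

```python
from typing import Any, Dict, List


def strip_cache_control(
    messages: List[Dict[str, Any]], supports_caching: bool
) -> List[Dict[str, Any]]:
    """Hypothetical helper: return copies of the messages, dropping the
    cache_control key when the target client does not support caching."""
    if supports_caching:
        return [dict(message) for message in messages]
    return [
        {key: value for key, value in message.items() if key != "cache_control"}
        for message in messages
    ]


# Usage: the same message list prepared for two kinds of clients.
original = [
    {"role": "system", "content": "Be concise.", "cache_control": {"type": "ephemeral"}},
    {"role": "user", "content": "Hello"},
]
for_other_client = strip_cache_control(original, supports_caching=False)
for_anthropic = strip_cache_control(original, supports_caching=True)
```

Returning copies leaves the caller's SpiceMessages data untouched, so the same message list can be reused across providers.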

I will start working on this issue