sgl-project/sglang

[Feature] Context Caching

Motivation

Allow long KV caches to be maintained across requests to improve TTFT on long prompts.

Related resources

I would expect this involves not releasing the KV cache.

Another related feature would be to allow a fixed and shared cache across all requests. Perhaps this would be even easier to implement.

Context caching is supported and enabled by default.

It is called RadixAttention in SGLang (https://lmsys.org/blog/2024-01-17-sglang/). It takes effect automatically; during prefill you can see the cache hit rate in the log:

Prefill batch. #new-seq: 1, #new-token: 3, #cached-token: 3, cache hit rate: 23.08%, #running-req: 0, #queue-req: 0
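
For reference, here is a minimal sketch of how to watch that cache hit rate climb, assuming an SGLang server exposing an OpenAI-compatible endpoint at http://localhost:30000/v1 (the endpoint and model name are placeholders; adjust to your deployment). It sends two requests that share a long prefix:

from openai import OpenAI

# Assumed local OpenAI-compatible endpoint served by SGLang; adjust as needed.
client = OpenAI(api_key="EMPTY", base_url="http://localhost:30000/v1")

MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"  # placeholder model name

# A long shared prefix followed by a short, varying question.
shared_prefix = "Berkshire Hathaway annual meeting transcript ... " * 2000

for question in ["Provide a summary of the text above.",
                 "List the key points discussed."]:
    resp = client.chat.completions.create(
        model=MODEL,
        messages=[{"role": "user", "content": shared_prefix + question}],
        max_tokens=64,
        temperature=0,
    )
    print(resp.choices[0].message.content[:80])

# The second request shares the long prefix with the first, so the server log
# should report a much larger #cached-token count and a higher cache hit rate.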

Hi @RonanKMcGovern 😂 Found out you're a YouTuber. Would you be interested in creating a video to introduce SGLang? Thanks a lot!

@RonanKMcGovern Great video! It would be even better if you could add the GitHub link for SGLang to the description. Thank you very much!

Sure, will do.

Just one question here on caching because I've been running some tests.

In my tests I run with an input of CONTEXT + SHORT_PROMPT (where SHORT_PROMPT varies for each request, to force the response to vary and not be cacheable).

When I run that a second time, I notice that time to first token decreases (makes sense), but the time for the remaining tokens also decreases. This second part surprises me, because I would have expected caching to help only to the extent that an entire prefix matches; i.e., from SHORT_PROMPT onwards there is surely no way to make use of cached prefixes, because the new tokens' information propagates forward.

I reviewed RadixAttention and it seems to confirm my understanding (but not my results). Am I missing something?

Is there some non-cache-related factor giving a speedup on subsequent requests?
Here are some results:

python sglang-cache.py --diff_qs
Context length: 400000 characters

Running test 1:
2024-08-29 14:46:57,482 - INFO - HTTP Request: POST https://ilwg1hanmwnfh9-8000.proxy.runpod.net/v1/chat/completions "HTTP/1.1 200 OK"
Total Time Taken: 57.44 seconds
Time to First Token: 8.58 seconds
Response Length: 2320 characters
Question: Provide a summary of the text above.
Summary: The text is a transcript of the 2023 Berkshire Hathaway annual meeting, where Warren Buffett and Cha...

Response Length: 2320 characters

Running test 2:
2024-08-29 14:47:53,131 - INFO - HTTP Request: POST https://ilwg1hanmwnfh9-8000.proxy.runpod.net/v1/chat/completions "HTTP/1.1 200 OK"
Total Time Taken: 21.32 seconds
Time to First Token: 4.30 seconds
Response Length: 2585 characters
Question: Analyze the main financial trends and strategies discussed in the text.
Summary: The main financial trends and strategies discussed in the text include:

1. **Berkshire Hathaway's e...

Response Length: 2585 characters

Running test 3:
2024-08-29 14:48:14,346 - INFO - HTTP Request: POST https://ilwg1hanmwnfh9-8000.proxy.runpod.net/v1/chat/completions "HTTP/1.1 200 OK"
Total Time Taken: 21.22 seconds
Time to First Token: 4.26 seconds
Response Length: 2650 characters
Question: Explain the key business principles and philosophies mentioned in the document.
Summary: The document discusses various business principles and philosophies mentioned by Warren Buffett and ...

Response Length: 2650 characters

and here is the test script:

from openai import OpenAI
import os
import time
import argparse
from dotenv import load_dotenv
import uuid
import logging

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Load environment variables
load_dotenv()

# Retrieve the env variables
model = os.getenv('MODEL')
api_endpoint = os.getenv('API_ENDPOINT')

openai_api_base = api_endpoint + '/v1'

# FYI, a runpod api key only needs to be set if using serverless.
api_key = os.getenv('API_KEY') if os.getenv('API_KEY') is not None else "EMPTY"

# Initialize the OpenAI client
client = OpenAI(
    api_key=api_key,
    base_url=openai_api_base,
)

# Argument parsing
parser = argparse.ArgumentParser(description='Run a test for long or short context.')
parser.add_argument('--context_length', type=int, default=400000,
                    help='Number of characters for the description length (default: 400000)')
parser.add_argument('--diff_qs', action='store_true', help='Use different questions for each test')
args = parser.parse_args()

def chat_completion_request_openai(messages, client):
    start_time = time.time()
    first_token_time = None
    full_response = ""

    try:
        stream = client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=0,
            max_tokens=500,
            stream=True
        )

        for chunk in stream:
            if chunk.choices[0].delta.content is not None:
                if first_token_time is None:
                    first_token_time = time.time()
                full_response += chunk.choices[0].delta.content

        end_time = time.time()
        total_time = end_time - start_time
        time_to_first_token = first_token_time - start_time if first_token_time else None

        print(f"Total Time Taken: {total_time:.2f} seconds")
        print(f"Time to First Token: {time_to_first_token:.2f} seconds" if time_to_first_token else "Time to First Token: N/A")
        print(f"Response Length: {len(full_response)} characters")

        return full_response
    except Exception as e:
        logging.error(f"An error occurred: {str(e)}")
        return None

# Read and truncate the text file
text_file = 'berkshire23.txt'
with open(text_file, 'r') as file:
    text = file.read()

# Truncate or extend text to match the specified context length
if len(text) > args.context_length:
    modified_text = text[:args.context_length]
else:
    modified_text = text.ljust(args.context_length, ' ')

print(f"Context length: {len(modified_text)} characters")

# Add a unique identifier to the context
unique_id = str(uuid.uuid4())
context = f"[UNIQUE_ID: {unique_id}]\n\n[TEXT_START]\n\n{modified_text}\n\n[TEXT_END]\n\n"

# Define different questions
questions = [
    "Provide a summary of the text above.",
    "Analyze the main financial trends and strategies discussed in the text.",
    "Explain the key business principles and philosophies mentioned in the document."
]

for i in range(1, 4):
    print(f"\nRunning test {i}:")
    if args.diff_qs:
        question = questions[i-1]
    else:
        question = questions[0]
    
    messages = [
        {"role": "user", "content": context + question}
    ]
    
    chat_response = chat_completion_request_openai(messages, client)
    if chat_response:
        print(f"Question: {question}")
        print(f"Summary: {chat_response[:100]}...\n")
        print(f"Response Length: {len(chat_response)} characters")
    else:
        logging.error("Failed to get a response.")

Your understanding is correct. The cache only helps first-token latency; it does not speed up the remaining tokens.
For your results:

1. Did you warm up the server correctly? (A warm-up sketch follows below.)
2. There can be some small runtime variance.
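
For reference, a minimal sketch of what a warm-up could look like in the test script further up (warm_up is a hypothetical helper, not an SGLang API); it pays one-time server costs before the timed runs without touching the prefix you want to measure:

# Hypothetical warm-up helper for the test script above: send a couple of
# untimed requests with an unrelated prompt so one-time costs (e.g. compilation,
# graph capture) are paid before the timed runs, without populating the
# prefix cache under test.
def warm_up(client, model, n=2):
    for _ in range(n):
        client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": "Warm-up request; reply with OK."}],
            max_tokens=8,
            temperature=0,
        )

# warm_up(client, model)  # call once before the "Running test" loop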

Added an SGLang link to the last video: https://youtu.be/I0ccoL80h9Y

Yeah, I've run more inference tests, now including a warm-up of similar length (although with a different start phrase, to ensure the warm-up isn't served from the cache).

Still, the second run (post warm-up) is quite a bit faster than the first run (post warm-up), and not just in time to first token. It's as though more memory read/write operations occur during token generation in the first run than in the second. I'll post again if I get better ideas on why.
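
One way to narrow this down would be to report decode throughput separately from time to first token, so that prefill/cache effects and per-token generation speed aren't conflated. A sketch of how the streaming loop in the script above could be extended (chunk count is used as a rough token count):

# Extension of the streaming loop in chat_completion_request_openai:
# count streamed chunks so decode throughput can be reported separately
# from time to first token.
num_chunks = 0
for chunk in stream:
    if chunk.choices[0].delta.content is not None:
        if first_token_time is None:
            first_token_time = time.time()
        num_chunks += 1
        full_response += chunk.choices[0].delta.content

end_time = time.time()
if first_token_time is not None and num_chunks > 1:
    decode_tps = (num_chunks - 1) / (end_time - first_token_time)
    print(f"Decode throughput: {decode_tps:.1f} chunks/s (roughly tokens/s)")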

@RonanKMcGovern Just saw your new video https://x.com/lm_zheng/status/1831247015562674254
Very nice explanation!

You are welcome to join our Slack channel, and we can collaborate on more video ideas later.
https://join.slack.com/t/sgl-fru7574/shared_invite/zt-2ngly9muu-t37XiH87qvD~6rVBTkTEHw