replicate/replicate-python

Setting timeouts for replicate.run()

darhsu opened this issue · 2 comments

Hello, I was wondering how you can set timeouts in the replicate.run() function.

I have tried using the replicate client but it didn't throw a timeout error:

from replicate.client import Client

replicate_client = Client(api_token="my_api_token", timeout=1)
replicate_client.run(
    "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b",
    input={"prompt": "a photo of an astronaut riding a horse on Mars"},
)

When I printed the timeout, it correctly displayed the timeout value.

print(replicate_client._timeout)
>>> 1

Hi @darhsu. The timeout parameter is passed to the underlying httpx client instance. It applies to each HTTP request separately: the request that creates the prediction and each subsequent request that polls for its completion. Each of these is likely to take less than a second, so the timeout doesn't trigger even when the prediction as a whole takes longer.

In general, I'd recommend against this kind of approach. Instead, you should try calling cancel on any predictions that haven't completed before some deadline.

Please note that the model you're running, SDXL, typically runs in a few seconds, so a one-second timeout on the whole run would almost always fail to produce results.
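The cancel-before-a-deadline approach described above could be sketched roughly as follows. This is a minimal sketch, not the library's own implementation: wait_or_cancel is a hypothetical helper, and it only assumes a prediction object that exposes a status attribute plus reload() and cancel() methods, which is how the Prediction objects in replicate-python appear to behave.

```python
import time

# Statuses after which a prediction will not change any more.
TERMINAL_STATUSES = {"succeeded", "failed", "canceled"}


def wait_or_cancel(prediction, deadline_seconds, poll_interval=0.5):
    """Poll `prediction` until it finishes or the deadline passes.

    Hypothetical helper: `prediction` is assumed to have a `status`
    attribute and `reload()` / `cancel()` methods.
    Returns True if the prediction reached a terminal status on its own,
    False if the deadline passed and cancel() was called.
    """
    deadline = time.monotonic() + deadline_seconds
    while prediction.status not in TERMINAL_STATUSES:
        if time.monotonic() >= deadline:
            # Out of time: ask the service to stop the prediction.
            prediction.cancel()
            return False
        time.sleep(poll_interval)
        # Refresh the local object with the latest state from the API.
        prediction.reload()
    return True
```

With the real client, the prediction would presumably come from replicate.predictions.create(version=..., input=...) rather than replicate.run(), since run() returns only the output and not the prediction object; wait_or_cancel(prediction, 60) would then give the prediction up to 60 seconds in total.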

@mattt Thanks for your help. Do you have any pointers on canceling async runs?

This is what I have so far. I've added a 60-second timeout to the sample code here.

import asyncio
import replicate

# https://replicate.com/stability-ai/sdxl
model_version = "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b"
prompts = [
    f"A chariot pulled by a team of {count} rainbow unicorns"
    for count in ["two", "four", "six", "eight"]
]


async def main():
    results = None
    try:
        # The timeout has to wrap the TaskGroup, because leaving the
        # TaskGroup block already waits for every task to finish.
        async with asyncio.timeout(60):
            async with asyncio.TaskGroup() as tg:
                tasks = [
                    tg.create_task(replicate.async_run(model_version, input={"prompt": prompt}))
                    for prompt in prompts
                ]
        results = [task.result() for task in tasks]
    except TimeoutError:
        # Cancel replicate async run
        pass

    print(results)


asyncio.run(main())
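One possible way to approach the question above is to track the prediction objects themselves and cancel whichever ones are still running at the deadline. The following is only a sketch under stated assumptions: wait_or_cancel_async and wait_all are hypothetical helpers, and each prediction is assumed to expose a status attribute and blocking reload() and cancel() methods, which are run in worker threads here. The library may also offer native async variants of these methods, but the sketch does not rely on them.

```python
import asyncio

# Statuses after which a prediction will not change any more.
TERMINAL_STATUSES = {"succeeded", "failed", "canceled"}


async def wait_or_cancel_async(prediction, timeout, poll_interval=1.0):
    """Poll one prediction; call cancel() if it is still running at the deadline.

    Hypothetical helper: `prediction` is assumed to have a `status`
    attribute and blocking `reload()` / `cancel()` methods.
    """
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    while prediction.status not in TERMINAL_STATUSES:
        if loop.time() >= deadline:
            # Run the blocking cancel call off the event loop.
            await asyncio.to_thread(prediction.cancel)
            break
        await asyncio.sleep(poll_interval)
        # Refresh the prediction's state without blocking other tasks.
        await asyncio.to_thread(prediction.reload)
    return prediction


async def wait_all(predictions, timeout, poll_interval=1.0):
    """Wait on several predictions concurrently, each with the same deadline."""
    return await asyncio.gather(
        *(wait_or_cancel_async(p, timeout, poll_interval) for p in predictions)
    )
```

Because replicate.async_run() returns only the output, this sketch assumes the predictions are created explicitly first (for example with replicate.predictions.create) so that there is an object to cancel. Cancelling the local asyncio tasks alone would stop the waiting, but would presumably leave the predictions running on Replicate's side.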