replicate/replicate-python

Is there any restriction on the number of async tasks?

April0409 opened this issue · 1 comment

Hello, look at the following implementation: when my prompts list has 100 items, it raises an error.
The error is shown below:

  • Exception Group Traceback (most recent call last):
    | File "..\Lib\site-packages\IPython\core\interactiveshell.py", line 3575, in run_code
    | await eval(code_obj, self.user_global_ns, self.user_ns)
    | File "..\AppData\Local\Temp\ipykernel_30540\1545999430.py", line 23, in
    | async with asyncio.TaskGroup() as tg:
    | File "..\Lib\asyncio\taskgroups.py", line 145, in __aexit__
    | raise me from None
    | ExceptionGroup: unhandled errors in a TaskGroup (9 sub-exceptions)

However, when my prompts list has only 5 items, it does work.
I don't know what is happening, but the tokens I've spent suddenly spiked with the 100-item list, which is painful.

async with asyncio.TaskGroup() as tg:
    tasks = [
        tg.create_task(replicate.async_run(model_version, input={"prompt": prompt}))
        for prompt in prompts
    ]
counter = 10000
async_generators = await asyncio.gather(*tasks)
for async_generator_iterator in async_generators:
    res = [i async for i in async_generator_iterator]
    filtered_list = [item.strip() for item in res if item.strip()]
    json_str = "".join(filtered_list)
    print(json_str)
    result = json_filter(json_str)
    id = f"{counter}"
    # print(result['analogous_sentence'])
    if result is not None:
        save_text_cont(id,result["analogous_sentence"],save_folder=save_folder,save_filename=save_filename)
    counter += 1

Hi @April0409. You can use an asyncio.BoundedSemaphore to limit the number of concurrent tasks like so:

import asyncio

import replicate

async def process_batch(batch, tg, semaphore):
    # The semaphore limits how many batches run at once;
    # tasks within a single batch still start together.
    async with semaphore:
        tasks = [
            tg.create_task(replicate.async_run(model_version, input={"prompt": prompt}))
            for prompt in batch
        ]
        async_generators = await asyncio.gather(*tasks)
        for async_generator_iterator in async_generators:
            # Process the results as before
            ...

async def main():
    batch_size = 100  # Adjust the batch size as needed
    max_concurrent_tasks = 10  # Adjust the maximum number of concurrent tasks

    semaphore = asyncio.BoundedSemaphore(max_concurrent_tasks)

    async with asyncio.TaskGroup() as tg:
        batches = [prompts[i:i+batch_size] for i in range(0, len(prompts), batch_size)]
        tasks = [asyncio.create_task(process_batch(batch, tg, semaphore)) for batch in batches]
        await asyncio.gather(*tasks)

asyncio.run(main())
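For what it's worth, the same throttling idea can be applied per task rather than per batch. Here is a self-contained sketch (no Replicate call — `fake_model_call` is a hypothetical stand-in for `replicate.async_run`) that wraps each individual call in the semaphore and tracks peak concurrency to show the cap is respected:

```python
import asyncio

async def limited_call(semaphore, coro_factory, tracker):
    # Acquire the semaphore before starting the work; release on exit.
    async with semaphore:
        tracker["current"] += 1
        tracker["peak"] = max(tracker["peak"], tracker["current"])
        try:
            return await coro_factory()
        finally:
            tracker["current"] -= 1

async def fake_model_call():
    # Hypothetical stand-in for a real API call like replicate.async_run.
    await asyncio.sleep(0.01)
    return "ok"

async def main():
    max_concurrent_tasks = 10
    semaphore = asyncio.BoundedSemaphore(max_concurrent_tasks)
    tracker = {"current": 0, "peak": 0}
    results = await asyncio.gather(
        *(limited_call(semaphore, fake_model_call, tracker) for _ in range(100))
    )
    return results, tracker

results, tracker = asyncio.run(main())
print(len(results), tracker["peak"])  # peak never exceeds max_concurrent_tasks
```

With this variant, at most `max_concurrent_tasks` requests are in flight at any moment, regardless of how long the prompts list is.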