karpetrosyan/hishel

Cancellation signals caught by hishel?

ratzrattillo opened this issue · 2 comments

In my project I have a method which starts and runs workers in parallel.
Workers are spawned in an asyncio.TaskGroup, and each is handed an asyncio.Queue (input_queue) from which it pulls its work items.
Within the async context of the TaskGroup, I wait until every item in input_queue has been processed (input_queue.join()). Once that happens, there is no more work to be done, so every task in the TaskGroup is cancelled via task.cancel(). The tasks then stop executing, and once all of them are cancelled, the async context of the TaskGroup is exited.

Now, when I use the basic httpx.AsyncHTTPTransport, this seems to work fine, but when I use hishel.AsyncCacheTransport, the workers are somehow never cancelled and keep running, even though the input queues are already empty.

Is it possible that hishel.AsyncCacheTransport swallows the cancellation signals?

Some code for better understanding:

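from __future__ import annotations  # lets the Worker annotation refer to a class defined elsewhere

import asyncio


class WorkerConfig:
    """
    WorkerConfig sets parameters to define worker execution like
    e.g. number of concurrently executing instances of a worker.
    """

    def __init__(self, worker: Worker, num_concurrent: int):
        self.worker = worker
        self.num_concurrent = num_concurrent

    def get_num_concurrent(self) -> int:
        return self.num_concurrent

    def get_worker(self) -> Worker:
        return self.worker


# run_workers is a hypothetical name for the method described above;
# `async with` is only valid inside a coroutine, so the logic is wrapped in one.
# Worker (defined elsewhere in the project) exposes do_work() and input_queue.
async def run_workers(worker_configs: list[WorkerConfig]) -> None:
    async with asyncio.TaskGroup() as group:
        tasks: set[asyncio.Task] = set()
        for worker_cfg in worker_configs:
            for _ in range(worker_cfg.get_num_concurrent()):
                task = group.create_task(worker_cfg.get_worker().do_work())
                tasks.add(task)
        # Wait until every queued item has been marked done via task_done().
        for worker_cfg in worker_configs:
            await worker_cfg.get_worker().input_queue.join()
        # No work left: cancel the workers so the TaskGroup can exit.
        for task in tasks:
            task.cancel()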

I was curious about this and tried to reproduce it, so I fleshed out your code and ended up with the following script:

#!/usr/bin/env python3

import hishel
import httpx

import asyncio


class Worker:
    def __init__(self, input_queue: asyncio.Queue, client: httpx.AsyncClient) -> None:
        self.input_queue = input_queue
        self.client = client

    async def do_work(self) -> None:
        while True:
            url = await self.input_queue.get()
            print(f"Pulled from queue: {url}")
            response = await self.client.get(url)
            if response.is_success:
                print(f"Successful request to {url}")
            else:
                print(f"Unsuccessful request to {url}")
            self.input_queue.task_done()


class WorkerConfig:
    """
    WorkerConfig sets parameters to define worker execution like
    e.g. number of concurrently executing instances of a worker.
    """

    def __init__(self, worker: Worker, num_concurrent: int):
        self.worker = worker
        self.num_concurrent = num_concurrent

    def get_num_concurrent(self) -> int:
        return self.num_concurrent

    def get_worker(self) -> Worker:
        return self.worker


async def work() -> None:
    queue = asyncio.Queue()
    # Uncached baseline for comparison:
    # async with httpx.AsyncClient() as client:
    async with httpx.AsyncClient(
        transport=hishel.AsyncCacheTransport(
            transport=httpx.AsyncHTTPTransport(),
            storage=hishel.AsyncInMemoryStorage()
        )
    ) as client:
        worker_configs = [
            WorkerConfig(
                worker=Worker(input_queue=queue, client=client),
                num_concurrent=2,
            ),
            WorkerConfig(
                worker=Worker(input_queue=queue, client=client),
                num_concurrent=1,
            ),
        ]

        async with asyncio.TaskGroup() as group:
            tasks: set[asyncio.Task] = set()

            for worker_cfg in worker_configs:
                for _ in range(worker_cfg.get_num_concurrent()):
                    task = group.create_task(worker_cfg.get_worker().do_work())
                    tasks.add(task)

            for _ in range(20):
                await queue.put("https://example.com")

            for worker_cfg in worker_configs:
                await worker_cfg.get_worker().input_queue.join()

            for task in tasks:
                task.cancel()


def main() -> None:
    asyncio.run(work())


if __name__ == "__main__":
    main()

For me (Python 3.12 and hishel 0.0.29), it seems to work fine: the workers process their items, the tasks are canceled, they exit, and the program terminates. Can you provide more detail about what exactly you are doing?
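If the hang shows up again, one way to gather that detail is to cancel the workers and then wait with a timeout, instead of letting the TaskGroup wait indefinitely, so that any task which ignored the cancellation can be identified. `find_stuck_tasks` below is a hypothetical helper built on asyncio.wait, and `swallows_once` simulates a coroutine that ignores the first cancellation; neither is part of hishel or httpx.

```python
import asyncio
from collections.abc import Iterable


async def find_stuck_tasks(tasks: Iterable[asyncio.Task], timeout: float) -> set[asyncio.Task]:
    """Cancel every task, then return the ones still running after `timeout` seconds."""
    task_set = set(tasks)
    if not task_set:
        return set()  # asyncio.wait() rejects an empty collection
    for task in task_set:
        task.cancel()
    _done, pending = await asyncio.wait(task_set, timeout=timeout)
    return pending


async def well_behaved() -> None:
    await asyncio.sleep(3600)  # cancellation propagates normally


async def swallows_once() -> None:
    # Simulates code that catches the first CancelledError and keeps going.
    try:
        await asyncio.sleep(3600)
    except asyncio.CancelledError:
        pass
    await asyncio.sleep(3600)


async def main() -> list[str]:
    tasks = {
        asyncio.create_task(well_behaved(), name="well_behaved"),
        asyncio.create_task(swallows_once(), name="swallows_once"),
    }
    await asyncio.sleep(0)  # let both tasks reach their first await
    stuck = await find_stuck_tasks(tasks, timeout=0.1)
    names = sorted(task.get_name() for task in stuck)
    # Clean up: a second cancel() hits the second sleep and ends the task.
    for task in stuck:
        task.cancel()
    await asyncio.gather(*stuck, return_exceptions=True)
    return names


if __name__ == "__main__":
    print(asyncio.run(main()))  # ['swallows_once']
```

For a task that is still pending, `task.print_stack()` can help narrow down where its coroutine is suspended.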

Okay, that is interesting. I recently switched from my own worker implementation to Ray actors, and I am not currently using hishel, so I can't follow up on this topic right now. Since it seems like it may have been an implementation issue on my side, I will close the issue.
Once I find time to integrate hishel into the new architecture, I may reopen the issue if the problem still occurs.

Thank you very much for trying it out, too!