sysid/sse-starlette

asyncio.InvalidStateError on async generator when connection is closed with FastAPI

DurandA opened this issue · 3 comments

I am using a simple publish/subscribe pattern with FastAPI in order to broadcast data to clients using SSE:

import asyncio
from fastapi import FastAPI, Request
from sse_starlette.sse import EventSourceResponse


class PubSub:
    def __init__(self):
        self.waiter = asyncio.Future()

    def publish(self, value):
        waiter, self.waiter = self.waiter, asyncio.Future()
        waiter.set_result((value, self.waiter))

    async def subscribe(self):
        waiter = self.waiter
        while True:
            # Every up-to-date subscriber awaits the same shared Future here.
            value, waiter = await waiter
            yield value

    __aiter__ = subscribe

pubsub = PubSub()

async def ticker(pubsub):
    counter = 0
    while True:
        pubsub.publish(counter)
        counter += 1
        await asyncio.sleep(1)

app = FastAPI()

@app.on_event("startup")
async def on_startup():
    # Keep a reference so the task is not garbage-collected while it runs.
    app.state.ticker_task = asyncio.create_task(ticker(pubsub), name='my_task')

@app.get('/stream')
async def message_stream(request: Request):
    async def event_publisher():
        try:
            async for event in pubsub:
                yield dict(data=event)
        except asyncio.CancelledError:
            print(f"Disconnected from client (via refresh/close) {request.client}")
            # Do any other cleanup, if any
            raise
    return EventSourceResponse(event_publisher())

However, the task "my_task" dies with an exception as soon as the first client disconnects:

Task exception was never retrieved
future: <Task finished name='my_task' coro=<ticker() done, defined at /home/duranda/devel/fastapi-pubsub/main.py:51> exception=InvalidStateError('invalid state')>
Traceback (most recent call last):
  File "/home/duranda/devel/fastapi-pubsub/main.py", line 54, in ticker
    pubsub.publish(counter)
  File "/home/duranda/devel/fastapi-pubsub/main.py", line 38, in publish
    waiter.set_result((value, self.waiter))
asyncio.exceptions.InvalidStateError: invalid state

I also tried other patterns, such as AsyncIteratorObserver from aioreactive, with the same result: the task that feeds the async iterator ends up with an InvalidStateError.

sysid commented

@DurandA I am not sure I understand your post correctly, but I do not see a direct relation to sse-starlette. If you can be a bit more concrete, please feel free to reopen.

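DurandA commented

The issue was that EventSourceResponse cancels the task that iterates the async generator when the connection closes. That task is suspended on the shared waiter future, and cancelling a task also cancels the future it is currently awaiting, so the next publish() call fails in set_result() with InvalidStateError.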
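The cancellation chain can be reproduced without FastAPI or sse-starlette. The sketch below is not code from this thread: a hypothetical subscriber() coroutine waits directly on a Future named shared, which stands in for PubSub.waiter, and task.cancel() is assumed to play the role of sse-starlette cancelling the streaming task. It checks whether the shared Future ends up cancelled and whether a later set_result() (what publish() does) raises InvalidStateError.

```python
import asyncio


async def main():
    loop = asyncio.get_running_loop()
    shared = loop.create_future()  # stands in for PubSub.waiter

    async def subscriber():
        # Suspended directly on the shared Future, like `await waiter`
        # inside PubSub.subscribe().
        return await shared

    task = asyncio.create_task(subscriber())
    await asyncio.sleep(0)  # let the subscriber start waiting on `shared`

    task.cancel()  # assumed equivalent of the task being cancelled on disconnect
    await asyncio.wait([task])

    try:
        shared.set_result("next value")  # what publish() does on the next tick
        publish_failed = False
    except asyncio.InvalidStateError:
        publish_failed = True
    return shared.cancelled(), publish_failed


print(asyncio.run(main()))
```

Running it should print (True, True): cancelling the waiting task cancelled the shared Future too, so every other subscriber and the publisher are affected by one client leaving.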

The fix was to run each __anext__() call of the iterator in its own task and wrap that task in asyncio.shield():

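async def event_publisher():
    events = pubsub.__aiter__()  # named to avoid shadowing the built-in aiter()
    try:
        while True:
            # Run each __anext__() in its own task and shield it, so cancelling
            # event_publisher() does not cancel the task that is awaiting
            # pubsub's shared Future.
            task = asyncio.create_task(events.__anext__())
            event = await asyncio.shield(task)
            yield dict(data=event)
    except asyncio.CancelledError:
        print(f"Disconnected from client (via refresh/close) {request.client}")
        # Do any other cleanup, if any
        raise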
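A small self-contained sketch can show that the shield is what keeps the publisher alive. It runs the PubSub class from the issue against two versions of a consumer: one awaits __anext__() directly, the other uses the shielded pattern above. The consumer() and run() helpers are hypothetical, and task.cancel() is assumed to stand in for sse-starlette cancelling the response on disconnect.

```python
import asyncio


class PubSub:
    # The class from the issue, unchanged.
    def __init__(self):
        self.waiter = asyncio.Future()

    def publish(self, value):
        waiter, self.waiter = self.waiter, asyncio.Future()
        waiter.set_result((value, self.waiter))

    async def subscribe(self):
        waiter = self.waiter
        while True:
            value, waiter = await waiter
            yield value

    __aiter__ = subscribe


async def consumer(pubsub, received, shielded):
    events = pubsub.__aiter__()
    while True:
        if shielded:
            # Cancellation of this task stops at shield(); the inner task and
            # the shared Future it awaits are left alone.
            event = await asyncio.shield(asyncio.create_task(events.__anext__()))
        else:
            # Cancellation propagates through the generator into pubsub.waiter.
            event = await events.__anext__()
        received.append(event)


async def run(shielded):
    pubsub = PubSub()
    received = []
    task = asyncio.create_task(consumer(pubsub, received, shielded))
    for _ in range(3):
        await asyncio.sleep(0)  # let the consumer start waiting on pubsub.waiter
    pubsub.publish(0)
    while not received:
        await asyncio.sleep(0)  # wait until the consumer has handled value 0
    task.cancel()  # stands in for the response being cancelled on disconnect
    await asyncio.wait([task])
    try:
        pubsub.publish(1)  # the publisher's next tick
    except asyncio.InvalidStateError:
        return "InvalidStateError"
    return "publish ok"


async def main():
    return await run(shielded=False), await run(shielded=True)


print(asyncio.run(main()))
```

The unshielded run should report InvalidStateError and the shielded run "publish ok". One side effect of the shield: after the client is gone, the inner task keeps waiting and consumes one more value that nobody reads.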

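I suppose that the task is cancelled here, in sse-starlette's EventSourceResponse. Each coroutine is wrapped so that, as soon as it returns, it cancels the whole task group, including the task running stream_response, which is the one iterating the generator: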

async def wrap(func: Callable[[], Coroutine[None, None, None]]) -> None:
    await func()
    # noinspection PyAsyncCall
    task_group.cancel_scope.cancel()

task_group.start_soon(wrap, partial(self.stream_response, send))
task_group.start_soon(wrap, partial(self._ping, send))
task_group.start_soon(wrap, self.listen_for_exit_signal)
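The supervisor pattern in wrap() (the first coroutine to return cancels its siblings) can be approximated with plain asyncio. This is a rough analogue, not sse-starlette's code: it swaps anyio's cancel scope for asyncio.wait(..., return_when=FIRST_COMPLETED), and fake_stream() and fake_disconnect() are hypothetical stand-ins for the streaming coroutine and for whatever returns when the client leaves.

```python
import asyncio


async def first_completed_cancels_rest(*coros):
    # Rough stand-in for wrap() + task_group.cancel_scope.cancel():
    # run all coroutines, and once any of them returns, cancel the others.
    tasks = [asyncio.create_task(coro) for coro in coros]
    _, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    for task in pending:
        task.cancel()
    # Wait for the cancelled siblings to finish unwinding.
    await asyncio.gather(*pending, return_exceptions=True)


async def fake_stream(log):
    # Hypothetical stand-in for stream_response(): sends events forever.
    try:
        while True:
            await asyncio.sleep(0.01)
    except asyncio.CancelledError:
        log.append("fake_stream cancelled")
        raise


async def fake_disconnect(log):
    # Hypothetical stand-in for the coroutine that returns when the client leaves.
    await asyncio.sleep(0.05)
    log.append("client gone")


async def main():
    log = []
    await first_completed_cancels_rest(fake_stream(log), fake_disconnect(log))
    return log


print(asyncio.run(main()))
```

The streaming coroutine never finishes on its own; it only stops because a sibling returned, which is why any generator it is iterating sees a CancelledError.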

sysid commented

Thanks for sharing your experience. I am glad that you could find a solution.

Is there anything that can be improved on sse-starlette's side?