OpenAIServingChat cannot be instantiated within a running event loop

I am working with the OpenAI serving engines from the current main branch (Python 3.10).

When I try to instantiate an OpenAIServingChat from a coroutine, I get the error AttributeError: 'NoneType' object has no attribute 'chat_template'.

Code Example

Here is some sample code to replicate the problem:

from vllm import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat

import asyncio

async def main():
    model = "microsoft/phi-2"
    engine_args = AsyncEngineArgs(model=model)
    engine = AsyncLLMEngine.from_engine_args(engine_args)
    serving_chat = OpenAIServingChat(
        engine,
        served_model=model,
        response_role="assistant",
        chat_template=None,
    )
 

if __name__ == "__main__":
    asyncio.run(main())

If I turn the main coroutine into a regular function (just removing the async) and call it directly (without asyncio), everything works as expected.
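
For comparison, the working synchronous variant (the same example, just without async):

from vllm import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat

def main():
    model = "microsoft/phi-2"
    engine_args = AsyncEngineArgs(model=model)
    engine = AsyncLLMEngine.from_engine_args(engine_args)
    # No event loop is running here, so the base class can finish its
    # async post-init via asyncio.run() before this constructor returns.
    serving_chat = OpenAIServingChat(
        engine,
        served_model=model,
        response_role="assistant",
        chat_template=None,
    )

if __name__ == "__main__":
    main()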

Problem Investigation

From what I can tell the problem is as follows:

In the __init__ of OpenAIServing, lines 25ff read:

try:
    event_loop = asyncio.get_running_loop()
except RuntimeError:
    event_loop = None

if event_loop is not None and event_loop.is_running():
    # If the current is instanced by Ray Serve, there is already a running event loop
    event_loop.create_task(self._post_init())
else:  # When using single vLLM without engine_use_ray
    asyncio.run(self._post_init())
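
For context, asyncio.get_running_loop() raises a RuntimeError when called outside of a coroutine, which is what selects the branch here; a quick standalone check (plain Python, no vLLM):

import asyncio

try:
    asyncio.get_running_loop()
except RuntimeError as e:
    print(e)  # "no running event loop" when called outside a coroutine

async def inside():
    # Inside a coroutine there is always a running loop
    print(asyncio.get_running_loop().is_running())  # True

asyncio.run(inside())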

Synchronous Case

In the case of the synchronous main function above, we enter the else branch at the bottom. asyncio starts a new event loop, runs self._post_init() in it (which loads the tokenizer), and only returns once that has finished. That means the tokenizer is available when OpenAIServingChat calls self._load_chat_template() in its __init__.
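
The key point is that asyncio.run blocks until the coroutine has fully completed. A minimal sketch (the names here are illustrative, not vLLM's):

import asyncio

async def post_init():
    await asyncio.sleep(0.1)  # stand-in for loading the tokenizer
    return "tokenizer"

# asyncio.run only returns once post_init has finished,
# so the result is guaranteed to exist afterwards.
tokenizer = asyncio.run(post_init())
print(tokenizer)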

Asynchronous Case

In the case of the asynchronous main coroutine above, there already is a running event loop. Consequently, event_loop.create_task(self._post_init()) is called, which merely schedules the tokenizer loading for some point in the future. However, we do not hit an await before OpenAIServingChat calls self._load_chat_template(), so the loop never gets the chance to actually run _post_init(), and the tokenizer is still None when self._load_chat_template() tries to access it.
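
This behavior is easy to reproduce with plain asyncio (a standalone sketch, no vLLM involved):

import asyncio

async def set_flag(state):
    state["ready"] = True

async def main():
    state = {"ready": False}
    asyncio.get_running_loop().create_task(set_flag(state))
    print(state["ready"])   # False: the task is only scheduled, not run
    await asyncio.sleep(0)  # yield control to the event loop
    print(state["ready"])   # True: now the task has executed

asyncio.run(main())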

Possible solutions

I am not an expert in asyncio programming, so the only solution I have found so far is to make _load_chat_template in OpenAIServingChat async as well and to replicate the whole event-loop/create_task logic from OpenAIServing's __init__ for the chat-template loading in the __init__ of OpenAIServingChat. Experimentally that seems to work; however, it does not feel like a good solution, since asyncio gives no guarantee on the order in which scheduled tasks are run, so there could still be scenarios in which the error is triggered.

Edit: This does seem to be the only workable solution. To ensure things run in the correct order, _load_chat_template has to wait until the tokenizer is available, e.g.

async def _load_chat_template(self, chat_template):
    while self.tokenizer is None:
        await asyncio.sleep(0.01)
    ...

Additional Observation

Interestingly, the error is not triggered when using engine_use_ray=True or workers_use_ray=True in a synchronous main function. It appears that at the time __init__ is called there is not yet a running event loop, so we again hit the working else branch.

Here is the workaround I currently use:

import asyncio

from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat


class PatchedOpenAIServingChat(OpenAIServingChat):
    def __init__(
        self,
        engine: AsyncLLMEngine,
        served_model: str,
        response_role: str,
        chat_template=None,
    ):
        # Deliberately skip OpenAIServingChat.__init__ and call OpenAIServing.__init__
        # directly, so the original synchronous _load_chat_template is never invoked.
        super(OpenAIServingChat, self).__init__(engine=engine, served_model=served_model)
        self.response_role = response_role
        try:
            event_loop = asyncio.get_running_loop()
        except RuntimeError:
            event_loop = None

        if event_loop is not None and event_loop.is_running():
            event_loop.create_task(self._load_chat_template(chat_template))
        else:
            asyncio.run(self._load_chat_template(chat_template))

    async def _load_chat_template(self, chat_template):
        # Simply making this function async is usually already enough to give the
        # parent class time to load the tokenizer (so usually no sleeping happens
        # here). However, it feels safer to be explicit about this, since asyncio
        # does not guarantee the order in which scheduled tasks are run.
        while self.tokenizer is None:
            await asyncio.sleep(0.1)
        return super()._load_chat_template(chat_template)

In principle, this patch could just move the problem down the line: if we do not hit an await (or similar) after the initialization and before using the engine, neither _post_init nor _load_chat_template has had a chance to run. So the cleanest solution would be to set self._ready = True at the end of _load_chat_template and give the class another method

async def ready(self):
    while not self._ready:
        await asyncio.sleep(0.1)
    return True

so that we can await serving_chat.ready() in our code before we use it.
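
Putting it together, usage would look roughly like this (ready() and the _ready flag are the additions sketched above, not part of vLLM):

import asyncio

from vllm import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

async def main():
    engine_args = AsyncEngineArgs(model="microsoft/phi-2")
    engine = AsyncLLMEngine.from_engine_args(engine_args)
    serving_chat = PatchedOpenAIServingChat(
        engine,
        served_model="microsoft/phi-2",
        response_role="assistant",
        chat_template=None,
    )
    await serving_chat.ready()  # wait until tokenizer and chat template are loaded
    # ... use serving_chat ...

asyncio.run(main())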