Question: How to cancel a running task
realitix opened this issue · 3 comments
realitix commented
Hello,
I have a special case to handle and I don't see how to do it. At a given moment, I need to know whether a task (I have its ID) is currently in progress on a worker. Is that possible?
realitix commented
After further consideration, what I am looking for is the ability to stop an ongoing task. Is that possible?
s3rius commented
Currently there's no such functionality, but I really do want to define an interface to set up such task interrupters.
I'm open to discussion on that.
realitix commented
I developed a custom receiver for that. If anyone wants to do it with Redis, here is the code:
import asyncio
import uuid
from typing import Any, AsyncGenerator, cast
import anyio
from loguru import logger
from redis.asyncio import Redis
from taskiq.abc.broker import AckableMessage
from taskiq.message import BrokerMessage, TaskiqMessage
from taskiq.receiver.receiver import QUEUE_DONE, Receiver
from taskiq_redis import ListQueueBroker
# ruff: noqa: ANN401,BLE001,C901
# pylint: skip-file
# Key, inside a canceller message's kwargs, whose value is the id of the task
# that should be cancelled.
CANCELLER_KEY: str = "__cancel_task_id__"
class CancellableListQueueBroker(ListQueueBroker):
    """Redis list-queue broker that can also broadcast task cancellation requests.

    Cancellation requests are published on a Redis pub/sub channel
    (``queue_name_cancel``) instead of being pushed onto the work queue, so
    every subscriber of that channel receives every request.
    """

    def __init__(
        self,
        *args: Any,
        queue_name_cancel: str = "taskiq_cancel",
        **kwargs: Any,
    ) -> None:
        """Create the broker.

        :param args: positional arguments forwarded to ``ListQueueBroker``.
        :param queue_name_cancel: name of the pub/sub channel used for
            cancellation requests.
        :param kwargs: keyword arguments forwarded to ``ListQueueBroker``.
        """
        super().__init__(*args, **kwargs)
        self.queue_name_cancel = queue_name_cancel

    async def listen_canceller(self) -> AsyncGenerator[bytes, None]:
        """Yield the raw payload of every message published on the cancel channel.

        The pub/sub subscription is released when the generator is closed
        (e.g. via ``aclose()``), so callers should close it when they stop
        iterating.
        """
        async with Redis(connection_pool=self.connection_pool) as redis_conn:
            # Using the PubSub object as a context manager closes it on exit,
            # releasing its dedicated connection; previously it was never closed.
            async with redis_conn.pubsub() as redis_pubsub_channel:
                await redis_pubsub_channel.subscribe(self.queue_name_cancel)
                async for message in redis_pubsub_channel.listen():
                    if not message:
                        continue
                    if message["type"] != "message":
                        # e.g. subscribe confirmations; not a cancel request.
                        logger.debug("Received non-message from redis: {}", message)
                        continue
                    yield message["data"]

    async def cancel_task(self, task_id: uuid.UUID | str) -> None:
        """Publish a request to cancel the task identified by ``task_id``.

        :param task_id: id of the task to cancel. A ``uuid.UUID`` is sent as
            its ``.hex`` string; a ``str`` (such as the ``task_id`` of a kicked
            task) is sent unchanged.
        """
        taskiq_message: TaskiqMessage = self._prepare_message(task_id)
        broker_message: BrokerMessage = self.formatter.dumps(taskiq_message)
        async with Redis(connection_pool=self.connection_pool) as redis_conn:
            await redis_conn.publish(self.queue_name_cancel, broker_message.message)

    def _prepare_message(self, task_id: uuid.UUID | str) -> TaskiqMessage:
        """Build a "canceller" message carrying the target id under ``CANCELLER_KEY``."""
        # NOTE(review): ``.hex`` presumably matches ids from a hex-uuid
        # id_generator — confirm if a custom task id generator is used.
        target_id = task_id.hex if isinstance(task_id, uuid.UUID) else str(task_id)
        return TaskiqMessage(
            task_id=self.id_generator(),
            task_name="canceller",
            labels={},
            labels_types={},
            args=[],
            kwargs={CANCELLER_KEY: target_id},
        )
class CancellableReceiver(Receiver):
    """Receiver whose in-flight tasks can be cancelled by id.

    Every task execution runs in its own ``asyncio.Task`` named after the
    taskiq task id. ``runner_canceller`` listens on the broker's cancel
    channel (via ``listen_canceller``, so the broker must be a
    ``CancellableListQueueBroker`` or provide a compatible method) and calls
    ``Task.cancel()`` on the matching task.
    """

    def __init__(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Forward all arguments to ``Receiver`` and start with no tracked tasks."""
        super().__init__(*args, **kwargs)
        # Running task-execution tasks; each removes itself via ``task_cb``
        # in ``runner`` when it finishes.
        self.tasks: set[asyncio.Task[Any]] = set()

    def parse_message(self, message: bytes | AckableMessage) -> TaskiqMessage | None:
        """Decode a raw broker message into a ``TaskiqMessage``.

        :param message: raw bytes, or an ``AckableMessage`` wrapping them.
        :returns: the parsed message, or ``None`` when it cannot be decoded
            (the failure is logged together with its traceback).
        """
        message_data = message.data if isinstance(message, AckableMessage) else message
        try:
            taskiq_msg = self.broker.formatter.loads(message=message_data)
            taskiq_msg.parse_labels()
        except Exception as exc:
            # loguru uses ``{}`` placeholders and ``opt(exception=...)`` for
            # tracebacks; the stdlib-style ``%s`` / ``exc_info=True`` were not
            # interpreted (placeholders printed literally, no traceback).
            logger.opt(exception=exc).warning(
                "Cannot parse message: {}. Skipping execution.\n {}",
                message_data,
                exc,
            )
            return None
        return taskiq_msg

    async def listen(self) -> None:  # pragma: no cover
        """Run the prefetcher, the task runner and the cancel listener.

        ``runner_canceller`` otherwise waits on the pub/sub channel forever,
        so the task group is cancelled as soon as ``runner`` returns (after
        ``QUEUE_DONE``); without that, this method could never return.
        """
        if self.run_startup:
            await self.broker.startup()
        logger.info("Listening started.")
        queue: asyncio.Queue[bytes | AckableMessage] = asyncio.Queue()

        async with anyio.create_task_group() as gr:

            async def run_then_stop() -> None:
                await self.runner(queue)
                # Stop the (otherwise endless) cancel listener.
                gr.cancel_scope.cancel()

            gr.start_soon(self.prefetcher, queue)
            gr.start_soon(run_then_stop)
            gr.start_soon(self.runner_canceller)

        if self.on_exit is not None:
            self.on_exit(self)

    async def runner_canceller(
        self,
    ) -> None:
        """Cancel running tasks whose id arrives on the broker's cancel channel.

        Runs until the channel iterator is exhausted or this coroutine is
        cancelled; the iterator (and its subscription) is closed either way.
        NOTE(review): the cancelled callback is interrupted mid-run, so
        presumably no result is stored for that task — confirm how clients
        waiting on the result behave.
        """

        def cancel_task(task_id: str) -> None:
            # ``runner`` names each task after its taskiq task id.
            for task in self.tasks:
                if task.get_name() != task_id:
                    continue
                if task.cancel():
                    logger.info("Cancelling task {}", task_id)
                else:
                    logger.warning("Cannot cancel task {}", task_id)

        iterator = cast(CancellableListQueueBroker, self.broker).listen_canceller()
        try:
            async for message in iterator:
                taskiq_msg = self.parse_message(message)
                if taskiq_msg is None:
                    continue
                if CANCELLER_KEY in taskiq_msg.kwargs:
                    cancel_task(taskiq_msg.kwargs[CANCELLER_KEY])
        finally:
            # Shield the cleanup so a pending cancellation does not interrupt
            # closing the generator (and thus its pub/sub subscription).
            with anyio.CancelScope(shield=True):
                await iterator.aclose()

    async def runner(
        self,
        queue: asyncio.Queue[bytes | AckableMessage],
    ) -> None:
        """Start one asyncio task per queued message until ``QUEUE_DONE``.

        Each task is named after its taskiq task id (so ``runner_canceller``
        can find it) and tracked in ``self.tasks`` until it finishes.
        Messages that cannot be parsed are skipped.
        """

        def task_cb(task: asyncio.Task[Any]) -> None:
            self.tasks.discard(task)
            if self.sem is not None:
                self.sem.release()

        while True:
            if self.sem is not None:
                # One slot per started task; released by ``task_cb``.
                await self.sem.acquire()
            self.sem_prefetch.release()
            message = await queue.get()
            if message is QUEUE_DONE:
                break
            taskiq_msg = self.parse_message(message)
            if taskiq_msg is None:
                # No task is started, so ``task_cb`` will never release this
                # slot; release it here or each bad message permanently
                # lowers the concurrency limit.
                if self.sem is not None:
                    self.sem.release()
                continue
            task = asyncio.create_task(
                self.callback(message=message, raise_err=False),
                name=str(taskiq_msg.task_id),
            )
            self.tasks.add(task)
            task.add_done_callback(task_cb)