celery/py-amqp

How does pyamqp/kombu make multiple threads share the same amqp connection?

ponponon opened this issue · 5 comments

AMQP has not only the concept of a connection but also the concept of a channel. As I understand it, channels were introduced so that multiple threads or other concurrent tasks can share one connection, i.e. so that the connection can be multiplexed. For example, a process with 10 threads can open a single AMQP connection and create 10 channels on it, so that each thread holds its own channel and they can work concurrently without conflicting with each other.

For example, if I have 100 processes with 10 threads each and every thread holds its own AMQP connection, the RabbitMQ server has to maintain 1000 connections in total, which is a lot of pressure. If the threads of each process can share one connection, the RabbitMQ server only has to maintain 100 connections, which is much less stressful.
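
A quick sketch of that arithmetic, assuming one channel per thread in both layouts (the function name `broker_load` is made up for illustration):

```python
def broker_load(processes, threads_per_process, share_connection):
    """Return (connections, channels) the broker has to track."""
    workers = processes * threads_per_process
    # Shared layout: one connection per process.
    # Unshared layout: one connection per thread.
    connections = processes if share_connection else workers
    # Every thread still needs its own channel in either layout.
    channels = workers
    return connections, channels


if __name__ == "__main__":
    print(broker_load(100, 10, share_connection=False))
    print(broker_load(100, 10, share_connection=True))
```

Sharing reduces the number of connections by the thread count per process; the number of channels stays the same.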

How can I use py-amqp so that multiple threads or other concurrent tasks share one connection and multiplex it? Is there a code sample for this? I need both a producer sample and a consumer sample.

spumer commented

Hi, I'm developing and experimenting with this right now.

Here is my thread-safe transport implementation based on py-amqp:

shared_pyamqp_transport.py
"""
Do not share a channel between threads:
- Do not use `default_channel` from different threads:
 responses are bound to the channel, so one thread can receive an error meant for another thread, or cause an error in another thread.

A channel is bound to the thread in which it was created.
This is required because dispatching a frame can raise an error that the calling thread expects, e.g. in the queue_declare method.
If you declare a queue in passive mode and it does not exist, RabbitMQ will close the channel,
    and the exception MAY be raised in a different thread when that thread drains events.

To prevent this, the current implementation binds channels to threads and dispatches received frames only in the owning thread.

TODO:
 - ChannelPool: allow transferring channel ownership between threads, to minimize the time spent opening channels that are needed only occasionally.
    This is helpful for multiple threads that produce messages from time to time: instead of creating a channel per thread,
    we can move the channel to a new owner.
"""

import collections
import contextlib
import functools
import logging
import threading

import amqp
import kombu
import kombu.simple
import kombu.transport.pyamqp

logger = logging.getLogger(__name__)


class ThreadSafeChannel(kombu.transport.pyamqp.Channel):

    connection: 'ThreadSafeConnection'

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._owner_ident = threading.get_ident()
        self.connection.channel_thread_bindings[self._owner_ident].append(self.channel_id)

    def collect(self):
        conn = self.connection
        channel_frame_buff = conn.channel_frame_buff.pop(self.channel_id, ())
        if channel_frame_buff:
            logger.warning('No drained events after close (%s pending events)', len(channel_frame_buff))

        bindings = conn.channel_thread_bindings.get(self._owner_ident) or []
        try:
            bindings.remove(self.channel_id)
        except ValueError:
            pass

        super().collect()


class DrainGuard:
    def __init__(self):
        self._drain_cond = threading.Condition()
        self._drain_is_active_by = None
        self._drain_check_lock = threading.RLock()

    def is_drain_active(self):
        return self._drain_is_active_by is not None

    def start_drain(self):
        ctx = contextlib.ExitStack()
        if self._drain_is_active_by is None:
            # Optimization: take the condition lock only when a race is possible,
            # to prevent `wait_drain_finished` from exiting before the drain has really started.
            # This is not critical, because `drain_events` is called while the inner `promise` object is not ready,
            # but we do it here for a correct thread-safe implementation.
            ctx.enter_context(self._drain_cond)

        with ctx:
            acquired = self._drain_check_lock.acquire(blocking=False)
            if not acquired:
                return False

            assert self._drain_is_active_by is None
            self._drain_is_active_by = threading.get_ident()

        return True

    def finish_drain(self):
        caller = threading.get_ident()
        assert self._drain_is_active_by is not None, 'Drain must be started'
        assert self._drain_is_active_by == caller, 'You cannot finish a drain started by another thread'
        with self._drain_cond:
            self._drain_is_active_by = None
            self._drain_cond.notify_all()
            self._drain_check_lock.release()

    def wait_drain_finished(self, timeout=None):
        caller = threading.get_ident()
        assert self._drain_is_active_by != caller, 'You cannot wait for your own drain; deadlock detected'
        with self._drain_cond:
            if self.is_drain_active():
                self._drain_cond.wait(timeout=timeout)


class ThreadSafeConnection(kombu.transport.pyamqp.Connection):
    Channel = ThreadSafeChannel

    def __init__(self, *args, **kwargs):
        self._transport_lock = threading.RLock()

        self._create_channel_lock = threading.RLock()
        self._drain_guard = DrainGuard()
        self.channel_thread_bindings = collections.defaultdict(list)  # thread_ident -> [channel_id, ...]
        self.channel_frame_buff = collections.defaultdict(collections.deque)  # channel_id: [frame, frame, ...]

        # The connection object itself is treated as channel 0
        self.channel_thread_bindings[threading.get_ident()].append(0)

        super().__init__(*args, **kwargs)

    def channel(self, *args, **kwargs):
        with self._create_channel_lock:
            return super().channel(*args, **kwargs)

    def _claim_channel_id(self, channel_id):
        with self._create_channel_lock:
            return super()._claim_channel_id(channel_id)

    def _get_free_channel_id(self):
        with self._create_channel_lock:
            return super()._get_free_channel_id()

    def _dispatch_channel_frames(self, channel_id):
        buff = self.channel_frame_buff.get(channel_id, ())

        while buff:
            method_sig, payload, content = buff.popleft()
            self.channels[channel_id].dispatch_method(
                method_sig,
                payload,
                content,
            )

    def on_inbound_method(self, channel_id, method_sig, payload, content):
        if self.channels is None:
            raise amqp.exceptions.RecoverableConnectionError('Connection already closed')

        # buffer all frames for later dispatch (after the drain)
        self.channel_frame_buff[channel_id].append((method_sig, payload, content))

    def connect(self, *args, **kwargs):
        with self._transport_lock:
            res = super().connect(*args, **kwargs)
        return res

    @kombu.transport.pyamqp.Connection.frame_writer.setter
    def frame_writer(self, frame_writer):
        # frame_writer accesses the socket,
        # so make it thread-safe
        @functools.wraps(frame_writer)
        def wrapper(*args, **kwargs):
            with self._transport_lock:
                res = frame_writer(*args, **kwargs)
            return res

        self._frame_writer = wrapper

    def blocking_read(self, timeout=None):
        with self._transport_lock:
            return super().blocking_read(timeout=timeout)

    def collect(self):
        with self._transport_lock:
            super().collect()

    def drain_events(self, timeout=None):
        # When many threads get here, only one of them actually drains events.
        # Because draining is independent of the caller, all events will be dispatched to their channels.

        started = self._drain_guard.start_drain()

        if not started:
            # honor the caller's timeout while another thread is draining
            self._drain_guard.wait_drain_finished(timeout=timeout)
        else:
            try:
                with self._transport_lock:
                    super().drain_events(timeout=timeout)
            finally:
                self._drain_guard.finish_drain()

        me = threading.get_ident()
        my_channels = self.channel_thread_bindings[me]
        for channel_id in my_channels:
            self._dispatch_channel_frames(channel_id)


def install():
    kombu.transport.pyamqp.Transport.Connection = ThreadSafeConnection
    kombu.transport.pyamqp.SSLTransport.Connection = ThreadSafeConnection
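
To show the core idea of the transport above without a broker, here is a minimal stdlib-only sketch: one thread at a time reads the "wire" and only buffers frames, and every thread dispatches just the frames of the channels it owns. All names (`ToySharedConnection`, `open_channel`, `drain`, `run_demo`) are hypothetical and are not py-amqp or kombu API; the wire is a plain in-memory deque rather than a socket.

```python
import collections
import threading


class ToySharedConnection:
    """Hypothetical stand-in for a shared connection; no sockets involved."""

    def __init__(self, wire):
        self._wire = collections.deque(wire)  # pending (channel_id, frame) pairs
        self._cond = threading.Condition()
        self._reading = False                 # True while some thread is the reader
        self._lock = threading.Lock()         # guards _owners and _buffers
        self._owners = {}                     # channel_id -> owning thread ident
        self._buffers = {}                    # channel_id -> deque of frames

    def open_channel(self, channel_id):
        # Bind the channel to the thread that opens it.
        with self._lock:
            self._owners[channel_id] = threading.get_ident()
            self._buffers.setdefault(channel_id, collections.deque())

    def _read_wire(self):
        # Buffer every pending frame; nothing is dispatched here.
        while self._wire:
            channel_id, frame = self._wire.popleft()
            with self._lock:
                buff = self._buffers.setdefault(channel_id, collections.deque())
            buff.append(frame)

    def drain(self, handler):
        with self._cond:
            i_am_reader = not self._reading
            if i_am_reader:
                self._reading = True
            else:
                # Another thread is reading: wait until it has finished.
                self._cond.wait_for(lambda: not self._reading, timeout=1.0)
        if i_am_reader:
            try:
                self._read_wire()
            finally:
                with self._cond:
                    self._reading = False
                    self._cond.notify_all()
        # Dispatch only frames of channels owned by the calling thread.
        me = threading.get_ident()
        with self._lock:
            mine = [(cid, self._buffers[cid])
                    for cid, owner in self._owners.items() if owner == me]
        for cid, buff in mine:
            while buff:
                handler(cid, buff.popleft())


def run_demo():
    conn = ToySharedConnection([(1, "a1"), (2, "b1"), (1, "a2"), (2, "b2")])
    handled = []                       # (thread name, channel id, frame)
    all_open = threading.Barrier(2)

    def worker(channel_id):
        conn.open_channel(channel_id)
        all_open.wait()                # both owners are alive, so idents differ
        seen = []

        def handler(cid, frame):
            seen.append(frame)
            handled.append((threading.current_thread().name, cid, frame))

        while len(seen) < 2:
            conn.drain(handler)

    threads = [threading.Thread(target=worker, args=(cid,), name=f"owner-{cid}")
               for cid in (1, 2)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return sorted(handled)


if __name__ == "__main__":
    for row in run_demo():
        print(row)
```

Whichever thread wins the read, each frame is handled in the thread that owns its channel, so a channel-level error would surface in that owner's thread.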

To guarantee that only one connection is used to consume and one to produce, I wrote this as well:

connection_holder.py
import contextlib
import logging
import threading

import dramatiq
import kombu
import kombu.simple
import kombu.transport.pyamqp

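# NOTE: `KombuConnection` is not defined in shared_pyamqp_transport.py as posted above.
# Assumption: plain kombu.Connection is sufficient here, because install() swaps in
# ThreadSafeConnection at the transport level.
KombuConnection = kombu.Connection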
from .shared_pyamqp_transport import install as _install_shared_pyamqp_transport

_install_shared_pyamqp_transport()


logger = logging.getLogger(__name__)


def clone_to_threadsafe_connection(connection: kombu.Connection):
    """copy of kombu.Connection.clone method"""
    return KombuConnection(**dict(connection._info(resolve=False)))


class AutoReleaseProducer(kombu.Producer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if self.__connection__ is None:
            # keep a reference to the connection for a later `ensure`,
            # because the channel can lose it after close
            self.__connection__ = self._channel.connection.client

    def release(self):
        if hasattr(self.channel, 'release'):
            self.channel.release()


class ConnectionHolder:
    def __init__(
        self,
        connection: kombu.Connection,
        *,
        connect_max_retries=None,
    ):
        """
        :param connection:
        :param connect_max_retries: maximum number of retries trying to re-establish the connection,
            if the connection is lost/unavailable.
        """
        connection = clone_to_threadsafe_connection(connection)

        self.recoverable_connection_errors = connection.recoverable_connection_errors
        self.recoverable_channel_errors = connection.recoverable_channel_errors

        self._consumer_connection = connection.clone()
        self._producer_connection = connection.clone()
        self._conn_lock = threading.RLock()
        self.connect_max_retries = connect_max_retries
        self._consumer_channel_pool = None
        self._producer_channel_pool = None

    @contextlib.contextmanager
    def reraise_as_library_errors(
        self,
        ConnectionError=dramatiq.ConnectionError,  # noqa: N803,A002
        ChannelError=dramatiq.ConnectionError,  # noqa: N803
    ):
        try:
            yield
        except (ConnectionError, ChannelError):
            raise
        except self.recoverable_connection_errors as exc:
            raise ConnectionError(str(exc)) from exc
        except self.recoverable_channel_errors as exc:
            raise ChannelError(str(exc)) from exc

    @staticmethod
    def on_connection_error(exc, interval):
        logging.getLogger('ConnectionHolder').warning(
            'Broker connection error, trying again in %s seconds: %r.',
            interval,
            exc,
            exc_info=True,
            stack_info=True,
        )

    def get_consumer_connection(self) -> KombuConnection:
        with self._conn_lock:
            conn = self._consumer_connection.ensure_connection()
            return conn

    def get_producer_connection(self) -> KombuConnection:
        with self._conn_lock:
            conn = self._producer_connection.ensure_connection(errback=self.on_connection_error)
            return conn

    def acquire_producer(self):
        with self._conn_lock:
            conn = self.get_producer_connection()
            acquire = conn.ensure(
                conn,
                lambda *a, **kw: conn.default_channel_pool.acquire(*a, **kw),
                errback=self.on_connection_error,
            )
            channel = acquire(block=True, timeout=2)
            assert channel.is_open
            return AutoReleaseProducer(channel)

    def close(self):
        if self._consumer_connection.connected:
            with self.reraise_as_library_errors():
                self._consumer_connection.close()

        if self._producer_connection.connected:
            with self.reraise_as_library_errors():
                self._producer_connection.close()
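
The error-translation pattern used by `reraise_as_library_errors` can be tried in isolation. This is a minimal sketch with stand-in types: `LibraryConnectionError` is a hypothetical replacement for `dramatiq.ConnectionError`, and `OSError` plays the role of the recoverable connection errors.

```python
import contextlib


class LibraryConnectionError(Exception):
    """Hypothetical stand-in for dramatiq.ConnectionError."""


@contextlib.contextmanager
def reraise_as(target, recoverable):
    """Translate `recoverable` exceptions into `target`, keeping the cause."""
    try:
        yield
    except target:
        raise  # already the library's own error type
    except recoverable as exc:
        raise target(str(exc)) from exc


def demo():
    try:
        with reraise_as(LibraryConnectionError, (OSError,)):
            raise OSError("socket closed")
    except LibraryConnectionError as exc:
        # The original exception stays reachable through __cause__.
        return type(exc.__cause__).__name__, str(exc)


if __name__ == "__main__":
    print(demo())
```

Exceptions that are not listed as recoverable pass through unchanged.
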
spumer commented

This was tested with 100 threads, but it may have bugs; it's still an experiment. Let me know if you use it :)

spumer commented

So, I updated the snippet. Now it works correctly. Tested with 900 threads under heavy consuming and producing: each thread received a message, enqueued the same message again, and acked the current one.

spumer commented

@auvipy hi, what do you think about this? It's a workaround implementation, and maybe you know the right way to implement this in py-amqp/kombu.