How does pyamqp/kombu make multiple threads share the same amqp connection?
ponponon opened this issue · 5 comments
amqp has not only the concept of connection, but also the concept of channel. As I understand it, the channel concept was introduced to allow multiple threads or concurrent threads to share a connection, so that the connection can be multiplexed. For example, if you have a process with 10 threads, you can just create an amqp connection, and then create 10 channels on that amqp connection, so that each thread can hold a separate channel to achieve concurrency and not conflict with each other.
For example, if I have 100 processes, each with 10 threads, and each thread holds its own amqp connection, the RabbitMQ server has to maintain a total of 1000 amqp connections, which puts a lot of pressure on it. But if "multi-threaded or concurrent sharing of a connection" can be implemented, the RabbitMQ server only has to maintain a total of 100 amqp connections, which is much less stressful.
How to use py-amqp to implement "multiple threads or concurrent threads share a connection to multiplex connections"? Is there a code sample for this? I need a producer code sample and a consumer code sample.
Hi, I'm developing and experimenting with this right now.
Here is my thread-safe transport implementation based on py-amqp:
shared_pyamqp_transport.py
"""
Do not share a channel between threads:
- Do not use `default_channel` from different threads,
because responses are channel-bound: you can receive an error meant for another thread, or cause an error in a different thread.
A channel is bound to the thread in which it was created.
This is required because dispatching a frame can raise an error that is expected in the caller's thread, e.g. from the queue_declare method.
If you declare a queue in passive mode and it does not exist, RabbitMQ will close the channel,
and the exception MAY be raised in a different thread when that thread drains events.
To prevent this, the current implementation binds channels to threads and dispatches received frames only in their owning threads.
TODO:
- ChannelPool: allow transferring channel ownership between threads, to minimize how long channels stay open when they are needed only occasionally.
This is helpful for multiple threads that produce messages from time to time: instead of creating a channel per thread,
we can move channel ownership.
"""
import collections
import contextlib
import functools
import logging
import threading
import amqp
import kombu
import kombu.simple
import kombu.transport.pyamqp
logger = logging.getLogger(__name__)
class ThreadSafeChannel(kombu.transport.pyamqp.Channel):
    """Channel that remembers which thread created it.

    On construction the channel id is appended to the owning connection's
    ``channel_thread_bindings`` under the creating thread's ident.
    ``collect`` undoes that registration and discards any frames that are
    still buffered for this channel on the connection.
    """

    connection: 'ThreadSafeConnection'

    def __init__(self, *args, **kwargs):
        """Create the channel and register it under the calling thread."""
        super().__init__(*args, **kwargs)
        self._owner_ident = threading.get_ident()
        self.connection.channel_thread_bindings[self._owner_ident].append(self.channel_id)

    def collect(self):
        """Remove this channel's bookkeeping from the connection, then tear down.

        Buffered frames that were never dispatched are dropped with a
        warning. The connection-side cleanup is skipped when
        ``self.connection`` is ``None`` so that a repeated or late
        ``collect()`` does not fail with ``AttributeError``.
        NOTE(review): py-amqp's ``Channel.collect`` appears to clear
        ``self.connection`` -- confirm against the installed version.
        """
        conn = self.connection
        if conn is not None:
            pending = conn.channel_frame_buff.pop(self.channel_id, ())
            if pending:
                logger.warning('No drained events after close (%s pending events)', len(pending))
            bindings = conn.channel_thread_bindings.get(self._owner_ident) or []
            # The id may already have been unregistered; that is not an error.
            with contextlib.suppress(ValueError):
                bindings.remove(self.channel_id)
        super().collect()
class DrainGuard:
    """Lets exactly one thread at a time act as the "drainer".

    ``start_drain`` tries to claim the drainer role without blocking,
    ``finish_drain`` gives it up and wakes waiters, and
    ``wait_drain_finished`` blocks until the current drainer (if any)
    finishes or the timeout elapses.
    """

    def __init__(self):
        # Ident of the thread currently draining, or None when idle.
        self._drain_is_active_by = None
        # Held for the whole duration of a drain by the draining thread.
        self._drain_check_lock = threading.RLock()
        # Used to signal waiters when a drain finishes.
        self._drain_cond = threading.Condition()

    def is_drain_active(self):
        """Return True while some thread holds the drainer role."""
        return self._drain_is_active_by is not None

    def start_drain(self):
        """Try to become the drainer.

        :return: True if the calling thread claimed the role, False if the
            non-blocking acquire of the drain lock failed.
        """
        # Only take the condition's lock when no drain is recorded as active,
        # i.e. when a race with `wait_drain_finished` is possible; otherwise
        # run the claim attempt without it.
        if self._drain_is_active_by is None:
            guard = self._drain_cond
        else:
            guard = contextlib.nullcontext()
        with guard:
            if not self._drain_check_lock.acquire(blocking=False):
                return False
            assert self._drain_is_active_by is None
            self._drain_is_active_by = threading.get_ident()
            return True

    def finish_drain(self):
        """Release the drainer role and notify every waiting thread.

        Must be called by the thread that successfully called ``start_drain``.
        """
        owner = self._drain_is_active_by
        assert owner is not None, 'Drain must be started'
        assert owner == threading.get_ident(), 'You can not finish drain started by other thread'
        with self._drain_cond:
            self._drain_is_active_by = None
            self._drain_cond.notify_all()
        # Released only after waiters were notified.
        self._drain_check_lock.release()

    def wait_drain_finished(self, timeout=None):
        """Block until the active drain finishes or ``timeout`` elapses.

        Returns immediately when no drain is active. Must not be called by
        the thread that is currently draining.
        """
        me = threading.get_ident()
        assert self._drain_is_active_by != me, 'You can not wait your own; deadlock detected'
        with self._drain_cond:
            if not self.is_drain_active():
                return
            self._drain_cond.wait(timeout=timeout)
class ThreadSafeConnection(kombu.transport.pyamqp.Connection):
    """Connection that can be shared by several threads.

    Socket access (connect, frame writes, blocking reads, draining, collect)
    is serialized with ``_transport_lock``. Inbound methods are not
    dispatched immediately: ``on_inbound_method`` buffers them per channel
    and ``drain_events`` dispatches, in the calling thread, only the frames
    of channels registered under that thread in ``channel_thread_bindings``.
    """

    Channel = ThreadSafeChannel

    def __init__(self, *args, **kwargs):
        """Create the locks and bookkeeping, then run the base initializer."""
        self._transport_lock = threading.RLock()
        self._create_channel_lock = threading.RLock()
        self._drain_guard = DrainGuard()
        self.channel_thread_bindings = collections.defaultdict(list)  # thread_ident -> [channel_id, ...]
        self.channel_frame_buff = collections.defaultdict(collections.deque)  # channel_id: [frame, frame, ...]
        # The connection object itself is treated as channel 0
        self.channel_thread_bindings[threading.get_ident()].append(0)
        super().__init__(*args, **kwargs)

    def channel(self, *args, **kwargs):
        """Create a channel while holding the channel-creation lock."""
        with self._create_channel_lock:
            return super().channel(*args, **kwargs)

    def _claim_channel_id(self, channel_id):
        """Claim ``channel_id`` while holding the channel-creation lock."""
        with self._create_channel_lock:
            return super()._claim_channel_id(channel_id)

    def _get_free_channel_id(self):
        """Pick a free channel id while holding the channel-creation lock."""
        with self._create_channel_lock:
            return super()._get_free_channel_id()

    def _dispatch_channel_frames(self, channel_id):
        """Dispatch, in FIFO order, every frame buffered for ``channel_id``."""
        buff = self.channel_frame_buff.get(channel_id, ())
        while buff:
            method_sig, payload, content = buff.popleft()
            self.channels[channel_id].dispatch_method(
                method_sig,
                payload,
                content,
            )

    def on_inbound_method(self, channel_id, method_sig, payload, content):
        """Buffer an inbound method for later dispatch by the owning thread.

        :raises amqp.exceptions.RecoverableConnectionError: when
            ``self.channels`` is ``None`` (connection already closed).
        """
        if self.channels is None:
            raise amqp.exceptions.RecoverableConnectionError('Connection already closed')
        # collect all frames to late dispatch (after drain)
        self.channel_frame_buff[channel_id].append((method_sig, payload, content))

    def connect(self, *args, **kwargs):
        """Connect while holding the transport lock."""
        with self._transport_lock:
            res = super().connect(*args, **kwargs)
        return res

    @kombu.transport.pyamqp.Connection.frame_writer.setter
    def frame_writer(self, frame_writer):
        """Store ``frame_writer`` wrapped so each call holds the transport lock."""
        # frame_writer access to socket
        # make it thread-safe
        @functools.wraps(frame_writer)
        def wrapper(*args, **kwargs):
            with self._transport_lock:
                res = frame_writer(*args, **kwargs)
            return res

        self._frame_writer = wrapper

    def blocking_read(self, timeout=None):
        """Read from the transport while holding the transport lock."""
        with self._transport_lock:
            return super().blocking_read(timeout=timeout)

    def collect(self):
        """Tear down the connection while holding the transport lock."""
        with self._transport_lock:
            super().collect()

    def drain_events(self, timeout=None):
        """Drain the socket (or wait for the thread that is doing so), then
        dispatch the frames buffered for the calling thread's channels.

        :param timeout: forwarded to the base ``drain_events`` when this
            thread performs the drain; otherwise used as the upper bound for
            waiting on the thread that is draining.
        """
        # When all threads go here only one really drain events
        # Because this action independent of caller, all events will be dispatched to their channels
        started = self._drain_guard.start_drain()
        if not started:
            # Honor the caller's timeout instead of waiting indefinitely for
            # another thread's (possibly unbounded) drain.
            self._drain_guard.wait_drain_finished(timeout=timeout)
        else:
            try:
                with self._transport_lock:
                    super().drain_events(timeout=timeout)
            finally:
                self._drain_guard.finish_drain()
        me = threading.get_ident()
        # Iterate over a snapshot: dispatching a frame can lead to
        # Channel.collect(), which removes ids from the live binding list and
        # would otherwise make the loop skip channels. `.get` also avoids
        # creating an empty entry for threads that own no channel.
        my_channels = tuple(self.channel_thread_bindings.get(me, ()))
        for channel_id in my_channels:
            self._dispatch_channel_frames(channel_id)
def install():
    """Make kombu's pyamqp transports (plain and SSL) use ThreadSafeConnection."""
    pyamqp = kombu.transport.pyamqp
    for transport_cls in (pyamqp.Transport, pyamqp.SSLTransport):
        transport_cls.Connection = ThreadSafeConnection
To guarantee that only one connection is used to consume and produce, I wrote this too:
connection_holder.py
import contextlib
import logging
import threading
import dramatiq
import kombu
import kombu.simple
import kombu.transport.pyamqp
from .shared_pyamqp_transport import KombuConnection
from .shared_pyamqp_transport import install as _install_shared_pyamqp_transport
_install_shared_pyamqp_transport()
logger = logging.getLogger(__name__)
def clone_to_threadsafe_connection(connection: kombu.Connection):
    """Build a ``KombuConnection`` from the unresolved settings of ``connection``.

    Modeled on ``kombu.Connection.clone``: the settings come from
    ``connection._info(resolve=False)`` and are passed as keyword arguments.
    """
    settings = dict(connection._info(resolve=False))
    return KombuConnection(**settings)
class AutoReleaseProducer(kombu.Producer):
    """Producer whose ``release`` hands its channel back when the channel
    object offers a ``release`` method.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if self.__connection__ is not None:
            return
        # preserve connection instance to later ensure
        # due channel can miss it after close
        self.__connection__ = self._channel.connection.client

    def release(self):
        """Call ``release()`` on the channel if it has such an attribute."""
        channel = self.channel
        if hasattr(channel, 'release'):
            channel.release()
class ConnectionHolder:
    """Owns one consumer connection and one producer connection.

    Both are clones of a thread-safe copy of the connection passed to the
    constructor. Access to them is serialized with ``_conn_lock``.
    """

    def __init__(
        self,
        connection: kombu.Connection,
        *,
        connect_max_retries=None,
    ):
        """
        :param connection: template connection; it is cloned, not used directly.
        :param connect_max_retries: maximum number of retries trying to re-establish the connection,
        if the connection is lost/unavailable. ``None`` is passed through unchanged.
        """
        connection = clone_to_threadsafe_connection(connection)
        self.recoverable_connection_errors = connection.recoverable_connection_errors
        self.recoverable_channel_errors = connection.recoverable_channel_errors
        self._consumer_connection = connection.clone()
        self._producer_connection = connection.clone()
        self._conn_lock = threading.RLock()
        self.connect_max_retries = connect_max_retries
        self._consumer_channel_pool = None
        self._producer_channel_pool = None

    @contextlib.contextmanager
    def reraise_as_library_errors(
        self,
        ConnectionError=dramatiq.ConnectionError,  # noqa: N803,A002
        ChannelError=dramatiq.ConnectionError,  # noqa: N803
    ):
        """Context manager translating recoverable transport errors.

        Exceptions that already are ``ConnectionError``/``ChannelError`` pass
        through untouched; the connection's recoverable connection/channel
        errors are re-raised as ``ConnectionError``/``ChannelError`` with the
        original exception chained as the cause.
        """
        try:
            yield
        except (ConnectionError, ChannelError):
            raise
        except self.recoverable_connection_errors as exc:
            raise ConnectionError(str(exc)) from exc
        except self.recoverable_channel_errors as exc:
            raise ChannelError(str(exc)) from exc

    @staticmethod
    def on_connection_error(exc, interval):
        """Errback: log a warning about a broker connection error and the retry interval."""
        logging.getLogger('ConnectionHolder').warning(
            'Broker connection error, trying again in %s seconds: %r.',
            interval,
            exc,
            exc_info=True,
            stack_info=True,
        )

    def get_consumer_connection(self) -> KombuConnection:
        """Return the consumer connection, (re)connecting it if necessary."""
        with self._conn_lock:
            # Same errback and retry limit as the producer side, so that
            # `connect_max_retries` is actually honored.
            conn = self._consumer_connection.ensure_connection(
                errback=self.on_connection_error,
                max_retries=self.connect_max_retries,
            )
        return conn

    def get_producer_connection(self) -> KombuConnection:
        """Return the producer connection, (re)connecting it if necessary."""
        with self._conn_lock:
            conn = self._producer_connection.ensure_connection(
                errback=self.on_connection_error,
                max_retries=self.connect_max_retries,
            )
        return conn

    def acquire_producer(self):
        """Acquire a channel from the producer connection's default channel
        pool and wrap it in an ``AutoReleaseProducer``.

        The pool acquire is wrapped with ``conn.ensure`` using this holder's
        errback and retry limit, and is called with ``block=True, timeout=2``.
        """
        with self._conn_lock:
            conn = self.get_producer_connection()
            acquire = conn.ensure(
                conn,
                lambda *a, **kw: conn.default_channel_pool.acquire(*a, **kw),
                errback=self.on_connection_error,
                max_retries=self.connect_max_retries,
            )
            channel = acquire(block=True, timeout=2)
            assert channel.is_open
            return AutoReleaseProducer(channel)

    def close(self):
        """Close both connections that are currently connected.

        The producer connection is closed even when closing the consumer
        connection raises; recoverable errors are translated by
        ``reraise_as_library_errors``.
        """
        try:
            if self._consumer_connection.connected:
                with self.reraise_as_library_errors():
                    self._consumer_connection.close()
        finally:
            if self._producer_connection.connected:
                with self.reraise_as_library_errors():
                    self._producer_connection.close()
This was tested with 100 threads, but it may have bugs; it's still an experiment. Let me know if you end up using it :)
So, I updated the snippet. Now it works correctly. Tested with 900 threads under heavy consuming and producing: each thread received a message, enqueued the same message again, and then acked the current message.