Yakifo/amqtt

Cancelling a publish task will keep the puback waiter indefinitely (QOS > 0)

pazzarpj opened this issue · 1 comment

I'm currently on hbmqtt and am considering porting to amqtt because of this bug. However, the bug still exists in amqtt.

I came across this issue on an embedded system, where publish calls are cancelled (using wait_for) if they have not been acknowledged after 10 seconds. If a message was not acknowledged in time, its PUBACK waiter would stay in memory forever.

Code to reproduce this error is below. It sends 1000 messages in parallel and, on every loop, cancels a portion of them before they receive the PUBACK from the broker. The log then shows the message "PUBACK waiter with Id 'x' already done" for them. Eventually, all 65535 message ids are taken, no new messages can be sent, and the client gets stuck forever in a while loop looking for "next_packet_id", as there are none available.
Requires an MQTT broker running on 127.0.0.1 to work.

import random
import amqtt.client
import asyncio
from amqtt.mqtt.constants import QOS_1
import logging

# Module-level logger; it is not referenced anywhere else in this script.
log = logging.getLogger(__name__)
# Configure the root logger at INFO so records at INFO level and above are emitted.
logging.basicConfig(level=logging.INFO)


async def main():
    """Reproduce the leaked-waiter problem against a broker on 127.0.0.1:1883.

    Runs forever. Each round starts 1000 QoS 1 publish tasks, yields to the
    event loop until either every task has finished or a randomly chosen
    number of yields (between 2900 and 3005) has been exceeded, cancels the
    tasks that are still pending, and then waits for all tasks to settle.
    """
    host = "127.0.0.1"
    port = 1883
    client = amqtt.client.MQTTClient(
        config={
            "default_qos": QOS_1,
            "auto_reconnect": True,
            "reconnect_max_interval": 60,
            "reconnect_retries": -1,
            "keep_alive": 60,
        }
    )
    await client.connect(f"mqtt://{host}:{port}")

    while True:
        tasks = []
        for _ in range(1000):
            tasks.append(
                asyncio.create_task(client.publish("test_topic", b"test", qos=QOS_1))
            )
        await asyncio.sleep(0)

        yields = 0
        # Keep yielding to the event loop while at least one publish is pending.
        while not all(task.done() for task in tasks):
            await asyncio.sleep(0)
            yields += 1
            # A fresh random threshold is drawn on every iteration.
            if yields > random.randint(2900, 3005):
                cancelled = 0
                for task in tasks:
                    if task.done():
                        continue
                    task.cancel()
                    cancelled += 1
                print(f"Breaking after {yields}: Cancelled {cancelled}")
                break

        # Collect results and cancellations so no task is left un-awaited.
        await asyncio.gather(*tasks, return_exceptions=True)


# Start the reproduction only when this file is executed as a script.
if __name__ == "__main__":
    asyncio.run(main())

Unit tests that reproduce this issue on master are included in the PR; they pass with the fix from the PR applied.

@pytest.mark.asyncio
async def test_cancel_publish_qos1():
    """
    Tests that cancelling a QoS 1 publish cleans up its in-flight state.

    A publish task is started and cancelled as soon as its PUBACK waiter is
    registered; afterwards neither a PUBACK waiter nor an outgoing in-flight
    message may remain on the client.
    """
    data = b"data"
    broker = Broker(broker_config, plugin_namespace="amqtt.test.plugins")
    await broker.start()
    try:
        client_pub = MQTTClient()
        await client_pub.connect("mqtt://127.0.0.1/")
        try:
            assert client_pub.session.inflight_out_count == 0
            fut = asyncio.create_task(client_pub.publish("test_topic", data, QOS_1))
            assert len(client_pub._handler._puback_waiters) == 0
            # Yield until the waiter shows up. Stop as well if the publish
            # already finished, so the assertions below fail instead of the
            # test spinning forever.
            while len(client_pub._handler._puback_waiters) == 0 and not fut.done():
                await asyncio.sleep(0)
            assert len(client_pub._handler._puback_waiters) == 1
            assert client_pub.session.inflight_out_count == 1
            fut.cancel()
            await asyncio.wait([fut])
            assert len(client_pub._handler._puback_waiters) == 0
            assert client_pub.session.inflight_out_count == 0
        finally:
            # Disconnect even when an assertion fails.
            await client_pub.disconnect()
    finally:
        # Always stop the broker so a failure here cannot affect later tests.
        await broker.shutdown()


@pytest.mark.asyncio
async def test_cancel_publish_qos2_pubrec():
    """
    Tests that cancelling a QoS 2 publish while waiting for PUBREC cleans up.

    A publish task is started and cancelled as soon as its PUBREC waiter is
    registered; afterwards neither a PUBREC waiter nor an outgoing in-flight
    message may remain on the client.
    """
    data = b"data"
    broker = Broker(broker_config, plugin_namespace="amqtt.test.plugins")
    await broker.start()
    try:
        client_pub = MQTTClient()
        await client_pub.connect("mqtt://127.0.0.1/")
        try:
            assert client_pub.session.inflight_out_count == 0
            fut = asyncio.create_task(client_pub.publish("test_topic", data, QOS_2))
            assert len(client_pub._handler._pubrec_waiters) == 0
            # Yield until the waiter shows up. Stop as well if the publish
            # already finished (a cancelled task also reports done()), so the
            # assertions below fail instead of the test spinning forever.
            while len(client_pub._handler._pubrec_waiters) == 0 and not fut.done():
                await asyncio.sleep(0)
            assert len(client_pub._handler._pubrec_waiters) == 1
            assert client_pub.session.inflight_out_count == 1
            fut.cancel()
            # NOTE(review): this extra delay is not present in the sibling
            # tests; presumably it lets late broker replies be processed
            # before the state is checked - confirm whether it is needed.
            await asyncio.sleep(1)
            await asyncio.wait([fut])
            assert len(client_pub._handler._pubrec_waiters) == 0
            assert client_pub.session.inflight_out_count == 0
        finally:
            # Disconnect even when an assertion fails.
            await client_pub.disconnect()
    finally:
        # Always stop the broker so a failure here cannot affect later tests.
        await broker.shutdown()


@pytest.mark.asyncio
async def test_cancel_publish_qos2_pubcomp():
    """
    Tests that cancelling a QoS 2 publish while waiting for PUBCOMP cleans up.

    A publish task is started and cancelled as soon as its PUBCOMP waiter is
    registered; afterwards neither a PUBCOMP waiter nor an outgoing in-flight
    message may remain on the client.
    """
    data = b"data"
    broker = Broker(broker_config, plugin_namespace="amqtt.test.plugins")
    await broker.start()
    try:
        client_pub = MQTTClient()
        await client_pub.connect("mqtt://127.0.0.1/")
        try:
            assert client_pub.session.inflight_out_count == 0
            fut = asyncio.create_task(client_pub.publish("test_topic", data, QOS_2))
            assert len(client_pub._handler._pubcomp_waiters) == 0
            # Yield until the waiter shows up. Stop as well if the publish
            # already finished, so the assertions below fail instead of the
            # test spinning forever.
            while len(client_pub._handler._pubcomp_waiters) == 0 and not fut.done():
                await asyncio.sleep(0)
            assert len(client_pub._handler._pubcomp_waiters) == 1
            fut.cancel()
            await asyncio.wait([fut])
            assert len(client_pub._handler._pubcomp_waiters) == 0
            assert client_pub.session.inflight_out_count == 0
        finally:
            # Disconnect even when an assertion fails.
            await client_pub.disconnect()
    finally:
        # Always stop the broker so a failure here cannot affect later tests.
        await broker.shutdown()