diff --git a/python/idsse_common/idsse/common/rabbitmq_rpc.py b/python/idsse_common/idsse/common/rabbitmq_rpc.py new file mode 100644 index 0000000..b1972bb --- /dev/null +++ b/python/idsse_common/idsse/common/rabbitmq_rpc.py @@ -0,0 +1,244 @@ +"""Module for RPC (remote prodedure call, a.k.a. call-and-response) type RabbitMQ communication""" +# ---------------------------------------------------------------------------------- +# Created on Thu Feb 27 2025 +# +# Copyright (c) 2025 Colorado State University. All rights reserved. (1) +# +# Contributors: +# Mackenzie Grimes (1) +# +# ---------------------------------------------------------------------------------- +import logging +import logging.config +import uuid +from collections.abc import Callable +from concurrent.futures import Future +from copy import deepcopy +from typing import NamedTuple + +from pika.channel import Channel +from pika.spec import Basic, BasicProperties + +from .rabbitmq_utils import (Conn, + Consumer, + Exch, + RabbitMqParams, + RabbitMqParamsAndCallback, + Rpc, RabbitMqMessage, + threadsafe_ack, + threadsafe_nack, + blocking_publish) + +logger = logging.getLogger(__name__) + + +class RpcResponse(NamedTuple): + """Data class to specify how result of RPC request should be communicated to the RMQ broker. + Either ack or nack with no requeue (usually a response RabbitMqMessage should be published), + or nack with requeue True (to re-attempt processing). + + Message can be None (and is None by default), meaning request is only acked/nacked without + a response to the awaiting requestor, but this should generally only be used if requeue=True. + """ + message: RabbitMqMessage | None + ack: bool = True + requeue: bool = False + + +# Tech debt: this is a temporary class "alias" to Rpc to match naming convention of +# `RpcConsumer`. Will be deleted (and `Rpc` renamed) after current usages of `Rpc` are migrated +class RpcPublisher(Rpc): + """RabbitMQ RPC (remote procedure call) publishing client, runs in own thread to not block + heartbeat. This class can be used to send "requests" (outbound messages) over RabbitMQ and + block until a "response" (inbound message) comes back from an `RpcConsumer` instance. + All producing to/consuming of different queues and associating requests with their responses + is abstracted away. + + By RabbitMQ convention, RPC uses the built-in Direct Reply-To queue to field responses messages, + which generates a temporary, random queue name for that individual message, rather than + creating its own durable queue. Directing responses to a custom queue is not yet supported. + + The `start()` and `stop()` methods should be called from the same thread that created the + `RpcPublisher` instance. + + Example usage: + ``` + my_client = RpcPublisher(...insert RabbitMQ parameters...) + my_client.start() + + response = my_client.send_message(RabbitMqMessage('{"some": "json"}')) + # blocks while waiting for response + logger.info(f'Got response from external service: {response}') + ``` + """ + # pylint: disable=arguments-renamed,duplicate-code + def send_request(self, request: RabbitMqMessage) -> RabbitMqMessage | None: + """Send message to remote RabbitMQ service using thread-safe RPC. Will block until response + is received back, or timeout occurs. + + Args: + request (RabbitMqMessage): the RabbitMQ message body and (optional) properties to send + as a "request" to the listening RpcConsumer service. + + Returns: + RabbitMqMessage | None: The response message (body and properties), or None on request + timeout or error handling response. + """ + if not self.is_open: + logger.debug('RPC thread not yet initialized. Setting up now') + self.start() + + # generate unique ID to associate our request to external service's response + request_id = str(uuid.uuid4()) + + # send request to external RMQ service, providing unique RPC message ID and + # the queue where it should respond + if request.properties.headers is None: + request.properties.headers = {} + request.properties.headers['rpc'] = request_id + request.properties.reply_to = self._queue.name + + # overwrite routing key (if any) to enforce use of default Exchange and Direct Reply-to + request = RabbitMqMessage(request.body, request.properties, self._exch.route_key) + + # add future to dict where callback can retrieve it and set result + request_future = Future() + self._pending_requests[request_id] = request_future + + logger.debug('Publishing request message to external service with body: %s', request.body) + blocking_publish(self._consumer.channel, + self._exch, + request, + self._queue) + + try: + # block until callback runs (we'll know when the future's result has been changed) + return request_future.result(timeout=self._timeout) + except TimeoutError: + logger.warning('Timed out waiting for response. rpc request_id: %s', request_id) + self._pending_requests.pop(request_id) # stop tracking request Future + return None + except Exception as exc: # pylint: disable=broad-exception-caught + logger.warning('Unexpected response from external service: %s', str(exc)) + self._pending_requests.pop(request_id) # stop tracking request Future + return None + + +class RpcConsumer(): + """Consumer RPC (remote prodecure call) class that serves as the listener to `RpcPublisher` + messages. `RpcConsumer` creates a thread to constantly consume RPC message "requests" emitted + by `RpcPublisher`, form a response, and send back it to the `RpcPublisher` asynchronously. + + Note that RPC by RabbitMQ convention uses built-in Direct Reply-to queue over the default + exchange, and listeners for responses on a temporary queue unique to a given RPC request. + Publishing RpcConsumer responses to a custom, durable queue is not yet supported. + + Example usage: + ``` + def on_receive_request(message: RabbitMqMessage): + logger.info('Got request from external service: %s', message.body) + if message.properties.content_type == 'application/json': + return RpcResponse(RabbitMqMessage('success!'), ack=True) + return RpcResponse(None, ack=False, requeue=True) + + my_consumer = RpcConsumer(, + RmqParams(, ), + on_receive_request) + my_consumer.start() + ``` + """ + def __init__(self, + conn_params: Conn, + rmq_params: RabbitMqParams, + on_request_callback: Callable[[RabbitMqMessage], RpcResponse], + *args, + **kwargs): + """ + Args: + conn_params (Conn): parameters to connect to RabbitMQ server + rmq_params (RabbitMqParams): parameters of RMQ Exchange and Queue where RPC messages + are expected to be received from an `RpcPublisher`. + on_message_callback (Callable[[RabbitMqMesssage], RpcResponse]): a function that + receives an inbound RPC request message, does some work with it, then returns a + RpcResponse, which controls if message should be acked/nacked and some + RabbitMqMessage published back the original requester, or if the request + should be nack'd and requeued to re-attempt processing + """ + self._rmq_params = rmq_params + self._on_request_callback = on_request_callback + + # Start long-running thread to consume any messages from response queue + self._consumer = Consumer(conn_params, + RabbitMqParamsAndCallback(rmq_params, self._on_message), + *args, + **kwargs) + + @property + def is_open(self): + """Returns True if RabbitMQ connection (Consumer) is open and ready to receive messages""" + return self._consumer.is_alive() and self._consumer.channel.is_open + + def start(self): + """Start dedicated threads to asynchronously receive, and send, RPC messages using a new + RabbitMQ connection and channel. Note: this method can be called externally, but it is + not required to use the client. It will automatically call this internally as needed.""" + if not self.is_open: + logger.debug('Starting RPC thread to consume messages') + self._consumer.start() + self._consumer.join() + + def stop(self): + """Unsubscribe to queue and cleanup thread(s)""" + logger.debug('Shutting down RpcConsumer threads') + if not self.is_open: + logger.debug('RpcConsumer threads not running, nothing to cleanup') + return + + # tell Consumer to cleanup RabbitMQ resources and wait for thread to terminate + self._consumer.stop() + self._consumer.join() + + def _on_message(self, + channel: Channel, + method: Basic.Deliver, + properties: BasicProperties, + body: bytes): + """Handle receiving a request message from an `RpcPublisher`. Invoke user-provided callback + to form response body, then send RabbitMQ message over Exchange (likely default) and Queue + (likely a unique Direct Reply-to) that `RpcPublisher` specified in message props. + """ + request = body.decode() + logger.debug('Received request message from external message with body: %s', request) + response = self._on_request_callback(RabbitMqMessage(request, properties, + method.routing_key)) + + if response.ack: + threadsafe_ack(channel, + method.delivery_tag, + lambda: logger.debug('Request %s was acked', + properties.headers.get('rpc'))) + else: + threadsafe_nack(channel, + method.delivery_tag, + lambda: logger.debug('Request %s was nacked', + properties.headers.get('rpc')), + requeue=response.requeue) + + if content := response.message: + # per RabbitMQ RPC convention, always use default Exchange so clients can operate on + # any exchange and RabbitMQ will route the message to their queue correctly + exch = Exch('', 'direct', route_key=properties.reply_to) + logger.info('Publishing response to default exchange, routing key: %s', exch.route_key) + logger.debug('Publishing response to external request service: %s', content.body) + + # tag this response message using the "rpc" UUID from request's headers; required so + # RpcPublisher can associate each request/response pair and resolve all pending Futures + response_props = deepcopy(content.properties) + response_props.reply_to = properties.reply_to + if response_props.headers is None: + response_props.headers = {} + response_props.headers['rpc'] = properties.headers.get('rpc') + + blocking_publish(channel, + exch, + RabbitMqMessage(content.body, response_props, exch.route_key)) diff --git a/python/idsse_common/idsse/common/rabbitmq_utils.py b/python/idsse_common/idsse/common/rabbitmq_utils.py index 4fb6b0c..b7a03f2 100644 --- a/python/idsse_common/idsse/common/rabbitmq_utils.py +++ b/python/idsse_common/idsse/common/rabbitmq_utils.py @@ -10,14 +10,13 @@ # Mackenzie Grimes (2) # # ---------------------------------------------------------------------------------- - import contextvars import logging import logging.config import uuid -from concurrent.futures import Future, ThreadPoolExecutor from collections.abc import Callable +from concurrent.futures import Future, ThreadPoolExecutor from functools import partial from threading import Event, Thread from typing import NamedTuple @@ -96,7 +95,7 @@ class RabbitMqMessage(NamedTuple): Data class to hold a RabbitMQ message body, properties, and optional route_key (if outbound) """ body: str - properties: BasicProperties + properties: BasicProperties = BasicProperties() route_key: str | None = None @@ -277,7 +276,7 @@ def blocking_publish(self, publisher is configured to confirm delivery will return False if failed to confirm. """ - return _blocking_publish(self.channel, + return blocking_publish(self.channel, self._exch, RabbitMqMessage(message, properties, route_key), self._queue) @@ -296,6 +295,8 @@ def stop(self): class Rpc: """ + !! DEPRECATED !! Use `rabbitmq_rpc.RpcPublisher` instead. + RabbitMQ RPC (remote procedure call) client, runs in own thread to not block heartbeat. The start() and stop() methods should be called from the same thread that created the instance. @@ -308,12 +309,11 @@ class Rpc: is not yet supported by Rpc. Example usage: - + ``` my_client = RpcClient(...insert params here...) - response = my_client.send_message('{"some": "json"}') # blocks while waiting for response - logger.info(f'Response from external service: {response}') + ``` """ def __init__(self, conn_params: Conn, exch: Exch, timeout: float | None = None): """ @@ -332,18 +332,17 @@ def __init__(self, conn_params: Conn, exch: Exch, timeout: float | None = None): self._pending_requests: dict[str, Future] = {} # Start long-running thread to consume any messages from response queue - self._consumer = Consumer( - conn_params, - RabbitMqParamsAndCallback(RabbitMqParams(Exch('', 'direct'), self._queue), - self._on_response) - ) + self._consumer = Consumer(conn_params, + RabbitMqParamsAndCallback( + RabbitMqParams(Exch('', 'direct'), self._queue), + self._on_response)) @property def is_open(self) -> bool: """Returns True if RabbitMQ connection (Publisher) is open and ready to send messages""" return self._consumer.is_alive() and self._consumer.channel.is_open - def send_request(self, request_body: str | bytes) -> RabbitMqMessage | None: + def send_request(self, request_body: str | bytes) -> RabbitMqMessage | None: # pragma: no cover """Send message to remote RabbitMQ service using thread-safe RPC. Will block until response is received back, or timeout occurs. @@ -368,7 +367,7 @@ def send_request(self, request_body: str | bytes) -> RabbitMqMessage | None: self._pending_requests[request_id] = request_future logger.debug('Publishing request message to external service with body: %s', request_body) - _blocking_publish(self._consumer.channel, + blocking_publish(self._consumer.channel, self._exch, RabbitMqMessage(request_body, properties, self._exch.route_key), self._queue) @@ -410,7 +409,7 @@ def _on_response( channel: Channel, method: Basic.Deliver, properties: BasicProperties, - body: bytes + body: bytes, ): """Handle RabbitMQ message emitted to response queue.""" logger.debug('Received response with routing_key: %s, content_type: %s, message: %i', @@ -439,8 +438,7 @@ def _on_response( def subscribe_to_queue( connection: Conn | BlockingConnection, rmq_params: RabbitMqParams, - on_message_callback: Callable[ - [Channel, Basic.Deliver, BasicProperties, bytes], None], + on_message_callback: Callable[[Channel, Basic.Deliver, BasicProperties, bytes], None], channel: Channel | None = None ) -> tuple[BlockingConnection, BlockingChannel]: """ @@ -576,6 +574,40 @@ def threadsafe_nack( threadsafe_call(channel, lambda: channel.basic_nack(delivery_tag, requeue=requeue)) +def blocking_publish( + channel: BlockingChannel, + exch: Exch, + message_params: RabbitMqMessage, + queue: Queue | None = None, +) -> bool: + """ + Threadsafe, blocking publish on the specified RabbitMQ exch via the provided channel. + Is thread-safe. + + Args: + channel (BlockingChannel): the pika channel to use to publish. + exch (Exch): parameters for the RabbitMQ exchange to publish message to. + message_params (RabbitMqMessage): the message body to publish, plus properties and + queue (optional, Queue | None): parameters for RabbitMQ queue, if message is being + published to a "temporary"/"private" message queue. The published message will be + purged from this queue after its TTL expires. + Default is None (destination queue not private). + Returns: + (bool) True if message published successfully. If the provided queue is confirmed to + confirm delivery, will return False if failed to confirm. + """ + success_flag = [False] + done_event = Event() + threadsafe_call(channel, lambda: _publish(channel, + exch, + message_params, + queue, + success_flag, + done_event)) + done_event.wait() + return success_flag[0] + + def _initialize_exchange_and_queue(channel: Channel, params: RabbitMqParams) -> str: """Declare and bind RabbitMQ exchange and queue using the provided channel. @@ -638,7 +670,6 @@ def _initialize_connection_and_channel( _channel = channel queue_name = _initialize_exchange_and_queue(_channel, params) - return _connection, _channel, queue_name @@ -750,40 +781,6 @@ def _publish(channel: BlockingChannel, done_event.set() -def _blocking_publish( - channel: BlockingChannel, - exch: Exch, - message_params: RabbitMqMessage, - queue: Queue | None = None, -) -> bool: - """ - Threadsafe, blocking publish on the specified RabbitMQ exch via the provided channel. - Is thread-safe. - - Args: - channel (BlockingChannel): the pika channel to use to publish. - exch (Exch): parameters for the RabbitMQ exchange to publish message to. - message_params (RabbitMqMessage): the message body to publish, plus properties and - queue (optional, Queue | None): parameters for RabbitMQ queue, if message is being - published to a "temporary"/"private" message queue. The published message will be - purged from this queue after its TTL expires. - Default is None (destination queue not private). - Returns: - (bool) True if message published successfully. If the provided queue is confirmed to - confirm delivery, will return False if failed to confirm. - """ - success_flag = [False] - done_event = Event() - threadsafe_call(channel, lambda: _publish(channel, - exch, - message_params, - queue, - success_flag, - done_event)) - done_event.wait() - return success_flag[0] - - def _set_context(context): for var, value in context.items(): var.set(value) diff --git a/python/idsse_common/test/test_rabbitmq_rpc.py b/python/idsse_common/test/test_rabbitmq_rpc.py new file mode 100644 index 0000000..929a6f4 --- /dev/null +++ b/python/idsse_common/test/test_rabbitmq_rpc.py @@ -0,0 +1,273 @@ +"""Testing for RabbitMqUtils functions""" +# ------------------------------------------------------------------------------ +# Created on Thu Feb 27 2025 +# +# Copyright (c) 2025 Colorado State University. All rights reserved. (1) +# +# Contributors: +# Mackenzie Grimes (1) +# +# ------------------------------------------------------------------------------ +# pylint: disable=missing-function-docstring,missing-class-docstring,too-few-public-methods +# pylint: disable=redefined-outer-name,unused-argument,protected-access,duplicate-code +import json +from typing import NamedTuple +from unittest.mock import Mock +from uuid import UUID + +from pytest import fixture, MonkeyPatch +from pika import BasicProperties, BlockingConnection +from pika.adapters.blocking_connection import BlockingChannel + +from idsse.common.rabbitmq_utils import DIRECT_REPLY_QUEUE, Queue +from idsse.common.rabbitmq_rpc import (Conn, Consumer, Exch, Future, RabbitMqParams, + RabbitMqMessage, RpcConsumer, RpcPublisher, RpcResponse) + + +# Example data objects +CONN = Conn('localhost', '/', port=5672, username='user', password='password') +RMQ_PARAMS = RabbitMqParams( + Exch('test_criteria_exch', 'topic'), + Queue('test_criteria_queue', '', True, False, True) +) +EXAMPLE_UUID = 'b6591cc7-8b33-4cd3-aa22-408c83ac5e3c' + + +class Method(NamedTuple): + """Mock of pika.frame.Method""" + exchange: str = ' ' + queue: str = '' + delivery_tag: int = 0 + routing_key: str = '' + + +class Frame(NamedTuple): + """Mock of pika.frame.Frame""" + method: Method + + +# fixtures +@fixture +def mock_channel() -> Mock: + """Mock pika.adapters.blocking_connection.BlockingChannel object""" + def mock_queue_declare(queue: str, **_kwargs) -> Method: + return Frame(Method(queue=queue)) # create a usable (mock) Frame using queue name passed + def mock_exch_declare(exchange: str, **_kwargs) -> Method: + return Frame(Method(exchange=exchange)) + + mock_obj = Mock(spec=BlockingChannel, name='MockChannel') + mock_obj.exchange_declare = Mock(side_effect=mock_exch_declare) + mock_obj.queue_declare = Mock(side_effect=mock_queue_declare) + mock_obj.is_open = True + return mock_obj + + +@fixture +def mock_connection(mock_channel: Mock) -> Mock: + """Mock pika.BlockingChannel object""" + mock_obj = Mock(spec=BlockingConnection, name='MockConnection') + mock_obj.channel = Mock(return_value=mock_channel) + return mock_obj + + +@fixture +def mock_consumer(monkeypatch: MonkeyPatch, mock_connection: Mock, mock_channel: Mock) -> Mock: + """Mock rabbitmq_utils.Consumer thread instance""" + mock_obj = Mock(spec=Consumer, name='MockConsumer') + mock_obj.return_value.is_alive = Mock(return_value=False) # by default, thread not running + mock_obj.return_value.connection = mock_connection + mock_obj.return_value.channel = mock_channel + # hack pika add_callback_threadsafe to invoke immediately (hides complexity of threading) + mock_obj.return_value.channel.connection.add_callback_threadsafe = Mock( + side_effect=lambda cb: cb() + ) + + monkeypatch.setattr('idsse.common.rabbitmq_rpc.Consumer', mock_obj) + # temporarily need to mock Consumer out of rabbitmq_utils as well, where deprecated Rpc is + monkeypatch.setattr('idsse.common.rabbitmq_utils.Consumer', mock_obj) + return mock_obj + + +@fixture +def mock_uuid(monkeypatch: MonkeyPatch) -> Mock: + """Always return our example UUID str when UUID() is called""" + mock_obj = Mock() + mock_obj.UUID = Mock(side_effect=lambda: UUID(EXAMPLE_UUID)) + mock_obj.uuid4 = Mock(side_effect=lambda: EXAMPLE_UUID) + monkeypatch.setattr('idsse.common.rabbitmq_rpc.uuid', mock_obj) + return mock_obj + + +@fixture +def rpc_thread(mock_consumer: Mock, mock_uuid: Mock) -> RpcPublisher: + return RpcPublisher(CONN, RMQ_PARAMS.exchange, timeout=5) + + +# tests +def test_rpc_opens_new_connection_and_channel(rpc_thread: RpcPublisher, mock_consumer: Mock): + assert not rpc_thread.is_open + rpc_thread.start() + + mock_consumer.return_value.start.assert_called_once() + mock_consumer.return_value.is_alive = Mock(return_value=True) # Consumer thread would be live + + # stop Rpc client and confirm that Consumer thread was closed + rpc_thread.stop() + mock_consumer.return_value.is_alive = Mock(return_value=False) # Consumer thread would be dead + + assert not rpc_thread.is_open + + mock_consumer.return_value.stop.assert_called_once() + mock_consumer.return_value.join.assert_called_once() + + +def test_stop_does_nothing_if_not_started(rpc_thread: RpcPublisher, mock_consumer: Mock): + # calling stop before starting does nothing + rpc_thread.stop() + mock_consumer.stop.assert_not_called() + + rpc_thread.start() + mock_consumer.return_value.is_alive = Mock(return_value=True) # Consumer thread would be live + + # calling start when already running does nothing + mock_consumer.reset_mock() + rpc_thread.start() + mock_consumer.return_value.start.assert_not_called() + + +def test_send_request_works_without_calling_start(rpc_thread: RpcPublisher, + mock_channel: Mock, + mock_connection: Mock, + mock_consumer: Mock, + monkeypatch: MonkeyPatch): + example_message = {'value': 'hello world'} + + # when client calls blocking_publish, manually invoke response callback with a faked message + # from external service, simulating RMQ call/response + def mock_blocking_publish(*_args, **_kwargs): + # build mock message from imaginary external service + method = Method('', 123) + props = BasicProperties(content_type='application/json', headers={'rpc': EXAMPLE_UUID}) + body = bytes(json.dumps(example_message), encoding='utf-8') + rpc_thread._on_response(mock_channel, method, props, body) + + monkeypatch.setattr('idsse.common.rabbitmq_rpc.blocking_publish', + Mock(side_effect=mock_blocking_publish)) + + result = rpc_thread.send_request(RabbitMqMessage(json.dumps({'fake': 'request message'}))) + assert json.loads(result.body) == example_message + + +def test_send_request_times_out_if_no_response(mock_connection: Mock, + mock_consumer: Mock, + mock_uuid: Mock, + monkeypatch: MonkeyPatch): + # create client with same parameters, except a very short timeout + _thread = RpcPublisher(CONN, RMQ_PARAMS.exchange, timeout=0.01) + + # do nothing on message publish + monkeypatch.setattr('idsse.common.rabbitmq_rpc.blocking_publish', + Mock(side_effect=lambda *_args, **_kwargs: None)) + + result = _thread.send_request(RabbitMqMessage(json.dumps({'data': 123}))) + assert EXAMPLE_UUID not in _thread._pending_requests # request was cleaned up + assert result is None + + +def test_send_requests_returns_none_on_error(rpc_thread: RpcPublisher, mock_channel: Mock): + # pylint: disable=too-many-arguments + def mock_basic_publish(exchange, routing_key, body, properties = None, mandatory = False): + # cause exception for pending request Future + rpc_thread._pending_requests[EXAMPLE_UUID].set_exception(RuntimeError('Something broke')) + mock_channel.basic_publish.side_effect = mock_basic_publish + + result = rpc_thread.send_request(RabbitMqMessage({'data': 123})) + + assert EXAMPLE_UUID not in rpc_thread._pending_requests # request was cleaned up + assert result is None + + +def test_nacks_unrecognized_response(rpc_thread: RpcPublisher, + mock_connection: Mock, + mock_channel: Mock, + monkeypatch: MonkeyPatch): + rpc_thread._pending_requests = {'abcd': Future()} + delivery_tag = 123 + props = BasicProperties(content_type='application/json', headers={'rpc': 'unknown_id'}) + body = bytes(json.dumps({'data': 123}), encoding='utf-8') + + rpc_thread._on_response(mock_channel, Method(delivery_tag=delivery_tag), props, body) + + # unregistered message was nacked + mock_channel.basic_nack.assert_called_with(delivery_tag=delivery_tag, requeue=False) + # pending requests inside Rpc was not touched + assert 'abcd' in rpc_thread._pending_requests + assert not rpc_thread._pending_requests['abcd'].done() + + +def test_send_request_preserves_props(rpc_thread: RpcPublisher, mock_channel: Mock): + # pylint: disable=too-many-arguments + def mock_basic_publish(exchange, routing_key, body, properties = None, mandatory = False): + # cause exception for pending request Future + rpc_thread._pending_requests[EXAMPLE_UUID].set_exception(RuntimeError('Something broke')) + mock_channel.basic_publish.side_effect = mock_basic_publish + + result = rpc_thread.send_request(RabbitMqMessage({'data': 123})) + + assert EXAMPLE_UUID not in rpc_thread._pending_requests # request was cleaned up + assert result is None + + +def test_rpc_consumer_start_stop(mock_consumer: Mock): + mock_consumer.return_value.is_alive.return_value = False + rpc_consumer = RpcConsumer(CONN, RMQ_PARAMS, lambda: None) + + rpc_consumer.start() + mock_consumer.return_value.start.assert_called_once() + + mock_consumer.return_value.is_alive.return_value = True + + rpc_consumer.stop() + mock_consumer.return_value.stop.assert_called_once() + + +def test_rpc_consumer_on_message_ack(mock_channel: Mock, mock_consumer: Mock): + example_response = RabbitMqMessage('{"response": "bar"}', + BasicProperties(content_type='application/json', + correlation_id='some-correlation-id')) + mock_on_request = Mock(return_value=RpcResponse(example_response, ack=True)) + inbound_tag = 7 + inbound_rpc_id = '123' + inbound_reply_to = f'{DIRECT_REPLY_QUEUE}.g1h2AA5yZXBseUA1NTQ3NDU0OQAWaZUAAAAAZ7FK9g==' + inbound_props = BasicProperties(content_type='text/html', + headers={'rpc': inbound_rpc_id}, + reply_to=inbound_reply_to) + inbound_body = bytes('{"request": "foo"}', encoding='utf-8') + rpc_consumer = RpcConsumer(CONN, RMQ_PARAMS, mock_on_request) + + rpc_consumer._on_message(mock_channel, Method('', '', inbound_tag), inbound_props, inbound_body) + + mock_channel.basic_ack.assert_called_once_with(inbound_tag) + assert mock_channel.basic_publish.call_count == 1 + published_args = mock_channel.basic_publish.call_args[1] + assert published_args['body'] == example_response.body + assert published_args['properties'].reply_to == inbound_reply_to + assert published_args['properties'].content_type == 'application/json' + assert published_args['properties'].correlation_id == 'some-correlation-id' + assert published_args['properties'].headers['rpc'] == inbound_rpc_id + + +def test_rpc_consumer_on_message_nack(mock_channel: Mock, mock_consumer: Mock): + example_response = RabbitMqMessage('{"response": "bar"}', + BasicProperties(content_type='application/json')) + mock_on_request = Mock(return_value=RpcResponse(example_response, ack=False, requeue=True)) + inbound_tag = 7 + inbound_props = BasicProperties(content_type='application/json', + headers={'rpc': '123'}, + reply_to=DIRECT_REPLY_QUEUE) + inbound_body = bytes('{"request": "foo"}', encoding='utf-8') + rpc_consumer = RpcConsumer(CONN, RMQ_PARAMS, mock_on_request) + + rpc_consumer._on_message(mock_channel, Method('', '', inbound_tag), inbound_props, inbound_body) + + mock_channel.basic_nack.assert_called_once_with(inbound_tag, requeue=True) diff --git a/python/idsse_common/test/test_rabbitmq_utils.py b/python/idsse_common/test/test_rabbitmq_utils.py index b80a1b1..1156409 100644 --- a/python/idsse_common/test/test_rabbitmq_utils.py +++ b/python/idsse_common/test/test_rabbitmq_utils.py @@ -10,24 +10,21 @@ # # ------------------------------------------------------------------------------ # pylint: disable=missing-function-docstring,missing-class-docstring,too-few-public-methods -# pylint: disable=redefined-outer-name,unused-argument,duplicate-code,protected-access - -import json +# pylint: disable=redefined-outer-name,unused-argument,protected-access from threading import Event from typing import NamedTuple from unittest.mock import MagicMock, Mock, patch, ANY -from uuid import UUID from pytest import fixture, raises, MonkeyPatch -from pika import BasicProperties, BlockingConnection +from pika import BlockingConnection from pika.adapters.blocking_connection import BlockingChannel from pika.exceptions import UnroutableError -from idsse.common.rabbitmq_utils import ( - Conn, Consumer, Exch, Future, Queue, Publisher, RabbitMqParams, RabbitMqParamsAndCallback, - RabbitMqMessage, Rpc, subscribe_to_queue, _publish, _setup_exch_and_queue, - threadsafe_call, threadsafe_ack, threadsafe_nack -) +from idsse.common.rabbitmq_utils import (Conn, Consumer, Exch, Queue, Publisher, RabbitMqParams, + RabbitMqParamsAndCallback, RabbitMqMessage, + subscribe_to_queue, _publish, _setup_exch_and_queue, + threadsafe_call, threadsafe_ack, threadsafe_nack) + # Example data objects CONN = Conn('localhost', '/', port=5672, username='user', password='password') @@ -76,37 +73,6 @@ def mock_connection(mock_channel: Mock) -> Mock: return mock_obj -@fixture -def mock_consumer(monkeypatch: MonkeyPatch, mock_connection: Mock, mock_channel: Mock) -> Mock: - """Mock rabbitmq_utils.Consumer thread instance""" - mock_obj = Mock(spec=Consumer, name='MockConsumer') - mock_obj.return_value.is_alive = Mock(return_value=False) # by default, thread not running - mock_obj.return_value.connection = mock_connection - mock_obj.return_value.channel = mock_channel - # hack pika add_callback_threadsafe to invoke immediately (hides complexity of threading) - mock_obj.return_value.channel.connection.add_callback_threadsafe = Mock( - side_effect=lambda cb: cb() - ) - - monkeypatch.setattr('idsse.common.rabbitmq_utils.Consumer', mock_obj) - return mock_obj - - -@fixture -def mock_uuid(monkeypatch: MonkeyPatch) -> Mock: - """Always return our example UUID str when UUID() is called""" - mock_obj = Mock() - mock_obj.UUID = Mock(side_effect=lambda: UUID(EXAMPLE_UUID)) - mock_obj.uuid4 = Mock(side_effect=lambda: EXAMPLE_UUID) - monkeypatch.setattr('idsse.common.rabbitmq_utils.uuid', mock_obj) - return mock_obj - - -@fixture -def rpc_thread(mock_consumer: Mock, mock_uuid: Mock) -> Rpc: - return Rpc(CONN, RMQ_PARAMS.exchange, timeout=5) - - # tests def test_connection_params_works(monkeypatch: MonkeyPatch, mock_connection: Mock): mock_blocking_connection = Mock(return_value=mock_connection) @@ -254,112 +220,6 @@ def test_simple_publisher(monkeypatch: MonkeyPatch, mock_connection: Mock): assert 'MockChannel.close' in str(mock_threadsafe.call_args[0][1]) -def test_rpc_opens_new_connection_and_channel(rpc_thread: Rpc, mock_consumer: Mock): - assert not rpc_thread.is_open - rpc_thread.start() - - mock_consumer.return_value.start.assert_called_once() - mock_consumer.return_value.is_alive = Mock(return_value=True) # Consumer thread would be live - - # stop Rpc client and confirm that Consumer thread was closed - rpc_thread.stop() - mock_consumer.return_value.is_alive = Mock(return_value=False) # Consumer thread would be dead - - assert not rpc_thread.is_open - - mock_consumer.return_value.stop.assert_called_once() - mock_consumer.return_value.join.assert_called_once() - - -def test_stop_does_nothing_if_not_started(rpc_thread: Rpc, mock_consumer: Mock): - # calling stop before starting does nothing - rpc_thread.stop() - mock_consumer.stop.assert_not_called() - - rpc_thread.start() - mock_consumer.return_value.is_alive = Mock(return_value=True) # Consumer thread would be live - - # calling start when already running does nothing - mock_consumer.reset_mock() - rpc_thread.start() - mock_consumer.return_value.start.assert_not_called() - - -def test_send_request_works_without_calling_start(rpc_thread: Rpc, - mock_channel: Mock, - mock_connection: Mock, - mock_consumer: Mock, - monkeypatch: MonkeyPatch): - example_message = {'value': 'hello world'} - - # when client calls _blocking_publish, manually invoke response callback with a faked message - # from external service, simulating RMQ call/response - # pylint: disable=too-many-arguments - def mock_blocking_publish(*_args, **_kwargs): - # build mock message from imaginary external service - method = Method('', 123) - props = BasicProperties(content_type='application/json', headers={'rpc': EXAMPLE_UUID}) - body = bytes(json.dumps(example_message), encoding='utf-8') - rpc_thread._on_response(mock_channel, method, props, body) - - monkeypatch.setattr('idsse.common.rabbitmq_utils._blocking_publish', - Mock(side_effect=mock_blocking_publish)) - - result = rpc_thread.send_request(json.dumps({'fake': 'request message'})) - assert json.loads(result.body) == example_message - - -def test_send_request_times_out_if_no_response(mock_connection: Mock, - mock_consumer: Mock, - mock_uuid: Mock, - monkeypatch: MonkeyPatch): - # create client with same parameters, except a very short timeout - _thread = Rpc(CONN, RMQ_PARAMS.exchange, timeout=0.01) - - # do nothing on message publish - monkeypatch.setattr('idsse.common.rabbitmq_utils._blocking_publish', - Mock(side_effect=lambda *_args, **_kwargs: None)) - - result = _thread.send_request(json.dumps({'data': 123})) - assert EXAMPLE_UUID not in _thread._pending_requests # request was cleaned up - assert result is None - - -def test_send_requests_returns_none_on_error(rpc_thread: Rpc, - mock_connection: Mock, - monkeypatch: MonkeyPatch): - # pylint: disable=too-many-arguments - def mock_blocking_publish(channel, exch, message_params, queue=None, success_flag=None, - done_event=None): - # cause exception for pending request Future - rpc_thread._pending_requests[EXAMPLE_UUID].set_exception(RuntimeError('Something broke')) - - monkeypatch.setattr('idsse.common.rabbitmq_utils._blocking_publish', - Mock(side_effect=mock_blocking_publish)) - - result = rpc_thread.send_request({'data': 123}) - assert EXAMPLE_UUID not in rpc_thread._pending_requests # request was cleaned up - assert result is None - - -def test_nacks_unrecognized_response(rpc_thread: Rpc, - mock_connection: Mock, - mock_channel: Mock, - monkeypatch: MonkeyPatch): - rpc_thread._pending_requests = {'abcd': Future()} - delivery_tag = 123 - props = BasicProperties(content_type='application/json', headers={'rpc': 'unknown_id'}) - body = bytes(json.dumps({'data': 123}), encoding='utf-8') - - rpc_thread._on_response(mock_channel, Method(delivery_tag=delivery_tag), props, body) - - # unregistered message was nacked - mock_channel.basic_nack.assert_called_with(delivery_tag=delivery_tag, requeue=False) - # pending requests inside Rpc was not touched - assert 'abcd' in rpc_thread._pending_requests - assert not rpc_thread._pending_requests['abcd'].done() - - @fixture def mock_conn_params(): return Conn(