diff --git a/README.md b/README.md index 837dad0..c527749 100644 --- a/README.md +++ b/README.md @@ -43,8 +43,8 @@ A connection to the RabbitMQ AMQP 1.0 server can be established using the Enviro For example: ```python - environment = Environment() - connection = environment.connection("amqp://guest:guest@localhost:5672/") + environment = Environment("amqp://guest:guest@localhost:5672/") + connection = environment.connection() connection.dial() ``` diff --git a/examples/getting_started/getting_started.py b/examples/getting_started/getting_started.py index 9960a7b..2fb7c2e 100644 --- a/examples/getting_started/getting_started.py +++ b/examples/getting_started/getting_started.py @@ -50,8 +50,6 @@ def on_message(self, event: Event): if self._count == MESSAGES_TO_PUBLISH: print("closing receiver") # if you want you can add cleanup operations here - # event.receiver.close() - # event.connection.close() def on_connection_closed(self, event: Event): # if you want you can add cleanup operations here @@ -63,7 +61,7 @@ def on_link_closed(self, event: Event) -> None: def create_connection(environment: Environment) -> Connection: - connection = environment.connection("amqp://guest:guest@localhost:5672/") + connection = environment.connection() # in case of SSL enablement # ca_cert_file = ".ci/certs/ca_certificate.pem" # client_cert = ".ci/certs/client_certificate.pem" @@ -87,7 +85,7 @@ def main() -> None: routing_key = "routing-key" print("connection to amqp server") - environment = Environment() + environment = Environment(uri="amqp://guest:guest@localhost:5672/") connection = create_connection(environment) management = connection.management() diff --git a/examples/reconnection/reconnection_example.py b/examples/reconnection/reconnection_example.py index 3cf851b..6504f11 100644 --- a/examples/reconnection/reconnection_example.py +++ b/examples/reconnection/reconnection_example.py @@ -21,8 +21,6 @@ QuorumQueueSpecification, ) -environment = Environment() - # here we keep track of the objects we need to reconnect @dataclass @@ -42,6 +40,7 @@ class ConnectionConfiguration: def on_disconnection(): print("disconnected") + global environment exchange_name = "test-exchange" queue_name = "example-queue" routing_key = "routing-key" @@ -69,6 +68,11 @@ def on_disconnection(): ) +environment = Environment( + uri="amqp://guest:guest@localhost:5672/", on_disconnection_handler=on_disconnection +) + + class MyMessageHandler(AMQPMessagingHandler): def __init__(self): @@ -101,8 +105,6 @@ def on_message(self, event: Event): if self._count == MESSAGES_TO_PUBLSH: print("closing receiver") # if you want you can add cleanup operations here - # event.receiver.close() - # event.connection.close() def on_connection_closed(self, event: Event): # if you want you can add cleanup operations here @@ -122,10 +124,7 @@ def create_connection() -> Connection: # ] # connection = Connection(uris=uris, on_disconnection_handler=on_disconnected) - connection = environment.connection( - uri="amqp://guest:guest@localhost:5672/", - on_disconnection_handler=on_disconnection, - ) + connection = environment.connection() connection.dial() return connection diff --git a/examples/streams/example_with_streams.py b/examples/streams/example_with_streams.py index 49994f5..65a4157 100644 --- a/examples/streams/example_with_streams.py +++ b/examples/streams/example_with_streams.py @@ -54,8 +54,6 @@ def on_message(self, event: Event): if self._count == MESSAGES_TO_PUBLISH: print("closing receiver") # if you want you can add cleanup operations here - # event.receiver.close() - # event.connection.close() def on_connection_closed(self, event: Event): # if you want you can add cleanup operations here @@ -67,7 +65,7 @@ def on_link_closed(self, event: Event) -> None: def create_connection(environment: Environment) -> Connection: - connection = environment.connection("amqp://guest:guest@localhost:5672/") + connection = environment.connection() # in case of SSL enablement # ca_cert_file = ".ci/certs/ca_certificate.pem" # client_cert = ".ci/certs/client_certificate.pem" @@ -88,7 +86,7 @@ def main() -> None: queue_name = "example-queue" print("connection to amqp server") - environment = Environment() + environment = Environment("amqp://guest:guest@localhost:5672/") connection = create_connection(environment) management = connection.management() diff --git a/examples/tls/tls_example.py b/examples/tls/tls_example.py index d118c71..80d419c 100644 --- a/examples/tls/tls_example.py +++ b/examples/tls/tls_example.py @@ -51,8 +51,6 @@ def on_message(self, event: Event): if self._count == messages_to_publish: print("closing receiver") # if you want you can add cleanup operations here - # event.receiver.close() - # event.connection.close() def on_connection_closed(self, event: Event): # if you want you can add cleanup operations here @@ -65,16 +63,7 @@ def on_link_closed(self, event: Event) -> None: def create_connection(environment: Environment) -> Connection: # in case of SSL enablement - ca_cert_file = ".ci/certs/ca_certificate.pem" - client_cert = ".ci/certs/client_certificate.pem" - client_key = ".ci/certs/client_key.pem" - connection = environment.connection( - "amqps://guest:guest@localhost:5671/", - ssl_context=SslConfigurationContext( - ca_cert=ca_cert_file, - client_cert=ClientCert(client_cert=client_cert, client_key=client_key), - ), - ) + connection = environment.connection() connection.dial() return connection @@ -85,7 +74,17 @@ def main() -> None: exchange_name = "test-exchange" queue_name = "example-queue" routing_key = "routing-key" - environment = Environment() + ca_cert_file = ".ci/certs/ca_certificate.pem" + client_cert = ".ci/certs/client_certificate.pem" + client_key = ".ci/certs/client_key.pem" + + environment = Environment( + "amqps://guest:guest@localhost:5671/", + ssl_context=SslConfigurationContext( + ca_cert=ca_cert_file, + client_cert=ClientCert(client_cert=client_cert, client_key=client_key), + ), + ) print("connection to amqp server") connection = create_connection(environment) diff --git a/rabbitmq_amqp_python_client/connection.py b/rabbitmq_amqp_python_client/connection.py index e3a533f..553e9e6 100644 --- a/rabbitmq_amqp_python_client/connection.py +++ b/rabbitmq_amqp_python_client/connection.py @@ -49,8 +49,12 @@ def __init__( Raises: ValueError: If neither uri nor uris is provided """ + if uri is not None and uris is not None: + raise ValueError( + "Cannot specify both 'uri' and 'uris'. Choose one connection mode." + ) if uri is None and uris is None: - raise ValueError("You need to specify at least an addr or a list of addr") + raise ValueError("Must specify either 'uri' or 'uris' for connection.") self._addr: Optional[str] = uri self._addrs: Optional[list[str]] = uris self._conn: BlockingConnection @@ -117,8 +121,15 @@ def close(self) -> None: Closes the underlying connection and removes it from the connection list. """ logger.debug("Closing connection") - self._conn.close() - self._connections.remove(self) + try: + self._conn.close() + except Exception as e: + logger.error(f"Error closing connection: {e}") + raise e + + finally: + if self in self._connections: + self._connections.remove(self) def publisher(self, destination: str = "") -> Publisher: """ diff --git a/rabbitmq_amqp_python_client/environment.py b/rabbitmq_amqp_python_client/environment.py index 337f84c..2a04b42 100644 --- a/rabbitmq_amqp_python_client/environment.py +++ b/rabbitmq_amqp_python_client/environment.py @@ -23,22 +23,40 @@ class Environment: _connections (list[Connection]): List of active connections managed by this environment """ - def __init__(self): # type: ignore + def __init__( + self, # single-node mode + uri: Optional[str] = None, + # multi-node mode + uris: Optional[list[str]] = None, + ssl_context: Optional[SslConfigurationContext] = None, + on_disconnection_handler: Optional[CB] = None, # type: ignore + ): """ Initialize a new Environment instance. Creates an empty list to track active connections. + + Args: + uri: Single node connection URI + uris: List of URIs for multi-node setup + ssl_context: SSL configuration for secure connections + on_disconnection_handler: Callback for handling disconnection events + """ + if uri is not None and uris is not None: + raise ValueError( + "Cannot specify both 'uri' and 'uris'. Choose one connection mode." + ) + if uri is None and uris is None: + raise ValueError("Must specify either 'uri' or 'uris' for connection.") + self._uri = uri + self._uris = uris + self._ssl_context = ssl_context + self._on_disconnection_handler = on_disconnection_handler self._connections: list[Connection] = [] def connection( self, - # single-node mode - uri: Optional[str] = None, - # multi-node mode - uris: Optional[list[str]] = None, - ssl_context: Optional[SslConfigurationContext] = None, - on_disconnection_handler: Optional[CB] = None, # type: ignore ) -> Connection: """ Create and return a new connection. @@ -46,12 +64,6 @@ def connection( This method supports both single-node and multi-node configurations, with optional SSL/TLS security and disconnection handling. - Args: - uri: Single node connection URI - uris: List of URIs for multi-node setup - ssl_context: SSL configuration for secure connections - on_disconnection_handler: Callback for handling disconnection events - Returns: Connection: A new connection instance @@ -59,10 +71,10 @@ def connection( ValueError: If neither uri nor uris is provided """ connection = Connection( - uri=uri, - uris=uris, - ssl_context=ssl_context, - on_disconnection_handler=on_disconnection_handler, + uri=self._uri, + uris=self._uris, + ssl_context=self._ssl_context, + on_disconnection_handler=self._on_disconnection_handler, ) logger.debug("Environment: Creating and returning a new connection") self._connections.append(connection) diff --git a/tests/conftest.py b/tests/conftest.py index 1211d0d..d772f78 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,7 +15,7 @@ @pytest.fixture() def environment(pytestconfig): - environment = Environment() + environment = Environment(uri="amqp://guest:guest@localhost:5672/") try: yield environment @@ -25,8 +25,8 @@ def environment(pytestconfig): @pytest.fixture() def connection(pytestconfig): - environment = Environment() - connection = environment.connection("amqp://guest:guest@localhost:5672/") + environment = Environment(uri="amqp://guest:guest@localhost:5672/") + connection = environment.connection() connection.dial() try: yield connection @@ -37,17 +37,18 @@ def connection(pytestconfig): @pytest.fixture() def connection_ssl(pytestconfig): - environment = Environment() ca_cert_file = ".ci/certs/ca_certificate.pem" client_cert = ".ci/certs/client_certificate.pem" client_key = ".ci/certs/client_key.pem" - connection = environment.connection( + + environment = Environment( "amqps://guest:guest@localhost:5671/", ssl_context=SslConfigurationContext( ca_cert=ca_cert_file, client_cert=ClientCert(client_cert=client_cert, client_key=client_key), ), ) + connection = environment.connection() connection.dial() try: yield connection @@ -58,8 +59,8 @@ def connection_ssl(pytestconfig): @pytest.fixture() def management(pytestconfig): - environment = Environment() - connection = environment.connection("amqp://guest:guest@localhost:5672/") + environment = Environment(uri="amqp://guest:guest@localhost:5672/") + connection = environment.connection() connection.dial() try: management = connection.management() @@ -71,8 +72,8 @@ def management(pytestconfig): @pytest.fixture() def consumer(pytestconfig): - environment = Environment() - connection = environment.connection("amqp://guest:guest@localhost:5672/") + environment = Environment(uri="amqp://guest:guest@localhost:5672/") + connection = environment.connection() connection.dial() try: queue_name = "test-queue" @@ -105,7 +106,6 @@ def on_message(self, event: Event): self.delivery_context.accept(event) self._received = self._received + 1 if self._received == 1000: - event.connection.close() raise ConsumerTestException("consumed") @@ -123,7 +123,6 @@ def on_message(self, event: Event): self.delivery_context.accept(event) self._received = self._received + 1 if self._received == 10: - event.connection.close() raise ConsumerTestException("consumed") @@ -136,8 +135,6 @@ def __init__(self): def on_message(self, event: Event): self._received = self._received + 1 if self._received == 1000: - event.receiver.close() - event.connection.close() # Workaround to terminate the Consumer and notify the test when all messages are consumed raise ConsumerTestException("consumed") @@ -152,7 +149,6 @@ def on_message(self, event: Event): self.delivery_context.discard(event) self._received = self._received + 1 if self._received == 1000: - event.connection.close() raise ConsumerTestException("consumed") @@ -168,7 +164,6 @@ def on_message(self, event: Event): self.delivery_context.discard_with_annotations(event, annotations) self._received = self._received + 1 if self._received == 1000: - event.connection.close() raise ConsumerTestException("consumed") @@ -182,7 +177,7 @@ def on_message(self, event: Event): self.delivery_context.requeue(event) self._received = self._received + 1 if self._received == 1000: - event.connection.close() + # event.connection.close() raise ConsumerTestException("consumed") @@ -198,7 +193,6 @@ def on_message(self, event: Event): self.delivery_context.requeue_with_annotations(event, annotations) self._received = self._received + 1 if self._received == 1000: - event.connection.close() raise ConsumerTestException("consumed") @@ -214,5 +208,4 @@ def on_message(self, event: Event): self.delivery_context.requeue_with_annotations(event, annotations) self._received = self._received + 1 if self._received == 1000: - event.connection.close() raise ConsumerTestException("consumed") diff --git a/tests/test_connection.py b/tests/test_connection.py index ce6bd1d..ec8b8ac 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -2,7 +2,6 @@ from rabbitmq_amqp_python_client import ( ClientCert, - Connection, ConnectionClosed, Environment, SslConfigurationContext, @@ -20,30 +19,32 @@ def on_disconnected(): def test_connection() -> None: - environment = Environment() - connection = environment.connection("amqp://guest:guest@localhost:5672/") + environment = Environment(uri="amqp://guest:guest@localhost:5672/") + connection = environment.connection() connection.dial() environment.close() def test_environment_context_manager() -> None: - with Environment() as environment: - connection = environment.connection("amqp://guest:guest@localhost:5672/") + with Environment(uri="amqp://guest:guest@localhost:5672/") as environment: + connection = environment.connection() connection.dial() def test_connection_ssl() -> None: - environment = Environment() ca_cert_file = ".ci/certs/ca_certificate.pem" client_cert = ".ci/certs/client_certificate.pem" client_key = ".ci/certs/client_key.pem" - connection = environment.connection( + + environment = Environment( "amqps://guest:guest@localhost:5671/", ssl_context=SslConfigurationContext( ca_cert=ca_cert_file, client_cert=ClientCert(client_cert=client_cert, client_key=client_key), ), ) + + connection = environment.connection() connection.dial() environment.close() @@ -51,12 +52,12 @@ def test_connection_ssl() -> None: def test_environment_connections_management() -> None: - environment = Environment() - connection = environment.connection("amqp://guest:guest@localhost:5672/") + environment = Environment(uri="amqp://guest:guest@localhost:5672/") + connection = environment.connection() connection.dial() - connection2 = environment.connection("amqp://guest:guest@localhost:5672/") + connection2 = environment.connection() connection2.dial() - connection3 = environment.connection("amqp://guest:guest@localhost:5672/") + connection3 = environment.connection() connection3.dial() assert environment.active_connections == 3 @@ -89,17 +90,17 @@ def on_disconnected(): # reconnect if connection is not None: - connection = Connection("amqp://guest:guest@localhost:5672/") + connection = environment.connection() connection.dial() nonlocal reconnected reconnected = True - environment = Environment() - - connection = environment.connection( + environment = Environment( "amqp://guest:guest@localhost:5672/", on_disconnection_handler=on_disconnected ) + + connection = environment.connection() connection.dial() # delay diff --git a/tests/test_consumer.py b/tests/test_consumer.py index c02991e..db9bc6a 100644 --- a/tests/test_consumer.py +++ b/tests/test_consumer.py @@ -88,7 +88,7 @@ def test_consumer_async_queue_accept( publish_messages(connection, messages_to_send, queue_name) # we closed the connection so we need to open a new one - connection_consumer = environment.connection("amqp://guest:guest@localhost:5672/") + connection_consumer = environment.connection() connection_consumer.dial() consumer = connection_consumer.consumer( addr_queue, message_handler=MyMessageHandlerAccept() @@ -128,7 +128,7 @@ def test_consumer_async_queue_no_ack( publish_messages(connection, messages_to_send, queue_name) # we closed the connection so we need to open a new one - connection_consumer = environment.connection("amqp://guest:guest@localhost:5672/") + connection_consumer = environment.connection() connection_consumer.dial() consumer = connection_consumer.consumer( @@ -179,7 +179,7 @@ def test_consumer_async_queue_with_discard( publish_messages(connection, messages_to_send, queue_name) # we closed the connection so we need to open a new one - connection_consumer = environment.connection("amqp://guest:guest@localhost:5672/") + connection_consumer = environment.connection() connection_consumer.dial() consumer = connection_consumer.consumer( @@ -236,7 +236,7 @@ def test_consumer_async_queue_with_discard_with_annotations( addr_queue_dl = AddressHelper.queue_address(queue_dead_lettering) # we closed the connection so we need to open a new one - connection_consumer = environment.connection("amqp://guest:guest@localhost:5672/") + connection_consumer = environment.connection() connection_consumer.dial() consumer = connection_consumer.consumer( @@ -278,7 +278,6 @@ def test_consumer_async_queue_with_requeue( ) -> None: messages_to_send = 1000 - environment = Environment() queue_name = "test-queue-async-requeue" management = connection.management() @@ -290,7 +289,7 @@ def test_consumer_async_queue_with_requeue( publish_messages(connection, messages_to_send, queue_name) # we closed the connection so we need to open a new one - connection_consumer = environment.connection("amqp://guest:guest@localhost:5672/") + connection_consumer = environment.connection() connection_consumer.dial() consumer = connection_consumer.consumer( @@ -329,7 +328,7 @@ def test_consumer_async_queue_with_requeue_with_annotations( publish_messages(connection, messages_to_send, queue_name) # we closed the connection so we need to open a new one - connection_consumer = environment.connection("amqp://guest:guest@localhost:5672/") + connection_consumer = environment.connection() connection_consumer.dial() consumer = connection_consumer.consumer( @@ -377,7 +376,7 @@ def test_consumer_async_queue_with_requeue_with_invalid_annotations( publish_messages(connection, messages_to_send, queue_name) # we closed the connection so we need to open a new one - connection_consumer = environment.connection("amqp://guest:guest@localhost:5672/") + connection_consumer = environment.connection() connection_consumer.dial() try: diff --git a/tests/test_publisher.py b/tests/test_publisher.py index 83e364c..20abe10 100644 --- a/tests/test_publisher.py +++ b/tests/test_publisher.py @@ -265,19 +265,18 @@ def test_disconnection_reconnection() -> None: publisher = None queue_name = "test-queue" connection_test = None - environment = Environment() + environment = None def on_disconnected(): nonlocal publisher nonlocal queue_name nonlocal connection_test + nonlocal environment # reconnect if connection_test is not None: - connection_test = environment.connection( - "amqp://guest:guest@localhost:5672/" - ) + connection_test = environment.connection() connection_test.dial() if publisher is not None: @@ -288,9 +287,12 @@ def on_disconnected(): nonlocal reconnected reconnected = True - connection_test = environment.connection( + environment = Environment( "amqp://guest:guest@localhost:5672/", on_disconnection_handler=on_disconnected ) + + connection_test = environment.connection() + connection_test.dial() # delay time.sleep(5) @@ -328,7 +330,7 @@ def on_disconnected(): # cleanup, we need to create a new connection as the previous one # was closed by the test - connection_test = Connection("amqp://guest:guest@localhost:5672/") + connection_test = environment.connection() connection_test.dial() management = connection_test.management() diff --git a/tests/test_streams.py b/tests/test_streams.py index 5011392..535874a 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -32,9 +32,7 @@ def test_stream_read_from_last_default( # consume and then publish try: - connection_consumer = environment.connection( - "amqp://guest:guest@localhost:5672/" - ) + connection_consumer = environment.connection() connection_consumer.dial() consumer = connection_consumer.consumer( addr_queue, message_handler=MyMessageHandlerAcceptStreamOffset() @@ -71,9 +69,7 @@ def test_stream_read_from_last( # consume and then publish try: - connection_consumer = environment.connection( - "amqp://guest:guest@localhost:5672/" - ) + connection_consumer = environment.connection() connection_consumer.dial() consumer = connection_consumer.consumer( addr_queue, @@ -114,9 +110,7 @@ def test_stream_read_from_offset_zero( stream_filter_options.offset(0) try: - connection_consumer = environment.connection( - "amqp://guest:guest@localhost:5672/" - ) + connection_consumer = environment.connection() connection_consumer.dial() consumer = connection_consumer.consumer( addr_queue, @@ -157,9 +151,7 @@ def test_stream_read_from_offset_first( stream_filter_options.offset(OffsetSpecification.first) try: - connection_consumer = environment.connection( - "amqp://guest:guest@localhost:5672/" - ) + connection_consumer = environment.connection() connection_consumer.dial() consumer = connection_consumer.consumer( addr_queue, @@ -200,9 +192,7 @@ def test_stream_read_from_offset_ten( stream_filter_options.offset(10) try: - connection_consumer = environment.connection( - "amqp://guest:guest@localhost:5672/" - ) + connection_consumer = environment.connection() connection_consumer.dial() consumer = connection_consumer.consumer( addr_queue, @@ -239,9 +229,7 @@ def test_stream_filtering(connection: Connection, environment: Environment) -> N try: stream_filter_options = StreamOptions() stream_filter_options.filter_values(["banana"]) - connection_consumer = environment.connection( - "amqp://guest:guest@localhost:5672/" - ) + connection_consumer = environment.connection() connection_consumer.dial() consumer = connection_consumer.consumer( @@ -281,9 +269,7 @@ def test_stream_filtering_mixed( try: stream_filter_options = StreamOptions() stream_filter_options.filter_values(["banana"]) - connection_consumer = environment.connection( - "amqp://guest:guest@localhost:5672/" - ) + connection_consumer = environment.connection() connection_consumer.dial() consumer = connection_consumer.consumer( addr_queue, @@ -324,7 +310,7 @@ def test_stream_filtering_not_present( # consume and then publish stream_filter_options = StreamOptions() stream_filter_options.filter_values(["apple"]) - connection_consumer = environment.connection("amqp://guest:guest@localhost:5672/") + connection_consumer = environment.connection() connection_consumer.dial() consumer = connection_consumer.consumer( @@ -367,9 +353,7 @@ def test_stream_match_unfiltered( stream_filter_options = StreamOptions() stream_filter_options.filter_values(["banana"]) stream_filter_options.filter_match_unfiltered(True) - connection_consumer = environment.connection( - "amqp://guest:guest@localhost:5672/" - ) + connection_consumer = environment.connection() connection_consumer.dial() consumer = connection_consumer.consumer( addr_queue,