diff --git a/README.md b/README.md
index 6cc0e1e..ac89608 100644
--- a/README.md
+++ b/README.md
@@ -41,6 +41,68 @@ from memphis.types import Retention, Storage
 import asyncio
 ```
+### Quickstart - Producing and Consuming
+
+The most basic functionality of Memphis is the ability to produce messages to a station and to consume those messages.
+
+> The Memphis.py SDK uses asyncio for many of its functions. Make sure to call the following code inside an async function:
+
+```python
+async def main():
+    ...
+
+if __name__ == '__main__':
+    asyncio.run(main())
+```
+
+First, a connection to Memphis must be made:
+
+```python
+from memphis import Memphis
+
+# Connecting to the broker
+memphis = Memphis()
+
+await memphis.connect(
+    host="",
+    username="",
+    password="",
+    # account_id=, # For cloud users: found at the top of the overview page
+)
+```
+
+Then, to produce a message, call the `memphis.produce` function or create a producer and call its `producer.produce` function:
+
+```python
+await memphis.produce(
+    station_name="",
+    producer_name="",
+    message={
+        "id": 1,
+        "chocolates_to_eat": 3
+    }
+)
+```
+
+Lastly, to consume this message, call the `memphis.fetch_messages` function or create a consumer and call its `consumer.fetch` function:
+
+```python
+import json
+
+from memphis.message import Message
+
+messages: list[Message] = await memphis.fetch_messages(
+    station_name="",
+    consumer_name="",
+)  # Type-hinting the return value here helps with LSP integration
+
+for consumed_message in messages:
+    msg_data = json.loads(consumed_message.get_data())
+    print(f"Ate {msg_data['chocolates_to_eat']} chocolates... Yum")
+
+    await consumed_message.ack()
+```
+
+> Remember to call `await memphis.close()` to close the connection.
+
 ### Connecting to Memphis
 
 First, we need to create a Memphis `object` and then connect to Memphis using `memphis.connect`.
diff --git a/examples/consumer.py b/examples/consumer.py
index 39d0c9f..fe2941a 100644
--- a/examples/consumer.py
+++ b/examples/consumer.py
@@ -1,46 +1,53 @@
-from __future__ import annotations
+"""
+An example consumer for the Memphis.dev python SDK.
+"""
 import asyncio
-
-from memphis import Memphis, MemphisConnectError, MemphisError, MemphisHeaderError
+import json
+from memphis import Memphis
+from memphis.message import Message
 
 
 async def main():
-    async def msg_handler(msgs, error, _):
-        try:
-            for msg in msgs:
-                print("message: ", msg.get_data())
-                await msg.ack()
-                if error:
-                    print(error)
-        except (MemphisError, MemphisConnectError, MemphisHeaderError) as e:
-            print(e)
-            return
-
+    """
+    Async entry point for the asyncio runtime.
+ """ try: + # Connecting to the broker memphis = Memphis() + await memphis.connect( host="", - username="", - connection_token="", + username="", + password="", + # account_id=, # For cloud users on, at the top of the overview page ) consumer = await memphis.consumer( station_name="", consumer_name="", - consumer_group="", ) - consumer.set_context({"key": "value"}) - consumer.consume(msg_handler) - # Keep your main thread alive so the consumer will keep receiving data - await asyncio.Event().wait() + while True: + messages: list[ + Message + ] = await consumer.fetch() # Type-hint the return here for LSP integration - except (MemphisError, MemphisConnectError) as e: - print(e) + if len(messages) == 0: + continue + for consumed_message in messages: + msg_data = json.loads(consumed_message.get_data()) + + # Do something with the message data + print(msg_data) + await consumed_message.ack() + + except Exception as e: + print(e) finally: - await memphis.close() + if memphis != None: + await memphis.close() if __name__ == "__main__": diff --git a/examples/producer.py b/examples/producer.py index 01dd381..1419892 100644 --- a/examples/producer.py +++ b/examples/producer.py @@ -1,45 +1,38 @@ -from __future__ import annotations +""" +An example producer for the Memphis.dev python SDK. +""" import asyncio - -from memphis import ( - Headers, - Memphis, - MemphisConnectError, - MemphisError, - MemphisHeaderError, - MemphisSchemaError, -) +from memphis import Memphis, MemphisConnectError, MemphisError async def main(): + """ + Async main function used for the asyncio runtime. + """ try: + # Connecting to the broker memphis = Memphis() + await memphis.connect( host="", - username="", - connection_token="", + username="", + password="", + # account_id=, # For cloud users on, at the top of the overview page ) + # Creating a producer and producing a message. 
+ # You can also use the memphis.producer function producer = await memphis.producer( - station_name="", producer_name="" + station_name="", # Matches the station name in memphis cloud + producer_name="", ) - headers = Headers() - headers.add("key", "value") - for i in range(5): - await producer.produce( - bytearray("Message #" + str(i) + ": Hello world", "utf-8"), - headers=headers, - ) # you can send the message parameter as dict as well - - except ( - MemphisError, - MemphisConnectError, - MemphisHeaderError, - MemphisSchemaError, - ) as e: - print(e) + for i in range(10): + await producer.produce(message={"id": i, "chocolates_to_eat": 3}) + + except (MemphisError, MemphisConnectError) as e: + print(e) finally: await memphis.close() diff --git a/memphis/consumer.py b/memphis/consumer.py index 437c03b..7a1fc96 100644 --- a/memphis/consumer.py +++ b/memphis/consumer.py @@ -8,6 +8,7 @@ from memphis.utils import default_error_handler, get_internal_name from memphis.message import Message from memphis.partition_generator import PartitionGenerator +from memphis.exceptions import MemphisErrors class Consumer: @@ -37,7 +38,9 @@ def __init__( self.consumer_group = consumer_group.lower() self.pull_interval_ms = pull_interval_ms self.batch_size = batch_size - self.batch_max_time_to_wait_ms = batch_max_time_to_wait_ms if batch_max_time_to_wait_ms >= 100 else 100 + self.batch_max_time_to_wait_ms = ( + batch_max_time_to_wait_ms if batch_max_time_to_wait_ms >= 100 else 100 + ) self.max_ack_time_ms = max_ack_time_ms self.max_msg_deliveries = max_msg_deliveries self.ping_consumer_interval_ms = 30000 @@ -58,12 +61,16 @@ def __init__( self.loading_thread = None self.t_dls = asyncio.create_task(self.__consume_dls()) - def set_context(self, context): """Set a context (dict) that will be passed to each message handler call.""" self.context = context - def consume(self, callback, consumer_partition_key: str = None, consumer_partition_number: int = -1): + def consume( + self, + callback, + consumer_partition_key: str = None, + consumer_partition_number: int = -1, + ): """ This method starts consuming events from the specified station and invokes the provided callback function for each batch of messages received. 
@@ -101,16 +108,26 @@ async def main(): asyncio.run(main()) """ self.dls_callback_func = callback - self.t_consume = asyncio.create_task(self.__consume(callback, partition_key=consumer_partition_key, consumer_partition_number=consumer_partition_number)) + self.t_consume = asyncio.create_task( + self.__consume( + callback, + partition_key=consumer_partition_key, + consumer_partition_number=consumer_partition_number, + ) + ) - async def __consume(self, callback, partition_key: str = None, consumer_partition_number: int = -1): + async def __consume( + self, callback, partition_key: str = None, consumer_partition_number: int = -1 + ): partition_number = 1 if consumer_partition_number > 0 and partition_key is not None: - raise MemphisError('Can not use both partition number and partition key') + raise MemphisErrors.PartitionNumberKeyError elif partition_key is not None: partition_number = self.get_partition_from_key(partition_key) elif consumer_partition_number > 0: - self.validate_partition_number(consumer_partition_number, self.inner_station_name) + self.validate_partition_number( + consumer_partition_number, self.inner_station_name + ) partition_number = consumer_partition_number while True: @@ -121,19 +138,25 @@ async def __consume(self, callback, partition_key: str = None, consumer_partitio partition_number = next(self.partition_generator) memphis_messages = [] - msgs = await self.subscriptions[partition_number].fetch(self.batch_size) + msgs = await self.subscriptions[partition_number].fetch( + self.batch_size + ) for msg in msgs: memphis_messages.append( - Message(msg, self.connection, self.consumer_group, self.internal_station_name, partition=partition_number) + Message( + msg, + self.connection, + self.consumer_group, + self.internal_station_name, + partition=partition_number, + ) ) await callback(memphis_messages, None, self.context) await asyncio.sleep(self.pull_interval_ms / 1000) except asyncio.TimeoutError: - await callback( - [], MemphisError("Memphis: TimeoutError"), self.context - ) + await callback([], MemphisErrors.TimeoutError, self.context) continue except Exception as e: if self.connection.is_connection_active: @@ -155,13 +178,25 @@ async def __consume_dls(self): if index_to_insert >= 10000: index_to_insert %= 10000 self.dls_messages.insert( - index_to_insert, Message( - msg, self.connection, self.consumer_group, self.internal_station_name) + index_to_insert, + Message( + msg, + self.connection, + self.consumer_group, + self.internal_station_name, + ), ) self.dls_current_index += 1 if self.dls_callback_func != None: await self.dls_callback_func( - [Message(msg, self.connection, self.consumer_group, self.internal_station_name)], + [ + Message( + msg, + self.connection, + self.consumer_group, + self.internal_station_name, + ) + ], None, self.context, ) @@ -170,7 +205,13 @@ async def __consume_dls(self): await self.dls_callback_func([], MemphisError(str(e)), self.context) return - async def fetch(self, batch_size: int = 10, consumer_partition_key: str = None, consumer_partition_number: int = -1, prefetch: bool = False): + async def fetch( + self, + batch_size: int = 10, + consumer_partition_key: str = None, + consumer_partition_number: int = -1, + prefetch: bool = False, + ): """ Fetch a batch of messages. 
@@ -181,7 +222,7 @@ async def fetch(batch_size: int = 10, consumer_partition_key: str = None,
 
         Example:
             import asyncio
-            
+
             from memphis import Memphis
 
             async def main(host, username, password, station):
@@ -189,18 +230,18 @@ async def main(host, username, password, station):
                 await memphis.connect(host=host,
                                       username=username,
                                       password=password)
-                
+
                 consumer = await memphis.consumer(station_name=station,
                                                   consumer_name="test-consumer",
                                                   consumer_group="test-consumer-group")
-                
+
                 while True:
                     batch = await consumer.fetch()
                     print("Received {} messages".format(len(batch)))
                     for msg in batch:
                         serialized_record = msg.get_data()
                         print("Message:", serialized_record)
-                
+
                 await memphis.close()
 
             if __name__ == '__main__':
@@ -208,28 +249,33 @@ async def main(host, username, password, station):
                                         username,
                                         password,
                                         station))
-            
+
         """
         messages = []
         partition_number = 1
         if len(self.subscriptions) > 1:
            if consumer_partition_number > 0 and consumer_partition_key is not None:
-                raise MemphisError('Can not use both partition number and partition key')
+                raise MemphisErrors.PartitionNumberKeyError
            elif consumer_partition_key is not None:
                partition_number = self.get_partition_from_key(consumer_partition_key)
            elif consumer_partition_number > 0:
-                self.validate_partition_number(consumer_partition_number, self.inner_station_name)
+                self.validate_partition_number(
+                    consumer_partition_number, self.inner_station_name
+                )
                partition_number = consumer_partition_number
            else:
                partition_number = next(self.partition_generator)
 
-
         if prefetch and len(self.cached_messages) > 0:
             if len(self.cached_messages) >= batch_size:
                 messages = self.cached_messages[:batch_size]
                 self.cached_messages = self.cached_messages[batch_size:]
-                number_of_messages_to_prefetch = batch_size * 2 - batch_size # calculated for clarity
-                self.load_messages_to_cache(number_of_messages_to_prefetch, partition_number)
+                number_of_messages_to_prefetch = (
+                    batch_size * 2 - batch_size
+                )  # calculated for clarity
+                self.load_messages_to_cache(
+                    number_of_messages_to_prefetch, partition_number
+                )
                 return messages
             else:
                 messages = self.cached_messages
@@ -239,8 +285,7 @@ async def main(host, username, password, station):
         if self.connection.is_connection_active:
             try:
                 if batch_size > self.MAX_BATCH_SIZE or batch_size < 1:
-                    raise MemphisError(
-                        f"Batch size can not be greater than {self.MAX_BATCH_SIZE} or less than 1")
+                    raise MemphisErrors.InvalidBatchSize
                 self.batch_size = batch_size
                 if len(self.dls_messages) > 0:
                     if len(self.dls_messages) <= batch_size:
@@ -255,10 +300,19 @@ async def main(host, username, password, station):
                     msgs = await self.subscriptions[partition_number].fetch(batch_size)
                     for msg in msgs:
                         messages.append(
-                            Message(msg,self.connection,self.consumer_group,self.internal_station_name,partition=partition_number))
+                            Message(
+                                msg,
+                                self.connection,
+                                self.consumer_group,
+                                self.internal_station_name,
+                                partition=partition_number,
+                            )
+                        )
                     if prefetch:
                         number_of_messages_to_prefetch = batch_size * 2
-                        self.load_messages_to_cache(number_of_messages_to_prefetch, partition_number)
+                        self.load_messages_to_cache(
+                            number_of_messages_to_prefetch, partition_number
+                        )
                     return messages
             except Exception as e:
                 if "timeout" not in str(e).lower():
@@ -266,15 +320,19 @@ async def main(host, username, password, station):
 
         return messages
 
-
     async def __ping_consumer(self, callback):
         while True:
             try:
                 await asyncio.sleep(self.ping_consumer_interval_ms / 1000)
                 station_inner = get_internal_name(self.station_name)
                 consumer_group = 
get_internal_name(self.consumer_group) - if self.inner_station_name in self.connection.partition_consumers_updates_data: - for p in self.connection.partition_consumers_updates_data[station_inner]["partitions_list"]: + if ( + self.inner_station_name + in self.connection.partition_consumers_updates_data + ): + for p in self.connection.partition_consumers_updates_data[ + station_inner + ]["partitions_list"]: stream_name = f"{station_inner}${str(p)}" await self.connection.broker_connection.consumer_info( stream_name, consumer_group, timeout=30 @@ -306,8 +364,7 @@ async def destroy(self, timeout_retries=5): "connection_id": self.connection.connection_id, "req_version": 1, } - consumer_name = json.dumps( - destroy_consumer_req, indent=2).encode("utf-8") + consumer_name = json.dumps(destroy_consumer_req, indent=2).encode("utf-8") # pylint: disable=protected-access res = await self.connection._request( "$memphis_consumer_destructions", consumer_name, 5, timeout_retries @@ -320,16 +377,14 @@ async def destroy(self, timeout_retries=5): internal_station_name = get_internal_name(self.station_name) if self.connection.schema_updates_data != {}: clients_number = ( - self.connection.clients_per_station.get( - internal_station_name) - 1 + self.connection.clients_per_station.get(internal_station_name) - 1 ) self.connection.clients_per_station[ internal_station_name ] = clients_number if clients_number == 0: - sub = self.connection.schema_updates_subs.get( - internal_station_name) + sub = self.connection.schema_updates_subs.get(internal_station_name) task = self.connection.schema_tasks.get(internal_station_name) if internal_station_name in self.connection.schema_updates_data: del self.connection.schema_updates_data[internal_station_name] @@ -349,26 +404,36 @@ async def destroy(self, timeout_retries=5): def get_partition_from_key(self, key): try: - index = mmh3.hash(key, self.connection.SEED, signed=False) % len(self.subscriptions) - return self.connection.partition_consumers_updates_data[self.inner_station_name]["partitions_list"][index] + index = mmh3.hash(key, self.connection.SEED, signed=False) % len( + self.subscriptions + ) + return self.connection.partition_consumers_updates_data[ + self.inner_station_name + ]["partitions_list"][index] except Exception as e: raise e def validate_partition_number(self, partition_number, station_name): - partitions_list = self.connection.partition_consumers_updates_data[station_name]["partitions_list"] + partitions_list = self.connection.partition_consumers_updates_data[ + station_name + ]["partitions_list"] if partitions_list is not None: if partition_number < 1 or partition_number > len(partitions_list): - raise MemphisError("Partition number is out of range") + raise MemphisErrors.PartitionOutOfRange elif partition_number not in partitions_list: - raise MemphisError(f"Partition {str(partition_number)} does not exist in station {station_name}") + raise MemphisErrors.partition_not_in_station( + partition_number, station_name + ) else: - raise MemphisError(f"Partition {str(partition_number)} does not exist in station {station_name}") + raise MemphisErrors.partition_not_in_station(partition_number, station_name) def load_messages_to_cache(self, batch_size, partition_number): if not self.loading_thread or not self.loading_thread.is_alive(): asyncio.ensure_future(self.__load_messages(batch_size, partition_number)) async def __load_messages(self, batch_size, partition_number): - new_messages = await self.fetch(batch_size, consumer_partition_number=partition_number) + 
new_messages = await self.fetch( + batch_size, consumer_partition_number=partition_number + ) if new_messages is not None: self.cached_messages.extend(new_messages) diff --git a/memphis/exceptions.py b/memphis/exceptions.py index 65c8c25..8986b33 100644 --- a/memphis/exceptions.py +++ b/memphis/exceptions.py @@ -1,3 +1,8 @@ +from dataclasses import dataclass + +MAX_BATCH_SIZE = 5000 + + class MemphisError(Exception): def __init__(self, message): message = message.replace("nats", "memphis") @@ -21,3 +26,92 @@ class MemphisSchemaError(MemphisError): class MemphisHeaderError(MemphisError): pass + + +@dataclass +class MemphisErrors: + TimeoutError: MemphisError = MemphisError("Memphis: TimeoutError") + PartitionNumberKeyError: MemphisError = MemphisError( + "Can not use both partition number and partition key" + ) + PartitionOutOfRange: MemphisError = MemphisError("Partition number is out of range") + InvalidBatchSize: MemphisError = MemphisError( + f"Batch size can not be greater than {MAX_BATCH_SIZE} or less than 1" + ) + DeadConnection: MemphisError = MemphisError("Connection is dead") + MissingNameOrStationName: MemphisError = MemphisError( + "name and station name can not be empty" + ) + MissingStationName: MemphisError = MemphisError("station name is missing") + InvalidStationNameTpye: MemphisError = MemphisError( + "station_name should be either string or list of strings" + ) + NonPositiveStartConsumeFromSeq: MemphisError = MemphisError( + "start_consume_from_sequence has to be a positive number" + ) + InvalidMinLasMessagesVal: MemphisError = MemphisError( + "min value for last_messages is -1" + ) + ContainsStartConsumeAndLastMessages: MemphisError = MemphisError( + "Consumer creation options can't contain both start_consume_from_sequence and last_messages" + ) + MissingSchemaName: MemphisError = MemphisError("Schema name cannot be empty") + SchemaNameTooLong: MemphisError = MemphisError( + "Schema name should be under 128 characters" + ) + InvalidSchemaChars: MemphisError = MemphisError( + "Only alphanumeric and the '_', '-', '.' 
characters are allowed in the schema name"
+    )
+    InvalidSchemaStartChar: MemphisError = MemphisError(
+        "Schema name cannot start or end with a non-alphanumeric character"
+    )
+    DLSCannotBeDelayed: MemphisError = MemphisError("cannot delay DLS message")
+
+    InvalidKeyInHeader: MemphisHeaderError = MemphisHeaderError(
+        "Keys in headers should not start with $memphis"
+    )
+
+    InvalidConnectionType: MemphisConnectError = MemphisConnectError(
+        "You have to connect with one of the following methods: connection token / password"
+    )
+    MissingTLSCert: MemphisConnectError = MemphisConnectError(
+        "Must provide a TLS cert file"
+    )
+    MissingTLSKey: MemphisConnectError = MemphisConnectError(
+        "Must provide a TLS key file"
+    )
+    MissingTLSCa: MemphisConnectError = MemphisConnectError(
+        "Must provide a TLS ca file"
+    )
+
+    UnsupportedMsgType: MemphisSchemaError = MemphisSchemaError(
+        "Unsupported message type"
+    )
+
+    @staticmethod
+    def expecting_format(error: Exception, expected_format: str):
+        return MemphisError(f"Expecting {expected_format} format: " + str(error))
+
+    @staticmethod
+    def schema_validation_failed(error):
+        return MemphisSchemaError("Schema validation has failed: " + str(error))
+
+    @staticmethod
+    def schema_msg_mismatch(error: Exception):
+        return MemphisSchemaError(
+            f"Deserialization failed because the message format does not align with the currently attached schema: {str(error)}"
+        )
+
+    @staticmethod
+    def partition_not_in_station(partition_number, station_name):
+        return MemphisError(
+            f"Partition {str(partition_number)} does not exist in station {station_name}"
+        )
+
+    @staticmethod
+    def invalid_schema_type(schema_type):
+        return MemphisError(
+            f"schema type not supported: {schema_type} is not json, graphql, protobuf or avro"
+        )
diff --git a/memphis/memphis.py b/memphis/memphis.py
index 178df75..63c8b94 100644
--- a/memphis/memphis.py
+++ b/memphis/memphis.py
@@ -28,7 +28,7 @@ from google.protobuf.message_factory import MessageFactory
 from graphql import build_schema as build_graphql_schema
 from memphis.consumer import Consumer
-from memphis.exceptions import MemphisConnectError, MemphisError
+from memphis.exceptions import MemphisError, MemphisErrors
 from memphis.headers import Headers
 from memphis.producer import Producer
 from memphis.station import Station
@@ -37,6 +37,8 @@ from memphis.partition_generator import PartitionGenerator
 
 app_id = str(uuid.uuid4())
+
+
 # pylint: disable=too-many-lines
 class Memphis:
     MAX_BATCH_SIZE = 5000
@@ -78,16 +80,14 @@ async def get_msgs_sdk_clients_updates(self, iterable: Iterable):
                     "update"
                 ]
             elif data["type"] == "remove_station":
-                self.unset_cached_producer_station(data['station_name'])
-                self.unset_cached_consumer_station(data['station_name'])
+                self.unset_cached_producer_station(data["station_name"])
+                self.unset_cached_consumer_station(data["station_name"])
         except Exception as err:
             raise MemphisError(err)
 
     async def sdk_client_updates_listener(self):
         try:
-            sub = await self.broker_manager.subscribe(
-                "$memphis_sdk_clients_updates"
-            )
+            sub = await self.broker_manager.subscribe("$memphis_sdk_clients_updates")
             self.update_configurations_sub = sub
             loop = asyncio.get_event_loop()
             task = loop.create_task(
@@ -101,11 +101,14 @@ async def get_broker_manager_connection(self, connection_opts):
         if "user" in connection_opts:
+
             async def ping_error_cb(e):
                 if "authorization violation" not in (str(e)).lower():
                     print(MemphisError(str(e)))
+
             async def error_cb(e):
                 return
+
ping_connection_opts = copy.deepcopy(connection_opts) ping_connection_opts["allow_reconnect"] = False ping_connection_opts["error_cb"] = ping_error_cb @@ -121,7 +124,9 @@ async def error_cb(e): except Exception as e: if "authorization violation" in str(e).lower(): try: - if "localhost" in connection_opts['servers']: # for handling bad quality networks like port fwd + if ( + "localhost" in connection_opts["servers"] + ): # for handling bad quality networks like port fwd await asyncio.sleep(1) ping_connection_opts["user"] = self.username ping_connection_opts["error_cb"] = error_cb @@ -133,8 +138,8 @@ async def error_cb(e): else: raise e - if "localhost" in connection_opts['servers']: - await asyncio.sleep(1) # for handling bad quality networks like port fwd + if "localhost" in connection_opts["servers"]: + await asyncio.sleep(1) # for handling bad quality networks like port fwd return await broker.connect(**connection_opts) async def connect( @@ -182,14 +187,17 @@ async def connect( self.connection_id = str(uuid.uuid4()) try: if self.connection_token != "" and self.password != "": - raise MemphisConnectError( - "You have to connect with one of the following methods: connection token / password") + raise MemphisErrors.InvalidConnectionType if self.connection_token == "" and self.password == "": - raise MemphisConnectError( - "You have to connect with one of the following methods: connection token / password") + raise MemphisErrors.InvalidConnectionType + self.broker_manager = None + async def closed_callback(): - if self.broker_manager is not None and self.broker_manager.last_error is not None: + if ( + self.broker_manager is not None + and self.broker_manager.last_error is not None + ): print(MemphisError(str(self.broker_manager.last_error))) connection_opts = { @@ -203,13 +211,12 @@ async def closed_callback(): } if cert_file != "" or key_file != "" or ca_file != "": if cert_file == "": - raise MemphisConnectError("Must provide a TLS cert file") + raise MemphisErrors.MissingTLSCert if key_file == "": - raise MemphisConnectError("Must provide a TLS key file") + raise MemphisErrors.MissingTLSKey if ca_file == "": - raise MemphisConnectError("Must provide a TLS ca file") - ssl_ctx = ssl.create_default_context( - purpose=ssl.Purpose.SERVER_AUTH) + raise MemphisErrors.MissingTLSCa + ssl_ctx = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH) ssl_ctx.load_verify_locations(ca_file) ssl_ctx.load_cert_chain(certfile=cert_file, keyfile=key_file) connection_opts["tls"] = ssl_ctx @@ -217,11 +224,12 @@ async def closed_callback(): if self.connection_token != "": connection_opts["token"] = self.connection_token else: - connection_opts["user"] = self.username + \ - "$" + str(self.account_id) + connection_opts["user"] = self.username + "$" + str(self.account_id) connection_opts["password"] = self.password - self.broker_manager = await self.get_broker_manager_connection(connection_opts) + self.broker_manager = await self.get_broker_manager_connection( + connection_opts + ) await self.sdk_client_updates_listener() self.broker_connection = self.broker_manager.jetstream() self.is_connection_active = True @@ -268,7 +276,7 @@ async def station( """ try: if not self.is_connection_active: - raise MemphisError("Connection is dead") + raise MemphisErrors.DeadConnection if partitions_number == 0: partitions_number = 1 @@ -286,14 +294,17 @@ async def station( }, "username": self.username, "tiered_storage_enabled": tiered_storage_enabled, - "partitions_number" : partitions_number, - "dls_station": dls_station 
+ "partitions_number": partitions_number, + "dls_station": dls_station, } create_station_req_bytes = json.dumps(create_station_req, indent=2).encode( "utf-8" ) err_msg = await self._request( - "$memphis_station_creations", create_station_req_bytes, 20, timeout_retries + "$memphis_station_creations", + create_station_req_bytes, + 20, + timeout_retries, ) err_msg = err_msg.data.decode("utf-8") @@ -326,9 +337,12 @@ async def enforce_schema(self, name, station_name, timeout_retries=5): """ try: if name == "" or station_name == "": - raise MemphisError("name and station name can not be empty") - msg = {"name": name, "station_name": station_name, - "username": self.username} + raise MemphisErrors.MissingNameOrStationName + msg = { + "name": name, + "station_name": station_name, + "username": self.username, + } msg_to_send = json.dumps(msg).encode("utf-8") err_msg = await self._request( "$memphis_schema_attachments", msg_to_send, 20, timeout_retries @@ -349,7 +363,7 @@ async def detach_schema(self, station_name, timeout_retries=5): """ try: if station_name == "": - raise MemphisError("station name is missing") + raise MemphisErrors.MissingStationName msg = {"station_name": station_name, "username": self.username} msg_to_send = json.dumps(msg).encode("utf-8") err_msg = await self._request( @@ -410,9 +424,11 @@ async def _request(self, subject, payload, timeout, timeout_retries=5): res = await self.broker_manager.request(subject, payload, timeout=timeout) return res except Exception as e: - if 'timeout' not in str(e).lower() or timeout_retries <= 0: + if "timeout" not in str(e).lower() or timeout_retries <= 0: raise MemphisError(str(e)) from e - return await self._request(subject, payload, timeout=timeout, timeout_retries=timeout_retries-1) + return await self._request( + subject, payload, timeout=timeout, timeout_retries=timeout_retries - 1 + ) async def producer( self, @@ -433,17 +449,23 @@ async def producer( """ try: if not self.is_connection_active: - raise MemphisError("Connection is dead") + raise MemphisErrors.DeadConnection if not isinstance(station_name, str) and not isinstance(station_name, list): - raise MemphisError("station_name should be either string or list of strings") + raise MemphisErrors.InvalidStationNameTpye real_name = producer_name.lower() if generate_random_suffix: - warnings.warn("Deprecation warning: generate_random_suffix will be stopped to be supported after November 1'st, 2023.") + warnings.warn( + "Deprecation warning: generate_random_suffix will be stopped to be supported after November 1'st, 2023." 
+ ) producer_name = self.__generate_random_suffix(producer_name) if isinstance(station_name, str): - return await self._single_station_producer(station_name, producer_name, real_name, timeout_retries) + return await self._single_station_producer( + station_name, producer_name, real_name, timeout_retries + ) else: - return await self._multi_station_producer(station_name, producer_name, real_name) + return await self._multi_station_producer( + station_name, producer_name, real_name + ) except Exception as e: raise MemphisError(str(e)) from e @@ -469,13 +491,16 @@ async def _single_station_producer( "req_version": 4, "username": self.username, "app_id": app_id, - "sdk_lang": "python" + "sdk_lang": "python", } - create_producer_req_bytes = json.dumps(create_producer_req, indent=2).encode( - "utf-8" - ) + create_producer_req_bytes = json.dumps( + create_producer_req, indent=2 + ).encode("utf-8") create_res = await self._request( - "$memphis_producer_creations", create_producer_req_bytes, 20, timeout_retries + "$memphis_producer_creations", + create_producer_req_bytes, + 20, + timeout_retries, ) create_res = create_res.data.decode("utf-8") create_res = json.loads(create_res) @@ -484,9 +509,9 @@ async def _single_station_producer( if "partitions_update" in create_res: if create_res["partitions_update"]["partitions_list"] is not None: - self.partition_producers_updates_data[internal_station_name] = create_res[ - "partitions_update" - ] + self.partition_producers_updates_data[ + internal_station_name + ] = create_res["partitions_update"] self.station_schemaverse_to_dls[internal_station_name] = create_res[ "schemaverse_to_dls" @@ -502,7 +527,10 @@ async def _single_station_producer( if "station_version" in create_res: if create_res["station_version"] >= 2: - await self.start_listen_for_functions_updates(internal_station_name, create_res["station_partitions_first_functions"]) + await self.start_listen_for_functions_updates( + internal_station_name, + create_res["station_partitions_first_functions"], + ) producer = Producer(self, producer_name, station_name, real_name) map_key = internal_station_name + "_" + real_name @@ -512,45 +540,32 @@ async def _single_station_producer( except Exception as e: raise MemphisError(str(e)) from e - async def _multi_station_producer( - self, - station_names: List[str], - producer_name: str, - real_name: str + self, station_names: List[str], producer_name: str, real_name: str ): return Producer(self, producer_name, station_names, real_name) def update_schema_data(self, station_name): internal_station_name = get_internal_name(station_name) if self.schema_updates_data[internal_station_name] != {}: - if ( - self.schema_updates_data[internal_station_name]["type"] - == "protobuf" - ): + if self.schema_updates_data[internal_station_name]["type"] == "protobuf": self.parse_descriptor(internal_station_name) if self.schema_updates_data[internal_station_name]["type"] == "json": schema = self.schema_updates_data[internal_station_name][ "active_version" ]["schema_content"] - self.json_schemas[internal_station_name] = json.loads( - schema) - elif ( - self.schema_updates_data[internal_station_name]["type"] == "graphql" - ): + self.json_schemas[internal_station_name] = json.loads(schema) + elif self.schema_updates_data[internal_station_name]["type"] == "graphql": self.graphql_schemas[internal_station_name] = build_graphql_schema( - self.schema_updates_data[internal_station_name][ - "active_version" - ]["schema_content"] + 
self.schema_updates_data[internal_station_name]["active_version"][ + "schema_content" + ] ) - elif ( - self.schema_updates_data[internal_station_name]["type"] == "avro" - ): + elif self.schema_updates_data[internal_station_name]["type"] == "avro": schema = self.schema_updates_data[internal_station_name][ "active_version" ]["schema_content"] - self.avro_schemas[internal_station_name] = json.loads( - schema) + self.avro_schemas[internal_station_name] = json.loads(schema) async def get_msg_schema_updates(self, internal_station_name, iterable): async for msg in iterable: @@ -591,7 +606,7 @@ def parse_descriptor(self, station_name): raise MemphisError(str(e)) from e async def start_listen_for_functions_updates(self, station_name, first_functions): - #first_functions should contain the dict of the first function of each partition key: partition number, value: first function id + # first_functions should contain the dict of the first function of each partition key: partition number, value: first function id if station_name in self.functions_updates_subs: self.functions_clients_per_station[station_name] += 1 @@ -682,28 +697,25 @@ async def consumer( """ try: if not self.is_connection_active: - raise MemphisError("Connection is dead") + raise MemphisErrors.DeadConnection if batch_size > self.MAX_BATCH_SIZE or batch_size < 1: - raise MemphisError( - f"Batch size can not be greater than {self.MAX_BATCH_SIZE} or less than 1") + raise MemphisErrors.InvalidBatchSize real_name = consumer_name.lower() if generate_random_suffix: - warnings.warn("Deprecation warning: generate_random_suffix will be stopped to be supported after November 1'st, 2023.") + warnings.warn( + "Deprecation warning: generate_random_suffix will be stopped to be supported after November 1'st, 2023." 
+ ) consumer_name = self.__generate_random_suffix(consumer_name) cg = consumer_name if not consumer_group else consumer_group if start_consume_from_sequence <= 0: - raise MemphisError( - "start_consume_from_sequence has to be a positive number" - ) + raise MemphisErrors.NonPositiveStartConsumeFromSeq if last_messages < -1: - raise MemphisError("min value for last_messages is -1") + raise MemphisErrors.InvalidMinLasMessagesVal if start_consume_from_sequence > 1 and last_messages > -1: - raise MemphisError( - "Consumer creation options can't contain both start_consume_from_sequence and last_messages" - ) + raise MemphisErrors.ContainsStartConsumeAndLastMessages create_consumer_req = { "name": consumer_name, "station_name": station_name, @@ -717,14 +729,17 @@ async def consumer( "req_version": 4, "username": self.username, "app_id": app_id, - "sdk_lang":"python" + "sdk_lang": "python", } - create_consumer_req_bytes = json.dumps(create_consumer_req, indent=2).encode( - "utf-8" - ) + create_consumer_req_bytes = json.dumps( + create_consumer_req, indent=2 + ).encode("utf-8") creation_res = await self._request( - "$memphis_consumer_creations", create_consumer_req_bytes, 20, timeout_retries + "$memphis_consumer_creations", + create_consumer_req_bytes, + 20, + timeout_retries, ) creation_res = creation_res.data.decode("utf-8") if creation_res != "": @@ -736,7 +751,9 @@ async def consumer( internal_station_name = get_internal_name(station_name) if creation_res["partitions_update"]["partitions_list"] is not None: - self.partition_consumers_updates_data[internal_station_name] = creation_res["partitions_update"] + self.partition_consumers_updates_data[ + internal_station_name + ] = creation_res["partitions_update"] except: raise MemphisError(creation_res) @@ -745,22 +762,31 @@ async def consumer( partition_generator = None if inner_station_name in self.partition_consumers_updates_data: - partition_generator = PartitionGenerator(self.partition_consumers_updates_data[inner_station_name]["partitions_list"]) + partition_generator = PartitionGenerator( + self.partition_consumers_updates_data[inner_station_name][ + "partitions_list" + ] + ) consumer_group = get_internal_name(cg.lower()) subscriptions = {} if inner_station_name not in self.partition_consumers_updates_data: subject = inner_station_name + ".final" - psub = await self.broker_connection.pull_subscribe(subject, durable=consumer_group) + psub = await self.broker_connection.pull_subscribe( + subject, durable=consumer_group + ) subscriptions[1] = psub else: - for p in self.partition_consumers_updates_data[inner_station_name]["partitions_list"]: + for p in self.partition_consumers_updates_data[inner_station_name][ + "partitions_list" + ]: subject = f"{inner_station_name}${str(p)}.final" - psub = await self.broker_connection.pull_subscribe(subject, durable=consumer_group) + psub = await self.broker_connection.pull_subscribe( + subject, durable=consumer_group + ) subscriptions[p] = psub - internal_station_name = get_internal_name(station_name) map_key = internal_station_name + "_" + real_name if "schema_update" in creation_res: @@ -799,7 +825,7 @@ async def produce( async_produce: bool = False, msg_id: Union[str, None] = None, producer_partition_key: Union[str, None] = None, - producer_partition_number: Union[int, -1] = -1 + producer_partition_number: Union[int, -1] = -1, ): """Produces a message into a station without the need to create a producer. 
        Args:
@@ -818,15 +844,36 @@
         """
         try:
             if not isinstance(station_name, str) and not isinstance(station_name, list):
-                raise MemphisError("station_name should be either string or list of strings")
+                raise MemphisErrors.InvalidStationNameTpye
             if isinstance(station_name, str):
-                await self._single_station_produce(station_name, producer_name, message, generate_random_suffix, ack_wait_sec, headers, async_produce, msg_id, producer_partition_key, producer_partition_number)
+                await self._single_station_produce(
+                    station_name,
+                    producer_name,
+                    message,
+                    generate_random_suffix,
+                    ack_wait_sec,
+                    headers,
+                    async_produce,
+                    msg_id,
+                    producer_partition_key,
+                    producer_partition_number,
+                )
             else:
-                await self._multi_station_produce(station_name, producer_name, message, generate_random_suffix, ack_wait_sec, headers, async_produce, msg_id, producer_partition_key, producer_partition_number)
+                await self._multi_station_produce(
+                    station_name,
+                    producer_name,
+                    message,
+                    generate_random_suffix,
+                    ack_wait_sec,
+                    headers,
+                    async_produce,
+                    msg_id,
+                    producer_partition_key,
+                    producer_partition_number,
+                )
         except Exception as e:
             raise MemphisError(str(e)) from e
 
-
     async def _single_station_produce(
         self,
         station_name: str,
@@ -838,7 +885,7 @@ async def _single_station_produce(
         async_produce: bool = False,
         msg_id: Union[str, None] = None,
         producer_partition_key: Union[str, None] = None,
-        producer_partition_number: Union[int, -1] = -1
+        producer_partition_number: Union[int, -1] = -1,
     ):
         try:
             internal_station_name = get_internal_name(station_name)
@@ -859,7 +906,7 @@ async def _single_station_produce(
                 async_produce=async_produce,
                 msg_id=msg_id,
                 producer_partition_key=producer_partition_key,
-                producer_partition_number=producer_partition_number
+                producer_partition_number=producer_partition_number,
             )
         except Exception as e:
             raise MemphisError(str(e)) from e
@@ -875,7 +922,7 @@ async def _multi_station_produce(
         async_produce: bool = False,
         msg_id: Union[str, None] = None,
         producer_partition_key: Union[str, None] = None,
-        producer_partition_number: Union[int, -1] = -1
+        producer_partition_number: Union[int, -1] = -1,
     ):
         try:
             producer = await self.producer(
@@ -890,12 +937,11 @@ async def _multi_station_produce(
                 async_produce=async_produce,
                 msg_id=msg_id,
                 producer_partition_key=producer_partition_key,
-                producer_partition_number=producer_partition_number
+                producer_partition_number=producer_partition_number,
             )
         except Exception as e:
             raise MemphisError(str(e)) from e
 
-
     async def fetch_messages(
         self,
         station_name: str,
@@ -933,11 +979,11 @@ async def fetch_messages(
         try:
             consumer = None
             if not self.is_connection_active:
-                raise MemphisError(
-                    "Cant fetch messages without being connected!")
+                raise MemphisError("Can't fetch messages without being connected!")
             if batch_size > self.MAX_BATCH_SIZE or batch_size < 1:
                 raise MemphisError(
-                    f"Batch size can not be greater than {self.MAX_BATCH_SIZE} or less than 1")
+                    f"Batch size can not be greater than {self.MAX_BATCH_SIZE} or less than 1"
+                )
             internal_station_name = get_internal_name(station_name)
             consumer_map_key = internal_station_name + "_" + consumer_name.lower()
             if consumer_map_key in self.consumers_map:
@@ -948,22 +994,30 @@ async def fetch_messages(
                 consumer_name=consumer_name,
                 consumer_group=consumer_group,
                 batch_size=batch_size,
-                batch_max_time_to_wait_ms=batch_max_time_to_wait_ms if batch_max_time_to_wait_ms >= 100 else 100,
+                batch_max_time_to_wait_ms=batch_max_time_to_wait_ms
+                if batch_max_time_to_wait_ms >= 100
+                else 100,
                max_ack_time_ms=max_ack_time_ms,
                max_msg_deliveries=max_msg_deliveries,
                generate_random_suffix=generate_random_suffix,
                start_consume_from_sequence=start_consume_from_sequence,
                last_messages=last_messages,
            )
-            messages = await consumer.fetch(batch_size, consumer_partition_key=consumer_partition_key, consumer_partition_number=consumer_partition_number, prefetch=prefetch)
+            messages = await consumer.fetch(
+                batch_size,
+                consumer_partition_key=consumer_partition_key,
+                consumer_partition_number=consumer_partition_number,
+                prefetch=prefetch,
+            )
             if messages is None:
                 messages = []
             return messages
         except Exception as e:
             raise MemphisError(str(e)) from e
 
-    async def create_schema(self, schema_name, schema_type, schema_path, timeout_retries=5):
-
+    async def create_schema(
+        self, schema_name, schema_type, schema_path, timeout_retries=5
+    ):
         """Creates a new schema.
         If the schema already exists, a new version will be created.
         Args:
             schema_name (str): the name of the schema.
@@ -971,15 +1025,14 @@ async def create_schema(self, schema_name, schema_type, schema_path, timeout_ret
             schema_path (str): the path for the schema file
         """
 
-        if schema_type not in {'json', 'graphql', 'protobuf', 'avro'}:
-            raise MemphisError("schema type not supported" + type)
+        if schema_type not in {"json", "graphql", "protobuf", "avro"}:
+            raise MemphisErrors.invalid_schema_type(schema_type)
 
         try:
             await self.schema_name_validation(schema_name)
         except Exception as e:
             raise e
 
-
         schema_content = ""
         with open(schema_path, "rt", encoding="utf-8") as f:
             schema_content = f.read()
@@ -989,32 +1042,34 @@ async def create_schema(self, schema_name, schema_type, schema_path, timeout_ret
             "type": schema_type,
             "created_by_username": self.username,
             "schema_content": schema_content,
-            "message_struct_name": ""
+            "message_struct_name": "",
         }
-        create_schema_req_bytes = json.dumps(create_schema_req, indent=2).encode("utf-8")
+        create_schema_req_bytes = json.dumps(create_schema_req, indent=2).encode(
+            "utf-8"
+        )
         create_res = await self._request(
-            "$memphis_schema_creations", create_schema_req_bytes, 20, timeout_retries)
+            "$memphis_schema_creations", create_schema_req_bytes, 20, timeout_retries
+        )
         create_res = create_res.data.decode("utf-8")
         create_res = json.loads(create_res)
-        if create_res["error"] != "" and not "already exists" in create_res["error"] :
+        if create_res["error"] != "" and "already exists" not in create_res["error"]:
             raise MemphisError(create_res["error"])
 
-
     async def schema_name_validation(self, schema_name):
         if len(schema_name) == 0:
-            raise MemphisError("Schema name cannot be empty")
+            raise MemphisErrors.MissingSchemaName
 
         if len(schema_name) > 128:
-            raise MemphisError("Schema name should be under 128 characters")
+            raise MemphisErrors.SchemaNameTooLong
 
-        if re.fullmatch(r'^[a-z0-9_.-]*$', schema_name) is None:
-            raise MemphisError("Only alphanumeric and the '_', '-', '.' 
characters are allowed in the schema name") + if re.fullmatch(r"^[a-z0-9_.-]*$", schema_name) is None: + raise MemphisErrors.InvalidSchemaChars if not schema_name[0].isalnum() or not schema_name[-1].isalnum(): - raise MemphisError("Schema name cannot start or end with a non-alphanumeric character") + raise MemphisErrors.InvalidSchemaStartChar def is_connected(self): return self.broker_manager.is_connected @@ -1035,7 +1090,8 @@ def unset_cached_consumer_station(self, station_name): for key in list(self.consumers_map): consumer = self.consumers_map[key] consumer_station_name_internal = get_internal_name( - consumer.station_name) + consumer.station_name + ) if consumer_station_name_internal == internal_station_name: del self.consumers_map[key] except Exception as e: diff --git a/memphis/message.py b/memphis/message.py index 3ad94c1..b50172d 100644 --- a/memphis/message.py +++ b/memphis/message.py @@ -2,11 +2,18 @@ import json -from memphis.exceptions import MemphisConnectError, MemphisError, MemphisSchemaError +from memphis.exceptions import ( + MemphisConnectError, + MemphisError, + MemphisErrors, +) from memphis.station import Station + class Message: - def __init__(self, message, connection, cg_name, internal_station_name, partition = 0): + def __init__( + self, message, connection, cg_name, internal_station_name, partition=0 + ): self.message = message self.connection = connection self.cg_name = cg_name @@ -42,7 +49,7 @@ async def nack(self): """ nack - not ack for a message, meaning that the message will be redelivered again to the same consumers group without waiting to its ack wait time. """ - if not hasattr(self.message, 'nak'): + if not hasattr(self.message, "nak"): return await self.message.nak() @@ -52,7 +59,7 @@ async def dead_letter(self, reason: str): The message will still be available to other consumer groups """ try: - if not hasattr(self.message, 'term'): + if not hasattr(self.message, "term"): return await self.message.term() md = self.message.metadata() @@ -80,14 +87,18 @@ def get_data(self): async def get_data_deserialized(self): """Receive the message.""" try: - if self.connection.schema_updates_data and self.connection.schema_updates_data[self.internal_station_name] != {}: + if ( + self.connection.schema_updates_data + and self.connection.schema_updates_data[self.internal_station_name] + != {} + ): schema_type = self.connection.schema_updates_data[ self.internal_station_name ]["type"] try: await self.station.validate_msg(bytearray(self.message.data)) except Exception as e: - raise MemphisSchemaError("Deserialization has been failed since the message format does not align with the currently attached schema: " + str(e)) + raise MemphisErrors.schema_msg_mismatch(e) if schema_type == "protobuf": proto_msg = self.connection.proto_msgs[self.internal_station_name] proto_msg.ParseFromString(self.message.data) @@ -133,7 +144,7 @@ async def delay(self, delay): "$memphis_pm_id" in self.message.headers and "$memphis_pm_cg_name" in self.message.headers ): - raise MemphisError("cannot delay DLS message") + raise MemphisErrors.DLSCannotBeDelayed try: await self.message.nak(delay=delay) except Exception as e: diff --git a/memphis/partition_generator.py b/memphis/partition_generator.py index 57b9ad3..127671e 100644 --- a/memphis/partition_generator.py +++ b/memphis/partition_generator.py @@ -1,5 +1,6 @@ -#The PartitionGenerator class is used to create a round robin generator for station's partitions -#the class gets a list of partitions and by using the next() function it returns an 
item from the list +# The PartitionGenerator class is used to create a round robin generator for station's partitions +# the class gets a list of partitions and by using the next() function it returns an item from the list + class PartitionGenerator: def __init__(self, partitions_list): @@ -11,4 +12,3 @@ def __next__(self): partition_to_return = self.partitions_list[self.current] self.current = (self.current + 1) % self.length return partition_to_return - \ No newline at end of file diff --git a/memphis/producer.py b/memphis/producer.py index 55936d5..de1f983 100644 --- a/memphis/producer.py +++ b/memphis/producer.py @@ -7,7 +7,7 @@ import warnings import mmh3 -from memphis.exceptions import MemphisError, MemphisSchemaError +from memphis.exceptions import MemphisError, MemphisSchemaError, MemphisErrors from memphis.headers import Headers from memphis.utils import get_internal_name from memphis.partition_generator import PartitionGenerator @@ -18,7 +18,11 @@ class Producer: def __init__( - self, connection, producer_name: str, station_name: Union[str, List[str]] , real_name: str + self, + connection, + producer_name: str, + station_name: Union[str, List[str]], + real_name: str, ): self.connection = connection self.producer_name = producer_name.lower() @@ -35,7 +39,11 @@ def __init__( self.internal_station_name = get_internal_name(self.station_name) self.loop = asyncio.get_running_loop() if self.internal_station_name in connection.partition_producers_updates_data: - self.partition_generator = PartitionGenerator(connection.partition_producers_updates_data[self.internal_station_name]["partitions_list"]) + self.partition_generator = PartitionGenerator( + connection.partition_producers_updates_data[self.internal_station_name][ + "partitions_list" + ] + ) # pylint: disable=too-many-arguments async def produce( @@ -48,7 +56,7 @@ async def produce( msg_id: Union[str, None] = None, concurrent_task_limit: Union[int, None] = None, producer_partition_key: Union[str, None] = None, - producer_partition_number: Union[int, -1] = -1 + producer_partition_number: Union[int, -1] = -1, ): """Produces a message into a station. Args: @@ -85,7 +93,7 @@ async def produce( async_produce=async_produce, msg_id=msg_id, producer_partition_key=producer_partition_key, - producer_partition_number=producer_partition_number + producer_partition_number=producer_partition_number, ) else: await self._single_station_produce( @@ -97,7 +105,7 @@ async def produce( msg_id=msg_id, concurrent_task_limit=concurrent_task_limit, producer_partition_key=producer_partition_key, - producer_partition_number=producer_partition_number + producer_partition_number=producer_partition_number, ) async def _single_station_produce( @@ -110,7 +118,7 @@ async def _single_station_produce( msg_id: Union[str, None] = None, concurrent_task_limit: Union[int, None] = None, producer_partition_key: Union[str, None] = None, - producer_partition_number: Union[int, -1] = -1 + producer_partition_number: Union[int, -1] = -1, ): """Produces a message into a station. 
Args: @@ -156,24 +164,45 @@ async def _single_station_produce( else: headers = memphis_headers - if self.internal_station_name not in self.connection.partition_producers_updates_data: + if ( + self.internal_station_name + not in self.connection.partition_producers_updates_data + ): partition_name = self.internal_station_name - elif len(self.connection.partition_producers_updates_data[self.internal_station_name]['partitions_list']) == 1: + elif ( + len( + self.connection.partition_producers_updates_data[ + self.internal_station_name + ]["partitions_list"] + ) + == 1 + ): partition_name = f"{self.internal_station_name}${self.connection.partition_producers_updates_data[self.internal_station_name]['partitions_list'][0]}" elif producer_partition_number > 0 and producer_partition_key is not None: - raise MemphisError('Can not use both partition number and partition key') + raise MemphisError( + "Can not use both partition number and partition key" + ) elif producer_partition_key is not None: partition_number = self.get_partition_from_key(producer_partition_key) partition_name = f"{self.internal_station_name}${str(partition_number)}" elif producer_partition_number > 0: - self.validate_partition_number(producer_partition_number, self.internal_station_name) - partition_name = f"{self.internal_station_name}${str(producer_partition_number)}" + self.validate_partition_number( + producer_partition_number, self.internal_station_name + ) + partition_name = ( + f"{self.internal_station_name}${str(producer_partition_number)}" + ) else: partition_name = f"{self.internal_station_name}${str(next(self.partition_generator))}" if self.internal_station_name in self.connection.functions_updates_data: partition_number = partition_name.split("$")[1] - if partition_number in self.connection.functions_updates_data[self.internal_station_name]: + if ( + partition_number + in self.connection.functions_updates_data[ + self.internal_station_name + ] + ): full_subject_name = f"{partition_name}.functions.{self.connection.functions_updates_data[self.internal_station_name][partition_number]}" else: full_subject_name = f"{partition_name}.final" @@ -182,25 +211,29 @@ async def _single_station_produce( if async_produce: nonblocking = True - warnings.warn("The argument async_produce is deprecated. " + \ - "Please use the argument nonblocking instead.") + warnings.warn( + "The argument async_produce is deprecated. " + + "Please use the argument nonblocking instead." 
+                )
         if nonblocking:
             try:
                 task = self.loop.create_task(
-                    self.connection.broker_connection.publish(
-                            full_subject_name,
-                            message,
-                            timeout=ack_wait_sec,
-                            headers=headers,
-                    )
-                )
+                    self.connection.broker_connection.publish(
+                        full_subject_name,
+                        message,
+                        timeout=ack_wait_sec,
+                        headers=headers,
+                    )
+                )
                 self.background_tasks.add(task)
                 task.add_done_callback(self.background_tasks.discard)

                 # block until the number of outstanding async tasks is reduced
-                if concurrent_task_limit is not None and \
-                    len(self.background_tasks) >= concurrent_task_limit:
+                if (
+                    concurrent_task_limit is not None
+                    and len(self.background_tasks) >= concurrent_task_limit
+                ):
                     desired_size = concurrent_task_limit / 2
                     while len(self.background_tasks) > desired_size:
                         await asyncio.sleep(0.1)
@@ -254,14 +287,14 @@ async def _single_station_produce(
                         "data": msg_hex,
                         "headers": headers,
                     },
-                    "validation_error": str(e)
+                    "validation_error": str(e),
                 }
                 buf = json.dumps(buf).encode("utf-8")
-                await self.connection.broker_manager.publish("$memphis_schemaverse_dls", buf)
+                await self.connection.broker_manager.publish(
+                    "$memphis_schemaverse_dls", buf
+                )

-                if self.connection.cluster_configurations.get(
-                    "send_notification"
-                ):
+                if self.connection.cluster_configurations.get("send_notification"):
                     await self.connection.send_notification(
                         "Schema validation has failed",
                         "Station: "
@@ -274,7 +307,7 @@ async def _single_station_produce(
                         schemaverse_fail_alert_type,
                     )
                 raise e
-        except Exception as e: # pylint: disable-next=no-member
+        except Exception as e:  # pylint: disable-next=no-member
            if hasattr(e, "status_code") and e.status_code == "503":
                 raise MemphisError(
                     "Produce operation has failed, please check whether Station/Producer still exist"
@@ -289,7 +322,7 @@ async def _multi_station_produce(
         async_produce: Union[bool, None] = None,
         msg_id: Union[str, None] = None,
         producer_partition_key: Union[str, None] = None,
-        producer_partition_number: Union[int, -1] = -1
+        producer_partition_number: Union[int, -1] = -1,
     ):
         for sn in self.station_name:
             await self.connection.produce(
@@ -301,7 +334,7 @@ async def _multi_station_produce(
                 async_produce=async_produce,
                 msg_id=msg_id,
                 producer_partition_key=producer_partition_key,
-                producer_partition_number=producer_partition_number
+                producer_partition_number=producer_partition_number,
             )

     # pylint: enable=too-many-arguments
@@ -339,16 +372,12 @@ async def _destroy_single_station_producer(self, timeout_retries=5):
             internal_station_name = get_internal_name(self.station_name)

             clients_number = (
-                self.connection.clients_per_station.get(
-                    internal_station_name) - 1
+                self.connection.clients_per_station.get(internal_station_name) - 1
             )
-            self.connection.clients_per_station[
-                internal_station_name
-            ] = clients_number
+            self.connection.clients_per_station[internal_station_name] = clients_number

             if clients_number == 0:
-                sub = self.connection.schema_updates_subs.get(
-                    internal_station_name)
+                sub = self.connection.schema_updates_subs.get(internal_station_name)
                 task = self.connection.schema_tasks.get(internal_station_name)
                 if internal_station_name in self.connection.schema_updates_data:
                     del self.connection.schema_updates_data[internal_station_name]
@@ -361,13 +390,17 @@ async def _destroy_single_station_producer(self, timeout_retries=5):
                 if sub is not None:
                     await sub.unsubscribe()

-            self.connection.functions_clients_per_station[internal_station_name] -= 1
-            if self.connection.functions_clients_per_station[internal_station_name] == 0:
+            if (
+                self.connection.functions_clients_per_station[internal_station_name]
+                == 0
+            ):
                 if internal_station_name in self.connection.functions_updates_data:
                     del self.connection.functions_updates_data[internal_station_name]
                 if internal_station_name in self.connection.functions_updates_subs:
-                    sub = self.connection.functions_updates_subs.get(internal_station_name)
+                    sub = self.connection.functions_updates_subs.get(
+                        internal_station_name
+                    )
                     if sub is not None:
                         await sub.unsubscribe()
                     del self.connection.functions_updates_subs[internal_station_name]
@@ -377,7 +410,6 @@ async def _destroy_single_station_producer(self, timeout_retries=5):
                         task.cancel()
                     del self.connection.functions_tasks[internal_station_name]

-
             map_key = internal_station_name + "_" + self.real_name
             del self.connection.producers_map[map_key]
@@ -385,27 +417,44 @@ async def _destroy_single_station_producer(self, timeout_retries=5):
             raise Exception(e)

     async def _destroy_multi_station_producer(self, timeout_retries=5):
-        internal_station_name_list = [get_internal_name(station_name) for station_name in self.station_name]
-        producer_keys = [f"{internal_station_name}_{self.real_name}" for internal_station_name in internal_station_name_list]
-        producers = [self.connection.producers_map.get(producer_key) for producer_key in producer_keys]
+        internal_station_name_list = [
+            get_internal_name(station_name) for station_name in self.station_name
+        ]
+        producer_keys = [
+            f"{internal_station_name}_{self.real_name}"
+            for internal_station_name in internal_station_name_list
+        ]
+        producers = [
+            self.connection.producers_map.get(producer_key)
+            for producer_key in producer_keys
+        ]
         producers = [producer for producer in producers if producer is not None]
         for producer in producers:
             await producer.destroy(timeout_retries)
-
     def get_partition_from_key(self, key):
         try:
-            index = mmh3.hash(key, self.connection.SEED, signed=False) % len(self.connection.partition_producers_updates_data[self.internal_station_name]["partitions_list"])
-            return self.connection.partition_producers_updates_data[self.internal_station_name]["partitions_list"][index]
+            index = mmh3.hash(key, self.connection.SEED, signed=False) % len(
+                self.connection.partition_producers_updates_data[
+                    self.internal_station_name
+                ]["partitions_list"]
+            )
+            return self.connection.partition_producers_updates_data[
+                self.internal_station_name
+            ]["partitions_list"][index]
         except Exception as e:
             raise e

     def validate_partition_number(self, partition_number, station_name):
-        partitions_list = self.connection.partition_producers_updates_data[station_name]["partitions_list"]
+        partitions_list = self.connection.partition_producers_updates_data[
+            station_name
+        ]["partitions_list"]
         if partitions_list is not None:
             if partition_number < 1 or partition_number > len(partitions_list):
-                raise MemphisError("Partition number is out of range")
+                raise MemphisErrors.PartitionOutOfRange
             elif partition_number not in partitions_list:
-                raise MemphisError(f"Partition {str(partition_number)} does not exist in station {station_name}")
+                raise MemphisErrors.partition_not_in_station(
+                    partition_number, station_name
+                )
         else:
-            raise MemphisError(f"Partition {str(partition_number)} does not exist in station {station_name}")
+            raise MemphisErrors.partition_not_in_station(partition_number, station_name)
diff --git a/memphis/station.py b/memphis/station.py
index 36f29e5..9565ab8 100644
--- a/memphis/station.py
+++ b/memphis/station.py
@@ -6,7 +6,7 @@ import google.protobuf.json_format as protobuf_json_format
 import fastavro

-from memphis.exceptions import MemphisError, MemphisSchemaError
+from memphis.exceptions import MemphisError, MemphisSchemaError, MemphisErrors
 from memphis.utils import get_internal_name
@@ -48,7 +48,7 @@ async def validate_msg(self, message):
             msg_to_send = message.SerializeToString()
             return msg_to_send
         elif not isinstance(message, bytearray) and not isinstance(message, dict):
-            raise MemphisSchemaError("Unsupported message type")
+            raise MemphisErrors.UnsupportedMsgType
         else:
             if isinstance(message, dict):
                 message = bytearray(json.dumps(message).encode("utf-8"))
@@ -89,10 +89,10 @@ def validate_protobuf(self, message):
                 except Exception as e:
                     raise MemphisSchemaError(str(e))
             else:
-                raise MemphisSchemaError("Unsupported message type")
+                raise MemphisErrors.UnsupportedMsgType

         except Exception as e:
-            raise MemphisSchemaError("Schema validation has failed: " + str(e))
+            raise MemphisErrors.schema_validation_failed(e)

     def validate_json_schema(self, message):
         try:
@@ -100,12 +100,12 @@ def validate_json_schema(self, message):
                 try:
                     message_obj = json.loads(message)
                 except Exception as e:
-                    raise Exception("Expecting Json format: " + str(e))
+                    raise MemphisErrors.expecting_format(e, "JSON")
             elif isinstance(message, dict):
                 message_obj = message
                 message = bytearray(json.dumps(message_obj).encode("utf-8"))
             else:
-                raise Exception("Unsupported message type")
+                raise MemphisErrors.UnsupportedMsgType

             validate(
                 instance=message_obj,
@@ -113,7 +113,7 @@ def validate_json_schema(self, message):
             )
             return message
         except Exception as e:
-            raise MemphisSchemaError("Schema validation has failed: " + str(e))
+            raise MemphisErrors.schema_validation_failed(e)

     def validate_graphql(self, message):
         try:
@@ -128,19 +128,18 @@ def validate_graphql(self, message):
                 message = str(msg.loc.source.body)
                 message = message.encode("utf-8")
             else:
-                raise Exception("Unsupported message type")
+                raise MemphisErrors.UnsupportedMsgType
             validate_res = validate_graphql(
                 schema=self.connection.graphql_schemas[self.internal_station_name],
                 document_ast=msg,
             )
             if len(validate_res) > 0:
-                raise Exception(
-                    "Schema validation has failed: " + str(validate_res))
+                raise MemphisErrors.schema_validation_failed(validate_res)
             return message
         except Exception as e:
             if "Syntax Error" in str(e):
                 e = "Invalid message format, expected GraphQL"
-            raise MemphisSchemaError("Schema validation has failed: " + str(e))
+            raise MemphisErrors.schema_validation_failed(e)

     def validate_avro_schema(self, message):
         try:
@@ -148,12 +147,12 @@ def validate_avro_schema(self, message):
                 try:
                     message_obj = json.loads(message)
                 except Exception as e:
-                    raise Exception("Expecting Avro format: " + str(e))
+                    raise MemphisErrors.expecting_format(e, "Avro")
             elif isinstance(message, dict):
                 message_obj = message
                 message = bytearray(json.dumps(message_obj).encode("utf-8"))
             else:
-                raise Exception("Unsupported message type")
+                raise MemphisErrors.UnsupportedMsgType

             fastavro.validate(
                 message_obj,
@@ -161,7 +160,7 @@ def validate_avro_schema(self, message):
             )
             return message
         except fastavro.validation.ValidationError as e:
-            raise MemphisSchemaError("Schema validation has failed: " + str(e))
+            raise MemphisErrors.schema_validation_failed(e)

     async def destroy(self, timeout_retries=5):
         """Destroy the station."""
@@ -193,13 +192,14 @@ async def destroy(self, timeout_retries=5):
                 if sub is not None:
                     await sub.unsubscribe()

-
             if internal_station_name in self.connection.functions_clients_per_station:
                 del self.connection.functions_clients_per_station[internal_station_name]
             if internal_station_name in self.connection.functions_updates_data:
                 del self.connection.functions_updates_data[internal_station_name]
             if internal_station_name in self.connection.functions_updates_subs:
-                function_sub = self.connection.functions_updates_subs.get(internal_station_name)
+                function_sub = self.connection.functions_updates_subs.get(
+                    internal_station_name
+                )
                 if function_sub is not None:
                     await function_sub.unsubscribe()
                 del self.connection.functions_updates_subs[internal_station_name]