From 7029a1843ff3a9713c1011f260e5e5db1db1cf8a Mon Sep 17 00:00:00 2001 From: priyal1508 <54278892+priyal1508@users.noreply.github.com> Date: Thu, 5 Sep 2024 10:52:29 +0530 Subject: [PATCH 1/4] changes for adi based skillset --- aisearch-skillset/ai_search.py | 690 ++++++++++++++++++ aisearch-skillset/deploy.py | 80 ++ aisearch-skillset/environment.py | 192 +++++ aisearch-skillset/inquiry_document.py | 320 ++++++++ function_apps/common/ai_search.py | 127 ++++ function_apps/indexer/adi_2_aisearch.py | 460 ++++++++++++ function_apps/indexer/function_app.py | 296 ++++++++ .../indexer/key_phrase_extraction.py | 112 +++ .../indexer/pre_embedding_cleaner.py | 144 ++++ function_apps/indexer/requirements.txt | 26 + 10 files changed, 2447 insertions(+) create mode 100644 aisearch-skillset/ai_search.py create mode 100644 aisearch-skillset/deploy.py create mode 100644 aisearch-skillset/environment.py create mode 100644 aisearch-skillset/inquiry_document.py create mode 100644 function_apps/common/ai_search.py create mode 100644 function_apps/indexer/adi_2_aisearch.py create mode 100644 function_apps/indexer/function_app.py create mode 100644 function_apps/indexer/key_phrase_extraction.py create mode 100644 function_apps/indexer/pre_embedding_cleaner.py create mode 100644 function_apps/indexer/requirements.txt diff --git a/aisearch-skillset/ai_search.py b/aisearch-skillset/ai_search.py new file mode 100644 index 0000000..7573055 --- /dev/null +++ b/aisearch-skillset/ai_search.py @@ -0,0 +1,690 @@ +from abc import ABC, abstractmethod +from azure.search.documents.indexes.models import ( + SearchIndex, + SearchableField, + VectorSearch, + VectorSearchProfile, + HnswAlgorithmConfiguration, + SemanticSearch, + NativeBlobSoftDeleteDeletionDetectionPolicy, + HighWaterMarkChangeDetectionPolicy, + WebApiSkill, + CustomVectorizer, + CustomWebApiParameters, + SearchIndexer, + SearchIndexerSkillset, + SearchIndexerDataContainer, + SearchIndexerDataSourceConnection, + SearchIndexerDataSourceType, + SearchIndexerDataUserAssignedIdentity, + OutputFieldMappingEntry, + InputFieldMappingEntry, + SynonymMap, + DocumentExtractionSkill, + OcrSkill, + MergeSkill, + ConditionalSkill, + SplitSkill, +) +from azure.core.exceptions import HttpResponseError +from azure.search.documents.indexes import SearchIndexerClient, SearchIndexClient +from environment import ( + get_fq_blob_connection_string, + get_blob_container_name, + get_custom_skill_function_url, + get_managed_identity_fqname, + get_function_app_authresourceid, + IndexerType, +) + + +class AISearch(ABC): + def __init__( + self, + endpoint: str, + credential, + suffix: str | None = None, + rebuild: bool | None = False, + ): + """Initialize the AI search class + + Args: + endpoint (str): The search endpoint + credential (AzureKeyCredential): The search credential""" + self.indexer_type = None + + if rebuild is not None: + self.rebuild = rebuild + else: + self.rebuild = False + + if suffix is None: + self.suffix = "" + self.test = False + else: + self.suffix = f"-{suffix}-test" + self.test = True + + self._search_indexer_client = SearchIndexerClient(endpoint, credential) + self._search_index_client = SearchIndexClient(endpoint, credential) + + @property + def indexer_name(self): + return f"{str(self.indexer_type.value)}-indexer{self.suffix}" + + @property + def skillset_name(self): + return f"{str(self.indexer_type.value)}-skillset{self.suffix}" + + @property + def semantic_config_name(self): + return f"{str(self.indexer_type.value)}-semantic-config{self.suffix}" + + 
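+    # Illustrative example of the naming convention (values assumed, not part of
+    # this change): with indexer_type=IndexerType.INQUIRY_DOCUMENT and
+    # suffix="dev", these name properties resolve to
+    # "inquiry-document-indexer-dev-test", "inquiry-document-skillset-dev-test"
+    # and "inquiry-document-semantic-config-dev-test"; with no suffix the
+    # "-dev-test" part is omitted.
+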
@property + def index_name(self): + return f"{str(self.indexer_type.value)}-index{self.suffix}" + + @property + def data_source_name(self): + blob_container_name = get_blob_container_name(self.indexer_type) + return f"{blob_container_name}-data-source{self.suffix}" + + @property + def vector_search_profile_name(self): + return ( + f"{str(self.indexer_type.value)}-compass-vector-search-profile{self.suffix}" + ) + + @abstractmethod + def get_index_fields(self) -> list[SearchableField]: + """Get the index fields for the indexer. + + Returns: + list[SearchableField]: The index fields""" + + @abstractmethod + def get_semantic_search(self) -> SemanticSearch: + """Get the semantic search configuration for the indexer. + + Returns: + SemanticSearch: The semantic search configuration""" + + @abstractmethod + def get_skills(self): + """Get the skillset for the indexer.""" + + @abstractmethod + def get_indexer(self) -> SearchIndexer: + """Get the indexer for the indexer.""" + + def get_index_projections(self): + """Get the index projections for the indexer.""" + return None + + def get_synonym_map_names(self): + return [] + + def get_user_assigned_managed_identity( + self, + ) -> SearchIndexerDataUserAssignedIdentity: + """Get user assigned managed identity details""" + + user_assigned_identity = SearchIndexerDataUserAssignedIdentity( + user_assigned_identity=get_managed_identity_fqname() + ) + return user_assigned_identity + + def get_data_source(self) -> SearchIndexerDataSourceConnection: + """Get the data source for the indexer.""" + + if self.indexer_type == IndexerType.BUSINESS_GLOSSARY: + data_deletion_detection_policy = None + else: + data_deletion_detection_policy = ( + NativeBlobSoftDeleteDeletionDetectionPolicy() + ) + + data_change_detection_policy = HighWaterMarkChangeDetectionPolicy( + high_water_mark_column_name="metadata_storage_last_modified" + ) + + container = SearchIndexerDataContainer( + name=get_blob_container_name(self.indexer_type) + ) + + data_source_connection = SearchIndexerDataSourceConnection( + name=self.data_source_name, + type=SearchIndexerDataSourceType.AZURE_BLOB, + connection_string=get_fq_blob_connection_string(), + container=container, + data_change_detection_policy=data_change_detection_policy, + data_deletion_detection_policy=data_deletion_detection_policy, + identity=self.get_user_assigned_managed_identity(), + ) + + return data_source_connection + + def get_compass_vector_custom_skill( + self, context, source, target_name="vector" + ) -> WebApiSkill: + """Get the custom skill for compass. 
+ + Args: + ----- + context (str): The context of the skill + source (str): The source of the skill + target_name (str): The target name of the skill + + Returns: + -------- + WebApiSkill: The custom skill for compass""" + + if self.test: + batch_size = 2 + degree_of_parallelism = 2 + else: + batch_size = 4 + degree_of_parallelism = 8 + + embedding_skill_inputs = [ + InputFieldMappingEntry(name="text", source=source), + ] + embedding_skill_outputs = [ + OutputFieldMappingEntry(name="vector", target_name=target_name) + ] + # Limit the number of documents to be processed in parallel to avoid timing out on compass api + embedding_skill = WebApiSkill( + name="Compass Connector API", + description="Skill to generate embeddings via compass API connector", + context=context, + uri=get_custom_skill_function_url("compass"), + timeout="PT230S", + batch_size=batch_size, + degree_of_parallelism=degree_of_parallelism, + http_method="POST", + inputs=embedding_skill_inputs, + outputs=embedding_skill_outputs, + auth_resource_id=get_function_app_authresourceid(), + auth_identity=self.get_user_assigned_managed_identity(), + ) + + return embedding_skill + + def get_pre_embedding_cleaner_skill( + self, context, source, chunk_by_page=False, target_name="cleaned_chunk" + ) -> WebApiSkill: + """Get the custom skill for data cleanup. + + Args: + ----- + context (str): The context of the skill + inputs (List[InputFieldMappingEntry]): The inputs of the skill + outputs (List[OutputFieldMappingEntry]): The outputs of the skill + + Returns: + -------- + WebApiSkill: The custom skill for data cleanup""" + + if self.test: + batch_size = 2 + degree_of_parallelism = 2 + else: + batch_size = 16 + degree_of_parallelism = 16 + + pre_embedding_cleaner_skill_inputs = [ + InputFieldMappingEntry(name="chunk", source=source) + ] + + pre_embedding_cleaner_skill_outputs = [ + OutputFieldMappingEntry(name="cleaned_chunk", target_name=target_name), + OutputFieldMappingEntry(name="chunk", target_name="chunk"), + OutputFieldMappingEntry(name="section", target_name="eachsection"), + ] + + if chunk_by_page: + pre_embedding_cleaner_skill_outputs.extend( + [ + OutputFieldMappingEntry(name="page_number", target_name="page_no"), + ] + ) + + pre_embedding_cleaner_skill = WebApiSkill( + name="Pre Embedding Cleaner Skill", + description="Skill to clean the data before sending to embedding", + context=context, + uri=get_custom_skill_function_url("pre_embedding_cleaner"), + timeout="PT230S", + batch_size=batch_size, + degree_of_parallelism=degree_of_parallelism, + http_method="POST", + inputs=pre_embedding_cleaner_skill_inputs, + outputs=pre_embedding_cleaner_skill_outputs, + auth_resource_id=get_function_app_authresourceid(), + auth_identity=self.get_user_assigned_managed_identity(), + ) + + return pre_embedding_cleaner_skill + + def get_text_split_skill(self, context, source) -> SplitSkill: + """Get the skill for text split. 
+ + Args: + ----- + context (str): The context of the skill + inputs (List[InputFieldMappingEntry]): The inputs of the skill + outputs (List[OutputFieldMappingEntry]): The outputs of the skill + + Returns: + -------- + splitSKill: The skill for text split""" + + text_split_skill = SplitSkill( + name="Text Split Skill", + description="Skill to split the text before sending to embedding", + context=context, + text_split_mode="pages", + maximum_page_length=2000, + page_overlap_length=500, + inputs=[InputFieldMappingEntry(name="text", source=source)], + outputs=[OutputFieldMappingEntry(name="textItems", target_name="pages")], + ) + + return text_split_skill + + def get_custom_text_split_skill( + self, + context, + source, + text_split_mode="semantic", + maximum_page_length=1000, + separator=" ", + initial_threshold=0.7, + appending_threshold=0.6, + merging_threshold=0.6, + ) -> WebApiSkill: + """Get the custom skill for text split. + + Args: + ----- + context (str): The context of the skill + inputs (List[InputFieldMappingEntry]): The inputs of the skill + outputs (List[OutputFieldMappingEntry]): The outputs of the skill + + Returns: + -------- + WebApiSkill: The custom skill for text split""" + + if self.test: + batch_size = 2 + degree_of_parallelism = 2 + else: + batch_size = 2 + degree_of_parallelism = 6 + + text_split_skill_inputs = [ + InputFieldMappingEntry(name="text", source=source), + ] + + headers = { + "text_split_mode": text_split_mode, + "maximum_page_length": maximum_page_length, + "separator": separator, + "initial_threshold": initial_threshold, + "appending_threshold": appending_threshold, + "merging_threshold": merging_threshold, + } + + text_split_skill = WebApiSkill( + name="Text Split Skill", + description="Skill to split the text before sending to embedding", + context=context, + uri=get_custom_skill_function_url("split"), + timeout="PT230S", + batch_size=batch_size, + degree_of_parallelism=degree_of_parallelism, + http_method="POST", + http_headers=headers, + inputs=text_split_skill_inputs, + outputs=[OutputFieldMappingEntry(name="chunks", target_name="pages")], + auth_resource_id=get_function_app_authresourceid(), + auth_identity=self.get_user_assigned_managed_identity(), + ) + + return text_split_skill + + def get_adi_skill(self, chunk_by_page=False) -> WebApiSkill: + """Get the custom skill for adi. + + Returns: + -------- + WebApiSkill: The custom skill for adi""" + + if self.test: + batch_size = 1 + degree_of_parallelism = 4 + else: + batch_size = 1 + degree_of_parallelism = 16 + + if chunk_by_page: + output = [ + OutputFieldMappingEntry(name="extracted_content", target_name="pages") + ] + else: + output = [ + OutputFieldMappingEntry( + name="extracted_content", target_name="extracted_content" + ) + ] + + adi_skill = WebApiSkill( + name="ADI Skill", + description="Skill to generate ADI", + context="/document", + uri=get_custom_skill_function_url("adi"), + timeout="PT230S", + batch_size=batch_size, + degree_of_parallelism=degree_of_parallelism, + http_method="POST", + http_headers={"chunk_by_page": chunk_by_page}, + inputs=[ + InputFieldMappingEntry( + name="source", source="/document/metadata_storage_path" + ) + ], + outputs=output, + auth_resource_id=get_function_app_authresourceid(), + auth_identity=self.get_user_assigned_managed_identity(), + ) + + return adi_skill + + def get_key_phrase_extraction_skill(self, context, source) -> WebApiSkill: + """Get the key phrase extraction skill. 
+ + Args: + ----- + context (str): The context of the skill + source (str): The source of the skill + + Returns: + -------- + WebApiSkill: The key phrase extraction skill""" + + if self.test: + batch_size = 4 + degree_of_parallelism = 4 + else: + batch_size = 16 + degree_of_parallelism = 16 + + keyphrase_extraction_skill_inputs = [ + InputFieldMappingEntry(name="text", source=source), + ] + keyphrase_extraction__skill_outputs = [ + OutputFieldMappingEntry(name="keyPhrases", target_name="keywords") + ] + key_phrase_extraction_skill = WebApiSkill( + name="Key phrase extraction API", + description="Skill to extract keyphrases", + context=context, + uri=get_custom_skill_function_url("keyphraseextraction"), + timeout="PT230S", + batch_size=batch_size, + degree_of_parallelism=degree_of_parallelism, + http_method="POST", + inputs=keyphrase_extraction_skill_inputs, + outputs=keyphrase_extraction__skill_outputs, + auth_resource_id=get_function_app_authresourceid(), + auth_identity=self.get_user_assigned_managed_identity(), + ) + + return key_phrase_extraction_skill + + def get_document_extraction_skill(self, context, source) -> DocumentExtractionSkill: + """Get the document extraction utility skill. + + Args: + ----- + context (str): The context of the skill + source (str): The source of the skill + + Returns: + -------- + DocumentExtractionSkill: The document extraction utility skill""" + + doc_extraction_skill = DocumentExtractionSkill( + description="Extraction skill to extract content from office docs like excel, ppt, doc etc", + context=context, + inputs=[InputFieldMappingEntry(name="file_data", source=source)], + outputs=[ + OutputFieldMappingEntry( + name="content", target_name="extracted_content" + ), + OutputFieldMappingEntry( + name="normalized_images", target_name="extracted_normalized_images" + ), + ], + ) + + return doc_extraction_skill + + def get_ocr_skill(self, context, source) -> OcrSkill: + """Get the ocr utility skill + Args: + ----- + context (str): The context of the skill + source (str): The source of the skill + + Returns: + -------- + OcrSkill: The ocr skill""" + + if self.test: + batch_size = 2 + degree_of_parallelism = 2 + else: + batch_size = 2 + degree_of_parallelism = 2 + + ocr_skill_inputs = [ + InputFieldMappingEntry(name="image", source=source), + ] + ocr__skill_outputs = [OutputFieldMappingEntry(name="text", target_name="text")] + ocr_skill = WebApiSkill( + name="ocr API", + description="Skill to extract text from images", + context=context, + uri=get_custom_skill_function_url("ocr"), + timeout="PT230S", + batch_size=batch_size, + degree_of_parallelism=degree_of_parallelism, + http_method="POST", + inputs=ocr_skill_inputs, + outputs=ocr__skill_outputs, + auth_resource_id=get_function_app_authresourceid(), + auth_identity=self.get_user_assigned_managed_identity(), + ) + + return ocr_skill + + def get_merge_skill(self, context, source) -> MergeSkill: + """Get the merge + Args: + ----- + context (str): The context of the skill + source (array): The source of the skill + + Returns: + -------- + mergeSkill: The merge skill""" + + merge_skill = MergeSkill( + description="Merge skill for combining OCR'd and regular text", + context=context, + inputs=[ + InputFieldMappingEntry(name="text", source=source[0]), + InputFieldMappingEntry(name="itemsToInsert", source=source[1]), + InputFieldMappingEntry(name="offsets", source=source[2]), + ], + outputs=[ + OutputFieldMappingEntry(name="mergedText", target_name="merged_content") + ], + ) + + return merge_skill + + def 
get_conditional_skill(self, context, source) -> ConditionalSkill: + """Get the merge + Args: + ----- + context (str): The context of the skill + source (array): The source of the skill + + Returns: + -------- + ConditionalSkill: The conditional skill""" + + conditional_skill = ConditionalSkill( + description="Select between OCR and Document Extraction output", + context=context, + inputs=[ + InputFieldMappingEntry(name="condition", source=source[0]), + InputFieldMappingEntry(name="whenTrue", source=source[1]), + InputFieldMappingEntry(name="whenFalse", source=source[2]), + ], + outputs=[ + OutputFieldMappingEntry(name="output", target_name="updated_content") + ], + ) + + return conditional_skill + + def get_compass_vector_search(self) -> VectorSearch: + """Get the vector search configuration for compass. + + Args: + indexer_type (str): The type of the indexer + + Returns: + VectorSearch: The vector search configuration + """ + vectorizer_name = ( + f"{str(self.indexer_type.value)}-compass-vectorizer{self.suffix}" + ) + algorithim_name = f"{str(self.indexer_type.value)}-hnsw-algorithm{self.suffix}" + + vector_search = VectorSearch( + algorithms=[ + HnswAlgorithmConfiguration(name=algorithim_name), + ], + profiles=[ + VectorSearchProfile( + name=self.vector_search_profile_name, + algorithm_configuration_name=algorithim_name, + vectorizer=vectorizer_name, + ) + ], + vectorizers=[ + CustomVectorizer( + name=vectorizer_name, + custom_web_api_parameters=CustomWebApiParameters( + uri=get_custom_skill_function_url("compass"), + auth_resource_id=get_function_app_authresourceid(), + auth_identity=self.get_user_assigned_managed_identity(), + ), + ), + ], + ) + + return vector_search + + def deploy_index(self): + """This function deploys index""" + + index_fields = self.get_index_fields() + vector_search = self.get_compass_vector_search() + semantic_search = self.get_semantic_search() + index = SearchIndex( + name=self.index_name, + fields=index_fields, + vector_search=vector_search, + semantic_search=semantic_search, + ) + if self.rebuild: + self._search_index_client.delete_index(self.index_name) + self._search_index_client.create_or_update_index(index) + + print(f"{index.name} created") + + def deploy_skillset(self): + """This function deploys the skillset.""" + skills = self.get_skills() + index_projections = self.get_index_projections() + + skillset = SearchIndexerSkillset( + name=self.skillset_name, + description="Skillset to chunk documents and generating embeddings", + skills=skills, + index_projections=index_projections, + ) + + self._search_indexer_client.create_or_update_skillset(skillset) + print(f"{skillset.name} created") + + def deploy_data_source(self): + """This function deploys the data source.""" + data_source = self.get_data_source() + + result = self._search_indexer_client.create_or_update_data_source_connection( + data_source + ) + + print(f"Data source '{result.name}' created or updated") + + return result + + def deploy_indexer(self): + """This function deploys the indexer.""" + indexer = self.get_indexer() + + result = self._search_indexer_client.create_or_update_indexer(indexer) + + print(f"Indexer '{result.name}' created or updated") + + return result + + def run_indexer(self): + """This function runs the indexer.""" + self._search_indexer_client.run_indexer(self.indexer_name) + + print( + f"{self.indexer_name} is running. If queries return no results, please wait a bit and try again." 
+        )
+
+    def reset_indexer(self):
+        """This function resets the indexer."""
+        self._search_indexer_client.reset_indexer(self.indexer_name)
+
+        print(f"{self.indexer_name} reset.")
+
+    def deploy_synonym_map(self) -> None:
+        synonym_maps = self.get_synonym_map_names()
+        if len(synonym_maps) > 0:
+            for synonym_map in synonym_maps:
+                try:
+                    synonym_map = SynonymMap(name=synonym_map, synonyms="")
+                    self._search_index_client.create_synonym_map(synonym_map)
+                except HttpResponseError:
+                    print("Unable to deploy synonym map as it already exists.")
+
+    def deploy(self):
+        """This function deploys the whole AI search pipeline."""
+        self.deploy_data_source()
+        self.deploy_synonym_map()
+        self.deploy_index()
+        self.deploy_skillset()
+        self.deploy_indexer()
+
+        print(f"{str(self.indexer_type.value)} deployed")
diff --git a/aisearch-skillset/deploy.py b/aisearch-skillset/deploy.py
new file mode 100644
index 0000000..d98e099
--- /dev/null
+++ b/aisearch-skillset/deploy.py
@@ -0,0 +1,80 @@
+import argparse
+from environment import get_search_endpoint, get_managed_identity_id, get_search_key, get_key_vault_url
+from azure.core.credentials import AzureKeyCredential
+from azure.identity import DefaultAzureCredential, ManagedIdentityCredential, EnvironmentCredential
+from azure.keyvault.secrets import SecretClient
+from inquiry_document import InquiryDocumentAISearch
+
+
+def main(args):
+    endpoint = get_search_endpoint()
+
+    try:
+        credential = DefaultAzureCredential(managed_identity_client_id=get_managed_identity_id())
+        # Initialize the Key Vault client
+        client = SecretClient(vault_url=get_key_vault_url(), credential=credential)
+        print("Using managed identity credential")
+    except Exception as e:
+        print(e)
+        credential = AzureKeyCredential(get_search_key(client=client))
+        print("Using Azure Key credential")
+
+    if args.indexer_type == "inquiry":
+        # Deploy the inquiry index
+        index_config = InquiryDocumentAISearch(
+            endpoint=endpoint,
+            credential=credential,
+            suffix=args.suffix,
+            rebuild=args.rebuild,
+            enable_page_by_chunking=args.enable_page_chunking,
+        )
+    elif args.indexer_type == "summary":
+        # Deploy the summary index
+        index_config = SummaryDocumentAISearch(
+            endpoint=endpoint,
+            credential=credential,
+            suffix=args.suffix,
+            rebuild=args.rebuild,
+            enable_page_by_chunking=args.enable_page_chunking,
+        )
+    elif args.indexer_type == "glossary":
+        # Deploy the business glossary index
+        index_config = BusinessGlossaryAISearch(endpoint, credential)
+
+    index_config.deploy()
+
+    if args.rebuild:
+        index_config.reset_indexer()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Process some arguments.")
+    parser.add_argument(
+        "--indexer_type",
+        type=str,
+        required=True,
+        help="Type of indexer to deploy: 
inquiry/summary/glossary", + ) + parser.add_argument( + "--rebuild", + type=bool, + required=False, + help="Whether want to delete and rebuild the index", + ) + parser.add_argument( + "--enable_page_chunking", + type=bool, + required=False, + help="Whether want to enable chunking by page in adi skill, if no value is passed considered False", + ) + parser.add_argument( + "--suffix", + type=str, + required=False, + help="Suffix to be attached to indexer objects", + ) + + args = parser.parse_args() + main(args) diff --git a/aisearch-skillset/environment.py b/aisearch-skillset/environment.py new file mode 100644 index 0000000..7503a68 --- /dev/null +++ b/aisearch-skillset/environment.py @@ -0,0 +1,192 @@ +"""Module providing environment definition""" +import os +from dotenv import find_dotenv, load_dotenv +from enum import Enum + +load_dotenv(find_dotenv()) + + +class IndexerType(Enum): + """The type of the indexer""" + + INQUIRY_DOCUMENT = "inquiry-document" + SUMMARY_DOCUMENT = "summary-document" + BUSINESS_GLOSSARY = "business-glossary" + +# key vault +def get_key_vault_url() ->str: + """ + This function returns key vault url + """ + return os.environ.get("KeyVault__Url") + +# managed identity id +def get_managed_identity_id() -> str: + """ + This function returns maanged identity id + """ + return os.environ.get("AIService__AzureSearchOptions__ManagedIdentity__ClientId") + + +def get_managed_identity_fqname() -> str: + """ + This function returns maanged identity name + """ + return os.environ.get("AIService__AzureSearchOptions__ManagedIdentity__FQName") + + +# function app details +def get_function_app_authresourceid() -> str: + """ + This function returns apps registration in microsoft entra id + """ + return os.environ.get("FunctionApp__AuthResourceId") + + +def get_function_app_end_point() -> str: + """ + This function returns function app endpoint + """ + return os.environ.get("FunctionApp__Endpoint") + +def get_function_app_key() -> str: + """ + This function returns function app key + """ + return os.environ.get("FunctionApp__Key") + +def get_function_app_compass_function() -> str: + """ + This function returns function app compass function name + """ + return os.environ.get("FunctionApp__Compass__FunctionName") + + +def get_function_app_pre_embedding_cleaner_function() -> str: + """ + This function returns function app data cleanup function name + """ + return os.environ.get("FunctionApp__PreEmbeddingCleaner__FunctionName") + + +def get_function_app_adi_function() -> str: + """ + This function returns function app adi name + """ + return os.environ.get("FunctionApp__DocumentIntelligence__FunctionName") + + +def get_function_app_custom_split_function() -> str: + """ + This function returns function app adi name + """ + return os.environ.get("FunctionApp__CustomTextSplit__FunctionName") + + +def get_function_app_keyphrase_extractor_function() -> str: + """ + This function returns function app keyphrase extractor name + """ + return os.environ.get("FunctionApp__KeyphraseExtractor__FunctionName") + + +def get_function_app_ocr_function() -> str: + """ + This function returns function app ocr name + """ + return os.environ.get("FunctionApp__Ocr__FunctionName") + + +# search +def get_search_endpoint() -> str: + """ + This function returns azure ai search service endpoint + """ + return os.environ.get("AIService__AzureSearchOptions__Endpoint") + + +def get_search_user_assigned_identity() -> str: + """ + This function returns azure ai search service endpoint + """ + return 
os.environ.get("AIService__AzureSearchOptions__UserAssignedIdentity") + + +def get_search_key(client) -> str: + """ + This function returns azure ai search service admin key + """ + search_service_key_secret_name = str(os.environ.get("AIService__AzureSearchOptions__name")) + "-PrimaryKey" + retrieved_secret = client.get_secret(search_service_key_secret_name) + return retrieved_secret.value + +def get_search_key_secret() -> str: + """ + This function returns azure ai search service admin key + """ + return os.environ.get("AIService__AzureSearchOptions__Key__Secret") + + +def get_search_embedding_model_dimensions(indexer_type: IndexerType) -> str: + """ + This function returns dimensions for embedding model + """ + + normalised_indexer_type = ( + indexer_type.value.replace("-", " ").title().replace(" ", "") + ) + + return os.environ.get( + f"AIService__AzureSearchOptions__{normalised_indexer_type}__EmbeddingDimensions" + ) + +def get_blob_connection_string() -> str: + """ + This function returns azure blob storage connection string + """ + return os.environ.get("StorageAccount__ConnectionString") + +def get_fq_blob_connection_string() -> str: + """ + This function returns azure blob storage connection string + """ + return os.environ.get("StorageAccount__FQEndpoint") + + +def get_blob_container_name(indexer_type: str) -> str: + """ + This function returns azure blob container name + """ + normalised_indexer_type = ( + indexer_type.value.replace("-", " ").title().replace(" ", "") + ) + return os.environ.get(f"StorageAccount__{normalised_indexer_type}__Container") + + +def get_custom_skill_function_url(skill_type: str): + """ + Get the function app url that is hosting the custom skill + """ + url = ( + get_function_app_end_point() + + "/api/function_name?code=" + + get_function_app_key() + ) + if skill_type == "compass": + url = url.replace("function_name", get_function_app_compass_function()) + elif skill_type == "pre_embedding_cleaner": + url = url.replace( + "function_name", get_function_app_pre_embedding_cleaner_function() + ) + elif skill_type == "adi": + url = url.replace("function_name", get_function_app_adi_function()) + elif skill_type == "split": + url = url.replace("function_name", get_function_app_custom_split_function()) + elif skill_type == "keyphraseextraction": + url = url.replace( + "function_name", get_function_app_keyphrase_extractor_function() + ) + elif skill_type == "ocr": + url = url.replace("function_name", get_function_app_ocr_function()) + + return url diff --git a/aisearch-skillset/inquiry_document.py b/aisearch-skillset/inquiry_document.py new file mode 100644 index 0000000..3f9dd0a --- /dev/null +++ b/aisearch-skillset/inquiry_document.py @@ -0,0 +1,320 @@ +from azure.search.documents.indexes.models import ( + SearchFieldDataType, + SearchField, + SearchableField, + SemanticField, + SemanticPrioritizedFields, + SemanticConfiguration, + SemanticSearch, + InputFieldMappingEntry, + SearchIndexer, + FieldMapping, + IndexingParameters, + IndexingParametersConfiguration, + BlobIndexerImageAction, + SearchIndexerIndexProjections, + SearchIndexerIndexProjectionSelector, + SearchIndexerIndexProjectionsParameters, + IndexProjectionMode, + SimpleField, + BlobIndexerDataToExtract, + IndexerExecutionEnvironment, + BlobIndexerPDFTextRotationAlgorithm, +) +from ai_search import AISearch +from environment import ( + get_search_embedding_model_dimensions, + IndexerType, +) + + +class InquiryDocumentAISearch(AISearch): + """This class is used to deploy the inquiry document index.""" 
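+    # Usage sketch (illustrative only, mirroring deploy.py in this patch; the
+    # endpoint, credential and suffix values are placeholders):
+    #
+    #   index_config = InquiryDocumentAISearch(
+    #       endpoint=get_search_endpoint(),
+    #       credential=DefaultAzureCredential(),
+    #       suffix="dev",
+    #       rebuild=False,
+    #       enable_page_by_chunking=True,
+    #   )
+    #   index_config.deploy()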
+ + def __init__( + self, + endpoint, + credential, + suffix=None, + rebuild=False, + enable_page_by_chunking=False, + ): + super().__init__(endpoint, credential, suffix, rebuild) + + self.indexer_type = IndexerType.INQUIRY_DOCUMENT + if enable_page_by_chunking is not None: + self.enable_page_by_chunking = enable_page_by_chunking + else: + self.enable_page_by_chunking = False + + # explicitly setting it to false no matter what output comes in + # might be removed later + # self.enable_page_by_chunking = False + + def get_index_fields(self) -> list[SearchableField]: + """This function returns the index fields for inquiry document. + + Returns: + list[SearchableField]: The index fields for inquiry document""" + + fields = [ + SimpleField(name="Id", type=SearchFieldDataType.String, filterable=True), + SearchableField( + name="Title", type=SearchFieldDataType.String, filterable=True + ), + SearchableField( + name="DealId", + type=SearchFieldDataType.String, + sortable=True, + filterable=True, + facetable=True, + ), + SearchableField( + name="OracleId", + type=SearchFieldDataType.String, + sortable=True, + filterable=True, + facetable=True, + ), + SearchableField( + name="ChunkId", + type=SearchFieldDataType.String, + key=True, + analyzer_name="keyword", + ), + SearchableField( + name="Chunk", + type=SearchFieldDataType.String, + sortable=False, + filterable=False, + facetable=False, + ), + SearchableField( + name="Section", + type=SearchFieldDataType.String, + collection=True, + ), + SearchField( + name="ChunkEmbedding", + type=SearchFieldDataType.Collection(SearchFieldDataType.Single), + vector_search_dimensions=get_search_embedding_model_dimensions( + self.indexer_type + ), + vector_search_profile_name=self.vector_search_profile_name, + ), + SearchableField( + name="Keywords", type=SearchFieldDataType.String, collection=True + ), + SearchableField( + name="SourceUrl", + type=SearchFieldDataType.String, + sortable=True, + filterable=True, + facetable=True, + ), + SearchableField( + name="AdditionalMetadata", + type=SearchFieldDataType.String, + sortable=True, + filterable=True, + facetable=True, + ), + ] + + if self.enable_page_by_chunking: + fields.extend( + [ + SearchableField( + name="PageNumber", + type=SearchFieldDataType.Int64, + sortable=True, + filterable=True, + facetable=True, + ) + ] + ) + + return fields + + def get_semantic_search(self) -> SemanticSearch: + """This function returns the semantic search configuration for inquiry document + + Returns: + SemanticSearch: The semantic search configuration""" + + semantic_config = SemanticConfiguration( + name=self.semantic_config_name, + prioritized_fields=SemanticPrioritizedFields( + title_field=SemanticField(field_name="Title"), + content_fields=[SemanticField(field_name="Chunk")], + keywords_fields=[ + SemanticField(field_name="Keywords"), + SemanticField(field_name="Section"), + ], + ), + ) + + semantic_search = SemanticSearch(configurations=[semantic_config]) + + return semantic_search + + def get_skills(self): + """This function returns the skills for inquiry document""" + + adi_skill = self.get_adi_skill(self.enable_page_by_chunking) + + text_split_skill = self.get_text_split_skill( + "/document", "/document/extracted_content/content" + ) + + pre_embedding_cleaner_skill = self.get_pre_embedding_cleaner_skill( + "/document/pages/*", "/document/pages/*", self.enable_page_by_chunking + ) + + key_phrase_extraction_skill = self.get_key_phrase_extraction_skill( + "/document/pages/*", "/document/pages/*/cleaned_chunk" + ) + + 
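+        # The embedding skill below consumes the cleaned chunks emitted by the
+        # pre-embedding cleaner skill (source "/document/pages/*/cleaned_chunk").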
embedding_skill = self.get_compass_vector_custom_skill( + "/document/pages/*", "/document/pages/*/cleaned_chunk" + ) + + if self.enable_page_by_chunking: + skills = [ + adi_skill, + pre_embedding_cleaner_skill, + key_phrase_extraction_skill, + embedding_skill, + ] + else: + skills = [ + adi_skill, + text_split_skill, + pre_embedding_cleaner_skill, + key_phrase_extraction_skill, + embedding_skill, + ] + + return skills + + def get_index_projections(self) -> SearchIndexerIndexProjections: + """This function returns the index projections for inquiry document.""" + mappings =[ + InputFieldMappingEntry( + name="Chunk", source="/document/pages/*/chunk" + ), + InputFieldMappingEntry( + name="ChunkEmbedding", + source="/document/pages/*/vector", + ), + InputFieldMappingEntry( + name="Title", + source="/document/Title" + ), + InputFieldMappingEntry( + name="DealId", + source="/document/DealId" + ), + InputFieldMappingEntry( + name="OracleId", + source="/document/OracleId" + ), + InputFieldMappingEntry( + name="SourceUrl", + source="/document/SourceUrl" + ), + InputFieldMappingEntry( + name="Keywords", + source="/document/pages/*/keywords" + ), + InputFieldMappingEntry( + name="AdditionalMetadata", + source="/document/AdditionalMetadata", + ), + InputFieldMappingEntry( + name="Section", + source="/document/pages/*/eachsection" + ) + ] + + if self.enable_page_by_chunking: + mappings.extend( + [ + InputFieldMappingEntry( + name="PageNumber", source="/document/pages/*/page_no" + ) + ] + ) + + index_projections = SearchIndexerIndexProjections( + selectors=[ + SearchIndexerIndexProjectionSelector( + target_index_name=self.index_name, + parent_key_field_name="Id", + source_context="/document/pages/*", + mappings=mappings + ), + ], + parameters=SearchIndexerIndexProjectionsParameters( + projection_mode=IndexProjectionMode.SKIP_INDEXING_PARENT_DOCUMENTS + ), + ) + + return index_projections + + def get_indexer(self) -> SearchIndexer: + """This function returns the indexer for inquiry document. 
+ + Returns: + SearchIndexer: The indexer for inquiry document""" + if self.test: + schedule = None + batch_size = 4 + else: + schedule = {"interval": "PT15M"} + batch_size = 16 + + indexer_parameters = IndexingParameters( + batch_size=batch_size, + configuration=IndexingParametersConfiguration( + # image_action=BlobIndexerImageAction.GENERATE_NORMALIZED_IMAGE_PER_PAGE, + data_to_extract=BlobIndexerDataToExtract.ALL_METADATA, + query_timeout=None, + # allow_skillset_to_read_file_data=True, + execution_environment=IndexerExecutionEnvironment.PRIVATE, + # pdf_text_rotation_algorithm=BlobIndexerPDFTextRotationAlgorithm.DETECT_ANGLES, + fail_on_unprocessable_document=False, + fail_on_unsupported_content_type=False, + index_storage_metadata_only_for_oversized_documents=True, + indexed_file_name_extensions=".pdf,.pptx,.docx", + ), + max_failed_items=5, + ) + + indexer = SearchIndexer( + name=self.indexer_name, + description="Indexer to index documents and generate embeddings", + skillset_name=self.skillset_name, + target_index_name=self.index_name, + data_source_name=self.data_source_name, + schedule=schedule, + field_mappings=[ + FieldMapping( + source_field_name="metadata_storage_name", target_field_name="Title" + ), + FieldMapping(source_field_name="Deal_ID", target_field_name="DealId"), + FieldMapping( + source_field_name="Oracle_ID", target_field_name="OracleId" + ), + FieldMapping( + source_field_name="SharePointUrl", target_field_name="SourceUrl" + ), + FieldMapping( + source_field_name="Additional_Metadata", + target_field_name="AdditionalMetadata", + ), + ], + parameters=indexer_parameters, + ) + + return indexer diff --git a/function_apps/common/ai_search.py b/function_apps/common/ai_search.py new file mode 100644 index 0000000..1bba829 --- /dev/null +++ b/function_apps/common/ai_search.py @@ -0,0 +1,127 @@ +from azure.search.documents.indexes.aio import SearchIndexerClient, SearchIndexClient +from azure.search.documents.aio import SearchClient +from azure.search.documents.indexes.models import SynonymMap +from azure.identity import DefaultAzureCredential +from azure.core.exceptions import HttpResponseError +import logging +import os +from enum import Enum +from openai import AsyncAzureOpenAI +from azure.search.documents.models import VectorizedQuery + + +class IndexerStatusEnum(Enum): + RETRIGGER = "RETRIGGER" + RUNNING = "RUNNING" + SUCCESS = "SUCCESS" + + +class AISearchHelper: + def __init__(self): + self._client_id = os.environ["FunctionApp__ClientId"] + + self._endpoint = os.environ["AIService__AzureSearchOptions__Endpoint"] + + async def get_index_client(self): + credential = DefaultAzureCredential(managed_identity_client_id=self._client_id) + + return SearchIndexClient(self._endpoint, credential) + + async def get_indexer_client(self): + credential = DefaultAzureCredential(managed_identity_client_id=self._client_id) + + return SearchIndexerClient(self._endpoint, credential) + + async def get_search_client(self, index_name): + credential = DefaultAzureCredential(managed_identity_client_id=self._client_id) + + return SearchClient(self._endpoint, index_name, credential) + + async def upload_synonym_map(self, synonym_map_name: str, synonyms: str): + index_client = await self.get_index_client() + async with index_client: + try: + await index_client.delete_synonym_map(synonym_map_name) + except HttpResponseError as e: + logging.error("Unable to delete synonym map %s", e) + + logging.info("Synonyms: %s", synonyms) + synonym_map = SynonymMap(name=synonym_map_name, synonyms=synonyms) + 
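+            # Recreate the synonym map from scratch: the delete above removes any
+            # existing map (failures are only logged), then the new synonyms are
+            # uploaded.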
await index_client.create_synonym_map(synonym_map) + + async def get_indexer_status(self, indexer_name): + indexer_client = await self.get_indexer_client() + async with indexer_client: + try: + status = await indexer_client.get_indexer_status(indexer_name) + + last_execution_result = status.last_result + + if last_execution_result.status == "inProgress": + return IndexerStatusEnum.RUNNING, last_execution_result.start_time + elif last_execution_result.status in ["success", "transientFailure"]: + return IndexerStatusEnum.SUCCESS, last_execution_result.start_time + else: + return IndexerStatusEnum.RETRIGGER, last_execution_result.start_time + except HttpResponseError as e: + logging.error("Unable to get indexer status %s", e) + + async def trigger_indexer(self, indexer_name): + indexer_client = await self.get_indexer_client() + async with indexer_client: + try: + await indexer_client.run_indexer(indexer_name) + except HttpResponseError as e: + logging.error("Unable to run indexer %s", e) + + async def search_index( + self, index_name, semantic_config, search_text, deal_id=None + ): + """Search the index using the provided search text.""" + async with AsyncAzureOpenAI( + # This is the default and can be omitted + api_key=os.environ["AIService__Compass_Key"], + azure_endpoint=os.environ["AIService__Compass_Endpoint"], + api_version="2023-03-15-preview", + ) as open_ai_client: + embeddings = await open_ai_client.embeddings.create( + model=os.environ["AIService__Compass_Models__Embedding"], + input=search_text, + ) + + # Extract the embedding vector + embedding_vector = embeddings.data[0].embedding + + vector_query = VectorizedQuery( + vector=embedding_vector, + k_nearest_neighbors=5, + fields="ChunkEmbedding", + ) + + if deal_id: + filter_expression = f"DealId eq '{deal_id}'" + else: + filter_expression = None + + logging.info(f"Filter Expression: {filter_expression}") + + search_client = await self.get_search_client(index_name) + async with search_client: + results = await search_client.search( + top=3, + query_type="semantic", + semantic_configuration_name=semantic_config, + search_text=search_text, + select="Title,Chunk", + vector_queries=[vector_query], + filter=filter_expression, + ) + + documents = [ + document + async for result in results.by_page() + async for document in result + ] + + logging.info(f"Documents: {documents}") + return documents diff --git a/function_apps/indexer/adi_2_aisearch.py b/function_apps/indexer/adi_2_aisearch.py new file mode 100644 index 0000000..e0542fb --- /dev/null +++ b/function_apps/indexer/adi_2_aisearch.py @@ -0,0 +1,460 @@ +import base64 +from azure.core.credentials import AzureKeyCredential +from azure.ai.documentintelligence.aio import DocumentIntelligenceClient +from azure.ai.documentintelligence.models import AnalyzeResult, ContentFormat +import os +import re +import asyncio +import fitz +from PIL import Image +import io +import aiohttp +import logging +from common.storage_account import StorageAccountHelper +import concurrent.futures +import json + + +def crop_image_from_pdf_page(pdf_path, page_number, bounding_box): + """ + Crops a region from a given page in a PDF and returns it as an image. + + :param pdf_path: Path to the PDF file. + :param page_number: The page number to crop from (0-indexed). + :param bounding_box: A tuple of (x0, y0, x1, y1) coordinates for the bounding box. + :return: A PIL Image of the cropped area. + """ + doc = fitz.open(pdf_path) + page = doc.load_page(page_number) + + # Cropping the page. 
The rect requires the coordinates in the format (x0, y0, x1, y1). + bbx = [x * 72 for x in bounding_box] + rect = fitz.Rect(bbx) + pix = page.get_pixmap(matrix=fitz.Matrix(300 / 72, 300 / 72), clip=rect) + + img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples) + + doc.close() + return img + + +def clean_adi_markdown(markdown_text: str, page_no:int,remove_irrelevant_figures=False): + """Clean Markdown text extracted by the Azure Document Intelligence service. + + Args: + ----- + markdown_text (str): The original Markdown text. + remove_irrelevant_figures (bool): Whether to remove all figures or just irrelevant ones. + + Returns: + -------- + str: The cleaned Markdown text. + """ + + # # Remove the page number comment + # page_number_pattern = r"" + # cleaned_text = re.sub(page_number_pattern, "", markdown_text) + + # # Replace the page header comment with its content + # page_header_pattern = r"" + # cleaned_text = re.sub( + # page_header_pattern, lambda match: match.group(1), cleaned_text + # ) + + # # Replace the page footer comment with its content + # page_footer_pattern = r"" + # cleaned_text = re.sub( + # page_footer_pattern, lambda match: match.group(1), cleaned_text + # ) + output_dict = {} + comment_patterns = r"||" + cleaned_text = re.sub(comment_patterns, "", markdown_text, flags=re.DOTALL) + + combined_pattern = r'(.*?)\n===|\n## ?(.*?)\n|\n### ?(.*?)\n' + doc_metadata = re.findall(combined_pattern, cleaned_text, re.DOTALL) + doc_metadata = [match for group in doc_metadata for match in group if match] + + + if remove_irrelevant_figures: + # Remove irrelevant figures + irrelevant_figure_pattern = ( + r"
.*?.*?
\s*" + ) + cleaned_text = re.sub( + irrelevant_figure_pattern, "", cleaned_text, flags=re.DOTALL + ) + + # Replace ':selected:' with a new line + cleaned_text = re.sub(r":(selected|unselected):", "\n", cleaned_text) + output_dict['content'] = cleaned_text + output_dict['section'] = doc_metadata + + # add page number when chunk by page is enabled + if page_no> -1: + output_dict['page_number'] = page_no + + return output_dict + + +def update_figure_description(md_content, img_description, idx): + """ + Updates the figure description in the Markdown content. + + Args: + md_content (str): The original Markdown content. + img_description (str): The new description for the image. + idx (int): The index of the figure. + + Returns: + str: The updated Markdown content with the new figure description. + """ + + # The substring you're looking for + start_substring = f"![](figures/{idx})" + end_substring = "" + new_string = f'' + + new_md_content = md_content + # Find the start and end indices of the part to replace + start_index = md_content.find(start_substring) + if start_index != -1: # if start_substring is found + start_index += len( + start_substring + ) # move the index to the end of start_substring + end_index = md_content.find(end_substring, start_index) + if end_index != -1: # if end_substring is found + # Replace the old string with the new string + new_md_content = ( + md_content[:start_index] + new_string + md_content[end_index:] + ) + + return new_md_content + + +async def understand_image_with_vlm(image_base64): + """ + Sends a base64-encoded image to a VLM (Vision Language Model) endpoint for financial analysis. + + Args: + image_base64 (str): The base64-encoded string representation of the image. + + Returns: + str: The response from the VLM, which is either a financial analysis or a statement indicating the image is not useful. + """ + # prompt = "Describe the image ONLY IF it is useful for financial analysis. Otherwise, say 'NOT USEFUL IMAGE' and NOTHING ELSE. " + prompt = "Perform financial analysis of the image ONLY IF the image is of graph, chart, flowchart or table. Otherwise, say 'NOT USEFUL IMAGE' and NOTHING ELSE. " + headers = {"Content-Type": "application/json"} + data = {"prompt": prompt, "image": image_base64} + vlm_endpoint = os.environ["AIServices__VLM__Endpoint"] + async with aiohttp.ClientSession() as session: + async with session.post( + vlm_endpoint, headers=headers, json=data, timeout=30 + ) as response: + response_data = await response.json() + response_text = response_data["response"].split("")[0] + + if ( + "not useful for financial analysis" in response_text + or "NOT USEFUL IMAGE" in response_text + ): + return "Irrelevant Image" + else: + return response_text + + +def pil_image_to_base64(image, image_format="JPEG"): + """ + Converts a PIL image to a base64-encoded string. + + Args: + image (PIL.Image.Image): The image to be converted. + image_format (str): The format to save the image in. Defaults to "JPEG". + + Returns: + str: The base64-encoded string representation of the image. + """ + if image.mode == "RGBA" and image_format == "JPEG": + image = image.convert("RGB") + buffered = io.BytesIO() + image.save(buffered, format=image_format) + return base64.b64encode(buffered.getvalue()).decode("utf-8") + + +async def process_figures_from_extracted_content( + file_path: str, markdown_content: str, figures: list, page_number: None | int = None +) -> str: + """Process the figures extracted from the content using ADI and send them for analysis. 
+ + Args: + ----- + file_path (str): The path to the PDF file. + markdown_content (str): The extracted content in Markdown format. + figures (list): The list of figures extracted by the Azure Document Intelligence service. + page_number (int): The page number to process. If None, all pages are processed. + + Returns: + -------- + str: The updated Markdown content with the figure descriptions.""" + for idx, figure in enumerate(figures): + img_description = "" + logging.debug(f"Figure #{idx} has the following spans: {figure.spans}") + + caption_region = figure.caption.bounding_regions if figure.caption else [] + for region in figure.bounding_regions: + # Skip the region if it is not on the specified page + if page_number is not None and region.page_number != page_number: + continue + + if region not in caption_region: + # To learn more about bounding regions, see https://aka.ms/bounding-region + bounding_box = ( + region.polygon[0], # x0 (left) + region.polygon[1], # y0 (top) + region.polygon[4], # x1 (right) + region.polygon[5], # y1 (bottom) + ) + cropped_image = crop_image_from_pdf_page( + file_path, region.page_number - 1, bounding_box + ) # page_number is 1-indexed3 + + image_base64 = pil_image_to_base64(cropped_image) + + img_description += await understand_image_with_vlm(image_base64) + logging.info(f"\tDescription of figure {idx}: {img_description}") + + markdown_content = update_figure_description( + markdown_content, img_description, idx + ) + + return markdown_content + + +def create_page_wise_content(result: AnalyzeResult) -> list: + """Create a list of page-wise content extracted by the Azure Document Intelligence service. + + Args: + ----- + result (AnalyzeResult): The result of the document analysis. + + Returns: + -------- + list: A list of page-wise content extracted by the Azure Document Intelligence service. + """ + + page_wise_content = [] + page_numbers = [] + page_number = 0 + for page in result.pages: + page_content = result.content[ + page.spans[0]["offset"] : page.spans[0]["offset"] + page.spans[0]["length"] + ] + page_wise_content.append(page_content) + page_number+=1 + page_numbers.append(page_number) + + return page_wise_content,page_numbers + + +async def analyse_document(file_path: str) -> AnalyzeResult: + """Analyse a document using the Azure Document Intelligence service. + + Args: + ----- + file_path (str): The path to the document to analyse. + + Returns: + -------- + AnalyzeResult: The result of the document analysis.""" + with open(file_path, "rb") as f: + file_read = f.read() + # base64_encoded_file = base64.b64encode(file_read).decode("utf-8") + + async with DocumentIntelligenceClient( + endpoint=os.environ["AIService__Services__Endpoint"], + credential=AzureKeyCredential(os.environ["AIService__Services__Key"]), + ) as document_intelligence_client: + poller = await document_intelligence_client.begin_analyze_document( + model_id="prebuilt-layout", + analyze_request=file_read, + output_content_format=ContentFormat.MARKDOWN, + content_type="application/octet-stream", + ) + + result = await poller.result() + + if result is None or result.content is None or result.pages is None: + raise ValueError( + "Failed to analyze the document with Azure Document Intelligence." 
+ ) + + return result + + +async def process_adi_2_ai_search(record: dict, chunk_by_page: bool = False) -> dict: + logging.info("Python HTTP trigger function processed a request.") + + storage_account_helper = StorageAccountHelper() + + try: + source = record["data"]["source"] + logging.info(f"Request Body: {record}") + except KeyError: + return { + "recordId": record["recordId"], + "data": {}, + "errors": [ + { + "message": "Failed to extract data with ADI. Pass a valid source in the request body.", + } + ], + "warnings": None, + } + else: + logging.info(f"Source: {source}") + + try: + source_parts = source.split("/") + blob = "/".join(source_parts[4:]) + logging.info(f"Blob: {blob}") + + container = source_parts[3] + + file_extension = blob.split(".")[-1] + target_file_name = f"{record['recordId']}.{file_extension}" + + temp_file_path, _ = await storage_account_helper.download_blob_to_temp_dir( + blob, container, target_file_name + ) + logging.info(temp_file_path) + except Exception as e: + logging.error(f"Failed to download the blob: {e}") + return { + "recordId": record["recordId"], + "data": {}, + "errors": [ + { + "message": f"Failed to download the blob. Check the source and try again. {e}", + } + ], + "warnings": None, + } + + try: + result = await analyse_document(temp_file_path) + except Exception as e: + logging.error(e) + logging.info("Sleeping for 10 seconds and retrying") + await asyncio.sleep(10) + try: + result = await analyse_document(temp_file_path) + except ValueError as inner_e: + logging.error(inner_e) + logging.error( + f"Failed to analyze the document with Azure Document Intelligence: {e}" + ) + logging.error( + "Failed to analyse %s with Azure Document Intelligence.", blob + ) + await storage_account_helper.add_metadata_to_blob( + blob, container, {"AzureSearch_Skip": "true"} + ) + return { + "recordId": record["recordId"], + "data": {}, + "errors": [ + { + "message": f"Failed to analyze the document with Azure Document Intelligence. This blob will now be skipped {inner_e}", + } + ], + "warnings": None, + } + except Exception as inner_e: + logging.error(inner_e) + logging.error( + "Failed to analyse %s with Azure Document Intelligence.", blob + ) + return { + "recordId": record["recordId"], + "data": {}, + "errors": [ + { + "message": f"Failed to analyze the document with Azure Document Intelligence. Check the logs and try again. 
{inner_e}", + } + ], + "warnings": None, + } + + try: + if chunk_by_page: + markdown_content,page_no = create_page_wise_content(result) + else: + markdown_content = result.content + + # Remove this line when VLM is ready + content_with_figures = markdown_content + + # if chunk_by_page: + # tasks = [ + # process_figures_from_extracted_content( + # temp_file_path, page_content, result.figures, page_number=idx + # ) + # for idx, page_content in enumerate(markdown_content) + # ] + # content_with_figures = await asyncio.gather(*tasks) + # else: + # content_with_figures = await process_figures_from_extracted_content( + # temp_file_path, markdown_content, result.figures + # ) + + # Remove remove_irrelevant_figures=True when VLM is ready + if chunk_by_page: + cleaned_result = [] + with concurrent.futures.ProcessPoolExecutor() as executor: + results = executor.map(clean_adi_markdown,content_with_figures, page_no,[False] * len(content_with_figures)) + + for cleaned_content in results: + cleaned_result.append(cleaned_content) + + # with concurrent.futures.ProcessPoolExecutor() as executor: + # futures = { + # executor.submit( + # clean_adi_markdown, page_content, False + # ): page_content + # for page_content in content_with_figures + # } + # for future in concurrent.futures.as_completed(futures): + # cleaned_result.append(future.result()) + else: + cleaned_result = clean_adi_markdown( + content_with_figures, page_no=-1,remove_irrelevant_figures=False + ) + except Exception as e: + logging.error(e) + logging.error(f"Failed to process the extracted content: {e}") + return { + "recordId": record["recordId"], + "data": {}, + "errors": [ + { + "message": f"Failed to process the extracted content. Check the logs and try again. {e}", + } + ], + "warnings": None, + } + + logging.info("Document Extracted") + logging.info(f"Result: {cleaned_result}") + + src = { + "recordId": record["recordId"], + "data": {"extracted_content": cleaned_result}, + } + + json_str = json.dumps(src, indent=4) + + logging.info(f"final output: {json_str}") + + return { + "recordId": record["recordId"], + "data": {"extracted_content": cleaned_result}, + } diff --git a/function_apps/indexer/function_app.py b/function_apps/indexer/function_app.py new file mode 100644 index 0000000..12d5d5b --- /dev/null +++ b/function_apps/indexer/function_app.py @@ -0,0 +1,296 @@ +from datetime import datetime, timedelta, timezone +import azure.functions as func +import logging +import json +import asyncio + +from adi_2_ai_search import process_adi_2_ai_search +from common.service_bus import ServiceBusHelper +from pre_embedding_cleaner import process_pre_embedding_cleaner + +from text_split import process_text_split +from ai_search_2_compass import process_ai_search_2_compass +from key_phrase_extraction import process_key_phrase_extraction +from ocr import process_ocr +from pending_index_completion import process_pending_index_completion +from pending_index_trigger import process_pending_index_trigger + +from common.payloads.pending_index_trigger import PendingIndexTriggerPayload + +from common.payloads.header import TaskEnum + +logging.basicConfig(level=logging.INFO) +app = func.FunctionApp(http_auth_level=func.AuthLevel.FUNCTION) + + +@app.route(route="text_split", methods=[func.HttpMethod.POST]) +async def text_split(req: func.HttpRequest) -> func.HttpResponse: + """Extract the content from a document using ADI.""" + + try: + req_body = req.get_json() + values = req_body.get("values") + text_split_config = req.headers + except ValueError: + return 
func.HttpResponse( + "Please valid Custom Skill Payload in the request body", status_code=400 + ) + else: + logging.debug(f"Input Values: {values}") + + record_tasks = [] + + for value in values: + record_tasks.append( + asyncio.create_task(process_text_split(value, text_split_config)) + ) + + results = await asyncio.gather(*record_tasks) + logging.debug(f"Results: {results}") + + return func.HttpResponse( + json.dumps({"values": results}), + status_code=200, + mimetype="application/json", + ) + + +@app.route(route="ai_search_2_compass", methods=[func.HttpMethod.POST]) +async def ai_search_2_compass(req: func.HttpRequest) -> func.HttpResponse: + logging.info("Python HTTP trigger function processed a request.") + + """HTTP trigger for AI Search 2 Compass function. + + Args: + req (func.HttpRequest): The HTTP request object. + + Returns: + func.HttpResponse: The HTTP response object.""" + logging.info("Python HTTP trigger function processed a request.") + + try: + req_body = req.get_json() + values = req_body.get("values") + except ValueError: + return func.HttpResponse( + "Please valid Custom Skill Payload in the request body", status_code=400 + ) + else: + logging.debug("Input Values: %s", values) + + record_tasks = [] + + for value in values: + record_tasks.append(asyncio.create_task(process_ai_search_2_compass(value))) + + results = await asyncio.gather(*record_tasks) + logging.debug("Results: %s", results) + vectorised_tasks = {"values": results} + + return func.HttpResponse( + json.dumps(vectorised_tasks), status_code=200, mimetype="application/json" + ) + + +@app.route(route="adi_2_ai_search", methods=[func.HttpMethod.POST]) +async def adi_2_ai_search(req: func.HttpRequest) -> func.HttpResponse: + """Extract the content from a document using ADI.""" + + try: + req_body = req.get_json() + values = req_body.get("values") + adi_config = req.headers + + chunk_by_page = adi_config.get("chunk_by_page", "False").lower() == "true" + logging.info(f"Chunk by Page: {chunk_by_page}") + except ValueError: + return func.HttpResponse( + "Please valid Custom Skill Payload in the request body", status_code=400 + ) + else: + logging.debug("Input Values: %s", values) + + record_tasks = [] + + for value in values: + record_tasks.append( + asyncio.create_task( + process_adi_2_ai_search(value, chunk_by_page=chunk_by_page) + ) + ) + + results = await asyncio.gather(*record_tasks) + logging.debug("Results: %s", results) + + return func.HttpResponse( + json.dumps({"values": results}), + status_code=200, + mimetype="application/json", + ) + + +@app.route(route="pre_embedding_cleaner", methods=[func.HttpMethod.POST]) +async def pre_embedding_cleaner(req: func.HttpRequest) -> func.HttpResponse: + """HTTP trigger for data cleanup function. + + Args: + req (func.HttpRequest): The HTTP request object. 
+ + Returns: + func.HttpResponse: The HTTP response object.""" + logging.info("Python HTTP trigger data cleanup function processed a request.") + + try: + req_body = req.get_json() + values = req_body.get("values") + except ValueError: + return func.HttpResponse( + "Please valid Custom Skill Payload in the request body", status_code=400 + ) + else: + logging.debug("Input Values: %s", values) + + record_tasks = [] + + for value in values: + record_tasks.append( + asyncio.create_task(process_pre_embedding_cleaner(value)) + ) + + results = await asyncio.gather(*record_tasks) + logging.debug("Results: %s", results) + cleaned_tasks = {"values": results} + + return func.HttpResponse( + json.dumps(cleaned_tasks), status_code=200, mimetype="application/json" + ) + + +@app.route(route="keyphrase_extractor", methods=[func.HttpMethod.POST]) +async def keyphrase_extractor(req: func.HttpRequest) -> func.HttpResponse: + """HTTP trigger for data cleanup function. + + Args: + req (func.HttpRequest): The HTTP request object. + + Returns: + func.HttpResponse: The HTTP response object.""" + logging.info("Python HTTP trigger data cleanup function processed a request.") + + try: + req_body = req.get_json() + values = req_body.get("values") + logging.info(req_body) + except ValueError: + return func.HttpResponse( + "Please valid Custom Skill Payload in the request body", status_code=400 + ) + else: + logging.debug("Input Values: %s", values) + + record_tasks = [] + + for value in values: + record_tasks.append( + asyncio.create_task(process_key_phrase_extraction(value)) + ) + + results = await asyncio.gather(*record_tasks) + logging.debug("Results: %s", results) + cleaned_tasks = {"values": results} + + return func.HttpResponse( + json.dumps(cleaned_tasks), status_code=200, mimetype="application/json" + ) + + +@app.route(route="ocr", methods=[func.HttpMethod.POST]) +async def ocr(req: func.HttpRequest) -> func.HttpResponse: + """HTTP trigger for data cleanup function. + + Args: + req (func.HttpRequest): The HTTP request object. + + Returns: + func.HttpResponse: The HTTP response object.""" + logging.info("Python HTTP trigger data cleanup function processed a request.") + + try: + req_body = req.get_json() + values = req_body.get("values") + except ValueError: + return func.HttpResponse( + "Please valid Custom Skill Payload in the request body", status_code=400 + ) + else: + logging.debug("Input Values: %s", values) + + record_tasks = [] + + for value in values: + record_tasks.append(asyncio.create_task(process_ocr(value))) + + results = await asyncio.gather(*record_tasks) + logging.debug("Results: %s", results) + cleaned_tasks = {"values": results} + + return func.HttpResponse( + json.dumps(cleaned_tasks), status_code=200, mimetype="application/json" + ) + + +@app.service_bus_queue_trigger( + arg_name="msg", + queue_name="pending_index_trigger", + connection="ServiceBusTrigger", +) +async def pending_index_trigger(msg: func.ServiceBusMessage): + logging.info( + f"trigger-indexer: Python ServiceBus queue trigger processed message: {msg}" + ) + try: + payload = PendingIndexTriggerPayload.from_service_bus_message(msg) + await process_pending_index_trigger(payload) + except ValueError as ve: + logging.error(f"ValueError: {ve}") + except Exception as e: + logging.error(f"Error processing ServiceBus message: {e}") + + if "On-demand indexer invocation is permitted every 180 seconds" in str(e): + logging.warning( + f"Indexer invocation limit reached: {e}. Scheduling a retry." 
+ ) + service_bus_helper = ServiceBusHelper() + message = PendingIndexTriggerPayload( + header=payload.header, body=payload.body, errors=[] + ) + queue = TaskEnum.PENDING_INDEX_TRIGGER.value + minutes = 2 ** (11 - payload.header.retries_remaining) + enqueue_time = datetime.now(timezone.utc) + timedelta(minutes=minutes) + await service_bus_helper.send_message_to_service_bus_queue( + queue, message, enqueue_time=enqueue_time + ) + else: + raise e + + +@app.service_bus_queue_trigger( + arg_name="msg", + queue_name="pending_index_completion", + connection="ServiceBusTrigger", +) +async def pending_index_completion(msg: func.ServiceBusMessage): + logging.info( + f"indexer-polling-trigger: Python ServiceBus queue trigger processed message: {msg}" + ) + + try: + payload = PendingIndexTriggerPayload.from_service_bus_message(msg) + await process_pending_index_completion(payload) + except ValueError as ve: + logging.error(f"ValueError: {ve}") + except Exception as e: + logging.error(f"Error processing ServiceBus message: {e}") + if "The operation has timed out" in str(e): + logging.error("The operation has timed out.") + raise e diff --git a/function_apps/indexer/key_phrase_extraction.py b/function_apps/indexer/key_phrase_extraction.py new file mode 100644 index 0000000..c6ab40e --- /dev/null +++ b/function_apps/indexer/key_phrase_extraction.py @@ -0,0 +1,112 @@ +import logging +import json +import os +from azure.ai.textanalytics.aio import TextAnalyticsClient +from azure.core.exceptions import HttpResponseError +from azure.core.credentials import AzureKeyCredential +import asyncio + +MAX_TEXT_ELEMENTS = 5120 + +def split_document(document, max_size): + """Split a document into chunks of max_size.""" + return [document[i:i + max_size] for i in range(0, len(document), max_size)] + +async def extract_key_phrases_from_text(data: list[str],max_key_phrase_count:int) -> list[str]: + logging.info("Python HTTP trigger function processed a request.") + + max_retries = 5 + key_phrase_list = [] + text_analytics_client = TextAnalyticsClient( + endpoint=os.environ["AIService__Services__Endpoint"], + credential=AzureKeyCredential(os.environ["AIService__Services__Key"]), + ) + + try: + async with text_analytics_client: + retries = 0 + while retries < max_retries: + try: + # Split large documents + split_documents = [] + for doc in data: + if len(doc) > MAX_TEXT_ELEMENTS: + split_documents.extend(split_document(doc, MAX_TEXT_ELEMENTS)) + else: + split_documents.append(doc) + result = await text_analytics_client.extract_key_phrases(split_documents) + for idx,doc in enumerate(result): + if not doc.is_error: + key_phrase_list.extend(doc.key_phrases[:max_key_phrase_count]) + else: + raise Exception(f"Document {idx} error: {doc.error}") + break # Exit the loop if the request is successful + except HttpResponseError as e: + if e.status_code == 429: # Rate limiting error + retries += 1 + wait_time = 2 ** retries # Exponential backoff + print(f"Rate limit exceeded. Retrying in {wait_time} seconds...") + await asyncio.sleep(wait_time) + else: + raise Exception(f"An error occurred: {e}") + except Exception as e: + raise Exception(f"An error occurred: {e}") + + return key_phrase_list + + +async def process_key_phrase_extraction(record: dict,max_key_phrase_count:int =5 ) -> dict: + """Extract key phrases using azure ai services. + + Args: + record (dict): The record to process. 
+ max_key_phrase_count(int): no of keywords to return + + Returns: + dict: extracted key words.""" + + try: + json_str = json.dumps(record, indent=4) + + logging.info(f"key phrase extraction Input: {json_str}") + extracted_record = { + "recordId": record["recordId"], + "data": {}, + "errors": None, + "warnings": None, + } + extracted_record["data"]["keyPhrases"] = await extract_key_phrases_from_text( + [record["data"]["text"]],max_key_phrase_count + ) + except Exception as e: + logging.error("key phrase extraction Error: %s", e) + await asyncio.sleep(10) + try: + extracted_record = { + "recordId": record["recordId"], + "data": {}, + "errors": None, + "warnings": None, + } + extracted_record["data"][ + "keyPhrases" + ] = await extract_key_phrases_from_text([record["data"]["text"]],max_key_phrase_count) + except Exception as inner_e: + logging.error("key phrase extraction Error: %s", inner_e) + logging.error( + "Failed to extract key phrase. Check function app logs for more details of exact failure." + ) + return { + "recordId": record["recordId"], + "data": {}, + "errors": [ + { + "message": "Failed to extract key phrase. Check function app logs for more details of exact failure." + } + ], + "warnings": None, + } + json_str = json.dumps(extracted_record, indent=4) + + logging.info(f"key phrase extraction output: {json_str}") + return extracted_record diff --git a/function_apps/indexer/pre_embedding_cleaner.py b/function_apps/indexer/pre_embedding_cleaner.py new file mode 100644 index 0000000..2fdf87a --- /dev/null +++ b/function_apps/indexer/pre_embedding_cleaner.py @@ -0,0 +1,144 @@ +import logging +import json +import string +import nltk +import re +from nltk.tokenize import word_tokenize + +nltk.download("punkt") +nltk.download("stopwords") + +import re + +# Configure logging +logging.basicConfig(level=logging.ERROR, format='%(asctime)s - %(levelname)s - %(message)s') + +def get_section(cleaned_text:str) -> list: + """ + Returns the section details from the content + + Args: + cleaned_text: The input text + + Returns: + list: The sections related to text + + """ + combined_pattern = r'(.*?)\n===|\n## ?(.*?)\n|\n### ?(.*?)\n' + doc_metadata = re.findall(combined_pattern, cleaned_text, re.DOTALL) + doc_metadata = [match for group in doc_metadata for match in group if match] + return doc_metadata + +def remove_markdown_tags(text:str, tag_patterns:dict) ->str: + """ + Remove specified Markdown tags from the text, keeping the contents of the tags. + + Args: + text: The input text containing Markdown tags. + tag_patterns: A dictionary where keys are tags and values are their specific patterns. + + Returns: + str: The text with specified tags removed. + """ + try: + for tag, pattern in tag_patterns.items(): + try: + # Replace the tags using the specific pattern, keeping the content inside the tags + text = re.sub(pattern, r'\1', text, flags=re.DOTALL) + except re.error as e: + logging.error(f"Regex error for tag '{tag}': {e}") + except Exception as e: + logging.error(f"An error occurred in remove_markdown_tags: {e}") + return text + +def clean_text(src_text: str) -> str: + """This function performs following cleanup activities on the text, remove all unicode characters + remove line spacing,remove stop words, normalize characters + + Args: + src_text (str): The text to cleanup. + + Returns: + str: The clean text.""" + + try: + # Define specific patterns for each tag + tag_patterns = { + "figurecontent": r"", + "figure": r"
<figure>(.*?)</figure>", + "figures": r"\(figures/\d+\)(.*?)\(figures/\d+\)", + "figcaption": r"<figcaption>(.*?)</figcaption>
", + } + cleaned_text = remove_markdown_tags(src_text, tag_patterns) + + # remove line breaks + cleaned_text = re.sub(r"\n", "", cleaned_text) + + # remove stopwords + tokens = word_tokenize(cleaned_text, "english") + stop_words = nltk.corpus.stopwords.words("english") + filtered_tokens = [word for word in tokens if word not in stop_words] + cleaned_text = " ".join(filtered_tokens) + + # remove special characters + cleaned_text = re.sub(r"[^a-zA-Z\s]", "", cleaned_text) + + # remove extra white spaces + cleaned_text = " ".join([word for word in cleaned_text.split()]) + + # case normalization + cleaned_text = cleaned_text.lower() + except Exception as e: + logging.error(f"An error occurred in clean_text: {e}") + return "" + return cleaned_text + + +async def process_pre_embedding_cleaner(record: dict) -> dict: + """Cleanup the data using standard python libraries. + + Args: + record (dict): The record to cleanup. + + Returns: + dict: The clean record.""" + + try: + json_str = json.dumps(record, indent=4) + + logging.info(f"embedding cleaner Input: {json_str}") + + cleaned_record = { + "recordId": record["recordId"], + "data": {}, + "errors": None, + "warnings": None, + } + + # scenarios when page by chunking is enabled + if isinstance(record["data"]["chunk"],dict): + cleaned_record["data"]["cleaned_chunk"] = clean_text(record["data"]["chunk"]["content"]) + cleaned_record["data"]["chunk"] = record["data"]["chunk"]["content"] + cleaned_record["data"]["section"] = record["data"]["chunk"]["section"] + cleaned_record["data"]["page_number"] = record["data"]["chunk"]["page_number"] + else: + cleaned_record["data"]["cleaned_chunk"] = clean_text(record["data"]["chunk"]) + cleaned_record["data"]["chunk"] = record["data"]["chunk"] + cleaned_record["data"]["section"] = get_section(record["data"]["chunk"]) + + except Exception as e: + logging.error("string cleanup Error: %s", e) + return { + "recordId": record["recordId"], + "data": {}, + "errors": [ + { + "message": "Failed to cleanup data. Check function app logs for more details of exact failure." 
+ } + ], + "warnings": None, + } + json_str = json.dumps(cleaned_record, indent=4) + + logging.info(f"embedding cleaner output: {json_str}") + return cleaned_record diff --git a/function_apps/indexer/requirements.txt b/function_apps/indexer/requirements.txt new file mode 100644 index 0000000..48c9837 --- /dev/null +++ b/function_apps/indexer/requirements.txt @@ -0,0 +1,26 @@ +# DO NOT include azure-functions-worker in this file +# The Python Worker is managed by Azure Functions platform +# Manually managing azure-functions-worker may cause unexpected issues +python-dotenv +azure-functions +openai +azure-storage-blob +pandas +azure-identity +openpyxl +regex +nltk==3.8.1 +bs4 +azure-search +azure-search-documents +azure-ai-documentintelligence +azure-ai-textanalytics +azure-ai-vision-imageanalysis +PyMuPDF +pillow +torch +aiohttp +spacy==3.7.5 +transformers +scikit-learn +en-core-web-md @ https://github.com/explosion/spacy-models/releases/download/en_core_web_md-3.7.1/en_core_web_md-3.7.1.tar.gz From 8d177a003b33b6b5433a91797149822fb0bd121f Mon Sep 17 00:00:00 2001 From: priyal1508 <54278892+priyal1508@users.noreply.github.com> Date: Thu, 5 Sep 2024 19:11:57 +0530 Subject: [PATCH 2/4] changes for common scripts --- aisearch-skillset/deploy.py | 13 ------ aisearch-skillset/inquiry_document.py | 16 ++++---- function_apps/common/payloads/error.py | 20 ++++++++++ function_apps/common/payloads/header.py | 40 +++++++++++++++++++ function_apps/common/payloads/payload.py | 20 ++++++++++ .../payloads/pending_index_completion.py | 40 +++++++++++++++++++ .../common/payloads/pennding_index_trigger.py | 32 +++++++++++++++ .../indexer/pending_index_completion.py | 0 8 files changed, 160 insertions(+), 21 deletions(-) create mode 100644 function_apps/common/payloads/error.py create mode 100644 function_apps/common/payloads/header.py create mode 100644 function_apps/common/payloads/payload.py create mode 100644 function_apps/common/payloads/pending_index_completion.py create mode 100644 function_apps/common/payloads/pennding_index_trigger.py create mode 100644 function_apps/indexer/pending_index_completion.py diff --git a/aisearch-skillset/deploy.py b/aisearch-skillset/deploy.py index d98e099..1b2190b 100644 --- a/aisearch-skillset/deploy.py +++ b/aisearch-skillset/deploy.py @@ -30,19 +30,6 @@ def main(args): rebuild=args.rebuild, enable_page_by_chunking=args.enable_page_chunking ) - elif args.indexer_type == "summary": - # Deploy the summarises index - index_config = SummaryDocumentAISearch( - endpoint=endpoint, - credential=credential, - suffix=args.suffix, - rebuild=args.rebuild, - enable_page_by_chunking=args.enable_page_chunking - ) - elif args.indexer_type == "glossary": - # Deploy business glossary index - index_config = BusinessGlossaryAISearch(endpoint, credential) - index_config.deploy() if args.rebuild: diff --git a/aisearch-skillset/inquiry_document.py b/aisearch-skillset/inquiry_document.py index 3f9dd0a..b70251e 100644 --- a/aisearch-skillset/inquiry_document.py +++ b/aisearch-skillset/inquiry_document.py @@ -63,14 +63,14 @@ def get_index_fields(self) -> list[SearchableField]: name="Title", type=SearchFieldDataType.String, filterable=True ), SearchableField( - name="DealId", + name="ID1", type=SearchFieldDataType.String, sortable=True, filterable=True, facetable=True, ), SearchableField( - name="OracleId", + name="ID2", type=SearchFieldDataType.String, sortable=True, filterable=True, @@ -212,12 +212,12 @@ def get_index_projections(self) -> SearchIndexerIndexProjections: 
source="/document/Title" ), InputFieldMappingEntry( - name="DealId", - source="/document/DealId" + name="ID1", + source="/document/ID1" ), InputFieldMappingEntry( - name="OracleId", - source="/document/OracleId" + name="ID2", + source="/document/ID2" ), InputFieldMappingEntry( name="SourceUrl", @@ -302,9 +302,9 @@ def get_indexer(self) -> SearchIndexer: FieldMapping( source_field_name="metadata_storage_name", target_field_name="Title" ), - FieldMapping(source_field_name="Deal_ID", target_field_name="DealId"), + FieldMapping(source_field_name="ID1", target_field_name="ID1"), FieldMapping( - source_field_name="Oracle_ID", target_field_name="OracleId" + source_field_name="ID2", target_field_name="ID2" ), FieldMapping( source_field_name="SharePointUrl", target_field_name="SourceUrl" diff --git a/function_apps/common/payloads/error.py b/function_apps/common/payloads/error.py new file mode 100644 index 0000000..49e456e --- /dev/null +++ b/function_apps/common/payloads/error.py @@ -0,0 +1,20 @@ +from typing import Optional +from pydantic import BaseModel, Field, ConfigDict +from datetime import datetime, timezone + + +class Error(BaseModel): + """Error item model""" + + code: str = Field(..., description="The error code") + message: str = Field(..., description="The error message") + details: Optional[str] = Field( + None, description="Detailed error information from Python" + ) + timestamp: Optional[datetime] = Field( + ..., + description="Creation timestamp in UTC", + default_factory=lambda: datetime.now(timezone.utc), + ) + + __config__ = ConfigDict(extra="ignore") diff --git a/function_apps/common/payloads/header.py b/function_apps/common/payloads/header.py new file mode 100644 index 0000000..e7a521c --- /dev/null +++ b/function_apps/common/payloads/header.py @@ -0,0 +1,40 @@ +from pydantic import BaseModel, Field, ConfigDict +from datetime import datetime, timezone +from enum import Enum + + +class DataTypeEnum(Enum): + """Type enum""" + + BUSINESS_GLOSSARY = "business_glossary" + SUMMARY = "summary" + + +class TaskEnum(Enum): + """Task enum""" + + PENDING_INDEX_COMPLETION = "pending_index_completion" + PENDING_INDEX_TRIGGER = "pending_index_trigger" + PENDING_SUMMARY_GENERATION = "pending_summary_generation" + + +class Header(BaseModel): + """Header model""" + + creation_timestamp: datetime = Field( + ..., + description="Creation timestamp in UTC", + default_factory=lambda: datetime.now(timezone.utc), + ) + last_processed_timestamp: datetime = Field( + ..., + description="Last processed timestamp in UTC", + default_factory=lambda: datetime.now(timezone.utc), + ) + retries_remaining: int = Field( + description="Number of retries remaining", default=10 + ) + data_type: DataTypeEnum = Field(..., description="Data type") + task: TaskEnum = Field(..., description="Task name") + + __config__ = ConfigDict(extra="ignore") diff --git a/function_apps/common/payloads/payload.py b/function_apps/common/payloads/payload.py new file mode 100644 index 0000000..fb2f4f9 --- /dev/null +++ b/function_apps/common/payloads/payload.py @@ -0,0 +1,20 @@ +from pydantic import BaseModel, ConfigDict +import logging + + +class Payload(BaseModel): + """Body model""" + + @classmethod + def from_service_bus_message(cls, message): + """ + Create a Payload object from a ServiceBusMessage object. + + :param message: The ServiceBusMessage object. + :return: The Body object. 
+ """ + message = message.get_body().decode("utf-8") + logging.info(f"ServiceBus message: {message}") + return cls.model_validate_json(message) + + __config__ = ConfigDict(extra="ignore") diff --git a/function_apps/common/payloads/pending_index_completion.py b/function_apps/common/payloads/pending_index_completion.py new file mode 100644 index 0000000..8aa0335 --- /dev/null +++ b/function_apps/common/payloads/pending_index_completion.py @@ -0,0 +1,40 @@ +from pydantic import BaseModel, Field, ConfigDict +from datetime import datetime, timezone +from typing import Optional, List + +from common.payloads.header import Header +from common.payloads.error import Error +from common.payloads.payload import Payload + + +class PendingIndexCompletionBody(BaseModel): + """Body model""" + + indexer: str = Field(..., description="The indexer to trigger") + deal_id: Optional[int] = Field(None, description="The deal ID") + blob_storage_url: Optional[str] = Field( + ..., description="The URL to the blob storage" + ) + deal_name: Optional[str] = Field( + None, description="The text name for the integer deal ID" + ) + business_unit: Optional[str] = Field(None, description="The business unit") + indexer_start_time: Optional[datetime] = Field( + ..., + description="The time the indexer was triggered successfully", + default_factory=lambda: datetime.now(timezone.utc), + ) + + __config__ = ConfigDict(extra="ignore") + + +class PendingIndexCompletionPayload(Payload): + """Pending Index Trigger model""" + + header: Header = Field(..., description="Header information") + body: PendingIndexCompletionBody = Field(..., description="Body information") + errors: List[Error] | None = Field( + ..., description="List of errors", default_factory=list + ) + + __config__ = ConfigDict(extra="ignore") diff --git a/function_apps/common/payloads/pennding_index_trigger.py b/function_apps/common/payloads/pennding_index_trigger.py new file mode 100644 index 0000000..2a519d9 --- /dev/null +++ b/function_apps/common/payloads/pennding_index_trigger.py @@ -0,0 +1,32 @@ +from pydantic import BaseModel, Field, ConfigDict +from typing import Optional, List + +from common.payloads.header import Header +from common.payloads.error import Error +from common.payloads.payload import Payload + + +class PendingIndexTriggerBody(BaseModel): + """Body model""" + + indexer: str = Field(..., description="The indexer to trigger") + deal_id: Optional[int] = Field(None, description="The deal ID") + blob_storage_url: str = Field(..., description="The URL to the blob storage") + deal_name: Optional[str] = Field( + None, description="The text name for the integer deal ID" + ) + business_unit: Optional[str] = Field(None, description="The business unit") + + __config__ = ConfigDict(extra="ignore") + + +class PendingIndexTriggerPayload(Payload): + """Pending Index Trigger model""" + + header: Header = Field(..., description="Header information") + body: PendingIndexTriggerBody = Field(..., description="Body information") + errors: List[Error] | None = Field( + ..., description="List of errors", default_factory=list + ) + + __config__ = ConfigDict(extra="ignore") diff --git a/function_apps/indexer/pending_index_completion.py b/function_apps/indexer/pending_index_completion.py new file mode 100644 index 0000000..e69de29 From 461702883d1f9c15725d87a396bcb7862a318092 Mon Sep 17 00:00:00 2001 From: priyal1508 <54278892+priyal1508@users.noreply.github.com> Date: Thu, 5 Sep 2024 19:14:10 +0530 Subject: [PATCH 3/4] fixing bugs --- .../{pennding_index_trigger.py => 
pending_index_trigger.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename function_apps/common/payloads/{pennding_index_trigger.py => pending_index_trigger.py} (100%) diff --git a/function_apps/common/payloads/pennding_index_trigger.py b/function_apps/common/payloads/pending_index_trigger.py similarity index 100% rename from function_apps/common/payloads/pennding_index_trigger.py rename to function_apps/common/payloads/pending_index_trigger.py From de48566b1c9c01478ad533e7cc48cca27e420d03 Mon Sep 17 00:00:00 2001 From: priyal1508 <54278892+priyal1508@users.noreply.github.com> Date: Thu, 5 Sep 2024 19:21:08 +0530 Subject: [PATCH 4/4] changes in fodler structure --- {aisearch-skillset => ai_search_with_adi}/ai_search.py | 0 {aisearch-skillset => ai_search_with_adi}/deploy.py | 0 {aisearch-skillset => ai_search_with_adi}/environment.py | 0 .../function_apps}/common/ai_search.py | 0 .../function_apps}/common/payloads/error.py | 0 .../function_apps}/common/payloads/header.py | 0 .../function_apps}/common/payloads/payload.py | 0 .../function_apps}/common/payloads/pending_index_completion.py | 0 .../function_apps}/common/payloads/pending_index_trigger.py | 0 .../function_apps}/indexer/adi_2_aisearch.py | 0 .../function_apps}/indexer/function_app.py | 0 .../function_apps}/indexer/key_phrase_extraction.py | 0 .../function_apps}/indexer/pending_index_completion.py | 0 .../function_apps}/indexer/pre_embedding_cleaner.py | 0 .../function_apps}/indexer/requirements.txt | 0 {aisearch-skillset => ai_search_with_adi}/inquiry_document.py | 0 16 files changed, 0 insertions(+), 0 deletions(-) rename {aisearch-skillset => ai_search_with_adi}/ai_search.py (100%) rename {aisearch-skillset => ai_search_with_adi}/deploy.py (100%) rename {aisearch-skillset => ai_search_with_adi}/environment.py (100%) rename {function_apps => ai_search_with_adi/function_apps}/common/ai_search.py (100%) rename {function_apps => ai_search_with_adi/function_apps}/common/payloads/error.py (100%) rename {function_apps => ai_search_with_adi/function_apps}/common/payloads/header.py (100%) rename {function_apps => ai_search_with_adi/function_apps}/common/payloads/payload.py (100%) rename {function_apps => ai_search_with_adi/function_apps}/common/payloads/pending_index_completion.py (100%) rename {function_apps => ai_search_with_adi/function_apps}/common/payloads/pending_index_trigger.py (100%) rename {function_apps => ai_search_with_adi/function_apps}/indexer/adi_2_aisearch.py (100%) rename {function_apps => ai_search_with_adi/function_apps}/indexer/function_app.py (100%) rename {function_apps => ai_search_with_adi/function_apps}/indexer/key_phrase_extraction.py (100%) rename {function_apps => ai_search_with_adi/function_apps}/indexer/pending_index_completion.py (100%) rename {function_apps => ai_search_with_adi/function_apps}/indexer/pre_embedding_cleaner.py (100%) rename {function_apps => ai_search_with_adi/function_apps}/indexer/requirements.txt (100%) rename {aisearch-skillset => ai_search_with_adi}/inquiry_document.py (100%) diff --git a/aisearch-skillset/ai_search.py b/ai_search_with_adi/ai_search.py similarity index 100% rename from aisearch-skillset/ai_search.py rename to ai_search_with_adi/ai_search.py diff --git a/aisearch-skillset/deploy.py b/ai_search_with_adi/deploy.py similarity index 100% rename from aisearch-skillset/deploy.py rename to ai_search_with_adi/deploy.py diff --git a/aisearch-skillset/environment.py b/ai_search_with_adi/environment.py similarity index 100% rename from aisearch-skillset/environment.py 
rename to ai_search_with_adi/environment.py diff --git a/function_apps/common/ai_search.py b/ai_search_with_adi/function_apps/common/ai_search.py similarity index 100% rename from function_apps/common/ai_search.py rename to ai_search_with_adi/function_apps/common/ai_search.py diff --git a/function_apps/common/payloads/error.py b/ai_search_with_adi/function_apps/common/payloads/error.py similarity index 100% rename from function_apps/common/payloads/error.py rename to ai_search_with_adi/function_apps/common/payloads/error.py diff --git a/function_apps/common/payloads/header.py b/ai_search_with_adi/function_apps/common/payloads/header.py similarity index 100% rename from function_apps/common/payloads/header.py rename to ai_search_with_adi/function_apps/common/payloads/header.py diff --git a/function_apps/common/payloads/payload.py b/ai_search_with_adi/function_apps/common/payloads/payload.py similarity index 100% rename from function_apps/common/payloads/payload.py rename to ai_search_with_adi/function_apps/common/payloads/payload.py diff --git a/function_apps/common/payloads/pending_index_completion.py b/ai_search_with_adi/function_apps/common/payloads/pending_index_completion.py similarity index 100% rename from function_apps/common/payloads/pending_index_completion.py rename to ai_search_with_adi/function_apps/common/payloads/pending_index_completion.py diff --git a/function_apps/common/payloads/pending_index_trigger.py b/ai_search_with_adi/function_apps/common/payloads/pending_index_trigger.py similarity index 100% rename from function_apps/common/payloads/pending_index_trigger.py rename to ai_search_with_adi/function_apps/common/payloads/pending_index_trigger.py diff --git a/function_apps/indexer/adi_2_aisearch.py b/ai_search_with_adi/function_apps/indexer/adi_2_aisearch.py similarity index 100% rename from function_apps/indexer/adi_2_aisearch.py rename to ai_search_with_adi/function_apps/indexer/adi_2_aisearch.py diff --git a/function_apps/indexer/function_app.py b/ai_search_with_adi/function_apps/indexer/function_app.py similarity index 100% rename from function_apps/indexer/function_app.py rename to ai_search_with_adi/function_apps/indexer/function_app.py diff --git a/function_apps/indexer/key_phrase_extraction.py b/ai_search_with_adi/function_apps/indexer/key_phrase_extraction.py similarity index 100% rename from function_apps/indexer/key_phrase_extraction.py rename to ai_search_with_adi/function_apps/indexer/key_phrase_extraction.py diff --git a/function_apps/indexer/pending_index_completion.py b/ai_search_with_adi/function_apps/indexer/pending_index_completion.py similarity index 100% rename from function_apps/indexer/pending_index_completion.py rename to ai_search_with_adi/function_apps/indexer/pending_index_completion.py diff --git a/function_apps/indexer/pre_embedding_cleaner.py b/ai_search_with_adi/function_apps/indexer/pre_embedding_cleaner.py similarity index 100% rename from function_apps/indexer/pre_embedding_cleaner.py rename to ai_search_with_adi/function_apps/indexer/pre_embedding_cleaner.py diff --git a/function_apps/indexer/requirements.txt b/ai_search_with_adi/function_apps/indexer/requirements.txt similarity index 100% rename from function_apps/indexer/requirements.txt rename to ai_search_with_adi/function_apps/indexer/requirements.txt diff --git a/aisearch-skillset/inquiry_document.py b/ai_search_with_adi/inquiry_document.py similarity index 100% rename from aisearch-skillset/inquiry_document.py rename to ai_search_with_adi/inquiry_document.py
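
For quick reference, every HTTP skill in function_app.py consumes the Azure AI Search custom-skill envelope ({"values": [{"recordId": ..., "data": {...}}]}) and returns the same shape. Below is a minimal local sketch (not part of the patch) for exercising the pre-embedding cleaner, intended to be run from the indexer directory; the record layout is inferred from process_pre_embedding_cleaner and the sample values are hypothetical.

import asyncio

from pre_embedding_cleaner import process_pre_embedding_cleaner

# Hypothetical record following the {"recordId", "data"} custom-skill shape.
# The nested "chunk" dict mirrors the page-wise chunking branch of
# process_pre_embedding_cleaner; a plain string chunk is also accepted.
record = {
    "recordId": "1",
    "data": {
        "chunk": {
            "content": "The Quarterly Report shows that revenue grew 12% year over year.",
            "section": ["Quarterly Report"],
            "page_number": 3,
        }
    },
}

if __name__ == "__main__":
    cleaned = asyncio.run(process_pre_embedding_cleaner(record))
    # Expect lower-cased, stop-word-free text in "cleaned_chunk", with the
    # original content, section and page_number passed through alongside it.
    print(cleaned["data"]["cleaned_chunk"])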