From 7029a1843ff3a9713c1011f260e5e5db1db1cf8a Mon Sep 17 00:00:00 2001
From: priyal1508 <54278892+priyal1508@users.noreply.github.com>
Date: Thu, 5 Sep 2024 10:52:29 +0530
Subject: [PATCH 1/4] changes for adi based skillset
---
aisearch-skillset/ai_search.py | 690 ++++++++++++++++++
aisearch-skillset/deploy.py | 80 ++
aisearch-skillset/environment.py | 192 +++++
aisearch-skillset/inquiry_document.py | 320 ++++++++
function_apps/common/ai_search.py | 127 ++++
function_apps/indexer/adi_2_aisearch.py | 460 ++++++++++++
function_apps/indexer/function_app.py | 296 ++++++++
.../indexer/key_phrase_extraction.py | 112 +++
.../indexer/pre_embedding_cleaner.py | 144 ++++
function_apps/indexer/requirements.txt | 26 +
10 files changed, 2447 insertions(+)
create mode 100644 aisearch-skillset/ai_search.py
create mode 100644 aisearch-skillset/deploy.py
create mode 100644 aisearch-skillset/environment.py
create mode 100644 aisearch-skillset/inquiry_document.py
create mode 100644 function_apps/common/ai_search.py
create mode 100644 function_apps/indexer/adi_2_aisearch.py
create mode 100644 function_apps/indexer/function_app.py
create mode 100644 function_apps/indexer/key_phrase_extraction.py
create mode 100644 function_apps/indexer/pre_embedding_cleaner.py
create mode 100644 function_apps/indexer/requirements.txt
diff --git a/aisearch-skillset/ai_search.py b/aisearch-skillset/ai_search.py
new file mode 100644
index 0000000..7573055
--- /dev/null
+++ b/aisearch-skillset/ai_search.py
@@ -0,0 +1,690 @@
+from abc import ABC, abstractmethod
+from azure.search.documents.indexes.models import (
+ SearchIndex,
+ SearchableField,
+ VectorSearch,
+ VectorSearchProfile,
+ HnswAlgorithmConfiguration,
+ SemanticSearch,
+ NativeBlobSoftDeleteDeletionDetectionPolicy,
+ HighWaterMarkChangeDetectionPolicy,
+ WebApiSkill,
+ CustomVectorizer,
+ CustomWebApiParameters,
+ SearchIndexer,
+ SearchIndexerSkillset,
+ SearchIndexerDataContainer,
+ SearchIndexerDataSourceConnection,
+ SearchIndexerDataSourceType,
+ SearchIndexerDataUserAssignedIdentity,
+ OutputFieldMappingEntry,
+ InputFieldMappingEntry,
+ SynonymMap,
+ DocumentExtractionSkill,
+ OcrSkill,
+ MergeSkill,
+ ConditionalSkill,
+ SplitSkill,
+)
+from azure.core.exceptions import HttpResponseError
+from azure.search.documents.indexes import SearchIndexerClient, SearchIndexClient
+from environment import (
+ get_fq_blob_connection_string,
+ get_blob_container_name,
+ get_custom_skill_function_url,
+ get_managed_identity_fqname,
+ get_function_app_authresourceid,
+ IndexerType,
+)
+
+
+class AISearch(ABC):
+ def __init__(
+ self,
+ endpoint: str,
+ credential,
+ suffix: str | None = None,
+ rebuild: bool | None = False,
+ ):
+        """Initialize the AI search class
+
+        Args:
+            endpoint (str): The search endpoint
+            credential (AzureKeyCredential): The search credential
+            suffix (str, optional): Suffix appended to all resource names for test deployments
+            rebuild (bool, optional): Whether to delete and rebuild the index"""
+ self.indexer_type = None
+
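+        # When rebuild is True, deploy_index() deletes the existing index before recreating it.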
+ if rebuild is not None:
+ self.rebuild = rebuild
+ else:
+ self.rebuild = False
+
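+        # A non-empty suffix marks a test deployment: it is appended to every resource name
+        # and the skills below use smaller batch sizes and lower parallelism.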
+ if suffix is None:
+ self.suffix = ""
+ self.test = False
+ else:
+ self.suffix = f"-{suffix}-test"
+ self.test = True
+
+ self._search_indexer_client = SearchIndexerClient(endpoint, credential)
+ self._search_index_client = SearchIndexClient(endpoint, credential)
+
+ @property
+ def indexer_name(self):
+ return f"{str(self.indexer_type.value)}-indexer{self.suffix}"
+
+ @property
+ def skillset_name(self):
+ return f"{str(self.indexer_type.value)}-skillset{self.suffix}"
+
+ @property
+ def semantic_config_name(self):
+ return f"{str(self.indexer_type.value)}-semantic-config{self.suffix}"
+
+ @property
+ def index_name(self):
+ return f"{str(self.indexer_type.value)}-index{self.suffix}"
+
+ @property
+ def data_source_name(self):
+ blob_container_name = get_blob_container_name(self.indexer_type)
+ return f"{blob_container_name}-data-source{self.suffix}"
+
+ @property
+ def vector_search_profile_name(self):
+ return (
+ f"{str(self.indexer_type.value)}-compass-vector-search-profile{self.suffix}"
+ )
+
+ @abstractmethod
+ def get_index_fields(self) -> list[SearchableField]:
+ """Get the index fields for the indexer.
+
+ Returns:
+ list[SearchableField]: The index fields"""
+
+ @abstractmethod
+ def get_semantic_search(self) -> SemanticSearch:
+ """Get the semantic search configuration for the indexer.
+
+ Returns:
+ SemanticSearch: The semantic search configuration"""
+
+ @abstractmethod
+ def get_skills(self):
+ """Get the skillset for the indexer."""
+
+ @abstractmethod
+ def get_indexer(self) -> SearchIndexer:
+        """Get the indexer definition."""
+
+ def get_index_projections(self):
+ """Get the index projections for the indexer."""
+ return None
+
+ def get_synonym_map_names(self):
+ return []
+
+ def get_user_assigned_managed_identity(
+ self,
+ ) -> SearchIndexerDataUserAssignedIdentity:
+ """Get user assigned managed identity details"""
+
+ user_assigned_identity = SearchIndexerDataUserAssignedIdentity(
+ user_assigned_identity=get_managed_identity_fqname()
+ )
+ return user_assigned_identity
+
+ def get_data_source(self) -> SearchIndexerDataSourceConnection:
+ """Get the data source for the indexer."""
+
+ if self.indexer_type == IndexerType.BUSINESS_GLOSSARY:
+ data_deletion_detection_policy = None
+ else:
+ data_deletion_detection_policy = (
+ NativeBlobSoftDeleteDeletionDetectionPolicy()
+ )
+
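+        # Incremental indexing: track changes via the blob's last-modified timestamp.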
+ data_change_detection_policy = HighWaterMarkChangeDetectionPolicy(
+ high_water_mark_column_name="metadata_storage_last_modified"
+ )
+
+ container = SearchIndexerDataContainer(
+ name=get_blob_container_name(self.indexer_type)
+ )
+
+ data_source_connection = SearchIndexerDataSourceConnection(
+ name=self.data_source_name,
+ type=SearchIndexerDataSourceType.AZURE_BLOB,
+ connection_string=get_fq_blob_connection_string(),
+ container=container,
+ data_change_detection_policy=data_change_detection_policy,
+ data_deletion_detection_policy=data_deletion_detection_policy,
+ identity=self.get_user_assigned_managed_identity(),
+ )
+
+ return data_source_connection
+
+ def get_compass_vector_custom_skill(
+ self, context, source, target_name="vector"
+ ) -> WebApiSkill:
+ """Get the custom skill for compass.
+
+ Args:
+ -----
+ context (str): The context of the skill
+ source (str): The source of the skill
+ target_name (str): The target name of the skill
+
+ Returns:
+ --------
+ WebApiSkill: The custom skill for compass"""
+
+ if self.test:
+ batch_size = 2
+ degree_of_parallelism = 2
+ else:
+ batch_size = 4
+ degree_of_parallelism = 8
+
+ embedding_skill_inputs = [
+ InputFieldMappingEntry(name="text", source=source),
+ ]
+ embedding_skill_outputs = [
+ OutputFieldMappingEntry(name="vector", target_name=target_name)
+ ]
+ # Limit the number of documents to be processed in parallel to avoid timing out on compass api
+ embedding_skill = WebApiSkill(
+ name="Compass Connector API",
+ description="Skill to generate embeddings via compass API connector",
+ context=context,
+ uri=get_custom_skill_function_url("compass"),
+ timeout="PT230S",
+ batch_size=batch_size,
+ degree_of_parallelism=degree_of_parallelism,
+ http_method="POST",
+ inputs=embedding_skill_inputs,
+ outputs=embedding_skill_outputs,
+ auth_resource_id=get_function_app_authresourceid(),
+ auth_identity=self.get_user_assigned_managed_identity(),
+ )
+
+ return embedding_skill
+
+ def get_pre_embedding_cleaner_skill(
+ self, context, source, chunk_by_page=False, target_name="cleaned_chunk"
+ ) -> WebApiSkill:
+        """Get the custom skill for data cleanup.
+
+        Args:
+        -----
+            context (str): The context of the skill
+            source (str): The source of the skill
+            chunk_by_page (bool): Whether page numbers are added to the output
+            target_name (str): The target name of the cleaned chunk
+
+        Returns:
+        --------
+            WebApiSkill: The custom skill for data cleanup"""
+
+ if self.test:
+ batch_size = 2
+ degree_of_parallelism = 2
+ else:
+ batch_size = 16
+ degree_of_parallelism = 16
+
+ pre_embedding_cleaner_skill_inputs = [
+ InputFieldMappingEntry(name="chunk", source=source)
+ ]
+
+ pre_embedding_cleaner_skill_outputs = [
+ OutputFieldMappingEntry(name="cleaned_chunk", target_name=target_name),
+ OutputFieldMappingEntry(name="chunk", target_name="chunk"),
+ OutputFieldMappingEntry(name="section", target_name="eachsection"),
+ ]
+
+ if chunk_by_page:
+ pre_embedding_cleaner_skill_outputs.extend(
+ [
+ OutputFieldMappingEntry(name="page_number", target_name="page_no"),
+ ]
+ )
+
+ pre_embedding_cleaner_skill = WebApiSkill(
+ name="Pre Embedding Cleaner Skill",
+ description="Skill to clean the data before sending to embedding",
+ context=context,
+ uri=get_custom_skill_function_url("pre_embedding_cleaner"),
+ timeout="PT230S",
+ batch_size=batch_size,
+ degree_of_parallelism=degree_of_parallelism,
+ http_method="POST",
+ inputs=pre_embedding_cleaner_skill_inputs,
+ outputs=pre_embedding_cleaner_skill_outputs,
+ auth_resource_id=get_function_app_authresourceid(),
+ auth_identity=self.get_user_assigned_managed_identity(),
+ )
+
+ return pre_embedding_cleaner_skill
+
+ def get_text_split_skill(self, context, source) -> SplitSkill:
+        """Get the skill for text splitting.
+
+        Args:
+        -----
+            context (str): The context of the skill
+            source (str): The source of the skill
+
+        Returns:
+        --------
+            SplitSkill: The skill for text splitting"""
+
+ text_split_skill = SplitSkill(
+ name="Text Split Skill",
+ description="Skill to split the text before sending to embedding",
+ context=context,
+ text_split_mode="pages",
+ maximum_page_length=2000,
+ page_overlap_length=500,
+ inputs=[InputFieldMappingEntry(name="text", source=source)],
+ outputs=[OutputFieldMappingEntry(name="textItems", target_name="pages")],
+ )
+
+ return text_split_skill
+
+ def get_custom_text_split_skill(
+ self,
+ context,
+ source,
+ text_split_mode="semantic",
+ maximum_page_length=1000,
+ separator=" ",
+ initial_threshold=0.7,
+ appending_threshold=0.6,
+ merging_threshold=0.6,
+ ) -> WebApiSkill:
+        """Get the custom skill for text splitting.
+
+        Args:
+        -----
+            context (str): The context of the skill
+            source (str): The source of the skill
+            text_split_mode (str): The text split mode, e.g. semantic
+            maximum_page_length (int): The maximum length of a chunk
+            separator (str): The separator between chunks
+            initial_threshold (float): The initial threshold for semantic chunking
+            appending_threshold (float): The appending threshold for semantic chunking
+            merging_threshold (float): The merging threshold for semantic chunking
+
+        Returns:
+        --------
+            WebApiSkill: The custom skill for text splitting"""
+
+ if self.test:
+ batch_size = 2
+ degree_of_parallelism = 2
+ else:
+ batch_size = 2
+ degree_of_parallelism = 6
+
+ text_split_skill_inputs = [
+ InputFieldMappingEntry(name="text", source=source),
+ ]
+
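+        # The chunking configuration is passed to the custom split function as HTTP headers.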
+ headers = {
+ "text_split_mode": text_split_mode,
+ "maximum_page_length": maximum_page_length,
+ "separator": separator,
+ "initial_threshold": initial_threshold,
+ "appending_threshold": appending_threshold,
+ "merging_threshold": merging_threshold,
+ }
+
+ text_split_skill = WebApiSkill(
+ name="Text Split Skill",
+ description="Skill to split the text before sending to embedding",
+ context=context,
+ uri=get_custom_skill_function_url("split"),
+ timeout="PT230S",
+ batch_size=batch_size,
+ degree_of_parallelism=degree_of_parallelism,
+ http_method="POST",
+ http_headers=headers,
+ inputs=text_split_skill_inputs,
+ outputs=[OutputFieldMappingEntry(name="chunks", target_name="pages")],
+ auth_resource_id=get_function_app_authresourceid(),
+ auth_identity=self.get_user_assigned_managed_identity(),
+ )
+
+ return text_split_skill
+
+ def get_adi_skill(self, chunk_by_page=False) -> WebApiSkill:
+        """Get the custom skill for ADI (Azure Document Intelligence).
+
+        Args:
+        -----
+            chunk_by_page (bool): Whether to chunk the extracted content by page
+
+        Returns:
+        --------
+            WebApiSkill: The custom skill for ADI"""
+
+ if self.test:
+ batch_size = 1
+ degree_of_parallelism = 4
+ else:
+ batch_size = 1
+ degree_of_parallelism = 16
+
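+        # With page-wise chunking the ADI function emits one entry per page under "pages";
+        # otherwise a single "extracted_content" value is returned for the whole document.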
+ if chunk_by_page:
+ output = [
+ OutputFieldMappingEntry(name="extracted_content", target_name="pages")
+ ]
+ else:
+ output = [
+ OutputFieldMappingEntry(
+ name="extracted_content", target_name="extracted_content"
+ )
+ ]
+
+ adi_skill = WebApiSkill(
+ name="ADI Skill",
+ description="Skill to generate ADI",
+ context="/document",
+ uri=get_custom_skill_function_url("adi"),
+ timeout="PT230S",
+ batch_size=batch_size,
+ degree_of_parallelism=degree_of_parallelism,
+ http_method="POST",
+ http_headers={"chunk_by_page": chunk_by_page},
+ inputs=[
+ InputFieldMappingEntry(
+ name="source", source="/document/metadata_storage_path"
+ )
+ ],
+ outputs=output,
+ auth_resource_id=get_function_app_authresourceid(),
+ auth_identity=self.get_user_assigned_managed_identity(),
+ )
+
+ return adi_skill
+
+ def get_key_phrase_extraction_skill(self, context, source) -> WebApiSkill:
+ """Get the key phrase extraction skill.
+
+ Args:
+ -----
+ context (str): The context of the skill
+ source (str): The source of the skill
+
+ Returns:
+ --------
+ WebApiSkill: The key phrase extraction skill"""
+
+ if self.test:
+ batch_size = 4
+ degree_of_parallelism = 4
+ else:
+ batch_size = 16
+ degree_of_parallelism = 16
+
+ keyphrase_extraction_skill_inputs = [
+ InputFieldMappingEntry(name="text", source=source),
+ ]
+        keyphrase_extraction_skill_outputs = [
+ OutputFieldMappingEntry(name="keyPhrases", target_name="keywords")
+ ]
+ key_phrase_extraction_skill = WebApiSkill(
+ name="Key phrase extraction API",
+ description="Skill to extract keyphrases",
+ context=context,
+ uri=get_custom_skill_function_url("keyphraseextraction"),
+ timeout="PT230S",
+ batch_size=batch_size,
+ degree_of_parallelism=degree_of_parallelism,
+ http_method="POST",
+ inputs=keyphrase_extraction_skill_inputs,
+            outputs=keyphrase_extraction_skill_outputs,
+ auth_resource_id=get_function_app_authresourceid(),
+ auth_identity=self.get_user_assigned_managed_identity(),
+ )
+
+ return key_phrase_extraction_skill
+
+ def get_document_extraction_skill(self, context, source) -> DocumentExtractionSkill:
+ """Get the document extraction utility skill.
+
+ Args:
+ -----
+ context (str): The context of the skill
+ source (str): The source of the skill
+
+ Returns:
+ --------
+ DocumentExtractionSkill: The document extraction utility skill"""
+
+ doc_extraction_skill = DocumentExtractionSkill(
+ description="Extraction skill to extract content from office docs like excel, ppt, doc etc",
+ context=context,
+ inputs=[InputFieldMappingEntry(name="file_data", source=source)],
+ outputs=[
+ OutputFieldMappingEntry(
+ name="content", target_name="extracted_content"
+ ),
+ OutputFieldMappingEntry(
+ name="normalized_images", target_name="extracted_normalized_images"
+ ),
+ ],
+ )
+
+ return doc_extraction_skill
+
+    def get_ocr_skill(self, context, source) -> WebApiSkill:
+        """Get the OCR custom skill.
+
+        Args:
+        -----
+            context (str): The context of the skill
+            source (str): The source of the skill
+
+        Returns:
+        --------
+            WebApiSkill: The OCR skill"""
+
+ if self.test:
+ batch_size = 2
+ degree_of_parallelism = 2
+ else:
+ batch_size = 2
+ degree_of_parallelism = 2
+
+ ocr_skill_inputs = [
+ InputFieldMappingEntry(name="image", source=source),
+ ]
+        ocr_skill_outputs = [OutputFieldMappingEntry(name="text", target_name="text")]
+ ocr_skill = WebApiSkill(
+ name="ocr API",
+ description="Skill to extract text from images",
+ context=context,
+ uri=get_custom_skill_function_url("ocr"),
+ timeout="PT230S",
+ batch_size=batch_size,
+ degree_of_parallelism=degree_of_parallelism,
+ http_method="POST",
+ inputs=ocr_skill_inputs,
+            outputs=ocr_skill_outputs,
+ auth_resource_id=get_function_app_authresourceid(),
+ auth_identity=self.get_user_assigned_managed_identity(),
+ )
+
+ return ocr_skill
+
+ def get_merge_skill(self, context, source) -> MergeSkill:
+        """Get the merge skill.
+
+        Args:
+        -----
+            context (str): The context of the skill
+            source (list): The sources of the skill (text, itemsToInsert, offsets)
+
+        Returns:
+        --------
+            MergeSkill: The merge skill"""
+
+ merge_skill = MergeSkill(
+ description="Merge skill for combining OCR'd and regular text",
+ context=context,
+ inputs=[
+ InputFieldMappingEntry(name="text", source=source[0]),
+ InputFieldMappingEntry(name="itemsToInsert", source=source[1]),
+ InputFieldMappingEntry(name="offsets", source=source[2]),
+ ],
+ outputs=[
+ OutputFieldMappingEntry(name="mergedText", target_name="merged_content")
+ ],
+ )
+
+ return merge_skill
+
+ def get_conditional_skill(self, context, source) -> ConditionalSkill:
+        """Get the conditional skill.
+
+        Args:
+        -----
+            context (str): The context of the skill
+            source (list): The sources of the skill (condition, whenTrue, whenFalse)
+
+        Returns:
+        --------
+            ConditionalSkill: The conditional skill"""
+
+ conditional_skill = ConditionalSkill(
+ description="Select between OCR and Document Extraction output",
+ context=context,
+ inputs=[
+ InputFieldMappingEntry(name="condition", source=source[0]),
+ InputFieldMappingEntry(name="whenTrue", source=source[1]),
+ InputFieldMappingEntry(name="whenFalse", source=source[2]),
+ ],
+ outputs=[
+ OutputFieldMappingEntry(name="output", target_name="updated_content")
+ ],
+ )
+
+ return conditional_skill
+
+ def get_compass_vector_search(self) -> VectorSearch:
+        """Get the vector search configuration for compass.
+
+        Returns:
+        --------
+            VectorSearch: The vector search configuration
+        """
+ vectorizer_name = (
+ f"{str(self.indexer_type.value)}-compass-vectorizer{self.suffix}"
+ )
+        algorithm_name = f"{str(self.indexer_type.value)}-hnsw-algorithm{self.suffix}"
+
+        vector_search = VectorSearch(
+            algorithms=[
+                HnswAlgorithmConfiguration(name=algorithm_name),
+            ],
+            profiles=[
+                VectorSearchProfile(
+                    name=self.vector_search_profile_name,
+                    algorithm_configuration_name=algorithm_name,
+ vectorizer=vectorizer_name,
+ )
+ ],
+ vectorizers=[
+ CustomVectorizer(
+ name=vectorizer_name,
+ custom_web_api_parameters=CustomWebApiParameters(
+ uri=get_custom_skill_function_url("compass"),
+ auth_resource_id=get_function_app_authresourceid(),
+ auth_identity=self.get_user_assigned_managed_identity(),
+ ),
+ ),
+ ],
+ )
+
+ return vector_search
+
+ def deploy_index(self):
+        """This function deploys the index."""
+
+ index_fields = self.get_index_fields()
+ vector_search = self.get_compass_vector_search()
+ semantic_search = self.get_semantic_search()
+ index = SearchIndex(
+ name=self.index_name,
+ fields=index_fields,
+ vector_search=vector_search,
+ semantic_search=semantic_search,
+ )
+ if self.rebuild:
+ self._search_index_client.delete_index(self.index_name)
+ self._search_index_client.create_or_update_index(index)
+
+ print(f"{index.name} created")
+
+ def deploy_skillset(self):
+ """This function deploys the skillset."""
+ skills = self.get_skills()
+ index_projections = self.get_index_projections()
+
+ skillset = SearchIndexerSkillset(
+ name=self.skillset_name,
+            description="Skillset to chunk documents and generate embeddings",
+ skills=skills,
+ index_projections=index_projections,
+ )
+
+ self._search_indexer_client.create_or_update_skillset(skillset)
+ print(f"{skillset.name} created")
+
+ def deploy_data_source(self):
+ """This function deploys the data source."""
+ data_source = self.get_data_source()
+
+ result = self._search_indexer_client.create_or_update_data_source_connection(
+ data_source
+ )
+
+ print(f"Data source '{result.name}' created or updated")
+
+ return result
+
+ def deploy_indexer(self):
+ """This function deploys the indexer."""
+ indexer = self.get_indexer()
+
+ result = self._search_indexer_client.create_or_update_indexer(indexer)
+
+ print(f"Indexer '{result.name}' created or updated")
+
+ return result
+
+ def run_indexer(self):
+ """This function runs the indexer."""
+ self._search_indexer_client.run_indexer(self.indexer_name)
+
+ print(
+ f"{self.indexer_name} is running. If queries return no results, please wait a bit and try again."
+ )
+
+ def reset_indexer(self):
+        """This function resets the indexer."""
+ self._search_indexer_client.reset_indexer(self.indexer_name)
+
+ print(f"{self.indexer_name} reset.")
+
+    def deploy_synonym_map(self):
+        """This function deploys the synonym maps."""
+        synonym_maps = self.get_synonym_map_names()
+ if len(synonym_maps) > 0:
+ for synonym_map in synonym_maps:
+ try:
+ synonym_map = SynonymMap(name=synonym_map, synonyms="")
+ self._search_index_client.create_synonym_map(synonym_map)
+ except HttpResponseError:
+ print("Unable to deploy synonym map as it already exists.")
+
+ def deploy(self):
+ """This function deploys the whole AI search pipeline."""
+ self.deploy_data_source()
+ self.deploy_synonym_map()
+ self.deploy_index()
+ self.deploy_skillset()
+ self.deploy_indexer()
+
+ print(f"{str(self.indexer_type.value)} deployed")
diff --git a/aisearch-skillset/deploy.py b/aisearch-skillset/deploy.py
new file mode 100644
index 0000000..d98e099
--- /dev/null
+++ b/aisearch-skillset/deploy.py
@@ -0,0 +1,80 @@
+import argparse
+from environment import get_search_endpoint, get_managed_identity_id, get_search_key, get_key_vault_url
+from azure.core.credentials import AzureKeyCredential
+from azure.identity import DefaultAzureCredential, ManagedIdentityCredential, EnvironmentCredential
+from azure.keyvault.secrets import SecretClient
+from inquiry_document import InquiryDocumentAISearch
+
+
+def main(args):
+ endpoint = get_search_endpoint()
+
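+    # Prefer managed identity; if that fails, fall back to the search admin key from Key Vault.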
+    try:
+        credential = DefaultAzureCredential(
+            managed_identity_client_id=get_managed_identity_id()
+        )
+        # initializing key vault client
+        client = SecretClient(vault_url=get_key_vault_url(), credential=credential)
+        print("Using managed identity credential")
+    except Exception as e:
+        print(e)
+        # Fall back to the search admin key stored in Key Vault
+        client = SecretClient(
+            vault_url=get_key_vault_url(), credential=DefaultAzureCredential()
+        )
+        credential = AzureKeyCredential(get_search_key(client=client))
+        print("Using Azure Key credential")
+
+ if args.indexer_type == "inquiry":
+ # Deploy the inquiry index
+ index_config = InquiryDocumentAISearch(
+ endpoint=endpoint,
+ credential=credential,
+ suffix=args.suffix,
+ rebuild=args.rebuild,
+ enable_page_by_chunking=args.enable_page_chunking
+ )
+ elif args.indexer_type == "summary":
+        # Deploy the summary index
+ index_config = SummaryDocumentAISearch(
+ endpoint=endpoint,
+ credential=credential,
+ suffix=args.suffix,
+ rebuild=args.rebuild,
+ enable_page_by_chunking=args.enable_page_chunking
+ )
+ elif args.indexer_type == "glossary":
+ # Deploy business glossary index
+ index_config = BusinessGlossaryAISearch(endpoint, credential)
+
+ index_config.deploy()
+
+ if args.rebuild:
+ index_config.reset_indexer()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Process some arguments.")
+ parser.add_argument(
+ "--indexer_type",
+ type=str,
+ required=True,
+        help="Type of indexer to deploy: inquiry, summary or glossary",
+ )
+ parser.add_argument(
+ "--rebuild",
+ type=bool,
+ required=False,
+        help="Whether to delete and rebuild the index",
+ )
+ parser.add_argument(
+ "--enable_page_chunking",
+ type=bool,
+ required=False,
+        help="Whether to enable page-wise chunking in the ADI skill; treated as False if not passed",
+ )
+ parser.add_argument(
+ "--suffix",
+ type=str,
+ required=False,
+ help="Suffix to be attached to indexer objects",
+ )
+
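+    # Example invocation (values are illustrative):
+    #   python deploy.py --indexer_type inquiry --suffix dev --rebuild True --enable_page_chunking True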
+ args = parser.parse_args()
+ main(args)
diff --git a/aisearch-skillset/environment.py b/aisearch-skillset/environment.py
new file mode 100644
index 0000000..7503a68
--- /dev/null
+++ b/aisearch-skillset/environment.py
@@ -0,0 +1,192 @@
+"""Module providing environment definition"""
+import os
+from dotenv import find_dotenv, load_dotenv
+from enum import Enum
+
+load_dotenv(find_dotenv())
+
+
+class IndexerType(Enum):
+ """The type of the indexer"""
+
+ INQUIRY_DOCUMENT = "inquiry-document"
+ SUMMARY_DOCUMENT = "summary-document"
+ BUSINESS_GLOSSARY = "business-glossary"
+
+# key vault
+def get_key_vault_url() -> str:
+    """
+    This function returns the Key Vault URL
+    """
+ return os.environ.get("KeyVault__Url")
+
+# managed identity id
+def get_managed_identity_id() -> str:
+ """
+    This function returns the managed identity client ID
+ """
+ return os.environ.get("AIService__AzureSearchOptions__ManagedIdentity__ClientId")
+
+
+def get_managed_identity_fqname() -> str:
+ """
+    This function returns the managed identity fully qualified name
+ """
+ return os.environ.get("AIService__AzureSearchOptions__ManagedIdentity__FQName")
+
+
+# function app details
+def get_function_app_authresourceid() -> str:
+ """
+    This function returns the function app's app registration resource ID in Microsoft Entra ID
+ """
+ return os.environ.get("FunctionApp__AuthResourceId")
+
+
+def get_function_app_end_point() -> str:
+ """
+ This function returns function app endpoint
+ """
+ return os.environ.get("FunctionApp__Endpoint")
+
+def get_function_app_key() -> str:
+ """
+ This function returns function app key
+ """
+ return os.environ.get("FunctionApp__Key")
+
+def get_function_app_compass_function() -> str:
+ """
+ This function returns function app compass function name
+ """
+ return os.environ.get("FunctionApp__Compass__FunctionName")
+
+
+def get_function_app_pre_embedding_cleaner_function() -> str:
+ """
+ This function returns function app data cleanup function name
+ """
+ return os.environ.get("FunctionApp__PreEmbeddingCleaner__FunctionName")
+
+
+def get_function_app_adi_function() -> str:
+ """
+ This function returns function app adi name
+ """
+ return os.environ.get("FunctionApp__DocumentIntelligence__FunctionName")
+
+
+def get_function_app_custom_split_function() -> str:
+ """
+    This function returns the custom text split function name
+ """
+ return os.environ.get("FunctionApp__CustomTextSplit__FunctionName")
+
+
+def get_function_app_keyphrase_extractor_function() -> str:
+ """
+ This function returns function app keyphrase extractor name
+ """
+ return os.environ.get("FunctionApp__KeyphraseExtractor__FunctionName")
+
+
+def get_function_app_ocr_function() -> str:
+ """
+ This function returns function app ocr name
+ """
+ return os.environ.get("FunctionApp__Ocr__FunctionName")
+
+
+# search
+def get_search_endpoint() -> str:
+ """
+ This function returns azure ai search service endpoint
+ """
+ return os.environ.get("AIService__AzureSearchOptions__Endpoint")
+
+
+def get_search_user_assigned_identity() -> str:
+ """
+    This function returns the Azure AI Search user assigned identity
+ """
+ return os.environ.get("AIService__AzureSearchOptions__UserAssignedIdentity")
+
+
+def get_search_key(client) -> str:
+ """
+ This function returns azure ai search service admin key
+ """
+ search_service_key_secret_name = str(os.environ.get("AIService__AzureSearchOptions__name")) + "-PrimaryKey"
+ retrieved_secret = client.get_secret(search_service_key_secret_name)
+ return retrieved_secret.value
+
+def get_search_key_secret() -> str:
+ """
+    This function returns the Azure AI Search admin key secret
+ """
+ return os.environ.get("AIService__AzureSearchOptions__Key__Secret")
+
+
+def get_search_embedding_model_dimensions(indexer_type: IndexerType) -> str:
+ """
+ This function returns dimensions for embedding model
+ """
+
+ normalised_indexer_type = (
+ indexer_type.value.replace("-", " ").title().replace(" ", "")
+ )
+
+ return os.environ.get(
+ f"AIService__AzureSearchOptions__{normalised_indexer_type}__EmbeddingDimensions"
+ )
+
+def get_blob_connection_string() -> str:
+ """
+ This function returns azure blob storage connection string
+ """
+ return os.environ.get("StorageAccount__ConnectionString")
+
+def get_fq_blob_connection_string() -> str:
+ """
+    This function returns the fully qualified Azure Blob Storage connection string
+ """
+ return os.environ.get("StorageAccount__FQEndpoint")
+
+
+def get_blob_container_name(indexer_type: IndexerType) -> str:
+ """
+ This function returns azure blob container name
+ """
+ normalised_indexer_type = (
+ indexer_type.value.replace("-", " ").title().replace(" ", "")
+ )
+ return os.environ.get(f"StorageAccount__{normalised_indexer_type}__Container")
+
+
+def get_custom_skill_function_url(skill_type: str):
+ """
+ Get the function app url that is hosting the custom skill
+ """
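+    # Build the base URL with the function key, then substitute the per-skill function name.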
+ url = (
+ get_function_app_end_point()
+ + "/api/function_name?code="
+ + get_function_app_key()
+ )
+ if skill_type == "compass":
+ url = url.replace("function_name", get_function_app_compass_function())
+ elif skill_type == "pre_embedding_cleaner":
+ url = url.replace(
+ "function_name", get_function_app_pre_embedding_cleaner_function()
+ )
+ elif skill_type == "adi":
+ url = url.replace("function_name", get_function_app_adi_function())
+ elif skill_type == "split":
+ url = url.replace("function_name", get_function_app_custom_split_function())
+ elif skill_type == "keyphraseextraction":
+ url = url.replace(
+ "function_name", get_function_app_keyphrase_extractor_function()
+ )
+ elif skill_type == "ocr":
+ url = url.replace("function_name", get_function_app_ocr_function())
+
+ return url
diff --git a/aisearch-skillset/inquiry_document.py b/aisearch-skillset/inquiry_document.py
new file mode 100644
index 0000000..3f9dd0a
--- /dev/null
+++ b/aisearch-skillset/inquiry_document.py
@@ -0,0 +1,320 @@
+from azure.search.documents.indexes.models import (
+ SearchFieldDataType,
+ SearchField,
+ SearchableField,
+ SemanticField,
+ SemanticPrioritizedFields,
+ SemanticConfiguration,
+ SemanticSearch,
+ InputFieldMappingEntry,
+ SearchIndexer,
+ FieldMapping,
+ IndexingParameters,
+ IndexingParametersConfiguration,
+ BlobIndexerImageAction,
+ SearchIndexerIndexProjections,
+ SearchIndexerIndexProjectionSelector,
+ SearchIndexerIndexProjectionsParameters,
+ IndexProjectionMode,
+ SimpleField,
+ BlobIndexerDataToExtract,
+ IndexerExecutionEnvironment,
+ BlobIndexerPDFTextRotationAlgorithm,
+)
+from ai_search import AISearch
+from environment import (
+ get_search_embedding_model_dimensions,
+ IndexerType,
+)
+
+
+class InquiryDocumentAISearch(AISearch):
+ """This class is used to deploy the inquiry document index."""
+
+ def __init__(
+ self,
+ endpoint,
+ credential,
+ suffix=None,
+ rebuild=False,
+ enable_page_by_chunking=False,
+ ):
+ super().__init__(endpoint, credential, suffix, rebuild)
+
+ self.indexer_type = IndexerType.INQUIRY_DOCUMENT
+ if enable_page_by_chunking is not None:
+ self.enable_page_by_chunking = enable_page_by_chunking
+ else:
+ self.enable_page_by_chunking = False
+
+ # explicitly setting it to false no matter what output comes in
+ # might be removed later
+ # self.enable_page_by_chunking = False
+
+ def get_index_fields(self) -> list[SearchableField]:
+ """This function returns the index fields for inquiry document.
+
+ Returns:
+ list[SearchableField]: The index fields for inquiry document"""
+
+ fields = [
+ SimpleField(name="Id", type=SearchFieldDataType.String, filterable=True),
+ SearchableField(
+ name="Title", type=SearchFieldDataType.String, filterable=True
+ ),
+ SearchableField(
+ name="DealId",
+ type=SearchFieldDataType.String,
+ sortable=True,
+ filterable=True,
+ facetable=True,
+ ),
+ SearchableField(
+ name="OracleId",
+ type=SearchFieldDataType.String,
+ sortable=True,
+ filterable=True,
+ facetable=True,
+ ),
+ SearchableField(
+ name="ChunkId",
+ type=SearchFieldDataType.String,
+ key=True,
+ analyzer_name="keyword",
+ ),
+ SearchableField(
+ name="Chunk",
+ type=SearchFieldDataType.String,
+ sortable=False,
+ filterable=False,
+ facetable=False,
+ ),
+ SearchableField(
+ name="Section",
+ type=SearchFieldDataType.String,
+ collection=True,
+ ),
+ SearchField(
+ name="ChunkEmbedding",
+ type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
+ vector_search_dimensions=get_search_embedding_model_dimensions(
+ self.indexer_type
+ ),
+ vector_search_profile_name=self.vector_search_profile_name,
+ ),
+ SearchableField(
+ name="Keywords", type=SearchFieldDataType.String, collection=True
+ ),
+ SearchableField(
+ name="SourceUrl",
+ type=SearchFieldDataType.String,
+ sortable=True,
+ filterable=True,
+ facetable=True,
+ ),
+ SearchableField(
+ name="AdditionalMetadata",
+ type=SearchFieldDataType.String,
+ sortable=True,
+ filterable=True,
+ facetable=True,
+ ),
+ ]
+
+ if self.enable_page_by_chunking:
+ fields.extend(
+ [
+ SearchableField(
+ name="PageNumber",
+ type=SearchFieldDataType.Int64,
+ sortable=True,
+ filterable=True,
+ facetable=True,
+ )
+ ]
+ )
+
+ return fields
+
+ def get_semantic_search(self) -> SemanticSearch:
+ """This function returns the semantic search configuration for inquiry document
+
+ Returns:
+ SemanticSearch: The semantic search configuration"""
+
+ semantic_config = SemanticConfiguration(
+ name=self.semantic_config_name,
+ prioritized_fields=SemanticPrioritizedFields(
+ title_field=SemanticField(field_name="Title"),
+ content_fields=[SemanticField(field_name="Chunk")],
+ keywords_fields=[
+ SemanticField(field_name="Keywords"),
+ SemanticField(field_name="Section"),
+ ],
+ ),
+ )
+
+ semantic_search = SemanticSearch(configurations=[semantic_config])
+
+ return semantic_search
+
+ def get_skills(self):
+ """This function returns the skills for inquiry document"""
+
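+        # Skill order: ADI extraction -> (optional text split) -> pre-embedding cleaning
+        # -> key phrase extraction -> embedding generation.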
+ adi_skill = self.get_adi_skill(self.enable_page_by_chunking)
+
+ text_split_skill = self.get_text_split_skill(
+ "/document", "/document/extracted_content/content"
+ )
+
+ pre_embedding_cleaner_skill = self.get_pre_embedding_cleaner_skill(
+ "/document/pages/*", "/document/pages/*", self.enable_page_by_chunking
+ )
+
+ key_phrase_extraction_skill = self.get_key_phrase_extraction_skill(
+ "/document/pages/*", "/document/pages/*/cleaned_chunk"
+ )
+
+ embedding_skill = self.get_compass_vector_custom_skill(
+ "/document/pages/*", "/document/pages/*/cleaned_chunk"
+ )
+
+ if self.enable_page_by_chunking:
+ skills = [
+ adi_skill,
+ pre_embedding_cleaner_skill,
+ key_phrase_extraction_skill,
+ embedding_skill,
+ ]
+ else:
+ skills = [
+ adi_skill,
+ text_split_skill,
+ pre_embedding_cleaner_skill,
+ key_phrase_extraction_skill,
+ embedding_skill,
+ ]
+
+ return skills
+
+ def get_index_projections(self) -> SearchIndexerIndexProjections:
+ """This function returns the index projections for inquiry document."""
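+        # Each chunk under /document/pages/* is projected as its own search document;
+        # parent documents are skipped via SKIP_INDEXING_PARENT_DOCUMENTS below.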
+        mappings = [
+            InputFieldMappingEntry(name="Chunk", source="/document/pages/*/chunk"),
+            InputFieldMappingEntry(
+                name="ChunkEmbedding",
+                source="/document/pages/*/vector",
+            ),
+            InputFieldMappingEntry(name="Title", source="/document/Title"),
+            InputFieldMappingEntry(name="DealId", source="/document/DealId"),
+            InputFieldMappingEntry(name="OracleId", source="/document/OracleId"),
+            InputFieldMappingEntry(name="SourceUrl", source="/document/SourceUrl"),
+            InputFieldMappingEntry(name="Keywords", source="/document/pages/*/keywords"),
+            InputFieldMappingEntry(
+                name="AdditionalMetadata",
+                source="/document/AdditionalMetadata",
+            ),
+            InputFieldMappingEntry(name="Section", source="/document/pages/*/eachsection"),
+        ]
+
+ if self.enable_page_by_chunking:
+ mappings.extend(
+ [
+ InputFieldMappingEntry(
+ name="PageNumber", source="/document/pages/*/page_no"
+ )
+ ]
+ )
+
+ index_projections = SearchIndexerIndexProjections(
+ selectors=[
+ SearchIndexerIndexProjectionSelector(
+ target_index_name=self.index_name,
+ parent_key_field_name="Id",
+ source_context="/document/pages/*",
+ mappings=mappings
+ ),
+ ],
+ parameters=SearchIndexerIndexProjectionsParameters(
+ projection_mode=IndexProjectionMode.SKIP_INDEXING_PARENT_DOCUMENTS
+ ),
+ )
+
+ return index_projections
+
+ def get_indexer(self) -> SearchIndexer:
+ """This function returns the indexer for inquiry document.
+
+ Returns:
+ SearchIndexer: The indexer for inquiry document"""
+ if self.test:
+ schedule = None
+ batch_size = 4
+ else:
+ schedule = {"interval": "PT15M"}
+ batch_size = 16
+
+ indexer_parameters = IndexingParameters(
+ batch_size=batch_size,
+ configuration=IndexingParametersConfiguration(
+ # image_action=BlobIndexerImageAction.GENERATE_NORMALIZED_IMAGE_PER_PAGE,
+ data_to_extract=BlobIndexerDataToExtract.ALL_METADATA,
+ query_timeout=None,
+ # allow_skillset_to_read_file_data=True,
+ execution_environment=IndexerExecutionEnvironment.PRIVATE,
+ # pdf_text_rotation_algorithm=BlobIndexerPDFTextRotationAlgorithm.DETECT_ANGLES,
+ fail_on_unprocessable_document=False,
+ fail_on_unsupported_content_type=False,
+ index_storage_metadata_only_for_oversized_documents=True,
+ indexed_file_name_extensions=".pdf,.pptx,.docx",
+ ),
+ max_failed_items=5,
+ )
+
+ indexer = SearchIndexer(
+ name=self.indexer_name,
+ description="Indexer to index documents and generate embeddings",
+ skillset_name=self.skillset_name,
+ target_index_name=self.index_name,
+ data_source_name=self.data_source_name,
+ schedule=schedule,
+ field_mappings=[
+ FieldMapping(
+ source_field_name="metadata_storage_name", target_field_name="Title"
+ ),
+ FieldMapping(source_field_name="Deal_ID", target_field_name="DealId"),
+ FieldMapping(
+ source_field_name="Oracle_ID", target_field_name="OracleId"
+ ),
+ FieldMapping(
+ source_field_name="SharePointUrl", target_field_name="SourceUrl"
+ ),
+ FieldMapping(
+ source_field_name="Additional_Metadata",
+ target_field_name="AdditionalMetadata",
+ ),
+ ],
+ parameters=indexer_parameters,
+ )
+
+ return indexer
diff --git a/function_apps/common/ai_search.py b/function_apps/common/ai_search.py
new file mode 100644
index 0000000..1bba829
--- /dev/null
+++ b/function_apps/common/ai_search.py
@@ -0,0 +1,127 @@
+from azure.search.documents.indexes.aio import SearchIndexerClient, SearchIndexClient
+from azure.search.documents.aio import SearchClient
+from azure.search.documents.indexes.models import SynonymMap
+from azure.identity import DefaultAzureCredential
+from azure.core.exceptions import HttpResponseError
+import logging
+import os
+from enum import Enum
+from openai import AsyncAzureOpenAI
+from azure.search.documents.models import VectorizedQuery
+
+
+class IndexerStatusEnum(Enum):
+ RETRIGGER = "RETRIGGER"
+ RUNNING = "RUNNING"
+ SUCCESS = "SUCCESS"
+
+
+class AISearchHelper:
+ def __init__(self):
+ self._client_id = os.environ["FunctionApp__ClientId"]
+
+ self._endpoint = os.environ["AIService__AzureSearchOptions__Endpoint"]
+
+ async def get_index_client(self):
+ credential = DefaultAzureCredential(managed_identity_client_id=self._client_id)
+
+ return SearchIndexClient(self._endpoint, credential)
+
+ async def get_indexer_client(self):
+ credential = DefaultAzureCredential(managed_identity_client_id=self._client_id)
+
+ return SearchIndexerClient(self._endpoint, credential)
+
+ async def get_search_client(self, index_name):
+ credential = DefaultAzureCredential(managed_identity_client_id=self._client_id)
+
+ return SearchClient(self._endpoint, index_name, credential)
+
+ async def upload_synonym_map(self, synonym_map_name: str, synonyms: str):
+ index_client = await self.get_index_client()
+ async with index_client:
+ try:
+ await index_client.delete_synonym_map(synonym_map_name)
+ except HttpResponseError as e:
+ logging.error("Unable to delete synonym map %s", e)
+
+ logging.info("Synonyms: %s", synonyms)
+ synonym_map = SynonymMap(name=synonym_map_name, synonyms=synonyms)
+ await index_client.create_synonym_map(synonym_map)
+
+ async def get_indexer_status(self, indexer_name):
+ indexer_client = await self.get_indexer_client()
+ async with indexer_client:
+ try:
+ status = await indexer_client.get_indexer_status(indexer_name)
+
+ last_execution_result = status.last_result
+
+ if last_execution_result.status == "inProgress":
+ return IndexerStatusEnum.RUNNING, last_execution_result.start_time
+ elif last_execution_result.status in ["success", "transientFailure"]:
+ return IndexerStatusEnum.SUCCESS, last_execution_result.start_time
+ else:
+ return IndexerStatusEnum.RETRIGGER, last_execution_result.start_time
+ except HttpResponseError as e:
+ logging.error("Unable to get indexer status %s", e)
+
+ async def trigger_indexer(self, indexer_name):
+ indexer_client = await self.get_indexer_client()
+ async with indexer_client:
+ try:
+ await indexer_client.run_indexer(indexer_name)
+ except HttpResponseError as e:
+ logging.error("Unable to run indexer %s", e)
+
+ async def search_index(
+ self, index_name, semantic_config, search_text, deal_id=None
+ ):
+ """Search the index using the provided search text."""
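+        # Embed the query, then run a hybrid vector + semantic search, optionally filtered by DealId.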
+ async with AsyncAzureOpenAI(
+            # Authenticate against the Compass embedding endpoint with its API key
+ api_key=os.environ["AIService__Compass_Key"],
+ azure_endpoint=os.environ["AIService__Compass_Endpoint"],
+ api_version="2023-03-15-preview",
+ ) as open_ai_client:
+ embeddings = await open_ai_client.embeddings.create(
+ model=os.environ["AIService__Compass_Models__Embedding"],
+ input=search_text,
+ )
+
+ # Extract the embedding vector
+ embedding_vector = embeddings.data[0].embedding
+
+ vector_query = VectorizedQuery(
+ vector=embedding_vector,
+ k_nearest_neighbors=5,
+ fields="ChunkEmbedding",
+ )
+
+ if deal_id:
+ filter_expression = f"DealId eq '{deal_id}'"
+ else:
+ filter_expression = None
+
+ logging.info(f"Filter Expression: {filter_expression}")
+
+ search_client = await self.get_search_client(index_name)
+ async with search_client:
+ results = await search_client.search(
+ top=3,
+ query_type="semantic",
+ semantic_configuration_name=semantic_config,
+ search_text=search_text,
+ select="Title,Chunk",
+ vector_queries=[vector_query],
+ filter=filter_expression,
+ )
+
+ documents = [
+ document
+ async for result in results.by_page()
+ async for document in result
+ ]
+
+ logging.info(f"Documents: {documents}")
+ return documents
diff --git a/function_apps/indexer/adi_2_aisearch.py b/function_apps/indexer/adi_2_aisearch.py
new file mode 100644
index 0000000..e0542fb
--- /dev/null
+++ b/function_apps/indexer/adi_2_aisearch.py
@@ -0,0 +1,460 @@
+import base64
+from azure.core.credentials import AzureKeyCredential
+from azure.ai.documentintelligence.aio import DocumentIntelligenceClient
+from azure.ai.documentintelligence.models import AnalyzeResult, ContentFormat
+import os
+import re
+import asyncio
+import fitz
+from PIL import Image
+import io
+import aiohttp
+import logging
+from common.storage_account import StorageAccountHelper
+import concurrent.futures
+import json
+
+
+def crop_image_from_pdf_page(pdf_path, page_number, bounding_box):
+ """
+ Crops a region from a given page in a PDF and returns it as an image.
+
+ :param pdf_path: Path to the PDF file.
+ :param page_number: The page number to crop from (0-indexed).
+ :param bounding_box: A tuple of (x0, y0, x1, y1) coordinates for the bounding box.
+ :return: A PIL Image of the cropped area.
+ """
+ doc = fitz.open(pdf_path)
+ page = doc.load_page(page_number)
+
+ # Cropping the page. The rect requires the coordinates in the format (x0, y0, x1, y1).
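+    # ADI reports PDF coordinates in inches; multiplying by 72 converts them to PDF points,
+    # and the 300/72 matrix renders the crop at 300 DPI.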
+ bbx = [x * 72 for x in bounding_box]
+ rect = fitz.Rect(bbx)
+ pix = page.get_pixmap(matrix=fitz.Matrix(300 / 72, 300 / 72), clip=rect)
+
+ img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
+
+ doc.close()
+ return img
+
+
+def clean_adi_markdown(
+    markdown_text: str, page_no: int, remove_irrelevant_figures=False
+):
+    """Clean Markdown text extracted by the Azure Document Intelligence service.
+
+    Args:
+    -----
+        markdown_text (str): The original Markdown text.
+        page_no (int): The page number of the chunk; pass -1 when not chunking by page.
+        remove_irrelevant_figures (bool): Whether to remove figures marked as irrelevant.
+
+    Returns:
+    --------
+        dict: The cleaned content, the section headings and, optionally, the page number.
+    """
+
+    # # Remove the page number comment
+    # page_number_pattern = r"<!-- PageNumber=\"[^\"]*\" -->"
+    # cleaned_text = re.sub(page_number_pattern, "", markdown_text)
+
+    # # Replace the page header comment with its content
+    # page_header_pattern = r"<!-- PageHeader=\"(.*?)\" -->"
+    # cleaned_text = re.sub(
+    #     page_header_pattern, lambda match: match.group(1), cleaned_text
+    # )
+
+    # # Replace the page footer comment with its content
+    # page_footer_pattern = r"<!-- PageFooter=\"(.*?)\" -->"
+    # cleaned_text = re.sub(
+    #     page_footer_pattern, lambda match: match.group(1), cleaned_text
+    # )
+    output_dict = {}
+    comment_patterns = (
+        r"<!-- PageNumber=\"[^\"]*\" -->|<!-- PageHeader=\"[^\"]*\" -->"
+        r"|<!-- PageFooter=\"[^\"]*\" -->|<!-- PageBreak -->"
+    )
+    cleaned_text = re.sub(comment_patterns, "", markdown_text, flags=re.DOTALL)
+
+ combined_pattern = r'(.*?)\n===|\n## ?(.*?)\n|\n### ?(.*?)\n'
+ doc_metadata = re.findall(combined_pattern, cleaned_text, re.DOTALL)
+ doc_metadata = [match for group in doc_metadata for match in group if match]
+
+
+ if remove_irrelevant_figures:
+ # Remove irrelevant figures
+        irrelevant_figure_pattern = (
+            r"<figure>.*?Irrelevant Image.*?</figure>\s*"
+        )
+ cleaned_text = re.sub(
+ irrelevant_figure_pattern, "", cleaned_text, flags=re.DOTALL
+ )
+
+ # Replace ':selected:' with a new line
+ cleaned_text = re.sub(r":(selected|unselected):", "\n", cleaned_text)
+ output_dict['content'] = cleaned_text
+ output_dict['section'] = doc_metadata
+
+ # add page number when chunk by page is enabled
+    if page_no > -1:
+ output_dict['page_number'] = page_no
+
+ return output_dict
+
+
+def update_figure_description(md_content, img_description, idx):
+ """
+ Updates the figure description in the Markdown content.
+
+ Args:
+ md_content (str): The original Markdown content.
+ img_description (str): The new description for the image.
+ idx (int): The index of the figure.
+
+ Returns:
+ str: The updated Markdown content with the new figure description.
+ """
+
+ # The substring you're looking for
+    start_substring = f"![](figures/{idx})"
+    end_substring = "</figure>"
+    new_string = f'<!-- FigureContent="{img_description}" -->'
+
+ new_md_content = md_content
+ # Find the start and end indices of the part to replace
+ start_index = md_content.find(start_substring)
+ if start_index != -1: # if start_substring is found
+ start_index += len(
+ start_substring
+ ) # move the index to the end of start_substring
+ end_index = md_content.find(end_substring, start_index)
+ if end_index != -1: # if end_substring is found
+ # Replace the old string with the new string
+ new_md_content = (
+ md_content[:start_index] + new_string + md_content[end_index:]
+ )
+
+ return new_md_content
+
+
+async def understand_image_with_vlm(image_base64):
+ """
+ Sends a base64-encoded image to a VLM (Vision Language Model) endpoint for financial analysis.
+
+ Args:
+ image_base64 (str): The base64-encoded string representation of the image.
+
+ Returns:
+ str: The response from the VLM, which is either a financial analysis or a statement indicating the image is not useful.
+ """
+ # prompt = "Describe the image ONLY IF it is useful for financial analysis. Otherwise, say 'NOT USEFUL IMAGE' and NOTHING ELSE. "
+ prompt = "Perform financial analysis of the image ONLY IF the image is of graph, chart, flowchart or table. Otherwise, say 'NOT USEFUL IMAGE' and NOTHING ELSE. "
+ headers = {"Content-Type": "application/json"}
+ data = {"prompt": prompt, "image": image_base64}
+ vlm_endpoint = os.environ["AIServices__VLM__Endpoint"]
+ async with aiohttp.ClientSession() as session:
+ async with session.post(
+ vlm_endpoint, headers=headers, json=data, timeout=30
+ ) as response:
+ response_data = await response.json()
+ response_text = response_data["response"].split("")[0]
+
+ if (
+ "not useful for financial analysis" in response_text
+ or "NOT USEFUL IMAGE" in response_text
+ ):
+ return "Irrelevant Image"
+ else:
+ return response_text
+
+
+def pil_image_to_base64(image, image_format="JPEG"):
+ """
+ Converts a PIL image to a base64-encoded string.
+
+ Args:
+ image (PIL.Image.Image): The image to be converted.
+ image_format (str): The format to save the image in. Defaults to "JPEG".
+
+ Returns:
+ str: The base64-encoded string representation of the image.
+ """
+ if image.mode == "RGBA" and image_format == "JPEG":
+ image = image.convert("RGB")
+ buffered = io.BytesIO()
+ image.save(buffered, format=image_format)
+ return base64.b64encode(buffered.getvalue()).decode("utf-8")
+
+
+async def process_figures_from_extracted_content(
+ file_path: str, markdown_content: str, figures: list, page_number: None | int = None
+) -> str:
+ """Process the figures extracted from the content using ADI and send them for analysis.
+
+ Args:
+ -----
+ file_path (str): The path to the PDF file.
+ markdown_content (str): The extracted content in Markdown format.
+ figures (list): The list of figures extracted by the Azure Document Intelligence service.
+ page_number (int): The page number to process. If None, all pages are processed.
+
+ Returns:
+ --------
+ str: The updated Markdown content with the figure descriptions."""
+ for idx, figure in enumerate(figures):
+ img_description = ""
+ logging.debug(f"Figure #{idx} has the following spans: {figure.spans}")
+
+ caption_region = figure.caption.bounding_regions if figure.caption else []
+ for region in figure.bounding_regions:
+ # Skip the region if it is not on the specified page
+ if page_number is not None and region.page_number != page_number:
+ continue
+
+ if region not in caption_region:
+ # To learn more about bounding regions, see https://aka.ms/bounding-region
+ bounding_box = (
+ region.polygon[0], # x0 (left)
+ region.polygon[1], # y0 (top)
+ region.polygon[4], # x1 (right)
+ region.polygon[5], # y1 (bottom)
+ )
+ cropped_image = crop_image_from_pdf_page(
+ file_path, region.page_number - 1, bounding_box
+                )  # page_number is 1-indexed
+
+ image_base64 = pil_image_to_base64(cropped_image)
+
+ img_description += await understand_image_with_vlm(image_base64)
+ logging.info(f"\tDescription of figure {idx}: {img_description}")
+
+ markdown_content = update_figure_description(
+ markdown_content, img_description, idx
+ )
+
+ return markdown_content
+
+
+def create_page_wise_content(result: AnalyzeResult) -> tuple[list, list]:
+ """Create a list of page-wise content extracted by the Azure Document Intelligence service.
+
+ Args:
+ -----
+ result (AnalyzeResult): The result of the document analysis.
+
+    Returns:
+    --------
+        tuple[list, list]: The page-wise content and the corresponding page numbers.
+    """
+
+ page_wise_content = []
+ page_numbers = []
+ page_number = 0
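+    # Slice each page's text out of the full content using the page's span offset and length.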
+ for page in result.pages:
+ page_content = result.content[
+ page.spans[0]["offset"] : page.spans[0]["offset"] + page.spans[0]["length"]
+ ]
+ page_wise_content.append(page_content)
+        page_number += 1
+        page_numbers.append(page_number)
+
+    return page_wise_content, page_numbers
+
+
+async def analyse_document(file_path: str) -> AnalyzeResult:
+ """Analyse a document using the Azure Document Intelligence service.
+
+ Args:
+ -----
+ file_path (str): The path to the document to analyse.
+
+ Returns:
+ --------
+ AnalyzeResult: The result of the document analysis."""
+ with open(file_path, "rb") as f:
+ file_read = f.read()
+ # base64_encoded_file = base64.b64encode(file_read).decode("utf-8")
+
+ async with DocumentIntelligenceClient(
+ endpoint=os.environ["AIService__Services__Endpoint"],
+ credential=AzureKeyCredential(os.environ["AIService__Services__Key"]),
+ ) as document_intelligence_client:
+ poller = await document_intelligence_client.begin_analyze_document(
+ model_id="prebuilt-layout",
+ analyze_request=file_read,
+ output_content_format=ContentFormat.MARKDOWN,
+ content_type="application/octet-stream",
+ )
+
+ result = await poller.result()
+
+ if result is None or result.content is None or result.pages is None:
+ raise ValueError(
+ "Failed to analyze the document with Azure Document Intelligence."
+ )
+
+ return result
+
+
+async def process_adi_2_ai_search(record: dict, chunk_by_page: bool = False) -> dict:
+ logging.info("Python HTTP trigger function processed a request.")
+
+ storage_account_helper = StorageAccountHelper()
+
+ try:
+ source = record["data"]["source"]
+ logging.info(f"Request Body: {record}")
+ except KeyError:
+ return {
+ "recordId": record["recordId"],
+ "data": {},
+ "errors": [
+ {
+ "message": "Failed to extract data with ADI. Pass a valid source in the request body.",
+ }
+ ],
+ "warnings": None,
+ }
+ else:
+ logging.info(f"Source: {source}")
+
+ try:
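+        # The source is a blob URL; derive the container and blob path from its segments.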
+ source_parts = source.split("/")
+ blob = "/".join(source_parts[4:])
+ logging.info(f"Blob: {blob}")
+
+ container = source_parts[3]
+
+ file_extension = blob.split(".")[-1]
+ target_file_name = f"{record['recordId']}.{file_extension}"
+
+ temp_file_path, _ = await storage_account_helper.download_blob_to_temp_dir(
+ blob, container, target_file_name
+ )
+ logging.info(temp_file_path)
+ except Exception as e:
+ logging.error(f"Failed to download the blob: {e}")
+ return {
+ "recordId": record["recordId"],
+ "data": {},
+ "errors": [
+ {
+ "message": f"Failed to download the blob. Check the source and try again. {e}",
+ }
+ ],
+ "warnings": None,
+ }
+
+ try:
+ result = await analyse_document(temp_file_path)
+ except Exception as e:
+ logging.error(e)
+ logging.info("Sleeping for 10 seconds and retrying")
+ await asyncio.sleep(10)
+ try:
+ result = await analyse_document(temp_file_path)
+ except ValueError as inner_e:
+ logging.error(inner_e)
+ logging.error(
+ f"Failed to analyze the document with Azure Document Intelligence: {e}"
+ )
+ logging.error(
+ "Failed to analyse %s with Azure Document Intelligence.", blob
+ )
+ await storage_account_helper.add_metadata_to_blob(
+ blob, container, {"AzureSearch_Skip": "true"}
+ )
+ return {
+ "recordId": record["recordId"],
+ "data": {},
+ "errors": [
+ {
+ "message": f"Failed to analyze the document with Azure Document Intelligence. This blob will now be skipped {inner_e}",
+ }
+ ],
+ "warnings": None,
+ }
+ except Exception as inner_e:
+ logging.error(inner_e)
+ logging.error(
+ "Failed to analyse %s with Azure Document Intelligence.", blob
+ )
+ return {
+ "recordId": record["recordId"],
+ "data": {},
+ "errors": [
+ {
+ "message": f"Failed to analyze the document with Azure Document Intelligence. Check the logs and try again. {inner_e}",
+ }
+ ],
+ "warnings": None,
+ }
+
+ try:
+ if chunk_by_page:
+            markdown_content, page_no = create_page_wise_content(result)
+ else:
+ markdown_content = result.content
+
+ # Remove this line when VLM is ready
+ content_with_figures = markdown_content
+
+ # if chunk_by_page:
+ # tasks = [
+ # process_figures_from_extracted_content(
+ # temp_file_path, page_content, result.figures, page_number=idx
+ # )
+ # for idx, page_content in enumerate(markdown_content)
+ # ]
+ # content_with_figures = await asyncio.gather(*tasks)
+ # else:
+ # content_with_figures = await process_figures_from_extracted_content(
+ # temp_file_path, markdown_content, result.figures
+ # )
+
+ # Remove remove_irrelevant_figures=True when VLM is ready
+ if chunk_by_page:
+ cleaned_result = []
+ with concurrent.futures.ProcessPoolExecutor() as executor:
+                results = executor.map(
+                    clean_adi_markdown,
+                    content_with_figures,
+                    page_no,
+                    [False] * len(content_with_figures),
+                )
+
+ for cleaned_content in results:
+ cleaned_result.append(cleaned_content)
+
+ # with concurrent.futures.ProcessPoolExecutor() as executor:
+ # futures = {
+ # executor.submit(
+ # clean_adi_markdown, page_content, False
+ # ): page_content
+ # for page_content in content_with_figures
+ # }
+ # for future in concurrent.futures.as_completed(futures):
+ # cleaned_result.append(future.result())
+ else:
+            cleaned_result = clean_adi_markdown(
+                content_with_figures, page_no=-1, remove_irrelevant_figures=False
+            )
+ except Exception as e:
+ logging.error(e)
+ logging.error(f"Failed to process the extracted content: {e}")
+ return {
+ "recordId": record["recordId"],
+ "data": {},
+ "errors": [
+ {
+ "message": f"Failed to process the extracted content. Check the logs and try again. {e}",
+ }
+ ],
+ "warnings": None,
+ }
+
+ logging.info("Document Extracted")
+ logging.info(f"Result: {cleaned_result}")
+
+ src = {
+ "recordId": record["recordId"],
+ "data": {"extracted_content": cleaned_result},
+ }
+
+ json_str = json.dumps(src, indent=4)
+
+ logging.info(f"final output: {json_str}")
+
+ return {
+ "recordId": record["recordId"],
+ "data": {"extracted_content": cleaned_result},
+ }
diff --git a/function_apps/indexer/function_app.py b/function_apps/indexer/function_app.py
new file mode 100644
index 0000000..12d5d5b
--- /dev/null
+++ b/function_apps/indexer/function_app.py
@@ -0,0 +1,296 @@
+from datetime import datetime, timedelta, timezone
+import azure.functions as func
+import logging
+import json
+import asyncio
+
+from adi_2_ai_search import process_adi_2_ai_search
+from common.service_bus import ServiceBusHelper
+from pre_embedding_cleaner import process_pre_embedding_cleaner
+
+from text_split import process_text_split
+from ai_search_2_compass import process_ai_search_2_compass
+from key_phrase_extraction import process_key_phrase_extraction
+from ocr import process_ocr
+from pending_index_completion import process_pending_index_completion
+from pending_index_trigger import process_pending_index_trigger
+
+from common.payloads.pending_index_trigger import PendingIndexTriggerPayload
+
+from common.payloads.header import TaskEnum
+
+logging.basicConfig(level=logging.INFO)
+app = func.FunctionApp(http_auth_level=func.AuthLevel.FUNCTION)
+
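+# Every route below implements the Azure AI Search custom WebApiSkill contract:
+# request and response bodies have the shape {"values": [{"recordId": "...", "data": {...}}]}.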
+
+@app.route(route="text_split", methods=[func.HttpMethod.POST])
+async def text_split(req: func.HttpRequest) -> func.HttpResponse:
+    """Split the extracted document content into chunks."""
+
+ try:
+ req_body = req.get_json()
+ values = req_body.get("values")
+ text_split_config = req.headers
+ except ValueError:
+        return func.HttpResponse(
+            "Please provide a valid Custom Skill Payload in the request body",
+            status_code=400,
+        )
+ else:
+ logging.debug(f"Input Values: {values}")
+
+ record_tasks = []
+
+ for value in values:
+ record_tasks.append(
+ asyncio.create_task(process_text_split(value, text_split_config))
+ )
+
+ results = await asyncio.gather(*record_tasks)
+ logging.debug(f"Results: {results}")
+
+ return func.HttpResponse(
+ json.dumps({"values": results}),
+ status_code=200,
+ mimetype="application/json",
+ )
+
+
+@app.route(route="ai_search_2_compass", methods=[func.HttpMethod.POST])
+async def ai_search_2_compass(req: func.HttpRequest) -> func.HttpResponse:
+ logging.info("Python HTTP trigger function processed a request.")
+
+ """HTTP trigger for AI Search 2 Compass function.
+
+ Args:
+ req (func.HttpRequest): The HTTP request object.
+
+ Returns:
+ func.HttpResponse: The HTTP response object."""
+ logging.info("Python HTTP trigger function processed a request.")
+
+ try:
+ req_body = req.get_json()
+ values = req_body.get("values")
+ except ValueError:
+ return func.HttpResponse(
+ "Please valid Custom Skill Payload in the request body", status_code=400
+ )
+ else:
+ logging.debug("Input Values: %s", values)
+
+ record_tasks = []
+
+ for value in values:
+ record_tasks.append(asyncio.create_task(process_ai_search_2_compass(value)))
+
+ results = await asyncio.gather(*record_tasks)
+ logging.debug("Results: %s", results)
+ vectorised_tasks = {"values": results}
+
+ return func.HttpResponse(
+ json.dumps(vectorised_tasks), status_code=200, mimetype="application/json"
+ )
+
+
+@app.route(route="adi_2_ai_search", methods=[func.HttpMethod.POST])
+async def adi_2_ai_search(req: func.HttpRequest) -> func.HttpResponse:
+ """Extract the content from a document using ADI."""
+
+ try:
+ req_body = req.get_json()
+ values = req_body.get("values")
+ adi_config = req.headers
+
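+        # Custom skill configuration is passed via HTTP headers; "chunk_by_page" toggles page-level chunking.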
+ chunk_by_page = adi_config.get("chunk_by_page", "False").lower() == "true"
+ logging.info(f"Chunk by Page: {chunk_by_page}")
+ except ValueError:
+ return func.HttpResponse(
+ "Please valid Custom Skill Payload in the request body", status_code=400
+ )
+ else:
+ logging.debug("Input Values: %s", values)
+
+ record_tasks = []
+
+ for value in values:
+ record_tasks.append(
+ asyncio.create_task(
+ process_adi_2_ai_search(value, chunk_by_page=chunk_by_page)
+ )
+ )
+
+ results = await asyncio.gather(*record_tasks)
+ logging.debug("Results: %s", results)
+
+ return func.HttpResponse(
+ json.dumps({"values": results}),
+ status_code=200,
+ mimetype="application/json",
+ )
+
+
+@app.route(route="pre_embedding_cleaner", methods=[func.HttpMethod.POST])
+async def pre_embedding_cleaner(req: func.HttpRequest) -> func.HttpResponse:
+ """HTTP trigger for data cleanup function.
+
+ Args:
+ req (func.HttpRequest): The HTTP request object.
+
+ Returns:
+ func.HttpResponse: The HTTP response object."""
+ logging.info("Python HTTP trigger data cleanup function processed a request.")
+
+ try:
+ req_body = req.get_json()
+ values = req_body.get("values")
+ except ValueError:
+ return func.HttpResponse(
+ "Please valid Custom Skill Payload in the request body", status_code=400
+ )
+ else:
+ logging.debug("Input Values: %s", values)
+
+ record_tasks = []
+
+ for value in values:
+ record_tasks.append(
+ asyncio.create_task(process_pre_embedding_cleaner(value))
+ )
+
+ results = await asyncio.gather(*record_tasks)
+ logging.debug("Results: %s", results)
+ cleaned_tasks = {"values": results}
+
+ return func.HttpResponse(
+ json.dumps(cleaned_tasks), status_code=200, mimetype="application/json"
+ )
+
+
+@app.route(route="keyphrase_extractor", methods=[func.HttpMethod.POST])
+async def keyphrase_extractor(req: func.HttpRequest) -> func.HttpResponse:
+ """HTTP trigger for data cleanup function.
+
+ Args:
+ req (func.HttpRequest): The HTTP request object.
+
+ Returns:
+ func.HttpResponse: The HTTP response object."""
+ logging.info("Python HTTP trigger data cleanup function processed a request.")
+
+ try:
+ req_body = req.get_json()
+ values = req_body.get("values")
+ logging.info(req_body)
+ except ValueError:
+ return func.HttpResponse(
+ "Please valid Custom Skill Payload in the request body", status_code=400
+ )
+ else:
+ logging.debug("Input Values: %s", values)
+
+ record_tasks = []
+
+ for value in values:
+ record_tasks.append(
+ asyncio.create_task(process_key_phrase_extraction(value))
+ )
+
+ results = await asyncio.gather(*record_tasks)
+ logging.debug("Results: %s", results)
+ cleaned_tasks = {"values": results}
+
+ return func.HttpResponse(
+ json.dumps(cleaned_tasks), status_code=200, mimetype="application/json"
+ )
+
+
+@app.route(route="ocr", methods=[func.HttpMethod.POST])
+async def ocr(req: func.HttpRequest) -> func.HttpResponse:
+ """HTTP trigger for data cleanup function.
+
+ Args:
+ req (func.HttpRequest): The HTTP request object.
+
+ Returns:
+ func.HttpResponse: The HTTP response object."""
+ logging.info("Python HTTP trigger data cleanup function processed a request.")
+
+ try:
+ req_body = req.get_json()
+ values = req_body.get("values")
+ except ValueError:
+ return func.HttpResponse(
+ "Please valid Custom Skill Payload in the request body", status_code=400
+ )
+ else:
+ logging.debug("Input Values: %s", values)
+
+ record_tasks = []
+
+ for value in values:
+ record_tasks.append(asyncio.create_task(process_ocr(value)))
+
+ results = await asyncio.gather(*record_tasks)
+ logging.debug("Results: %s", results)
+ cleaned_tasks = {"values": results}
+
+ return func.HttpResponse(
+ json.dumps(cleaned_tasks), status_code=200, mimetype="application/json"
+ )
+
+
+@app.service_bus_queue_trigger(
+ arg_name="msg",
+ queue_name="pending_index_trigger",
+ connection="ServiceBusTrigger",
+)
+async def pending_index_trigger(msg: func.ServiceBusMessage):
+ logging.info(
+ f"trigger-indexer: Python ServiceBus queue trigger processed message: {msg}"
+ )
+ try:
+ payload = PendingIndexTriggerPayload.from_service_bus_message(msg)
+ await process_pending_index_trigger(payload)
+ except ValueError as ve:
+ logging.error(f"ValueError: {ve}")
+ except Exception as e:
+ logging.error(f"Error processing ServiceBus message: {e}")
+
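+        # Azure AI Search throttles on-demand indexer runs; when throttled, re-queue the message with a delayed enqueue time.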
+ if "On-demand indexer invocation is permitted every 180 seconds" in str(e):
+ logging.warning(
+ f"Indexer invocation limit reached: {e}. Scheduling a retry."
+ )
+ service_bus_helper = ServiceBusHelper()
+ message = PendingIndexTriggerPayload(
+ header=payload.header, body=payload.body, errors=[]
+ )
+ queue = TaskEnum.PENDING_INDEX_TRIGGER.value
+ minutes = 2 ** (11 - payload.header.retries_remaining)
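+            # Exponential backoff: with the default of 10 retries remaining this waits 2 minutes, doubling as the remaining retry count drops.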
+ enqueue_time = datetime.now(timezone.utc) + timedelta(minutes=minutes)
+ await service_bus_helper.send_message_to_service_bus_queue(
+ queue, message, enqueue_time=enqueue_time
+ )
+ else:
+ raise e
+
+
+@app.service_bus_queue_trigger(
+ arg_name="msg",
+ queue_name="pending_index_completion",
+ connection="ServiceBusTrigger",
+)
+async def pending_index_completion(msg: func.ServiceBusMessage):
+ logging.info(
+ f"indexer-polling-trigger: Python ServiceBus queue trigger processed message: {msg}"
+ )
+
+ try:
+ payload = PendingIndexTriggerPayload.from_service_bus_message(msg)
+ await process_pending_index_completion(payload)
+ except ValueError as ve:
+ logging.error(f"ValueError: {ve}")
+ except Exception as e:
+ logging.error(f"Error processing ServiceBus message: {e}")
+ if "The operation has timed out" in str(e):
+ logging.error("The operation has timed out.")
+ raise e
diff --git a/function_apps/indexer/key_phrase_extraction.py b/function_apps/indexer/key_phrase_extraction.py
new file mode 100644
index 0000000..c6ab40e
--- /dev/null
+++ b/function_apps/indexer/key_phrase_extraction.py
@@ -0,0 +1,112 @@
+import logging
+import json
+import os
+from azure.ai.textanalytics.aio import TextAnalyticsClient
+from azure.core.exceptions import HttpResponseError
+from azure.core.credentials import AzureKeyCredential
+import asyncio
+
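+# Azure AI Language caps synchronous key phrase extraction at 5,120 text elements per document; larger documents are split below.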
+MAX_TEXT_ELEMENTS = 5120
+
+def split_document(document: str, max_size: int) -> list[str]:
+    """Split a document into chunks of at most max_size characters."""
+    return [document[i : i + max_size] for i in range(0, len(document), max_size)]
+
+
+async def extract_key_phrases_from_text(
+    data: list[str], max_key_phrase_count: int
+) -> list[str]:
+ logging.info("Python HTTP trigger function processed a request.")
+
+ max_retries = 5
+ key_phrase_list = []
+ text_analytics_client = TextAnalyticsClient(
+ endpoint=os.environ["AIService__Services__Endpoint"],
+ credential=AzureKeyCredential(os.environ["AIService__Services__Key"]),
+ )
+
+ try:
+ async with text_analytics_client:
+ retries = 0
+ while retries < max_retries:
+ try:
+ # Split large documents
+ split_documents = []
+ for doc in data:
+ if len(doc) > MAX_TEXT_ELEMENTS:
+ split_documents.extend(split_document(doc, MAX_TEXT_ELEMENTS))
+ else:
+ split_documents.append(doc)
+ result = await text_analytics_client.extract_key_phrases(split_documents)
+                    for idx, doc in enumerate(result):
+ if not doc.is_error:
+ key_phrase_list.extend(doc.key_phrases[:max_key_phrase_count])
+ else:
+ raise Exception(f"Document {idx} error: {doc.error}")
+ break # Exit the loop if the request is successful
+ except HttpResponseError as e:
+ if e.status_code == 429: # Rate limiting error
+ retries += 1
+ wait_time = 2 ** retries # Exponential backoff
+                        logging.warning(
+                            f"Rate limit exceeded. Retrying in {wait_time} seconds..."
+                        )
+ await asyncio.sleep(wait_time)
+ else:
+ raise Exception(f"An error occurred: {e}")
+ except Exception as e:
+ raise Exception(f"An error occurred: {e}")
+
+ return key_phrase_list
+
+
+async def process_key_phrase_extraction(
+    record: dict, max_key_phrase_count: int = 5
+) -> dict:
+ """Extract key phrases using azure ai services.
+
+ Args:
+ record (dict): The record to process.
+        max_key_phrase_count (int): The maximum number of key phrases to return.
+
+    Returns:
+        dict: The extracted key phrases."""
+
+ try:
+ json_str = json.dumps(record, indent=4)
+
+ logging.info(f"key phrase extraction Input: {json_str}")
+ extracted_record = {
+ "recordId": record["recordId"],
+ "data": {},
+ "errors": None,
+ "warnings": None,
+ }
+ extracted_record["data"]["keyPhrases"] = await extract_key_phrases_from_text(
+            [record["data"]["text"]], max_key_phrase_count
+ )
+ except Exception as e:
+ logging.error("key phrase extraction Error: %s", e)
+ await asyncio.sleep(10)
+ try:
+ extracted_record = {
+ "recordId": record["recordId"],
+ "data": {},
+ "errors": None,
+ "warnings": None,
+ }
+ extracted_record["data"][
+ "keyPhrases"
+ ] = await extract_key_phrases_from_text([record["data"]["text"]],max_key_phrase_count)
+ except Exception as inner_e:
+ logging.error("key phrase extraction Error: %s", inner_e)
+ logging.error(
+ "Failed to extract key phrase. Check function app logs for more details of exact failure."
+ )
+ return {
+ "recordId": record["recordId"],
+ "data": {},
+ "errors": [
+ {
+ "message": "Failed to extract key phrase. Check function app logs for more details of exact failure."
+ }
+ ],
+ "warnings": None,
+ }
+ json_str = json.dumps(extracted_record, indent=4)
+
+ logging.info(f"key phrase extraction output: {json_str}")
+ return extracted_record
diff --git a/function_apps/indexer/pre_embedding_cleaner.py b/function_apps/indexer/pre_embedding_cleaner.py
new file mode 100644
index 0000000..2fdf87a
--- /dev/null
+++ b/function_apps/indexer/pre_embedding_cleaner.py
@@ -0,0 +1,144 @@
+import logging
+import json
+import re
+
+import nltk
+from nltk.tokenize import word_tokenize
+
+# Download the corpora needed for tokenisation and stop word removal.
+nltk.download("punkt")
+nltk.download("stopwords")
+
+# Configure logging
+logging.basicConfig(
+    level=logging.ERROR, format="%(asctime)s - %(levelname)s - %(message)s"
+)
+
+def get_section(cleaned_text: str) -> list:
+ """
+ Returns the section details from the content
+
+ Args:
+ cleaned_text: The input text
+
+ Returns:
+ list: The sections related to text
+
+ """
+ combined_pattern = r'(.*?)\n===|\n## ?(.*?)\n|\n### ?(.*?)\n'
+ doc_metadata = re.findall(combined_pattern, cleaned_text, re.DOTALL)
+ doc_metadata = [match for group in doc_metadata for match in group if match]
+ return doc_metadata
+
+def remove_markdown_tags(text: str, tag_patterns: dict) -> str:
+ """
+ Remove specified Markdown tags from the text, keeping the contents of the tags.
+
+ Args:
+ text: The input text containing Markdown tags.
+ tag_patterns: A dictionary where keys are tags and values are their specific patterns.
+
+ Returns:
+ str: The text with specified tags removed.
+ """
+ try:
+ for tag, pattern in tag_patterns.items():
+ try:
+ # Replace the tags using the specific pattern, keeping the content inside the tags
+ text = re.sub(pattern, r'\1', text, flags=re.DOTALL)
+ except re.error as e:
+ logging.error(f"Regex error for tag '{tag}': {e}")
+ except Exception as e:
+ logging.error(f"An error occurred in remove_markdown_tags: {e}")
+ return text
+
+def clean_text(src_text: str) -> str:
+    """Clean up the text: strip Markdown tags and line breaks, remove stop words and
+    special characters, then normalise whitespace and case.
+
+    Args:
+        src_text (str): The text to clean up.
+
+    Returns:
+        str: The cleaned text."""
+
+ try:
+ # Define specific patterns for each tag
+        tag_patterns = {
+            "figurecontent": r"<!-- FigureContent=(.*?)-->",
+            "figure": r"<figure>(.*?)</figure>",
+            "figures": r"\(figures/\d+\)(.*?)\(figures/\d+\)",
+            "figcaption": r"<figcaption>(.*?)</figcaption>",
+        }
+ cleaned_text = remove_markdown_tags(src_text, tag_patterns)
+
+ # remove line breaks
+ cleaned_text = re.sub(r"\n", "", cleaned_text)
+
+ # remove stopwords
+ tokens = word_tokenize(cleaned_text, "english")
+ stop_words = nltk.corpus.stopwords.words("english")
+ filtered_tokens = [word for word in tokens if word not in stop_words]
+ cleaned_text = " ".join(filtered_tokens)
+
+ # remove special characters
+ cleaned_text = re.sub(r"[^a-zA-Z\s]", "", cleaned_text)
+
+ # remove extra white spaces
+ cleaned_text = " ".join([word for word in cleaned_text.split()])
+
+ # case normalization
+ cleaned_text = cleaned_text.lower()
+ except Exception as e:
+ logging.error(f"An error occurred in clean_text: {e}")
+ return ""
+ return cleaned_text
+
+
+async def process_pre_embedding_cleaner(record: dict) -> dict:
+ """Cleanup the data using standard python libraries.
+
+ Args:
+ record (dict): The record to cleanup.
+
+ Returns:
+ dict: The clean record."""
+
+ try:
+ json_str = json.dumps(record, indent=4)
+
+ logging.info(f"embedding cleaner Input: {json_str}")
+
+ cleaned_record = {
+ "recordId": record["recordId"],
+ "data": {},
+ "errors": None,
+ "warnings": None,
+ }
+
+        # When page-wise chunking is enabled, the chunk arrives as a dict with content and metadata.
+        if isinstance(record["data"]["chunk"], dict):
+            cleaned_record["data"]["cleaned_chunk"] = clean_text(
+                record["data"]["chunk"]["content"]
+            )
+            cleaned_record["data"]["chunk"] = record["data"]["chunk"]["content"]
+            cleaned_record["data"]["section"] = record["data"]["chunk"]["section"]
+            cleaned_record["data"]["page_number"] = record["data"]["chunk"][
+                "page_number"
+            ]
+        else:
+            cleaned_record["data"]["cleaned_chunk"] = clean_text(record["data"]["chunk"])
+            cleaned_record["data"]["chunk"] = record["data"]["chunk"]
+            cleaned_record["data"]["section"] = get_section(record["data"]["chunk"])
+
+ except Exception as e:
+ logging.error("string cleanup Error: %s", e)
+ return {
+ "recordId": record["recordId"],
+ "data": {},
+ "errors": [
+ {
+ "message": "Failed to cleanup data. Check function app logs for more details of exact failure."
+ }
+ ],
+ "warnings": None,
+ }
+ json_str = json.dumps(cleaned_record, indent=4)
+
+ logging.info(f"embedding cleaner output: {json_str}")
+ return cleaned_record
diff --git a/function_apps/indexer/requirements.txt b/function_apps/indexer/requirements.txt
new file mode 100644
index 0000000..48c9837
--- /dev/null
+++ b/function_apps/indexer/requirements.txt
@@ -0,0 +1,26 @@
+# DO NOT include azure-functions-worker in this file
+# The Python Worker is managed by Azure Functions platform
+# Manually managing azure-functions-worker may cause unexpected issues
+python-dotenv
+azure-functions
+openai
+azure-storage-blob
+pandas
+azure-identity
+openpyxl
+regex
+nltk==3.8.1
+bs4
+azure-search
+azure-search-documents
+azure-ai-documentintelligence
+azure-ai-textanalytics
+azure-ai-vision-imageanalysis
+PyMuPDF
+pillow
+torch
+aiohttp
+spacy==3.7.5
+transformers
+scikit-learn
+en-core-web-md @ https://github.com/explosion/spacy-models/releases/download/en_core_web_md-3.7.1/en_core_web_md-3.7.1.tar.gz
From 8d177a003b33b6b5433a91797149822fb0bd121f Mon Sep 17 00:00:00 2001
From: priyal1508 <54278892+priyal1508@users.noreply.github.com>
Date: Thu, 5 Sep 2024 19:11:57 +0530
Subject: [PATCH 2/4] changes for common scripts
---
aisearch-skillset/deploy.py | 13 ------
aisearch-skillset/inquiry_document.py | 16 ++++----
function_apps/common/payloads/error.py | 20 ++++++++++
function_apps/common/payloads/header.py | 40 +++++++++++++++++++
function_apps/common/payloads/payload.py | 20 ++++++++++
.../payloads/pending_index_completion.py | 40 +++++++++++++++++++
.../common/payloads/pennding_index_trigger.py | 32 +++++++++++++++
.../indexer/pending_index_completion.py | 0
8 files changed, 160 insertions(+), 21 deletions(-)
create mode 100644 function_apps/common/payloads/error.py
create mode 100644 function_apps/common/payloads/header.py
create mode 100644 function_apps/common/payloads/payload.py
create mode 100644 function_apps/common/payloads/pending_index_completion.py
create mode 100644 function_apps/common/payloads/pennding_index_trigger.py
create mode 100644 function_apps/indexer/pending_index_completion.py
diff --git a/aisearch-skillset/deploy.py b/aisearch-skillset/deploy.py
index d98e099..1b2190b 100644
--- a/aisearch-skillset/deploy.py
+++ b/aisearch-skillset/deploy.py
@@ -30,19 +30,6 @@ def main(args):
rebuild=args.rebuild,
enable_page_by_chunking=args.enable_page_chunking
)
- elif args.indexer_type == "summary":
- # Deploy the summarises index
- index_config = SummaryDocumentAISearch(
- endpoint=endpoint,
- credential=credential,
- suffix=args.suffix,
- rebuild=args.rebuild,
- enable_page_by_chunking=args.enable_page_chunking
- )
- elif args.indexer_type == "glossary":
- # Deploy business glossary index
- index_config = BusinessGlossaryAISearch(endpoint, credential)
-
index_config.deploy()
if args.rebuild:
diff --git a/aisearch-skillset/inquiry_document.py b/aisearch-skillset/inquiry_document.py
index 3f9dd0a..b70251e 100644
--- a/aisearch-skillset/inquiry_document.py
+++ b/aisearch-skillset/inquiry_document.py
@@ -63,14 +63,14 @@ def get_index_fields(self) -> list[SearchableField]:
name="Title", type=SearchFieldDataType.String, filterable=True
),
SearchableField(
- name="DealId",
+ name="ID1",
type=SearchFieldDataType.String,
sortable=True,
filterable=True,
facetable=True,
),
SearchableField(
- name="OracleId",
+ name="ID2",
type=SearchFieldDataType.String,
sortable=True,
filterable=True,
@@ -212,12 +212,12 @@ def get_index_projections(self) -> SearchIndexerIndexProjections:
source="/document/Title"
),
InputFieldMappingEntry(
- name="DealId",
- source="/document/DealId"
+ name="ID1",
+ source="/document/ID1"
),
InputFieldMappingEntry(
- name="OracleId",
- source="/document/OracleId"
+ name="ID2",
+ source="/document/ID2"
),
InputFieldMappingEntry(
name="SourceUrl",
@@ -302,9 +302,9 @@ def get_indexer(self) -> SearchIndexer:
FieldMapping(
source_field_name="metadata_storage_name", target_field_name="Title"
),
- FieldMapping(source_field_name="Deal_ID", target_field_name="DealId"),
+ FieldMapping(source_field_name="ID1", target_field_name="ID1"),
FieldMapping(
- source_field_name="Oracle_ID", target_field_name="OracleId"
+ source_field_name="ID2", target_field_name="ID2"
),
FieldMapping(
source_field_name="SharePointUrl", target_field_name="SourceUrl"
diff --git a/function_apps/common/payloads/error.py b/function_apps/common/payloads/error.py
new file mode 100644
index 0000000..49e456e
--- /dev/null
+++ b/function_apps/common/payloads/error.py
@@ -0,0 +1,20 @@
+from typing import Optional
+from pydantic import BaseModel, Field, ConfigDict
+from datetime import datetime, timezone
+
+
+class Error(BaseModel):
+ """Error item model"""
+
+ code: str = Field(..., description="The error code")
+ message: str = Field(..., description="The error message")
+ details: Optional[str] = Field(
+ None, description="Detailed error information from Python"
+ )
+    timestamp: Optional[datetime] = Field(
+        description="Creation timestamp in UTC",
+        default_factory=lambda: datetime.now(timezone.utc),
+    )
+
+    model_config = ConfigDict(extra="ignore")
diff --git a/function_apps/common/payloads/header.py b/function_apps/common/payloads/header.py
new file mode 100644
index 0000000..e7a521c
--- /dev/null
+++ b/function_apps/common/payloads/header.py
@@ -0,0 +1,40 @@
+from pydantic import BaseModel, Field, ConfigDict
+from datetime import datetime, timezone
+from enum import Enum
+
+
+class DataTypeEnum(Enum):
+ """Type enum"""
+
+ BUSINESS_GLOSSARY = "business_glossary"
+ SUMMARY = "summary"
+
+
+class TaskEnum(Enum):
+ """Task enum"""
+
+ PENDING_INDEX_COMPLETION = "pending_index_completion"
+ PENDING_INDEX_TRIGGER = "pending_index_trigger"
+ PENDING_SUMMARY_GENERATION = "pending_summary_generation"
+
+
+class Header(BaseModel):
+ """Header model"""
+
+    creation_timestamp: datetime = Field(
+        description="Creation timestamp in UTC",
+        default_factory=lambda: datetime.now(timezone.utc),
+    )
+    last_processed_timestamp: datetime = Field(
+        description="Last processed timestamp in UTC",
+        default_factory=lambda: datetime.now(timezone.utc),
+    )
+ retries_remaining: int = Field(
+ description="Number of retries remaining", default=10
+ )
+ data_type: DataTypeEnum = Field(..., description="Data type")
+ task: TaskEnum = Field(..., description="Task name")
+
+    model_config = ConfigDict(extra="ignore")
diff --git a/function_apps/common/payloads/payload.py b/function_apps/common/payloads/payload.py
new file mode 100644
index 0000000..fb2f4f9
--- /dev/null
+++ b/function_apps/common/payloads/payload.py
@@ -0,0 +1,20 @@
+from pydantic import BaseModel, ConfigDict
+import logging
+
+
+class Payload(BaseModel):
+ """Body model"""
+
+ @classmethod
+ def from_service_bus_message(cls, message):
+ """
+ Create a Payload object from a ServiceBusMessage object.
+
+ :param message: The ServiceBusMessage object.
+ :return: The Body object.
+ """
+ message = message.get_body().decode("utf-8")
+ logging.info(f"ServiceBus message: {message}")
+ return cls.model_validate_json(message)
+
+    model_config = ConfigDict(extra="ignore")
diff --git a/function_apps/common/payloads/pending_index_completion.py b/function_apps/common/payloads/pending_index_completion.py
new file mode 100644
index 0000000..8aa0335
--- /dev/null
+++ b/function_apps/common/payloads/pending_index_completion.py
@@ -0,0 +1,40 @@
+from pydantic import BaseModel, Field, ConfigDict
+from datetime import datetime, timezone
+from typing import Optional, List
+
+from common.payloads.header import Header
+from common.payloads.error import Error
+from common.payloads.payload import Payload
+
+
+class PendingIndexCompletionBody(BaseModel):
+ """Body model"""
+
+ indexer: str = Field(..., description="The indexer to trigger")
+ deal_id: Optional[int] = Field(None, description="The deal ID")
+ blob_storage_url: Optional[str] = Field(
+ ..., description="The URL to the blob storage"
+ )
+ deal_name: Optional[str] = Field(
+ None, description="The text name for the integer deal ID"
+ )
+ business_unit: Optional[str] = Field(None, description="The business unit")
+    indexer_start_time: Optional[datetime] = Field(
+        description="The time the indexer was triggered successfully",
+        default_factory=lambda: datetime.now(timezone.utc),
+    )
+
+    model_config = ConfigDict(extra="ignore")
+
+
+class PendingIndexCompletionPayload(Payload):
+ """Pending Index Trigger model"""
+
+ header: Header = Field(..., description="Header information")
+ body: PendingIndexCompletionBody = Field(..., description="Body information")
+    errors: List[Error] | None = Field(
+        description="List of errors", default_factory=list
+    )
+
+    model_config = ConfigDict(extra="ignore")
diff --git a/function_apps/common/payloads/pennding_index_trigger.py b/function_apps/common/payloads/pennding_index_trigger.py
new file mode 100644
index 0000000..2a519d9
--- /dev/null
+++ b/function_apps/common/payloads/pennding_index_trigger.py
@@ -0,0 +1,32 @@
+from pydantic import BaseModel, Field, ConfigDict
+from typing import Optional, List
+
+from common.payloads.header import Header
+from common.payloads.error import Error
+from common.payloads.payload import Payload
+
+
+class PendingIndexTriggerBody(BaseModel):
+ """Body model"""
+
+ indexer: str = Field(..., description="The indexer to trigger")
+ deal_id: Optional[int] = Field(None, description="The deal ID")
+ blob_storage_url: str = Field(..., description="The URL to the blob storage")
+ deal_name: Optional[str] = Field(
+ None, description="The text name for the integer deal ID"
+ )
+ business_unit: Optional[str] = Field(None, description="The business unit")
+
+    model_config = ConfigDict(extra="ignore")
+
+
+class PendingIndexTriggerPayload(Payload):
+ """Pending Index Trigger model"""
+
+ header: Header = Field(..., description="Header information")
+ body: PendingIndexTriggerBody = Field(..., description="Body information")
+    errors: List[Error] | None = Field(
+        description="List of errors", default_factory=list
+    )
+
+    model_config = ConfigDict(extra="ignore")
diff --git a/function_apps/indexer/pending_index_completion.py b/function_apps/indexer/pending_index_completion.py
new file mode 100644
index 0000000..e69de29
From 461702883d1f9c15725d87a396bcb7862a318092 Mon Sep 17 00:00:00 2001
From: priyal1508 <54278892+priyal1508@users.noreply.github.com>
Date: Thu, 5 Sep 2024 19:14:10 +0530
Subject: [PATCH 3/4] fixing bugs
---
.../{pennding_index_trigger.py => pending_index_trigger.py} | 0
1 file changed, 0 insertions(+), 0 deletions(-)
rename function_apps/common/payloads/{pennding_index_trigger.py => pending_index_trigger.py} (100%)
diff --git a/function_apps/common/payloads/pennding_index_trigger.py b/function_apps/common/payloads/pending_index_trigger.py
similarity index 100%
rename from function_apps/common/payloads/pennding_index_trigger.py
rename to function_apps/common/payloads/pending_index_trigger.py
From de48566b1c9c01478ad533e7cc48cca27e420d03 Mon Sep 17 00:00:00 2001
From: priyal1508 <54278892+priyal1508@users.noreply.github.com>
Date: Thu, 5 Sep 2024 19:21:08 +0530
Subject: [PATCH 4/4] changes in folder structure
---
{aisearch-skillset => ai_search_with_adi}/ai_search.py | 0
{aisearch-skillset => ai_search_with_adi}/deploy.py | 0
{aisearch-skillset => ai_search_with_adi}/environment.py | 0
.../function_apps}/common/ai_search.py | 0
.../function_apps}/common/payloads/error.py | 0
.../function_apps}/common/payloads/header.py | 0
.../function_apps}/common/payloads/payload.py | 0
.../function_apps}/common/payloads/pending_index_completion.py | 0
.../function_apps}/common/payloads/pending_index_trigger.py | 0
.../function_apps}/indexer/adi_2_aisearch.py | 0
.../function_apps}/indexer/function_app.py | 0
.../function_apps}/indexer/key_phrase_extraction.py | 0
.../function_apps}/indexer/pending_index_completion.py | 0
.../function_apps}/indexer/pre_embedding_cleaner.py | 0
.../function_apps}/indexer/requirements.txt | 0
{aisearch-skillset => ai_search_with_adi}/inquiry_document.py | 0
16 files changed, 0 insertions(+), 0 deletions(-)
rename {aisearch-skillset => ai_search_with_adi}/ai_search.py (100%)
rename {aisearch-skillset => ai_search_with_adi}/deploy.py (100%)
rename {aisearch-skillset => ai_search_with_adi}/environment.py (100%)
rename {function_apps => ai_search_with_adi/function_apps}/common/ai_search.py (100%)
rename {function_apps => ai_search_with_adi/function_apps}/common/payloads/error.py (100%)
rename {function_apps => ai_search_with_adi/function_apps}/common/payloads/header.py (100%)
rename {function_apps => ai_search_with_adi/function_apps}/common/payloads/payload.py (100%)
rename {function_apps => ai_search_with_adi/function_apps}/common/payloads/pending_index_completion.py (100%)
rename {function_apps => ai_search_with_adi/function_apps}/common/payloads/pending_index_trigger.py (100%)
rename {function_apps => ai_search_with_adi/function_apps}/indexer/adi_2_aisearch.py (100%)
rename {function_apps => ai_search_with_adi/function_apps}/indexer/function_app.py (100%)
rename {function_apps => ai_search_with_adi/function_apps}/indexer/key_phrase_extraction.py (100%)
rename {function_apps => ai_search_with_adi/function_apps}/indexer/pending_index_completion.py (100%)
rename {function_apps => ai_search_with_adi/function_apps}/indexer/pre_embedding_cleaner.py (100%)
rename {function_apps => ai_search_with_adi/function_apps}/indexer/requirements.txt (100%)
rename {aisearch-skillset => ai_search_with_adi}/inquiry_document.py (100%)
diff --git a/aisearch-skillset/ai_search.py b/ai_search_with_adi/ai_search.py
similarity index 100%
rename from aisearch-skillset/ai_search.py
rename to ai_search_with_adi/ai_search.py
diff --git a/aisearch-skillset/deploy.py b/ai_search_with_adi/deploy.py
similarity index 100%
rename from aisearch-skillset/deploy.py
rename to ai_search_with_adi/deploy.py
diff --git a/aisearch-skillset/environment.py b/ai_search_with_adi/environment.py
similarity index 100%
rename from aisearch-skillset/environment.py
rename to ai_search_with_adi/environment.py
diff --git a/function_apps/common/ai_search.py b/ai_search_with_adi/function_apps/common/ai_search.py
similarity index 100%
rename from function_apps/common/ai_search.py
rename to ai_search_with_adi/function_apps/common/ai_search.py
diff --git a/function_apps/common/payloads/error.py b/ai_search_with_adi/function_apps/common/payloads/error.py
similarity index 100%
rename from function_apps/common/payloads/error.py
rename to ai_search_with_adi/function_apps/common/payloads/error.py
diff --git a/function_apps/common/payloads/header.py b/ai_search_with_adi/function_apps/common/payloads/header.py
similarity index 100%
rename from function_apps/common/payloads/header.py
rename to ai_search_with_adi/function_apps/common/payloads/header.py
diff --git a/function_apps/common/payloads/payload.py b/ai_search_with_adi/function_apps/common/payloads/payload.py
similarity index 100%
rename from function_apps/common/payloads/payload.py
rename to ai_search_with_adi/function_apps/common/payloads/payload.py
diff --git a/function_apps/common/payloads/pending_index_completion.py b/ai_search_with_adi/function_apps/common/payloads/pending_index_completion.py
similarity index 100%
rename from function_apps/common/payloads/pending_index_completion.py
rename to ai_search_with_adi/function_apps/common/payloads/pending_index_completion.py
diff --git a/function_apps/common/payloads/pending_index_trigger.py b/ai_search_with_adi/function_apps/common/payloads/pending_index_trigger.py
similarity index 100%
rename from function_apps/common/payloads/pending_index_trigger.py
rename to ai_search_with_adi/function_apps/common/payloads/pending_index_trigger.py
diff --git a/function_apps/indexer/adi_2_aisearch.py b/ai_search_with_adi/function_apps/indexer/adi_2_aisearch.py
similarity index 100%
rename from function_apps/indexer/adi_2_aisearch.py
rename to ai_search_with_adi/function_apps/indexer/adi_2_aisearch.py
diff --git a/function_apps/indexer/function_app.py b/ai_search_with_adi/function_apps/indexer/function_app.py
similarity index 100%
rename from function_apps/indexer/function_app.py
rename to ai_search_with_adi/function_apps/indexer/function_app.py
diff --git a/function_apps/indexer/key_phrase_extraction.py b/ai_search_with_adi/function_apps/indexer/key_phrase_extraction.py
similarity index 100%
rename from function_apps/indexer/key_phrase_extraction.py
rename to ai_search_with_adi/function_apps/indexer/key_phrase_extraction.py
diff --git a/function_apps/indexer/pending_index_completion.py b/ai_search_with_adi/function_apps/indexer/pending_index_completion.py
similarity index 100%
rename from function_apps/indexer/pending_index_completion.py
rename to ai_search_with_adi/function_apps/indexer/pending_index_completion.py
diff --git a/function_apps/indexer/pre_embedding_cleaner.py b/ai_search_with_adi/function_apps/indexer/pre_embedding_cleaner.py
similarity index 100%
rename from function_apps/indexer/pre_embedding_cleaner.py
rename to ai_search_with_adi/function_apps/indexer/pre_embedding_cleaner.py
diff --git a/function_apps/indexer/requirements.txt b/ai_search_with_adi/function_apps/indexer/requirements.txt
similarity index 100%
rename from function_apps/indexer/requirements.txt
rename to ai_search_with_adi/function_apps/indexer/requirements.txt
diff --git a/aisearch-skillset/inquiry_document.py b/ai_search_with_adi/inquiry_document.py
similarity index 100%
rename from aisearch-skillset/inquiry_document.py
rename to ai_search_with_adi/inquiry_document.py