
Commit f880826

committed changes for comparison with old rag
1 parent ade6997

File tree: 8 files changed (+1234, −2 lines)

adi_function_app/function_app.py

Lines changed: 39 additions & 0 deletions

@@ -8,6 +8,7 @@
 from adi_2_ai_search import process_adi_2_ai_search
 from pre_embedding_cleaner import process_pre_embedding_cleaner
 from key_phrase_extraction import process_key_phrase_extraction
+from ocr import process_ocr

 logging.basicConfig(level=logging.DEBUG)
 app = func.FunctionApp(http_auth_level=func.AuthLevel.FUNCTION)
@@ -124,3 +125,41 @@ async def key_phrase_extractor(req: func.HttpRequest) -> func.HttpResponse:
         status_code=200,
         mimetype="application/json",
     )
+
+
+@app.route(route="ocr", methods=[func.HttpMethod.POST])
+async def ocr(req: func.HttpRequest) -> func.HttpResponse:
+    """HTTP trigger for the OCR function.
+
+    Args:
+        req (func.HttpRequest): The HTTP request object.
+
+    Returns:
+        func.HttpResponse: The HTTP response object."""
+    logging.info("Python HTTP trigger OCR function processed a request.")
+
+    try:
+        req_body = req.get_json()
+        values = req_body.get("values")
+        logging.info(req_body)
+    except ValueError:
+        return func.HttpResponse(
+            "Please pass a valid Custom Skill Payload in the request body",
+            status_code=400,
+        )
+    else:
+        logging.debug("Input Values: %s", values)
+
+        record_tasks = []
+
+        for value in values:
+            record_tasks.append(asyncio.create_task(process_ocr(value)))
+
+        results = await asyncio.gather(*record_tasks)
+        logging.debug("Results: %s", results)
+
+        return func.HttpResponse(
+            json.dumps({"values": results}),
+            status_code=200,
+            mimetype="application/json",
+        )
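
As an aside (not part of the commit), a minimal sketch of the custom skill payload the new /api/ocr route expects; the function host name, function key, and image URL below are placeholders.

import requests

# Custom skill payload: one record per image, matching what process_ocr reads.
payload = {
    "values": [
        {
            "recordId": "1",
            "data": {"image": {"url": "https://example.com/sample-page.png"}},
        }
    ]
}

response = requests.post(
    "https://<function-app>.azurewebsites.net/api/ocr",  # placeholder host
    params={"code": "<function-key>"},  # placeholder function key
    json=payload,
    timeout=230,
)
print(response.json())  # {"values": [{"recordId": "1", "data": {"text": ...}}]}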

adi_function_app/ocr.py

Lines changed: 82 additions & 0 deletions

@@ -0,0 +1,82 @@
+import logging
+import os
+
+from azure.ai.vision.imageanalysis.aio import ImageAnalysisClient
+from azure.ai.vision.imageanalysis.models import VisualFeatures
+from azure.core.credentials import AzureKeyCredential
+
+
+async def process_ocr(record: dict) -> dict:
+    logging.info("Python HTTP trigger function processed a request.")
+
+    try:
+        url = record["data"]["image"]["url"]
+        logging.info(f"Request Body: {record}")
+    except KeyError:
+        return {
+            "recordId": record["recordId"],
+            "data": {},
+            "errors": [
+                {
+                    "message": "Failed to extract data with OCR. Pass a valid source in the request body.",
+                }
+            ],
+            "warnings": None,
+        }
+    else:
+        logging.info(f"image url: {url}")
+
+        if url is not None:
+            try:
+                # keyvault_helper = KeyVaultHelper()
+                client = ImageAnalysisClient(
+                    endpoint=os.environ["AIService__Services__Endpoint"],
+                    credential=AzureKeyCredential(
+                        os.environ["AIService__Services__Key"]
+                    ),
+                )
+                result = await client.analyze_from_url(
+                    image_url=url, visual_features=[VisualFeatures.READ]
+                )
+                logging.info("logging output")
+
+                # Extract text from OCR results
+                text = " ".join([line.text for line in result.read.blocks[0].lines])
+                logging.info(text)
+
+            except KeyError as e:
+                logging.error(e)
+                logging.error(f"Failed to authenticate with OCR: {e}")
+                return {
+                    "recordId": record["recordId"],
+                    "data": {},
+                    "errors": [
+                        {
+                            "message": f"Failed to authenticate with OCR. Check the service credentials exist. {e}",
+                        }
+                    ],
+                    "warnings": None,
+                }
+            except Exception as e:
+                logging.error(e)
+                logging.error(
+                    f"Failed to analyze the image with Azure AI Vision: {e}"
+                )
+                return {
+                    "recordId": record["recordId"],
+                    "data": {},
+                    "errors": [
+                        {
+                            "message": f"Failed to analyze the document with OCR. Check the source and try again. {e}",
+                        }
+                    ],
+                    "warnings": None,
+                }
+        else:
+            return {
+                "recordId": record["recordId"],
+                "data": {"text": ""},
+            }
+
+        return {
+            "recordId": record["recordId"],
+            "data": {"text": text},
+        }
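
For reference (not in the commit), a quick local exercise of process_ocr, assuming AIService__Services__Endpoint and AIService__Services__Key are already set; the image URL is a placeholder.

import asyncio

from ocr import process_ocr

# Record shape mirrors one entry of the custom skill payload; URL is a placeholder.
record = {
    "recordId": "1",
    "data": {"image": {"url": "https://example.com/scanned-page.png"}},
}

result = asyncio.run(process_ocr(record))
print(result["data"].get("text", ""))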

deploy_ai_search/ai_search.py

Lines changed: 106 additions & 1 deletion

@@ -24,6 +24,9 @@
     InputFieldMappingEntry,
     SynonymMap,
     SplitSkill,
+    DocumentExtractionSkill,
+    OcrSkill,
+    MergeSkill,
     SearchIndexerIndexProjections,
     BlobIndexerParsingMode,
 )
@@ -147,7 +150,7 @@ def get_indexer(self) -> SearchIndexer:
         return None

     def get_index_projections(self) -> SearchIndexerIndexProjections:
-        """Get the index projections for the indexer."""
+        """Get the index projections for the indexer."""

         return None

@@ -420,6 +423,108 @@ def get_key_phrase_extraction_skill(self, context, source) -> WebApiSkill:

         return key_phrase_extraction_skill

+    def get_document_extraction_skill(self, context, source) -> DocumentExtractionSkill:
+        """Get the document extraction utility skill.
+
+        Args:
+        -----
+            context (str): The context of the skill
+            source (str): The source of the skill
+
+        Returns:
+        --------
+            DocumentExtractionSkill: The document extraction utility skill"""
+
+        doc_extraction_skill = DocumentExtractionSkill(
+            description="Extraction skill to extract content from office docs like Excel, PPT, DOC, etc.",
+            context=context,
+            inputs=[InputFieldMappingEntry(name="file_data", source=source)],
+            outputs=[
+                OutputFieldMappingEntry(
+                    name="content", target_name="extracted_content"
+                ),
+                OutputFieldMappingEntry(
+                    name="normalized_images", target_name="extracted_normalized_images"
+                ),
+            ],
+        )
+
+        return doc_extraction_skill
+
+    def get_ocr_skill(self, context, source) -> WebApiSkill:
+        """Get the OCR utility skill.
+
+        Args:
+        -----
+            context (str): The context of the skill
+            source (str): The source of the skill
+
+        Returns:
+        --------
+            WebApiSkill: The OCR custom skill"""
+
+        if self.test:
+            batch_size = 2
+            degree_of_parallelism = 2
+        else:
+            batch_size = 2
+            degree_of_parallelism = 2
+
+        ocr_skill_inputs = [
+            InputFieldMappingEntry(name="image", source=source),
+        ]
+        ocr_skill_outputs = [OutputFieldMappingEntry(name="text", target_name="text")]
+        ocr_skill = WebApiSkill(
+            name="ocr API",
+            description="Skill to extract text from images",
+            context=context,
+            uri=self.environment.get_custom_skill_function_url("ocr"),
+            timeout="PT230S",
+            batch_size=batch_size,
+            degree_of_parallelism=degree_of_parallelism,
+            http_method="POST",
+            inputs=ocr_skill_inputs,
+            outputs=ocr_skill_outputs,
+        )
+
+        if self.environment.identity_type != IdentityType.KEY:
+            ocr_skill.auth_identity = (
+                self.environment.function_app_app_registration_resource_id
+            )
+
+        if self.environment.identity_type == IdentityType.USER_ASSIGNED:
+            ocr_skill.auth_identity = (
+                self.environment.ai_search_user_assigned_identity
+            )
+
+        return ocr_skill
+
+    def get_merge_skill(self, context, source) -> MergeSkill:
+        """Get the merge skill.
+
+        Args:
+        -----
+            context (str): The context of the skill
+            source (list): The sources of the skill
+
+        Returns:
+        --------
+            MergeSkill: The merge skill"""
+
+        merge_skill = MergeSkill(
+            description="Merge skill for combining OCR'd and regular text",
+            context=context,
+            inputs=[
+                InputFieldMappingEntry(name="text", source=source[0]),
+                InputFieldMappingEntry(name="itemsToInsert", source=source[1]),
+                InputFieldMappingEntry(name="offsets", source=source[2]),
+            ],
+            outputs=[
+                OutputFieldMappingEntry(name="mergedText", target_name="merged_content")
+            ],
+        )
+
+        return merge_skill
+
+
     def get_vector_search(self) -> VectorSearch:
         """Get the vector search configuration for compass.
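
For context (assumed wiring, not shown in this commit): get_merge_skill expects source to carry the document text, the per-image OCR text, and the image content offsets, in that order. A standalone sketch with the conventional enrichment-tree paths (the paths are assumptions, adjust to the real skillset):

from azure.search.documents.indexes.models import (
    InputFieldMappingEntry,
    MergeSkill,
    OutputFieldMappingEntry,
)

# Assumed source paths for the three merge inputs.
merge_skill = MergeSkill(
    description="Merge skill for combining OCR'd and regular text",
    context="/document",
    inputs=[
        InputFieldMappingEntry(name="text", source="/document/content"),
        InputFieldMappingEntry(
            name="itemsToInsert", source="/document/normalized_images/*/text"
        ),
        InputFieldMappingEntry(
            name="offsets", source="/document/normalized_images/*/contentOffset"
        ),
    ],
    outputs=[
        OutputFieldMappingEntry(name="mergedText", target_name="merged_content")
    ],
)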

deploy_ai_search/deploy.py

Lines changed: 2 additions & 1 deletion

@@ -1,7 +1,8 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 import argparse
-from rag_documents import RagDocumentsAISearch
+# from rag_documents import RagDocumentsAISearch
+from rag_documents_old import RagDocumentsAISearch
 from text_2_sql import Text2SqlAISearch
 from text_2_sql_query_cache import Text2SqlQueryCacheAISearch

deploy_ai_search/environment.py

Lines changed: 9 additions & 0 deletions

@@ -217,6 +217,13 @@ def function_app_key_phrase_extractor_route(self) -> str:
         This function returns function app keyphrase extractor name
         """
         return os.environ.get("FunctionApp__KeyPhraseExtractor__FunctionName")
+
+    @property
+    def function_app_key_ocr_route(self) -> str:
+        """
+        This function returns the function app OCR route name
+        """
+        return os.environ.get("FunctionApp__Ocr__FunctionName")

     @property
     def open_ai_embedding_dimensions(self) -> str:
@@ -249,6 +256,8 @@ def get_custom_skill_function_url(self, skill_type: str):
             route = self.function_app_adi_route
         elif skill_type == "key_phrase_extraction":
             route = self.function_app_key_phrase_extractor_route
+        elif skill_type == "ocr":
+            route = self.function_app_key_ocr_route
         else:
             raise ValueError(f"Invalid skill type: {skill_type}")
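
A small sanity check (not part of the commit) for the settings the new OCR path reads; the variable names come from this diff and adi_function_app/ocr.py.

import os

# App settings referenced by the OCR skill path in this change set.
required = [
    "FunctionApp__Ocr__FunctionName",  # route returned by function_app_key_ocr_route
    "AIService__Services__Endpoint",   # Image Analysis endpoint used by process_ocr
    "AIService__Services__Key",        # Image Analysis key used by process_ocr
]

missing = [name for name in required if not os.environ.get(name)]
if missing:
    raise RuntimeError(f"Missing app settings: {missing}")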

deploy_ai_search/rag_documents.py

Lines changed: 1 addition & 0 deletions

@@ -162,6 +162,7 @@ def get_skills(self) -> list:

        Returns:
            list: The skillsets used in the indexer"""
+

        adi_skill = self.get_adi_skill(self.enable_page_by_chunking)
