Skip to content

Commit 42adc2a

Browse files
committed
Update the deployment script
1 parent ad8684f commit 42adc2a

File tree

3 files changed

+78
-81
lines changed

3 files changed

+78
-81
lines changed

ai_search_with_adi/ai_search/ai_search.py

Lines changed: 50 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
SearchIndexerDataContainer,
2121
SearchIndexerDataSourceConnection,
2222
SearchIndexerDataSourceType,
23-
SearchIndexerDataUserAssignedIdentity,
2423
OutputFieldMappingEntry,
2524
InputFieldMappingEntry,
2625
SynonymMap,
@@ -29,29 +28,21 @@
2928
)
3029
from azure.core.exceptions import HttpResponseError
3130
from azure.search.documents.indexes import SearchIndexerClient, SearchIndexClient
32-
from ai_search_with_adi.ai_search.environment import (
33-
get_fq_blob_connection_string,
34-
get_blob_container_name,
35-
get_custom_skill_function_url,
36-
get_managed_identity_fqname,
37-
get_function_app_authresourceid,
38-
)
31+
from ai_search_with_adi.ai_search.environment import AISearchEnvironment, IdentityType
3932

4033

4134
class AISearch(ABC):
35+
"""Handles the deployment of the AI search pipeline."""
36+
4237
def __init__(
4338
self,
44-
endpoint: str,
45-
credential,
4639
suffix: str | None = None,
4740
rebuild: bool | None = False,
4841
):
4942
"""Initialize the AI search class
5043
5144
Args:
52-
endpoint (str): The search endpoint
53-
credential (AzureKeyCredential): The search credential
54-
suffix (str, optional): The suffix for the indexer. Defaults to None.
45+
suffix (str, optional): The suffix for the indexer. Defaults to None. If a suffix is provided, it is assumed to be a test indexer.
5546
rebuild (bool, optional): Whether to rebuild the index. Defaults to False.
5647
"""
5748
self.indexer_type = None
@@ -61,15 +52,22 @@ def __init__(
6152
else:
6253
self.rebuild = False
6354

55+
# If suffix is None, this is not a test indexer. A test indexer limits the rate of indexing and turns off the schedule, which is useful for testing index changes.
6456
if suffix is None:
6557
self.suffix = ""
6658
self.test = False
6759
else:
6860
self.suffix = f"-{suffix}-test"
6961
self.test = True
7062

71-
self._search_indexer_client = SearchIndexerClient(endpoint, credential)
72-
self._search_index_client = SearchIndexClient(endpoint, credential)
63+
self.environment = AISearchEnvironment(indexer_type=self.indexer_type)
64+
65+
self._search_indexer_client = SearchIndexerClient(
66+
self.environment.ai_search_endpoint, self.environment.ai_search_credential
67+
)
68+
self._search_index_client = SearchIndexClient(
69+
self.environment.ai_search_endpoint, self.environment.ai_search_credential
70+
)
7371

7472
@property
7573
def indexer_name(self):
@@ -94,7 +92,7 @@ def index_name(self):
9492
@property
9593
def data_source_name(self):
9694
"""Get the data source name for the indexer."""
97-
blob_container_name = get_blob_container_name(self.indexer_type)
95+
blob_container_name = self.environment.get_blob_container_name()
9896
return f"{blob_container_name}-data-source{self.suffix}"
9997

10098
@property
@@ -146,16 +144,6 @@ def get_synonym_map_names(self) -> list[str]:
146144
"""Get the synonym map names for the indexer."""
147145
return []
148146

149-
def get_user_assigned_managed_identity(
150-
self,
151-
) -> SearchIndexerDataUserAssignedIdentity:
152-
"""Get user assigned managed identity details"""
153-
154-
user_assigned_identity = SearchIndexerDataUserAssignedIdentity(
155-
user_assigned_identity=get_managed_identity_fqname()
156-
)
157-
return user_assigned_identity
158-
159147
def get_data_source(self) -> SearchIndexerDataSourceConnection:
160148
"""Get the data source for the indexer."""
161149

@@ -166,19 +154,21 @@ def get_data_source(self) -> SearchIndexerDataSourceConnection:
166154
)
167155

168156
container = SearchIndexerDataContainer(
169-
name=get_blob_container_name(self.indexer_type)
157+
name=self.environment.get_blob_container_name()
170158
)
171159

172160
data_source_connection = SearchIndexerDataSourceConnection(
173161
name=self.data_source_name,
174162
type=SearchIndexerDataSourceType.AZURE_BLOB,
175-
connection_string=get_fq_blob_connection_string(),
163+
connection_string=self.environment.blob_connection_string,
176164
container=container,
177165
data_change_detection_policy=data_change_detection_policy,
178166
data_deletion_detection_policy=data_deletion_detection_policy,
179-
identity=self.get_user_assigned_managed_identity(),
180167
)
181168

169+
if self.environment.identity_type != IdentityType.KEY:
170+
data_source_connection.identity = self.environment.ai_search_identity_id
171+
182172
return data_source_connection
183173

184174
def get_pre_embedding_cleaner_skill(
@@ -226,17 +216,25 @@ def get_pre_embedding_cleaner_skill(
226216
name="Pre Embedding Cleaner Skill",
227217
description="Skill to clean the data before sending to embedding",
228218
context=context,
229-
uri=get_custom_skill_function_url("pre_embedding_cleaner"),
219+
uri=self.environment.get_custom_skill_function_url("pre_embedding_cleaner"),
230220
timeout="PT230S",
231221
batch_size=batch_size,
232222
degree_of_parallelism=degree_of_parallelism,
233223
http_method="POST",
234224
inputs=pre_embedding_cleaner_skill_inputs,
235225
outputs=pre_embedding_cleaner_skill_outputs,
236-
auth_resource_id=get_function_app_authresourceid(),
237-
auth_identity=self.get_user_assigned_managed_identity(),
238226
)
239227

228+
if self.environment.identity_type != IdentityType.KEY:
229+
pre_embedding_cleaner_skill.auth_identity = (
230+
self.environment.ai_search_identity_id
231+
)
232+
233+
if self.environment.identity_type == IdentityType.USER_ASSIGNED:
234+
pre_embedding_cleaner_skill.auth_resource_id = (
235+
self.environment.ai_search_user_assigned_identity
236+
)
237+
240238
return pre_embedding_cleaner_skill
241239

242240
def get_text_split_skill(self, context, source) -> SplitSkill:
@@ -294,7 +292,7 @@ def get_adi_skill(self, chunk_by_page=False) -> WebApiSkill:
294292
name="ADI Skill",
295293
description="Skill to generate ADI",
296294
context="/document",
297-
uri=get_custom_skill_function_url("adi"),
295+
uri=self.environment.get_custom_skill_function_url("adi"),
298296
timeout="PT230S",
299297
batch_size=batch_size,
300298
degree_of_parallelism=degree_of_parallelism,
@@ -306,10 +304,16 @@ def get_adi_skill(self, chunk_by_page=False) -> WebApiSkill:
306304
)
307305
],
308306
outputs=output,
309-
auth_resource_id=get_function_app_authresourceid(),
310-
auth_identity=self.get_user_assigned_managed_identity(),
311307
)
312308

309+
if self.environment.identity_type != IdentityType.KEY:
310+
adi_skill.auth_identity = self.environment.ai_search_identity_id
311+
312+
if self.environment.identity_type == IdentityType.USER_ASSIGNED:
313+
adi_skill.auth_resource_id = (
314+
self.environment.ai_search_user_assigned_identity
315+
)
316+
313317
return adi_skill
314318

315319
def get_vector_skill(
@@ -368,17 +372,25 @@ def get_key_phrase_extraction_skill(self, context, source) -> WebApiSkill:
368372
name="Key phrase extraction API",
369373
description="Skill to extract keyphrases",
370374
context=context,
371-
uri=get_custom_skill_function_url("keyphraseextraction"),
375+
uri=self.environment.get_custom_skill_function_url("key_phrase_extraction"),
372376
timeout="PT230S",
373377
batch_size=batch_size,
374378
degree_of_parallelism=degree_of_parallelism,
375379
http_method="POST",
376380
inputs=keyphrase_extraction_skill_inputs,
377381
outputs=keyphrase_extraction__skill_outputs,
378-
auth_resource_id=get_function_app_authresourceid(),
379-
auth_identity=self.get_user_assigned_managed_identity(),
380382
)
381383

384+
if self.environment.identity_type != IdentityType.KEY:
385+
key_phrase_extraction_skill.auth_identity = (
386+
self.environment.ai_search_identity_id
387+
)
388+
389+
if self.environment.identity_type == IdentityType.USER_ASSIGNED:
390+
key_phrase_extraction_skill.auth_resource_id = (
391+
self.environment.ai_search_user_assigned_identity
392+
)
393+
382394
return key_phrase_extraction_skill
383395

384396
def get_vector_search(self) -> VectorSearch:
Lines changed: 10 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,26 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
3-
43
import argparse
5-
from ai_search_with_adi.ai_search.environment import (
6-
get_search_endpoint,
7-
get_managed_identity_id,
8-
get_search_key,
9-
get_key_vault_url,
10-
)
11-
from azure.core.credentials import AzureKeyCredential
12-
from azure.identity import DefaultAzureCredential
13-
from azure.keyvault.secrets import SecretClient
144
from ai_search_with_adi.ai_search.rag_documents import RagDocumentsAISearch
155

166

17-
def main(args):
18-
endpoint = get_search_endpoint()
19-
20-
try:
21-
credential = DefaultAzureCredential(
22-
managed_identity_client_id=get_managed_identity_id()
23-
)
24-
# initializing key vault client
25-
client = SecretClient(vault_url=get_key_vault_url(), credential=credential)
26-
print("Using managed identity credential")
27-
except Exception as e:
28-
print(e)
29-
credential = AzureKeyCredential(get_search_key(client=client))
30-
print("Using Azure Key credential")
7+
def deploy_config(arguments: argparse.Namespace):
8+
"""Deploy the indexer configuration based on the arguments passed.
319
32-
if args.indexer_type == "rag":
33-
# Deploy the inquiry index
10+
Args:
11+
arguments (argparse.Namespace): The arguments passed to the script"""
12+
if arguments.indexer_type == "rag":
3413
index_config = RagDocumentsAISearch(
35-
endpoint=endpoint,
36-
credential=credential,
37-
suffix=args.suffix,
38-
rebuild=args.rebuild,
39-
enable_page_by_chunking=args.enable_page_chunking,
14+
suffix=arguments.suffix,
15+
rebuild=arguments.rebuild,
16+
enable_page_by_chunking=arguments.enable_page_chunking,
4017
)
4118
else:
4219
raise ValueError("Invalid Indexer Type")
4320

4421
index_config.deploy()
4522

46-
if args.rebuild:
23+
if arguments.rebuild:
4724
index_config.reset_indexer()
4825

4926

@@ -75,4 +52,4 @@ def main(args):
7552
)
7653

7754
args = parser.parse_args()
78-
main(args)
55+
deploy_config(args)

ai_search_with_adi/ai_search/rag_documents.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
)
2525
from ai_search import AISearch
2626
from ai_search_with_adi.ai_search.environment import (
27-
get_search_embedding_model_dimensions,
2827
IndexerType,
2928
)
3029

@@ -34,13 +33,17 @@ class RagDocumentsAISearch(AISearch):
3433

3534
def __init__(
3635
self,
37-
endpoint,
38-
credential,
39-
suffix=None,
40-
rebuild=False,
36+
suffix: str | None = None,
37+
rebuild: bool | None = False,
4138
enable_page_by_chunking=False,
4239
):
43-
super().__init__(endpoint, credential, suffix, rebuild)
40+
"""Initialize the RagDocumentsAISearch class. This class implements the deployment of the rag document index.
41+
42+
Args:
43+
suffix (str, optional): The suffix for the indexer. Defaults to None. If a suffix is provided, it is assumed to be a test indexer.
44+
rebuild (bool, optional): Whether to rebuild the index. Defaults to False.
45+
"""
46+
super().__init__(suffix, rebuild)
4447

4548
self.indexer_type = IndexerType.RAG_DOCUMENTS
4649
if enable_page_by_chunking is not None:
@@ -80,9 +83,7 @@ def get_index_fields(self) -> list[SearchableField]:
8083
SearchField(
8184
name="ChunkEmbedding",
8285
type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
83-
vector_search_dimensions=get_search_embedding_model_dimensions(
84-
self.indexer_type
85-
),
86+
vector_search_dimensions=self.environment.embedding_model_dimensions,
8687
vector_search_profile_name=self.vector_search_profile_name,
8788
),
8889
SearchableField(
@@ -224,19 +225,26 @@ def get_indexer(self) -> SearchIndexer:
224225
225226
Returns:
226227
SearchIndexer: The indexer for inquiry document"""
228+
229+
# Only place on schedule if it is not a test deployment
227230
if self.test:
228231
schedule = None
229232
batch_size = 4
230233
else:
231234
schedule = {"interval": "PT15M"}
232235
batch_size = 16
233236

237+
if self.environment.use_private_endpoint:
238+
execution_environment = IndexerExecutionEnvironment.PRIVATE
239+
else:
240+
execution_environment = IndexerExecutionEnvironment.STANDARD
241+
234242
indexer_parameters = IndexingParameters(
235243
batch_size=batch_size,
236244
configuration=IndexingParametersConfiguration(
237245
data_to_extract=BlobIndexerDataToExtract.ALL_METADATA,
238246
query_timeout=None,
239-
execution_environment=IndexerExecutionEnvironment.PRIVATE,
247+
execution_environment=execution_environment,
240248
fail_on_unprocessable_document=False,
241249
fail_on_unsupported_content_type=False,
242250
index_storage_metadata_only_for_oversized_documents=True,

0 commit comments

Comments
 (0)