20
20
SearchIndexerDataContainer ,
21
21
SearchIndexerDataSourceConnection ,
22
22
SearchIndexerDataSourceType ,
23
- SearchIndexerDataUserAssignedIdentity ,
24
23
OutputFieldMappingEntry ,
25
24
InputFieldMappingEntry ,
26
25
SynonymMap ,
29
28
)
30
29
from azure .core .exceptions import HttpResponseError
31
30
from azure .search .documents .indexes import SearchIndexerClient , SearchIndexClient
32
- from ai_search_with_adi .ai_search .environment import (
33
- get_fq_blob_connection_string ,
34
- get_blob_container_name ,
35
- get_custom_skill_function_url ,
36
- get_managed_identity_fqname ,
37
- get_function_app_authresourceid ,
38
- )
31
+ from ai_search_with_adi .ai_search .environment import AISearchEnvironment , IdentityType
39
32
40
33
41
34
class AISearch (ABC ):
35
+ """Handles the deployment of the AI search pipeline."""
36
+
42
37
def __init__ (
43
38
self ,
44
- endpoint : str ,
45
- credential ,
46
39
suffix : str | None = None ,
47
40
rebuild : bool | None = False ,
48
41
):
49
42
"""Initialize the AI search class
50
43
51
44
Args:
52
- endpoint (str): The search endpoint
53
- credential (AzureKeyCredential): The search credential
54
- suffix (str, optional): The suffix for the indexer. Defaults to None.
45
+ suffix (str, optional): The suffix for the indexer. Defaults to None. If a suffix is provided, the indexer is assumed to be a test indexer.
55
46
rebuild (bool, optional): Whether to rebuild the index. Defaults to False.
56
47
"""
57
48
self .indexer_type = None
@@ -61,15 +52,22 @@ def __init__(
61
52
else :
62
53
self .rebuild = False
63
54
55
+ # If suffix is None, this is not a test indexer. A test indexer limits the rate of indexing and turns off the schedule, which is useful for testing index changes.
64
56
if suffix is None :
65
57
self .suffix = ""
66
58
self .test = False
67
59
else :
68
60
self .suffix = f"-{ suffix } -test"
69
61
self .test = True
70
62
71
- self ._search_indexer_client = SearchIndexerClient (endpoint , credential )
72
- self ._search_index_client = SearchIndexClient (endpoint , credential )
63
+ self .environment = AISearchEnvironment (indexer_type = self .indexer_type )
64
+
65
+ self ._search_indexer_client = SearchIndexerClient (
66
+ self .environment .ai_search_endpoint , self .environment .ai_search_credential
67
+ )
68
+ self ._search_index_client = SearchIndexClient (
69
+ self .environment .ai_search_endpoint , self .environment .ai_search_credential
70
+ )
73
71
74
72
@property
75
73
def indexer_name (self ):
@@ -94,7 +92,7 @@ def index_name(self):
94
92
@property
95
93
def data_source_name (self ):
96
94
"""Get the data source name for the indexer."""
97
- blob_container_name = get_blob_container_name ( self .indexer_type )
95
+ blob_container_name = self .environment . get_blob_container_name ( )
98
96
return f"{ blob_container_name } -data-source{ self .suffix } "
99
97
100
98
@property
@@ -146,16 +144,6 @@ def get_synonym_map_names(self) -> list[str]:
146
144
"""Get the synonym map names for the indexer."""
147
145
return []
148
146
149
- def get_user_assigned_managed_identity (
150
- self ,
151
- ) -> SearchIndexerDataUserAssignedIdentity :
152
- """Get user assigned managed identity details"""
153
-
154
- user_assigned_identity = SearchIndexerDataUserAssignedIdentity (
155
- user_assigned_identity = get_managed_identity_fqname ()
156
- )
157
- return user_assigned_identity
158
-
159
147
def get_data_source (self ) -> SearchIndexerDataSourceConnection :
160
148
"""Get the data source for the indexer."""
161
149
@@ -166,19 +154,21 @@ def get_data_source(self) -> SearchIndexerDataSourceConnection:
166
154
)
167
155
168
156
container = SearchIndexerDataContainer (
169
- name = get_blob_container_name ( self .indexer_type )
157
+ name = self .environment . get_blob_container_name ( )
170
158
)
171
159
172
160
data_source_connection = SearchIndexerDataSourceConnection (
173
161
name = self .data_source_name ,
174
162
type = SearchIndexerDataSourceType .AZURE_BLOB ,
175
- connection_string = get_fq_blob_connection_string () ,
163
+ connection_string = self . environment . blob_connection_string ,
176
164
container = container ,
177
165
data_change_detection_policy = data_change_detection_policy ,
178
166
data_deletion_detection_policy = data_deletion_detection_policy ,
179
- identity = self .get_user_assigned_managed_identity (),
180
167
)
181
168
169
+ if self .environment .identity_type != IdentityType .KEY :
170
+ data_source_connection .identity = self .environment .ai_search_identity_id
171
+
182
172
return data_source_connection
183
173
184
174
def get_pre_embedding_cleaner_skill (
@@ -226,17 +216,25 @@ def get_pre_embedding_cleaner_skill(
226
216
name = "Pre Embedding Cleaner Skill" ,
227
217
description = "Skill to clean the data before sending to embedding" ,
228
218
context = context ,
229
- uri = get_custom_skill_function_url ("pre_embedding_cleaner" ),
219
+ uri = self . environment . get_custom_skill_function_url ("pre_embedding_cleaner" ),
230
220
timeout = "PT230S" ,
231
221
batch_size = batch_size ,
232
222
degree_of_parallelism = degree_of_parallelism ,
233
223
http_method = "POST" ,
234
224
inputs = pre_embedding_cleaner_skill_inputs ,
235
225
outputs = pre_embedding_cleaner_skill_outputs ,
236
- auth_resource_id = get_function_app_authresourceid (),
237
- auth_identity = self .get_user_assigned_managed_identity (),
238
226
)
239
227
228
+ if self .environment .identity_type != IdentityType .KEY :
229
+ pre_embedding_cleaner_skill .auth_identity = (
230
+ self .environment .ai_search_identity_id
231
+ )
232
+
233
+ if self .environment .identity_type == IdentityType .USER_ASSIGNED :
234
+ pre_embedding_cleaner_skill .auth_resource_id = (
235
+ self .environment .ai_search_user_assigned_identity
236
+ )
237
+
240
238
return pre_embedding_cleaner_skill
241
239
242
240
def get_text_split_skill (self , context , source ) -> SplitSkill :
@@ -294,7 +292,7 @@ def get_adi_skill(self, chunk_by_page=False) -> WebApiSkill:
294
292
name = "ADI Skill" ,
295
293
description = "Skill to generate ADI" ,
296
294
context = "/document" ,
297
- uri = get_custom_skill_function_url ("adi" ),
295
+ uri = self . environment . get_custom_skill_function_url ("adi" ),
298
296
timeout = "PT230S" ,
299
297
batch_size = batch_size ,
300
298
degree_of_parallelism = degree_of_parallelism ,
@@ -306,10 +304,16 @@ def get_adi_skill(self, chunk_by_page=False) -> WebApiSkill:
306
304
)
307
305
],
308
306
outputs = output ,
309
- auth_resource_id = get_function_app_authresourceid (),
310
- auth_identity = self .get_user_assigned_managed_identity (),
311
307
)
312
308
309
+ if self .environment .identity_type != IdentityType .KEY :
310
+ adi_skill .auth_identity = self .environment .ai_search_identity_id
311
+
312
+ if self .environment .identity_type == IdentityType .USER_ASSIGNED :
313
+ adi_skill .auth_resource_id = (
314
+ self .environment .ai_search_user_assigned_identity
315
+ )
316
+
313
317
return adi_skill
314
318
315
319
def get_vector_skill (
@@ -368,17 +372,25 @@ def get_key_phrase_extraction_skill(self, context, source) -> WebApiSkill:
368
372
name = "Key phrase extraction API" ,
369
373
description = "Skill to extract keyphrases" ,
370
374
context = context ,
371
- uri = get_custom_skill_function_url ("keyphraseextraction " ),
375
+ uri = self . environment . get_custom_skill_function_url ("key_phrase_extraction " ),
372
376
timeout = "PT230S" ,
373
377
batch_size = batch_size ,
374
378
degree_of_parallelism = degree_of_parallelism ,
375
379
http_method = "POST" ,
376
380
inputs = keyphrase_extraction_skill_inputs ,
377
381
outputs = keyphrase_extraction__skill_outputs ,
378
- auth_resource_id = get_function_app_authresourceid (),
379
- auth_identity = self .get_user_assigned_managed_identity (),
380
382
)
381
383
384
+ if self .environment .identity_type != IdentityType .KEY :
385
+ key_phrase_extraction_skill .auth_identity = (
386
+ self .environment .ai_search_identity_id
387
+ )
388
+
389
+ if self .environment .identity_type == IdentityType .USER_ASSIGNED :
390
+ key_phrase_extraction_skill .auth_resource_id = (
391
+ self .environment .ai_search_user_assigned_identity
392
+ )
393
+
382
394
return key_phrase_extraction_skill
383
395
384
396
def get_vector_search (self ) -> VectorSearch :
0 commit comments