@@ -73,12 +73,12 @@ def __check_data_type_embedding_list(attr_value: Any) -> bool:
     return True
 
 
-def __print_progress(progress: float) -> None:
+def __print_progress_a2vybg(progress: float) -> None:
     print(f"progress: {progress}", flush=True)
 
 
-def load_data_dict(record: Dict[str, Any]) -> Dict[str, Any]:
-    global vocab
+def load_data_dict_a2vybg(record: Dict[str, Any]) -> Dict[str, Any]:
+    global vocab_a2vybg
 
     if record["bytes"][:2] == "\\x":
         record["bytes"] = record["bytes"][2:]
@@ -87,7 +87,7 @@ def load_data_dict(record: Dict[str, Any]) -> Dict[str, Any]:
 
     byte = bytes.fromhex(record["bytes"])
     doc_bin_loaded = DocBin().from_bytes(byte)
-    docs = list(doc_bin_loaded.get_docs(vocab))
+    docs = list(doc_bin_loaded.get_docs(vocab_a2vybg))
     data_dict = {}
     for col, doc in zip(record["columns"], docs):
         data_dict[col] = doc
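Side note on the loading path above, since it is easy to trip over: `record["bytes"]` carries a hex string of a serialized spaCy `DocBin`, and `get_docs` only reconstructs the docs correctly when both sides share a vocab. A minimal sketch of that round-trip, with an illustrative `en` vocab and sample text (not from this repo):

```python
import spacy
from spacy.tokens import DocBin

nlp = spacy.blank("en")          # stands in for spacy.blank(iso2_code)

doc_bin = DocBin()
doc_bin.add(nlp("hello world"))

payload = doc_bin.to_bytes().hex()               # what a record's "bytes" field carries
restored = DocBin().from_bytes(bytes.fromhex(payload))
docs = list(restored.get_docs(nlp.vocab))        # same vocab on both sides
print(docs[0].text)                              # -> "hello world"
```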
@@ -99,62 +99,86 @@ def load_data_dict(record: Dict[str, Any]) -> Dict[str, Any]:
     return data_dict
 
 
-def parse_data_to_record_dict(
+def parse_data_to_record_dict_a2vybg(
     record_chunk: List[Dict[str, Any]]
 ) -> List[Dict[str, Any]]:
     result = []
     for r in record_chunk:
-        result.append({"id": r["record_id"], "data": load_data_dict(r)})
+        result.append({"id": r["record_id"], "data": load_data_dict_a2vybg(r)})
     return result
 
 
-def save_ac_value(record_id: str, attr_value: Any) -> None:
-    global calculated_attribute_by_record_id, processed_records, progress_size, amount
-    global check_data_type, py_data_types, llm_ac_cache, llm_config_hash, cached_records
+def send_cache_to_object_storage_a2vybg():
+    global llm_ac_cache_a2vybg, llm_config_hash_a2vybg, cached_records_a2vybg
+
+    if data_type == "LLM_RESPONSE" and "http" in CACHE_FILE_UPLOAD_LINK_A2VYBG:
+        llm_ac_cache_a2vybg[llm_config_hash_a2vybg] = cached_records_a2vybg
+        requests.put(CACHE_FILE_UPLOAD_LINK_A2VYBG, json=llm_ac_cache_a2vybg)
+
+
+def save_ac_value_a2vybg(record_id: str, attr_value: Any) -> None:
+    global calculated_attribute_by_record_id_a2vybg, processed_records_a2vybg, progress_size_a2vybg, amount_a2vybg
+    global check_data_type_a2vybg, py_data_types_a2vybg, llm_ac_cache_a2vybg, llm_config_hash_a2vybg, cached_records_a2vybg
     global CACHE_FILE_UPLOAD_LINK_A2VYBG
 
-    if not check_data_type(attr_value):
+    if not check_data_type_a2vybg(attr_value):
         raise ValueError(
             f"Attribute value `{attr_value}` is of type {type(attr_value)}, "
             f"but data_type {data_type} requires "
-            f"{str(py_data_types) if len(py_data_types) > 1 else str(py_data_types[0])}."
+            f"{str(py_data_types_a2vybg) if len(py_data_types_a2vybg) > 1 else str(py_data_types_a2vybg[0])}."
         )
 
-    calculated_attribute_by_record_id[record_id] = attr_value
+    calculated_attribute_by_record_id_a2vybg[record_id] = attr_value
 
-    if data_type == "LLM_RESPONSE" and "http" in CACHE_FILE_UPLOAD_LINK_A2VYBG:
-        llm_ac_cache[llm_config_hash] = cached_records
-        # TODO only save cache every few records to avoid request spamming
-        requests.put(CACHE_FILE_UPLOAD_LINK_A2VYBG, json=llm_ac_cache)
-
-    processed_records = processed_records + 1
-    if processed_records % progress_size == 0:
-        __print_progress(round(processed_records / amount, 2))
+    processed_records_a2vybg = processed_records_a2vybg + 1
+    if processed_records_a2vybg % progress_size_a2vybg == 0:
+        __print_progress_a2vybg(round(processed_records_a2vybg / amount_a2vybg, 2))
+    if data_type == "LLM_RESPONSE" and processed_records_a2vybg % 250 == 0:
+        send_cache_to_object_storage_a2vybg()
 
 
-def process_attribute_calculation(record_dict_list: List[Dict[str, Any]]) -> None:
+def process_attribute_calculation_a2vybg(
+    record_dict_list: List[Dict[str, Any]]
+) -> None:
     for record_dict in record_dict_list:
         attr_value: Any = attribute_calculators.ac(record_dict["data"])
-        save_ac_value(record_dict["id"], attr_value)
-
+        save_ac_value_a2vybg(record_dict["id"], attr_value)
 
-async def process_llm_record_batch(record_dict_batch: List[Dict[str, Any]]) -> None:
-    global DEFAULT_USER_PROMPT_A2VYBG, cached_records
 
-    for record_dict in record_dict_batch:
-        attribute_calculators.USER_PROMPT_A2VYBG = prepare_and_render_mustache(
-            DEFAULT_USER_PROMPT_A2VYBG, record_dict
-        )
-
-        attr_value: str = await attribute_calculators.ac(
-            record_dict["data"], cached_records
-        )
+def check_abort_status_a2vybg() -> bool:
+    # defined outside the async loop so it always reads the freshest value
+    global should_abort_a2vybg
+    return should_abort_a2vybg
 
-        save_ac_value(record_dict["id"], attr_value)
-
+
+async def process_llm_record_batch_a2vybg(
+    record_dict_batch: List[Dict[str, Any]]
+) -> None:
+    global DEFAULT_USER_PROMPT_A2VYBG, cached_records_a2vybg
 
-async def process_async_llm_calls(record_dict_list: List[Dict[str, Any]]) -> None:
-    global amount
+    for record_dict in record_dict_batch:
+        if check_abort_status_a2vybg():
+            return
+        try:
+            attribute_calculators.USER_PROMPT_A2VYBG = prepare_and_render_mustache(
+                DEFAULT_USER_PROMPT_A2VYBG, record_dict
+            )
+            attr_value: str = await attribute_calculators.ac(
+                record_dict["data"], cached_records_a2vybg
+            )
+
+            save_ac_value_a2vybg(record_dict["id"], attr_value)
+        except Exception as e:
+            global should_abort_a2vybg
+            should_abort_a2vybg = True
+            print(f"Error in record {record_dict['data']['running_id']}: {str(e)}")
+            return
+
+
+async def process_async_llm_calls_a2vybg(
+    record_dict_list: List[Dict[str, Any]]
+) -> None:
+    global amount_a2vybg
 
     def make_batches(
         iterable: List[Any], size: int = 1
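Two patterns land in this hunk: cache uploads are now throttled (the `% 250` guard in `save_ac_value_a2vybg` replaces the per-record `requests.put` that the old TODO complained about), and the batch workers cooperate on a module-level abort flag so the first failure drains the remaining work instead of burning more LLM calls. A self-contained sketch of both, with dummy stand-ins for the real calculators and cache upload:

```python
import asyncio

should_abort = False   # module-level flag shared by all batch workers
processed = 0


def flush_cache() -> None:
    # stand-in for send_cache_to_object_storage_a2vybg()
    print("flushing cache")


async def call_llm(record: int) -> None:
    # stand-in for attribute_calculators.ac(); record 13 is poisoned
    await asyncio.sleep(0)
    if record == 13:
        raise ValueError("bad record")


async def process_batch(batch: list) -> None:
    global should_abort, processed
    for record in batch:
        if should_abort:              # re-read the freshest value per record
            return
        try:
            await call_llm(record)
            processed += 1
            if processed % 5 == 0:    # throttled flush, not one put per record
                flush_cache()
        except Exception as e:
            should_abort = True       # tell every sibling batch to drain out
            print(f"Error in record {record}: {e}")
            return


async def main() -> None:
    batches = [list(range(10)), list(range(10, 20))]
    await asyncio.gather(*(process_batch(b) for b in batches))
    flush_cache()                     # final flush so the tail is not lost
    if should_abort:
        raise ValueError("Encountered error during LLM processing.")


asyncio.run(main())
```

Worth noting: this aborts at record granularity only after the current `await` completes; in-flight calls are not cancelled, they simply stop being followed by new ones.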
@@ -163,12 +187,15 @@ def make_batches(
         for ndx in range(0, length, size):
             yield iterable[ndx : min(ndx + size, length)]
 
-    batch_size = max(amount // int(attribute_calculators.NUM_WORKERS_A2VYBG), 1)
+    batch_size = max(amount_a2vybg // int(attribute_calculators.NUM_WORKERS_A2VYBG), 1)
     tasks = [
-        process_llm_record_batch(batch)
+        process_llm_record_batch_a2vybg(batch)
         for batch in make_batches(record_dict_list, size=batch_size)
     ]
     await asyncio.gather(*tasks)
+    send_cache_to_object_storage_a2vybg()
+    if check_abort_status_a2vybg():
+        raise ValueError("Encountered error during LLM processing.")
 
 
 if __name__ == "__main__":
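On the sizing above: `amount_a2vybg // NUM_WORKERS_A2VYBG` rounds down, so the comprehension can yield one more batch than there are workers (the remainder becomes a short extra task), and `max(..., 1)` keeps the size positive when there are fewer records than workers. A quick check, with the `make_batches` body reconstructed and an assumed worker count of 4:

```python
from typing import Any, List


def make_batches(iterable: List[Any], size: int = 1):
    length = len(iterable)
    for ndx in range(0, length, size):
        yield iterable[ndx : min(ndx + size, length)]


records = list(range(10))                         # amount = 10
num_workers = 4                                   # assumed, for illustration
batch_size = max(len(records) // num_workers, 1)  # -> 2
print([len(b) for b in make_batches(records, size=batch_size)])
# -> [2, 2, 2, 2, 2]: five tasks for four "workers"; asyncio absorbs the extra
```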
@@ -192,38 +219,46 @@ def make_batches(
         attribute_calculators, "CACHE_FILE_UPLOAD_LINK_A2VYBG", ""
     )
 
-    vocab = spacy.blank(iso2_code).vocab
+    vocab_a2vybg = spacy.blank(iso2_code).vocab
+
+    should_abort_a2vybg = False
 
     with open("docbin_full.json", "r") as infile:
         docbin_data = json.load(infile)
 
-    record_dict_list = parse_data_to_record_dict(docbin_data)
+    record_dict_list = parse_data_to_record_dict_a2vybg(docbin_data)
 
-    py_data_types, check_data_type = get_check_data_type_function(data_type)
+    py_data_types_a2vybg, check_data_type_a2vybg = get_check_data_type_function(
+        data_type
+    )
 
     print("Running attribute calculation.")
-    calculated_attribute_by_record_id = {}
-    amount = len(record_dict_list)
-    progress_size = min(
+    calculated_attribute_by_record_id_a2vybg = {}
+    amount_a2vybg = len(record_dict_list)
+    progress_size_a2vybg = min(
         100,
-        max(amount // int(getattr(attribute_calculators, "NUM_WORKERS_A2VYBG", 1)), 1),
+        max(
+            amount_a2vybg
+            // int(getattr(attribute_calculators, "NUM_WORKERS_A2VYBG", 1)),
+            1,
+        ),
     )
-    processed_records = 0
-    __print_progress(0.0)
+    processed_records_a2vybg = 0
+    __print_progress_a2vybg(0.0)
 
     if data_type == "LLM_RESPONSE":
-        llm_config = attribute_calculators.get_llm_config()
+        llm_config = attribute_calculators.get_llm_config_a2vybg()
         if "http" in CACHE_ACCESS_LINK_A2VYBG:
-            llm_ac_cache = requests.get(CACHE_ACCESS_LINK_A2VYBG).json()
+            llm_ac_cache_a2vybg = requests.get(CACHE_ACCESS_LINK_A2VYBG).json()
         else:
-            llm_ac_cache = {}
-        llm_config_hash = md5(json.dumps(llm_config).encode()).hexdigest()
+            llm_ac_cache_a2vybg = {}
+        llm_config_hash_a2vybg = md5(json.dumps(llm_config).encode()).hexdigest()
 
-        cached_records = llm_ac_cache.get(llm_config_hash, {})
-        asyncio.run(process_async_llm_calls(record_dict_list))
+        cached_records_a2vybg = llm_ac_cache_a2vybg.get(llm_config_hash_a2vybg, {})
+        asyncio.run(process_async_llm_calls_a2vybg(record_dict_list))
     else:
-        process_attribute_calculation(record_dict_list)
+        process_attribute_calculation_a2vybg(record_dict_list)
 
-    __print_progress(1.0)
+    __print_progress_a2vybg(1.0)
     print("Finished execution.")
-    requests.put(payload_url, json=calculated_attribute_by_record_id)
+    requests.put(payload_url, json=calculated_attribute_by_record_id_a2vybg)
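Finally, the cache layout in the `__main__` block: the remote cache is one JSON object keyed by the `md5` of the serialized LLM config, mapping to that config's `{record: response}` section, so any config change transparently starts a fresh section instead of replaying stale answers. A minimal sketch with an invented config; one caveat, `json.dumps` without `sort_keys=True` can hash equal configs differently if key order ever varies:

```python
import json
from hashlib import md5

llm_ac_cache = {}  # as fetched from CACHE_ACCESS_LINK_A2VYBG, or {} on a cold start

llm_config = {"model": "gpt-4o", "temperature": 0.0}     # illustrative only
llm_config_hash = md5(json.dumps(llm_config).encode()).hexdigest()

cached_records = llm_ac_cache.get(llm_config_hash, {})   # this config's section
cached_records["some-record-id"] = "cached response"
llm_ac_cache[llm_config_hash] = cached_records            # written back on flush

print(llm_config_hash[:8], llm_ac_cache)
```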