@@ -2,6 +2,7 @@
 import requests
 import spacy
 import sys
+import asyncio
 from mustache import prepare_and_render_mustache
 from spacy.tokens import DocBin
 
@@ -110,7 +111,8 @@ def parse_data_to_record_dict(record_chunk):
 # the script `labeling_functions` does not exist. It will be inserted at runtime
 import attribute_calculators
 
-DEFAULT_USER_PROMPT_A2VYBG = attribute_calculators.USER_PROMPT_A2VYBG
+if data_type == "LLM_RESPONSE":
+    DEFAULT_USER_PROMPT_A2VYBG = attribute_calculators.USER_PROMPT_A2VYBG
 
 vocab = spacy.blank(iso2_code).vocab
 
@@ -127,23 +129,53 @@ def parse_data_to_record_dict(record_chunk):
 progress_size = 100
 amount = len(record_dict_list)
 __print_progress(0.0)
-for record_dict in record_dict_list:
-    attribute_calculators.USER_PROMPT_A2VYBG = prepare_and_render_mustache(
-        DEFAULT_USER_PROMPT_A2VYBG, record_dict
-    )
-
-    idx += 1
-    if idx % progress_size == 0:
-        progress = round(idx / amount, 2)
-        __print_progress(progress)
-    attr_value = attribute_calculators.ac(record_dict["data"])
-    if not check_data_type(attr_value):
-        raise ValueError(
-            f"Attribute value `{attr_value}` is of type {type(attr_value)}, "
-            f"but data_type {data_type} requires "
-            f"{str(py_data_types) if len(py_data_types) > 1 else str(py_data_types[0])}."
+
+async def process_llm_record_batch(record_dict_batch: list):
+    """Process a batch of record_dicts, writes results into shared var calculated_attribute_by_record_id."""
+
+    for record_dict in record_dict_batch:
+        attribute_calculators.USER_PROMPT_A2VYBG = prepare_and_render_mustache(
+            DEFAULT_USER_PROMPT_A2VYBG, record_dict
         )
-    calculated_attribute_by_record_id[record_dict["id"]] = attr_value
-__print_progress(1.0)
-print("Finished execution.")
-requests.put(payload_url, json=calculated_attribute_by_record_id)
+
+        attr_value: str = await attribute_calculators.ac(record_dict["data"])
+
+        if not check_data_type(attr_value):
+            raise ValueError(
+                f"Attribute value `{attr_value}` is of type {type(attr_value)}, "
+                f"but data_type {data_type} requires "
+                f"{str(py_data_types) if len(py_data_types) > 1 else str(py_data_types[0])}."
+            )
+        calculated_attribute_by_record_id[record_dict["id"]] = attr_value
+
+async def process_async_llm_calls(record_dict_list):
+    batch_size = len(record_dict_list) // int(attribute_calculators.NUM_WORKERS)
+    record_dict_batches = [
+        record_dict_list[i : i + batch_size]
+        for i in range(0, len(record_dict_list), batch_size)
+    ]
+    tasks = [process_llm_record_batch(batch) for batch in record_dict_batches]
+    await asyncio.gather(*tasks)
+
+if data_type == "LLM_RESPONSE":
+    asyncio.run(process_async_llm_calls(record_dict_list))
+    requests.put(payload_url, json=calculated_attribute_by_record_id)
+    __print_progress(1.0)
+    print("Finished execution.")
+else:
+    for record_dict in record_dict_list:
+        idx += 1
+        if idx % progress_size == 0:
+            progress = round(idx / amount, 2)
+            __print_progress(progress)
+        attr_value = attribute_calculators.ac(record_dict["data"])
+        if not check_data_type(attr_value):
+            raise ValueError(
+                f"Attribute value `{attr_value}` is of type {type(attr_value)}, "
+                f"but data_type {data_type} requires "
+                f"{str(py_data_types) if len(py_data_types) > 1 else str(py_data_types[0])}."
            )
+        calculated_attribute_by_record_id[record_dict["id"]] = attr_value
+    __print_progress(1.0)
+    print("Finished execution.")
+    requests.put(payload_url, json=calculated_attribute_by_record_id)
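The diff above moves LLM-backed attribute calculation onto asyncio: the record list is split into roughly NUM_WORKERS batches, each batch is awaited sequentially inside its own coroutine, and the batches run concurrently via asyncio.gather, with results written into a shared dict keyed by record id. Below is a minimal, self-contained sketch of that pattern under stated assumptions: fake_llm_call is a hypothetical stand-in for attribute_calculators.ac, and the max(1, ...) guard is not in the diff (without it, fewer records than workers would make batch_size zero and range() would fail).

import asyncio


async def fake_llm_call(record: dict) -> str:
    # Hypothetical stand-in for attribute_calculators.ac; only simulates latency.
    await asyncio.sleep(0.01)
    return f"response for {record['id']}"


async def process_batch(batch: list, results: dict) -> None:
    # Process one batch sequentially, writing into a shared results dict
    # (mirrors calculated_attribute_by_record_id in the diff).
    for record in batch:
        results[record["id"]] = await fake_llm_call(record)


async def process_all(records: list, num_workers: int) -> dict:
    results: dict = {}
    # Guard against a zero slice step when there are fewer records than workers;
    # this is an assumption added for the sketch, not part of the original diff.
    batch_size = max(1, len(records) // num_workers)
    batches = [records[i : i + batch_size] for i in range(0, len(records), batch_size)]
    # Run the batches concurrently; each batch stays sequential internally.
    await asyncio.gather(*(process_batch(b, results) for b in batches))
    return results


if __name__ == "__main__":
    sample = [{"id": i, "data": f"text {i}"} for i in range(10)]
    print(asyncio.run(process_all(sample, num_workers=4)))

Since every coroutine writes to a distinct record id and the loop is single-threaded, the shared dict needs no locking, which matches how the diff collects results before the final requests.put.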