Skip to content

Commit 80d2a0c

Browse files
feat: implement asynchronous processing for LLM_RESPONSE data type
1 parent 73ac7ac commit 80d2a0c

File tree

1 file changed

+52
-20
lines changed

1 file changed

+52
-20
lines changed

run_ac.py

Lines changed: 52 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
import requests
33
import spacy
44
import sys
5+
import asyncio
56
from mustache import prepare_and_render_mustache
67
from spacy.tokens import DocBin
78

@@ -110,7 +111,8 @@ def parse_data_to_record_dict(record_chunk):
110111
# the script `attribute_calculators` does not exist. It will be inserted at runtime
111112
import attribute_calculators
112113

113-
DEFAULT_USER_PROMPT_A2VYBG = attribute_calculators.USER_PROMPT_A2VYBG
114+
if data_type == "LLM_RESPONSE":
115+
DEFAULT_USER_PROMPT_A2VYBG = attribute_calculators.USER_PROMPT_A2VYBG
114116

115117
vocab = spacy.blank(iso2_code).vocab
116118

@@ -127,23 +129,53 @@ def parse_data_to_record_dict(record_chunk):
127129
progress_size = 100
128130
amount = len(record_dict_list)
129131
__print_progress(0.0)
130-
for record_dict in record_dict_list:
131-
attribute_calculators.USER_PROMPT_A2VYBG = prepare_and_render_mustache(
132-
DEFAULT_USER_PROMPT_A2VYBG, record_dict
133-
)
134-
135-
idx += 1
136-
if idx % progress_size == 0:
137-
progress = round(idx / amount, 2)
138-
__print_progress(progress)
139-
attr_value = attribute_calculators.ac(record_dict["data"])
140-
if not check_data_type(attr_value):
141-
raise ValueError(
142-
f"Attribute value `{attr_value}` is of type {type(attr_value)}, "
143-
f"but data_type {data_type} requires "
144-
f"{str(py_data_types) if len(py_data_types) > 1 else str(py_data_types[0])}."
132+
133+
async def process_llm_record_batch(record_dict_batch: list):
134+
"""Process a batch of record_dicts, writes results into shared var calculated_attribute_by_record_id."""
# Coroutine that handles the records of one batch strictly one after another:
# for each record it renders the default user prompt with the record, awaits
# `attribute_calculators.ac` on the record's "data" entry, validates the result
# with `check_data_type`, and stores it under the record's "id".
#
# Raises:
#     ValueError: if `check_data_type` rejects a computed attribute value.
135+
136+
for record_dict in record_dict_batch:
# The rendered prompt is written to a module-level attribute of
# `attribute_calculators`, not passed to `ac` as an argument.
# NOTE(review): batches are gathered concurrently by `process_async_llm_calls`,
# so they all share this attribute. If `ac` reads it only after its first
# suspension point, another batch may already have overwritten it — confirm
# against the `ac` implementation.
137+
attribute_calculators.USER_PROMPT_A2VYBG = prepare_and_render_mustache(
138+
DEFAULT_USER_PROMPT_A2VYBG, record_dict
145139
)
# The four "-" entries below are the deleted side of the diff (tail of the old
# synchronous loop); they are not part of this coroutine.
146-
calculated_attribute_by_record_id[record_dict["id"]] = attr_value
147-
__print_progress(1.0)
148-
print("Finished execution.")
149-
requests.put(payload_url, json=calculated_attribute_by_record_id)
140+
# `ac` is awaited here, so in this code path it must return an awaitable.
141+
attr_value: str = await attribute_calculators.ac(record_dict["data"])
142+
# Reject values that `check_data_type` does not accept; the message names the
# offending value, its Python type, and the type(s) listed in `py_data_types`.
143+
if not check_data_type(attr_value):
144+
raise ValueError(
145+
f"Attribute value `{attr_value}` is of type {type(attr_value)}, "
146+
f"but data_type {data_type} requires "
147+
f"{str(py_data_types) if len(py_data_types) > 1 else str(py_data_types[0])}."
148+
)
# Results from every batch accumulate in this one shared dict, keyed by record id.
149+
calculated_attribute_by_record_id[record_dict["id"]] = attr_value
150+
151+
async def process_async_llm_calls(record_dict_list):
    """Split the records into batches and process all batches concurrently.

    The records are cut into at most ``attribute_calculators.NUM_WORKERS``
    contiguous batches of near-equal size. Each batch is handed to
    ``process_llm_record_batch`` and all batches are awaited together with
    ``asyncio.gather``.

    Args:
        record_dict_list: Sequence of record dicts to process. May be empty,
            in which case nothing is scheduled.

    Raises:
        Whatever ``process_llm_record_batch`` raises for a record; the first
        such exception is propagated by ``asyncio.gather``.
    """
    if not record_dict_list:
        # Nothing to do; also avoids building a zero-step range below.
        return

    # A worker count below 1 cannot form batches (0 divides by zero, negative
    # values yield an empty range and silently skip every record), so fall
    # back to a single worker.
    num_workers = max(1, int(attribute_calculators.NUM_WORKERS))

    # Ceiling division: guarantees batch_size >= 1 even when there are fewer
    # records than workers, and never produces more batches than workers.
    batch_size = -(-len(record_dict_list) // num_workers)

    record_dict_batches = [
        record_dict_list[start : start + batch_size]
        for start in range(0, len(record_dict_list), batch_size)
    ]
    await asyncio.gather(
        *(process_llm_record_batch(batch) for batch in record_dict_batches)
    )
159+
160+
# Compute the attribute for every record and upload the collected results.
# LLM_RESPONSE attributes go through the asyncio batch pipeline; every other
# data type is computed synchronously, one record at a time.
if data_type != "LLM_RESPONSE":
    for record in record_dict_list:
        idx += 1
        # Report progress once every `progress_size` records.
        if not idx % progress_size:
            __print_progress(round(idx / amount, 2))

        value = attribute_calculators.ac(record["data"])
        if check_data_type(value):
            calculated_attribute_by_record_id[record["id"]] = value
            continue

        # List all accepted types, or only the single one when there is one.
        if len(py_data_types) > 1:
            required = str(py_data_types)
        else:
            required = str(py_data_types[0])
        raise ValueError(
            f"Attribute value `{value}` is of type {type(value)}, "
            f"but data_type {data_type} requires "
            f"{required}."
        )

    __print_progress(1.0)
    print("Finished execution.")
    requests.put(payload_url, json=calculated_attribute_by_record_id)
else:
    # The async pipeline fills `calculated_attribute_by_record_id` itself;
    # this branch uploads the results before printing the final progress.
    asyncio.run(process_async_llm_calls(record_dict_list))
    requests.put(payload_url, json=calculated_attribute_by_record_id)
    __print_progress(1.0)
    print("Finished execution.")

0 commit comments

Comments
 (0)