Skip to content

Commit f1a9d7a

Browse files
refactor: rename functions and variables for clarity and consistency + cache at end
1 parent c41b868 commit f1a9d7a

File tree

1 file changed

+91
-56
lines changed

1 file changed

+91
-56
lines changed

run_ac.py

+91-56
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,12 @@ def __check_data_type_embedding_list(attr_value: Any) -> bool:
7373
return True
7474

7575

76-
def __print_progress(progress: float) -> None:
76+
def __print_progress_a2vybg(progress: float) -> None:
7777
print(f"progress: {progress}", flush=True)
7878

7979

80-
def load_data_dict(record: Dict[str, Any]) -> Dict[str, Any]:
81-
global vocab
80+
def load_data_dict_a2vybg(record: Dict[str, Any]) -> Dict[str, Any]:
81+
global vocab_a2vybg
8282

8383
if record["bytes"][:2] == "\\x":
8484
record["bytes"] = record["bytes"][2:]
@@ -87,7 +87,7 @@ def load_data_dict(record: Dict[str, Any]) -> Dict[str, Any]:
8787

8888
byte = bytes.fromhex(record["bytes"])
8989
doc_bin_loaded = DocBin().from_bytes(byte)
90-
docs = list(doc_bin_loaded.get_docs(vocab))
90+
docs = list(doc_bin_loaded.get_docs(vocab_a2vybg))
9191
data_dict = {}
9292
for col, doc in zip(record["columns"], docs):
9393
data_dict[col] = doc
@@ -99,62 +99,86 @@ def load_data_dict(record: Dict[str, Any]) -> Dict[str, Any]:
9999
return data_dict
100100

101101

102-
def parse_data_to_record_dict(
102+
def parse_data_to_record_dict_a2vybg(
103103
record_chunk: List[Dict[str, Any]]
104104
) -> List[Dict[str, Any]]:
105105
result = []
106106
for r in record_chunk:
107-
result.append({"id": r["record_id"], "data": load_data_dict(r)})
107+
result.append({"id": r["record_id"], "data": load_data_dict_a2vybg(r)})
108108
return result
109109

110110

111-
def save_ac_value(record_id: str, attr_value: Any) -> None:
112-
global calculated_attribute_by_record_id, processed_records, progress_size, amount
113-
global check_data_type, py_data_types, llm_ac_cache, llm_config_hash, cached_records
111+
def send_cache_to_object_storage_a2vybg():
112+
global llm_ac_cache_a2vybg, llm_config_hash_a2vybg, cached_records_a2vybg
113+
114+
if data_type == "LLM_RESPONSE" and "http" in CACHE_FILE_UPLOAD_LINK_A2VYBG:
115+
llm_ac_cache_a2vybg[llm_config_hash_a2vybg] = cached_records_a2vybg
116+
requests.put(CACHE_FILE_UPLOAD_LINK_A2VYBG, json=llm_ac_cache_a2vybg)
117+
118+
119+
def save_ac_value_a2vybg(record_id: str, attr_value: Any) -> None:
120+
global calculated_attribute_by_record_id_a2vybg, processed_records_a2vybg, progress_size_a2vybg, amount_a2vybg
121+
global check_data_type_a2vybg, py_data_types_a2vybg, llm_ac_cache_a2vybg, llm_config_hash_a2vybg, cached_records_a2vybg
114122
global CACHE_FILE_UPLOAD_LINK_A2VYBG
115123

116-
if not check_data_type(attr_value):
124+
if not check_data_type_a2vybg(attr_value):
117125
raise ValueError(
118126
f"Attribute value `{attr_value}` is of type {type(attr_value)}, "
119127
f"but data_type {data_type} requires "
120-
f"{str(py_data_types) if len(py_data_types) > 1 else str(py_data_types[0])}."
128+
f"{str(py_data_types_a2vybg) if len(py_data_types_a2vybg) > 1 else str(py_data_types_a2vybg[0])}."
121129
)
122130

123-
calculated_attribute_by_record_id[record_id] = attr_value
131+
calculated_attribute_by_record_id_a2vybg[record_id] = attr_value
124132

125-
if data_type == "LLM_RESPONSE" and "http" in CACHE_FILE_UPLOAD_LINK_A2VYBG:
126-
llm_ac_cache[llm_config_hash] = cached_records
127-
# TODO only save cache every few records to avoid request spamming
128-
requests.put(CACHE_FILE_UPLOAD_LINK_A2VYBG, json=llm_ac_cache)
129-
130-
processed_records = processed_records + 1
131-
if processed_records % progress_size == 0:
132-
__print_progress(round(processed_records / amount, 2))
133+
processed_records_a2vybg = processed_records_a2vybg + 1
134+
if processed_records_a2vybg % progress_size_a2vybg == 0:
135+
__print_progress_a2vybg(round(processed_records_a2vybg / amount_a2vybg, 2))
136+
if data_type == "LLM_RESPONSE" and processed_records_a2vybg % 250 == 0:
137+
send_cache_to_object_storage_a2vybg()
133138

134139

135-
def process_attribute_calculation(record_dict_list: List[Dict[str, Any]]) -> None:
140+
def process_attribute_calculation_a2vybg(
141+
record_dict_list: List[Dict[str, Any]]
142+
) -> None:
136143
for record_dict in record_dict_list:
137144
attr_value: Any = attribute_calculators.ac(record_dict["data"])
138-
save_ac_value(record_dict["id"], attr_value)
139-
145+
save_ac_value_a2vybg(record_dict["id"], attr_value)
140146

141-
async def process_llm_record_batch(record_dict_batch: List[Dict[str, Any]]) -> None:
142-
global DEFAULT_USER_PROMPT_A2VYBG, cached_records
143147

144-
for record_dict in record_dict_batch:
145-
attribute_calculators.USER_PROMPT_A2VYBG = prepare_and_render_mustache(
146-
DEFAULT_USER_PROMPT_A2VYBG, record_dict
147-
)
148-
149-
attr_value: str = await attribute_calculators.ac(
150-
record_dict["data"], cached_records
151-
)
148+
def check_abort_status_a2vybg() -> bool:
149+
# function outside the async loop for reading always the freshest value
150+
global should_abort_a2vybg
151+
return should_abort_a2vybg
152152

153-
save_ac_value(record_dict["id"], attr_value)
154153

154+
async def process_llm_record_batch_a2vybg(
155+
record_dict_batch: List[Dict[str, Any]]
156+
) -> None:
157+
global DEFAULT_USER_PROMPT_A2VYBG, cached_records_a2vybg
155158

156-
async def process_async_llm_calls(record_dict_list: List[Dict[str, Any]]) -> None:
157-
global amount
159+
for record_dict in record_dict_batch:
160+
if check_abort_status_a2vybg():
161+
return
162+
try:
163+
attribute_calculators.USER_PROMPT_A2VYBG = prepare_and_render_mustache(
164+
DEFAULT_USER_PROMPT_A2VYBG, record_dict
165+
)
166+
attr_value: str = await attribute_calculators.ac(
167+
record_dict["data"], cached_records_a2vybg
168+
)
169+
170+
save_ac_value_a2vybg(record_dict["id"], attr_value)
171+
except Exception as e:
172+
global should_abort_a2vybg
173+
should_abort_a2vybg = True
174+
print(f"Error in record {record_dict['data']['running_id']}: {str(e)}")
175+
return
176+
177+
178+
async def process_async_llm_calls_a2vybg(
179+
record_dict_list: List[Dict[str, Any]]
180+
) -> None:
181+
global amount_a2vybg
158182

159183
def make_batches(
160184
iterable: List[Any], size: int = 1
@@ -163,12 +187,15 @@ def make_batches(
163187
for ndx in range(0, length, size):
164188
yield iterable[ndx : min(ndx + size, length)]
165189

166-
batch_size = max(amount // int(attribute_calculators.NUM_WORKERS_A2VYBG), 1)
190+
batch_size = max(amount_a2vybg // int(attribute_calculators.NUM_WORKERS_A2VYBG), 1)
167191
tasks = [
168-
process_llm_record_batch(batch)
192+
process_llm_record_batch_a2vybg(batch)
169193
for batch in make_batches(record_dict_list, size=batch_size)
170194
]
171195
await asyncio.gather(*tasks)
196+
send_cache_to_object_storage_a2vybg()
197+
if check_abort_status_a2vybg():
198+
raise ValueError("Encountered error during LLM processing.")
172199

173200

174201
if __name__ == "__main__":
@@ -192,38 +219,46 @@ def make_batches(
192219
attribute_calculators, "CACHE_FILE_UPLOAD_LINK_A2VYBG", ""
193220
)
194221

195-
vocab = spacy.blank(iso2_code).vocab
222+
vocab_a2vybg = spacy.blank(iso2_code).vocab
223+
224+
should_abort_a2vybg = False
196225

197226
with open("docbin_full.json", "r") as infile:
198227
docbin_data = json.load(infile)
199228

200-
record_dict_list = parse_data_to_record_dict(docbin_data)
229+
record_dict_list = parse_data_to_record_dict_a2vybg(docbin_data)
201230

202-
py_data_types, check_data_type = get_check_data_type_function(data_type)
231+
py_data_types_a2vybg, check_data_type_a2vybg = get_check_data_type_function(
232+
data_type
233+
)
203234

204235
print("Running attribute calculation.")
205-
calculated_attribute_by_record_id = {}
206-
amount = len(record_dict_list)
207-
progress_size = min(
236+
calculated_attribute_by_record_id_a2vybg = {}
237+
amount_a2vybg = len(record_dict_list)
238+
progress_size_a2vybg = min(
208239
100,
209-
max(amount // int(getattr(attribute_calculators, "NUM_WORKERS_A2VYBG", 1)), 1),
240+
max(
241+
amount_a2vybg
242+
// int(getattr(attribute_calculators, "NUM_WORKERS_A2VYBG", 1)),
243+
1,
244+
),
210245
)
211-
processed_records = 0
212-
__print_progress(0.0)
246+
processed_records_a2vybg = 0
247+
__print_progress_a2vybg(0.0)
213248

214249
if data_type == "LLM_RESPONSE":
215-
llm_config = attribute_calculators.get_llm_config()
250+
llm_config = attribute_calculators.get_llm_config_a2vybg()
216251
if "http" in CACHE_ACCESS_LINK_A2VYBG:
217-
llm_ac_cache = requests.get(CACHE_ACCESS_LINK_A2VYBG).json()
252+
llm_ac_cache_a2vybg = requests.get(CACHE_ACCESS_LINK_A2VYBG).json()
218253
else:
219-
llm_ac_cache = {}
220-
llm_config_hash = md5(json.dumps(llm_config).encode()).hexdigest()
254+
llm_ac_cache_a2vybg = {}
255+
llm_config_hash_a2vybg = md5(json.dumps(llm_config).encode()).hexdigest()
221256

222-
cached_records = llm_ac_cache.get(llm_config_hash, {})
223-
asyncio.run(process_async_llm_calls(record_dict_list))
257+
cached_records_a2vybg = llm_ac_cache_a2vybg.get(llm_config_hash_a2vybg, {})
258+
asyncio.run(process_async_llm_calls_a2vybg(record_dict_list))
224259
else:
225-
process_attribute_calculation(record_dict_list)
260+
process_attribute_calculation_a2vybg(record_dict_list)
226261

227-
__print_progress(1.0)
262+
__print_progress_a2vybg(1.0)
228263
print("Finished execution.")
229-
requests.put(payload_url, json=calculated_attribute_by_record_id)
264+
requests.put(payload_url, json=calculated_attribute_by_record_id_a2vybg)

0 commit comments

Comments
 (0)