
Commit c96f42d
perf: add llm ac caching
1 parent e12aea1 commit c96f42d

1 file changed: 33 additions, 4 deletions

run_ac.py

@@ -4,8 +4,9 @@
 import requests
 import spacy
 import sys
-from mustache import prepare_and_render_mustache
 from spacy.tokens import DocBin
+from hashlib import md5
+from mustache import prepare_and_render_mustache
 
 
 def get_check_data_type_function(data_type: str) -> Tuple[List[Type], Callable]:
@@ -108,7 +109,9 @@ def parse_data_to_record_dict(
 
 
 def save_ac_value(record_id: str, attr_value: Any) -> None:
-    global calculated_attribute_by_record_id, processed_records, progress_size, amount, check_data_type, py_data_types
+    global calculated_attribute_by_record_id, processed_records, progress_size, amount
+    global check_data_type, py_data_types, llm_ac_cache, llm_config_hash, cached_records
+    global CACHE_FILE_UPLOAD_LINK_A2VYBG
 
     if not check_data_type(attr_value):
         raise ValueError(
@@ -119,6 +122,10 @@ def save_ac_value(record_id: str, attr_value: Any) -> None:
 
     calculated_attribute_by_record_id[record_id] = attr_value
 
+    if data_type == "LLM_RESPONSE":
+        llm_ac_cache[llm_config_hash] = cached_records
+        requests.put(CACHE_FILE_UPLOAD_LINK_A2VYBG, json=llm_ac_cache)
+
     processed_records = processed_records + 1
     if processed_records % progress_size == 0:
         __print_progress(round(processed_records / amount, 2))
@@ -131,14 +138,23 @@ def process_attribute_calculation(record_dict_list: List[Dict[str, Any]]) -> None:
 
 
 async def process_llm_record_batch(record_dict_batch: List[Dict[str, Any]]) -> None:
-    global DEFAULT_USER_PROMPT_A2VYBG
+    global DEFAULT_USER_PROMPT_A2VYBG, cached_records
 
     for record_dict in record_dict_batch:
         attribute_calculators.USER_PROMPT_A2VYBG = prepare_and_render_mustache(
             DEFAULT_USER_PROMPT_A2VYBG, record_dict
         )
 
-        attr_value: str = await attribute_calculators.ac(record_dict["data"])
+        if record_dict["id"] in cached_records:
+            print(
+                "Using cached value for record",
+                record_dict["data"]["running_id"],
+                flush=True,
+            )
+            attr_value: str = cached_records[record_dict["id"]]
+        else:
+            attr_value: str = await attribute_calculators.ac(record_dict["data"])
+            cached_records[record_dict["id"]] = attr_value
         save_ac_value(record_dict["id"], attr_value)
 
 
@@ -169,9 +185,17 @@ def make_batches(
 # the script `labeling_functions` does not exist. It will be inserted at runtime
 import attribute_calculators
 
+# exists for both LLM playground and (run-on-10, run-all)
 DEFAULT_USER_PROMPT_A2VYBG = getattr(
     attribute_calculators, "USER_PROMPT_A2VYBG", None
 )
+# exists only for (run-on-10, run-all)
+CACHE_ACCESS_LINK_A2VYBG = getattr(
+    attribute_calculators, "CACHE_ACCESS_LINK_A2VYBG", ""
+)
+CACHE_FILE_UPLOAD_LINK_A2VYBG = getattr(
+    attribute_calculators, "CACHE_FILE_UPLOAD_LINK_A2VYBG", ""
+)
 
 vocab = spacy.blank(iso2_code).vocab
 
@@ -193,6 +217,11 @@ def make_batches(
 __print_progress(0.0)
 
 if data_type == "LLM_RESPONSE":
+    llm_config = attribute_calculators.get_llm_config()
+    llm_ac_cache = requests.get(CACHE_ACCESS_LINK_A2VYBG).json()
+    llm_config_hash = md5(json.dumps(llm_config).encode()).hexdigest()
+
+    cached_records = llm_ac_cache.get(llm_config_hash, {})
     asyncio.run(process_async_llm_calls(record_dict_list))
 else:
     process_attribute_calculation(record_dict_list)
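For context on the scheme: the remote cache is keyed by an MD5 hash of the serialized LLM configuration, so any change to the model settings selects a fresh bucket and stale answers are never reused. Below is a minimal standalone sketch of that round-trip, assuming the cache is a plain JSON object shaped like {config_hash: {record_id: value}} and using hypothetical placeholder URLs in place of the injected CACHE_ACCESS_LINK_A2VYBG / CACHE_FILE_UPLOAD_LINK_A2VYBG values:

    import json
    from hashlib import md5

    import requests

    # Hypothetical endpoints; the real signed links are injected at runtime.
    CACHE_GET_URL = "https://example.com/llm-ac-cache.json"
    CACHE_PUT_URL = "https://example.com/llm-ac-cache.json"

    def config_hash(llm_config: dict) -> str:
        # Hash the serialized config, as in the commit. Note that without
        # sort_keys the digest depends on dict key order.
        return md5(json.dumps(llm_config).encode()).hexdigest()

    def load_bucket(cache: dict, llm_config: dict) -> dict:
        # Pick this configuration's bucket; missing hash means a cold cache.
        return cache.get(config_hash(llm_config), {})

    def store_bucket(cache: dict, llm_config: dict, bucket: dict) -> None:
        # Mirror the commit: write the bucket back into the in-memory cache
        # object and upload the whole thing in one PUT.
        cache[config_hash(llm_config)] = bucket
        requests.put(CACHE_PUT_URL, json=cache)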
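The per-record path in process_llm_record_batch is a look-aside cache: serve the stored value when the record id is already in the bucket, otherwise await the LLM call and write the result back. A sketch of the same pattern, with a hypothetical compute callback standing in for attribute_calculators.ac:

    from typing import Any, Awaitable, Callable, Dict

    async def get_or_compute(
        record_id: str,
        data: Dict[str, Any],
        bucket: Dict[str, str],
        compute: Callable[[Dict[str, Any]], Awaitable[str]],
    ) -> str:
        # Cache hit: reuse the stored answer and skip the LLM call entirely.
        if record_id in bucket:
            print("Using cached value for record", record_id, flush=True)
            return bucket[record_id]
        # Cache miss: compute, then remember the result for next time.
        value = await compute(data)
        bucket[record_id] = value
        return value

One consequence of the design worth noting: because save_ac_value uploads the full cache after every record, a run interrupted midway still leaves all completed answers persisted, at the cost of one PUT per record.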
