@@ -4,8 +4,9 @@
 import requests
 import spacy
 import sys
-from mustache import prepare_and_render_mustache
 from spacy.tokens import DocBin
+from hashlib import md5
+from mustache import prepare_and_render_mustache
 
 
 def get_check_data_type_function(data_type: str) -> Tuple[List[Type], Callable]:
@@ -108,7 +109,9 @@ def parse_data_to_record_dict(
 
 
 def save_ac_value(record_id: str, attr_value: Any) -> None:
-    global calculated_attribute_by_record_id, processed_records, progress_size, amount, check_data_type, py_data_types
+    global calculated_attribute_by_record_id, processed_records, progress_size, amount
+    global check_data_type, py_data_types, llm_ac_cache, llm_config_hash, cached_records
+    global CACHE_FILE_UPLOAD_LINK_A2VYBG
 
     if not check_data_type(attr_value):
         raise ValueError(
@@ -119,6 +122,10 @@ def save_ac_value(record_id: str, attr_value: Any) -> None:
 
     calculated_attribute_by_record_id[record_id] = attr_value
 
+    if data_type == "LLM_RESPONSE":
+        llm_ac_cache[llm_config_hash] = cached_records
+        requests.put(CACHE_FILE_UPLOAD_LINK_A2VYBG, json=llm_ac_cache)
+
     processed_records = processed_records + 1
     if processed_records % progress_size == 0:
         __print_progress(round(processed_records / amount, 2))
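The block added to `save_ac_value` is a write-through: the whole cache is uploaded after every LLM record, so an interrupted run loses at most the in-flight response. A minimal sketch of the pattern, assuming a two-level cache shaped `{config_hash: {record_id: value}}` and a hypothetical upload URL in place of the injected CACHE_FILE_UPLOAD_LINK_A2VYBG:

import requests

UPLOAD_LINK = "https://example.com/cache-upload"  # hypothetical stand-in

llm_ac_cache: dict = {}      # {config_hash: {record_id: value}}
cached_records: dict = {}    # slice of the cache for the active config
llm_config_hash = "abc123"   # hypothetical key

def persist_record(record_id: str, value: str) -> None:
    # Update the in-memory slice, then write the full cache back
    # after every record so progress survives a crash mid-run.
    cached_records[record_id] = value
    llm_ac_cache[llm_config_hash] = cached_records
    requests.put(UPLOAD_LINK, json=llm_ac_cache)

One PUT per record is simple but chatty; uploading every N records would cut traffic at the cost of losing up to N responses on a crash.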
@@ -131,14 +138,23 @@ def process_attribute_calculation(record_dict_list: List[Dict[str, Any]]) -> None:
 
 
 async def process_llm_record_batch(record_dict_batch: List[Dict[str, Any]]) -> None:
-    global DEFAULT_USER_PROMPT_A2VYBG
+    global DEFAULT_USER_PROMPT_A2VYBG, cached_records
 
     for record_dict in record_dict_batch:
         attribute_calculators.USER_PROMPT_A2VYBG = prepare_and_render_mustache(
             DEFAULT_USER_PROMPT_A2VYBG, record_dict
         )
 
-        attr_value: str = await attribute_calculators.ac(record_dict["data"])
+        if record_dict["id"] in cached_records:
+            print(
+                "Using cached value for record",
+                record_dict["data"]["running_id"],
+                flush=True,
+            )
+            attr_value: str = cached_records[record_dict["id"]]
+        else:
+            attr_value: str = await attribute_calculators.ac(record_dict["data"])
+            cached_records[record_dict["id"]] = attr_value
         save_ac_value(record_dict["id"], attr_value)
 
 
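The batch loop now does a look-aside check: ids already present in `cached_records` skip the LLM call entirely, and fresh responses are written back for the next run. The same pattern in isolation, with a hypothetical `call_llm` standing in for `attribute_calculators.ac`:

import asyncio

cached_records = {"r1": "cached answer"}  # pre-loaded cache slice

async def call_llm(data):
    # Hypothetical stand-in for attribute_calculators.ac.
    return f"llm({data})"

async def resolve(record_id, data):
    # Look-aside cache: hit -> return stored value; miss -> call and store.
    if record_id in cached_records:
        return cached_records[record_id]
    value = await call_llm(data)
    cached_records[record_id] = value
    return value

print(asyncio.run(resolve("r1", None)))  # cached answer
print(asyncio.run(resolve("r2", "x")))   # llm(x), now cached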
@@ -169,9 +185,17 @@ def make_batches(
     # the script `labeling_functions` does not exist. It will be inserted at runtime
     import attribute_calculators
 
+    # exists for both the LLM playground and (run-on-10, run-all)
     DEFAULT_USER_PROMPT_A2VYBG = getattr(
        attribute_calculators, "USER_PROMPT_A2VYBG", None
     )
+    # exists only for (run-on-10, run-all)
+    CACHE_ACCESS_LINK_A2VYBG = getattr(
+        attribute_calculators, "CACHE_ACCESS_LINK_A2VYBG", ""
+    )
+    CACHE_FILE_UPLOAD_LINK_A2VYBG = getattr(
+        attribute_calculators, "CACHE_FILE_UPLOAD_LINK_A2VYBG", ""
+    )
 
     vocab = spacy.blank(iso2_code).vocab
 
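Because `attribute_calculators` is generated and inserted at runtime, the cache link constants exist only in the run-on-10/run-all modes; `getattr` with a default keeps the playground path from raising AttributeError. A small sketch of that defensive lookup, using a hypothetical stand-in module:

import types

# Hypothetical stand-in for the runtime-inserted module.
attribute_calculators = types.SimpleNamespace(USER_PROMPT_A2VYBG="Hi {{name}}")

# Present in both modes:
prompt = getattr(attribute_calculators, "USER_PROMPT_A2VYBG", None)
# Absent in the playground, so fall back to "" instead of raising:
access_link = getattr(attribute_calculators, "CACHE_ACCESS_LINK_A2VYBG", "")

print(prompt)             # Hi {{name}}
print(repr(access_link))  # ''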
@@ -193,6 +217,11 @@ def make_batches(
     __print_progress(0.0)
 
     if data_type == "LLM_RESPONSE":
+        llm_config = attribute_calculators.get_llm_config()
+        llm_ac_cache = requests.get(CACHE_ACCESS_LINK_A2VYBG).json()
+        llm_config_hash = md5(json.dumps(llm_config).encode()).hexdigest()
+
+        cached_records = llm_ac_cache.get(llm_config_hash, {})
         asyncio.run(process_async_llm_calls(record_dict_list))
     else:
         process_attribute_calculation(record_dict_list)
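The cache bucket is selected by the MD5 of the JSON-serialized LLM config, so changing the model or any sampling setting falls through to a fresh sub-cache instead of replaying stale responses. A sketch of the key derivation as written in the diff; note that json.dumps without sort_keys=True is sensitive to key insertion order, so sort_keys=True would be the safer choice if the config dict can be built in different orders:

import json
from hashlib import md5

def config_hash(llm_config: dict) -> str:
    # Same derivation as the diff: hash the serialized config so that
    # any change in settings selects a different cache bucket.
    return md5(json.dumps(llm_config).encode()).hexdigest()

a = config_hash({"model": "gpt-4", "temperature": 0.0})  # hypothetical config
b = config_hash({"model": "gpt-4", "temperature": 0.7})
assert a != b                                                    # new settings, new bucket
assert a == config_hash({"model": "gpt-4", "temperature": 0.0})  # stable key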