From 6f53efbd3d0e7072b41fc87ced3dd3922a68481a Mon Sep 17 00:00:00 2001
From: JWittmeyer
Date: Wed, 21 May 2025 15:49:09 +0200
Subject: [PATCH 1/2] Adds delta logic

---
 controller.py    | 40 ++++++++++++++++++++++++++++------------
 submodules/model |  2 +-
 2 files changed, 29 insertions(+), 13 deletions(-)

diff --git a/controller.py b/controller.py
index 66c5ada..5b2087a 100644
--- a/controller.py
+++ b/controller.py
@@ -44,6 +44,7 @@ def generate_batches(
     attribute_values_raw: List[str],
     embedder: Transformer,
     attribute_name: str,
+    for_delta: bool = False,
 ) -> Iterator[Dict[List[str], List[Any]]]:
     length = len(record_ids)
     record_batches = []
@@ -61,7 +62,10 @@ def generate_batches(
         record_batches.append(record_ids_batch)
         document_batches.extend(documents)
 
-    embedding_batches = embedder.fit_transform(document_batches, as_generator=True)
+    if for_delta:
+        embedding_batches = embedder.transform(document_batches, as_generator=True)
+    else:
+        embedding_batches = embedder.fit_transform(document_batches, as_generator=True)
     for record_batch in record_batches:
         yield {"record_ids": record_batch, "embeddings": next(embedding_batches)}
 
@@ -193,6 +197,13 @@ def run_encoding(
         initial_count = record.count_attribute_list_entries(project_id, attribute_name)
     else:
         initial_count = record.count(project_id)
+
+    is_delta = False
+    # refinery gateway handles delta logic beforehand so if count is 0 we can be sure it's not a delta
+    if (tensor_count := embedding.get_tensor_count(embedding_id)) != 0:
+        is_delta = True
+        initial_count -= tensor_count
+
     seed_str = embedding_name
     torch.manual_seed(zlib.adler32(bytes(seed_str, "utf-8")))
     notification.create(
@@ -214,15 +225,18 @@ def run_encoding(
     else:
         config_string = model
 
-    embedder = get_embedder(
-        project_id,
-        embedding_type,
-        iso2_code,
-        platform,
-        model,
-        api_token,
-        additional_data,
-    )
+    if is_delta:
+        embedder = __setup_tmp_embedder(project_id, embedding_id)
+    else:
+        embedder = get_embedder(
+            project_id,
+            embedding_type,
+            iso2_code,
+            platform,
+            model,
+            api_token,
+            additional_data,
+        )
 
     if not embedder:
         raise Exception(
@@ -253,7 +267,7 @@ def run_encoding(
 
     try:
         record_ids, attribute_values_raw = record.get_attribute_data(
-            project_id, attribute_name
+            project_id, attribute_name, is_delta, embedding_id
        )
         embedding.update_embedding_state_encoding(
             project_id,
@@ -279,7 +293,8 @@ def run_encoding(
             True,
         )
         send_project_update(project_id, f"notification_created:{user_id}", True)
-        embedding.delete_tensors(embedding_id, with_commit=True)
+        if not is_delta:
+            embedding.delete_tensors(embedding_id, with_commit=True)
         chunk = 0
         embedding_canceled = False
         for pair in generate_batches(
@@ -289,6 +304,7 @@ def run_encoding(
             attribute_values_raw,
             embedder,
             attribute_name,
+            for_delta=is_delta,
         ):
             if chunk % 10 == 0:
                 session_token = general.remove_and_refresh_session(session_token, True)
diff --git a/submodules/model b/submodules/model
index 03717cc..e161a3e 160000
--- a/submodules/model
+++ b/submodules/model
@@ -1 +1 @@
-Subproject commit 03717ccfa8b63cdc52bc286b6a1454fc389d6d73
+Subproject commit e161a3e06fb7dd16d04d0ee6afefe86cb93cd20d

From 64bab6070b9d42d3649ad9eab108a0f13049118c Mon Sep 17 00:00:00 2001
From: JWittmeyer
Date: Fri, 23 May 2025 12:57:41 +0200
Subject: [PATCH 2/2] Submodule update

---
 submodules/model | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/submodules/model b/submodules/model
index e161a3e..42fa11b 160000
--- a/submodules/model
+++ b/submodules/model
@@ -1 +1 @@
-Subproject commit e161a3e06fb7dd16d04d0ee6afefe86cb93cd20d
+Subproject commit 42fa11b74903c17533c765d89bef6d0fb9bbf93a