
Commit d314536

Delta migration (#303)
* Added full admin access table and endpoint
* Added method for inviting users
* Fixed refreshing kratos cache
* Function for sending emails
* SMTP variables and login
* Removed unused code
* PR comments
* Added SSO provider to the invite users request
* PR comments
* Removed comments
* PR comments
* PR comments
* PR comments
* Adds attribute delta migration
* Adds embedding columns & delta logic
* Removes prints
* Test for delta
* Extend wait time for embeddings
* PR comments
* Submodule update

---------

Co-authored-by: Lina <lina.lumburovska@kern.ai>
1 parent 02bee88 commit d314536

13 files changed · +309 −53 lines changed
New Alembic migration "delta indicators embeddings" (revision 96fbb404381e)

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+"""delta indicators embeddings
+
+Revision ID: 96fbb404381e
+Revises: eb96f9b82cc1
+Create Date: 2025-05-21 11:09:33.093313
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '96fbb404381e'
+down_revision = 'eb96f9b82cc1'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    connection = op.get_bind()
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.add_column('embedding', sa.Column('delta_full_recalculation_threshold', sa.Float(), nullable=True))
+    op.add_column('embedding', sa.Column('current_delta_record_count', sa.Integer(), nullable=True))
+
+    update_sql = """
+        UPDATE public.embedding
+        SET delta_full_recalculation_threshold = 0.5,
+            current_delta_record_count = 0
+        WHERE delta_full_recalculation_threshold IS NULL
+    """
+    connection.execute(update_sql)
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_column('embedding', 'current_delta_record_count')
+    op.drop_column('embedding', 'delta_full_recalculation_threshold')
+    # ### end Alembic commands ###
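
The backfill above gives every existing embedding a delta threshold of 0.5 and resets its delta counter. A minimal sketch of applying or reverting this revision through the Alembic Python API (the usual path is the alembic CLI; the alembic.ini location is an assumption):

# Sketch only, not part of the commit. Assumes alembic.ini is reachable
# from the working directory the gateway runs migrations from.
from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")
command.upgrade(cfg, "96fbb404381e")      # add the two delta columns and backfill defaults
# command.downgrade(cfg, "eb96f9b82cc1")  # drop them again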

api/transfer.py

Lines changed: 3 additions & 3 deletions
@@ -4,7 +4,7 @@
 from typing import Optional
 from starlette.endpoints import HTTPEndpoint
 from starlette.responses import PlainTextResponse
-from controller.embedding.manager import recreate_embeddings
+from controller.embedding.manager import recreate_or_extend_embeddings

 from controller.transfer.cognition import (
     import_preparator as cognition_preparator,
@@ -165,7 +165,7 @@ def __recalculate_missing_attributes_and_embeddings(
     project_id: str, user_id: str
 ) -> None:
     __calculate_missing_attributes(project_id, user_id)
-    recreate_embeddings(project_id)
+    recreate_or_extend_embeddings(project_id)


 def __calculate_missing_attributes(project_id: str, user_id: str) -> None:
@@ -218,7 +218,7 @@ def __calculate_missing_attributes(project_id: str, user_id: str) -> None:
         if current_att.state == enums.AttributeState.RUNNING.value:
             continue
         elif current_att.state == enums.AttributeState.INITIAL.value:
-            attribute_manager.calculate_user_attribute_all_records(
+            attribute_manager.calculate_user_attribute_missing_records(
                 project_id,
                 project.get_org_id(project_id),
                 user_id,

conftest.py

Lines changed: 4 additions & 0 deletions
@@ -12,6 +12,7 @@
     project as project_bo,
     general,
 )
+from submodules.s3 import controller as s3
 from submodules.model.models import (
     Organization,
     User,
@@ -29,8 +30,10 @@ def database_session() -> Iterator[None]:
 @pytest.fixture(scope="session")
 def org() -> Iterator[Organization]:
     org_item = organization_bo.create(name="test_org", with_commit=True)
+    s3.create_bucket(str(org_item.id))
     yield org_item
     organization_bo.delete(org_item.id, with_commit=True)
+    s3.remove_bucket(str(org_item.id), True)


 @pytest.fixture(scope="session")
@@ -47,6 +50,7 @@ def refinery_project(org: Organization, user: User) -> Iterator[RefineryProject]
         name="test_project",
         description="test_description",
         created_by=user.id,
+        tokenizer="en_core_web_sm",
         with_commit=True,
     )
     yield project_item

controller/attribute/manager.py

Lines changed: 13 additions & 4 deletions
@@ -234,7 +234,7 @@ def __add_running_id(
     general.remove_and_refresh_session(session_token)


-def calculate_user_attribute_all_records(
+def calculate_user_attribute_missing_records(
     project_id: str,
     org_id: str,
     user_id: str,
@@ -285,7 +285,7 @@ def calculate_user_attribute_all_records(
         project_id=project_id, message=f"calculate_attribute:started:{attribute_id}"
     )
     daemon.run_without_db_token(
-        __calculate_user_attribute_all_records,
+        __calculate_user_attribute_missing_records,
         project_id,
         org_id,
         user_id,
@@ -294,7 +294,7 @@ def calculate_user_attribute_all_records(
     )


-def __calculate_user_attribute_all_records(
+def __calculate_user_attribute_missing_records(
     project_id: str,
     org_id: str,
     user_id: str,
@@ -303,9 +303,18 @@ def __calculate_user_attribute_all_records(
 ) -> None:
     session_token = general.get_ctx_token()

+    all_records_count = record.count(project_id)
+    count_delta = record.count_missing_delta(project_id, attribute_id)
+
+    if count_delta != all_records_count:
+        doc_bin = util.prepare_delta_records_doc_bin(
+            attribute_id=attribute_id, project_id=project_id
+        )
+    else:
+        doc_bin = "docbin_full"
     try:
         calculated_attributes = util.run_attribute_calculation_exec_env(
-            attribute_id=attribute_id, project_id=project_id, doc_bin="docbin_full"
+            attribute_id=attribute_id, project_id=project_id, doc_bin=doc_bin
         )
         if not calculated_attributes:
             __notify_attribute_calculation_failed(
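
Net effect: the renamed job only recalculates records that still miss the attribute value. If the number of missing records equals the total record count it keeps using the full doc bin; otherwise it builds a delta doc bin for just the missing ids. A minimal standalone sketch of that branch (counts are made up; "docbin_full" appears to act as a sentinel for the full doc bin in the exec environment):

# Sketch only: which doc bin the background job would pick.
def choose_doc_bin(all_records_count: int, count_missing: int) -> str:
    if count_missing != all_records_count:
        return "docbin_delta"  # stand-in for util.prepare_delta_records_doc_bin(...)
    return "docbin_full"       # everything is missing, so run over the full doc bin

print(choose_doc_bin(all_records_count=1000, count_missing=1000))  # docbin_full
print(choose_doc_bin(all_records_count=1000, count_missing=40))    # docbin_delta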

controller/attribute/util.py

Lines changed: 14 additions & 1 deletion
@@ -64,10 +64,23 @@ def prepare_sample_records_doc_bin(
 ) -> str:
     sample_records = record.get_attribute_calculation_sample_records(project_id)

+    return __prepare_records_doc_bin(
+        attribute_id, project_id, record_ids or [r[0] for r in sample_records]
+    )
+
+
+def prepare_delta_records_doc_bin(attribute_id: str, project_id: str) -> str:
+    missing_records = record.get_missing_delta_record_ids(project_id, attribute_id)
+    return __prepare_records_doc_bin(attribute_id, project_id, missing_records)
+
+
+def __prepare_records_doc_bin(
+    attribute_id: str, project_id: str, record_ids: List[str]
+) -> str:
     sample_records_doc_bin = tokenization.get_doc_bin_table_to_json(
         project_id=project_id,
         missing_columns=record.get_missing_columns_str(project_id),
-        record_ids=record_ids or [r[0] for r in sample_records],
+        record_ids=record_ids,
     )
     project_item = project.get(project_id)
     org_id = str(project_item.organization_id)

controller/embedding/manager.py

Lines changed: 77 additions & 38 deletions
@@ -13,6 +13,7 @@
     embedding,
     agreement,
     general,
+    record,
 )
 from submodules.model import daemon

@@ -99,7 +100,7 @@ def get_embedding_name(
     return name


-def recreate_embeddings(
+def recreate_or_extend_embeddings(
     project_id: str, embedding_ids: Optional[List[str]] = None, user_id: str = None
 ) -> None:
     if not embedding_ids:
@@ -126,7 +127,9 @@ def recreate_embeddings(
         embedding_item = embedding.get(project_id, embedding_id)
         if not embedding_item:
             continue
-        embedding_item = __recreate_embedding(project_id, embedding_id)
+        embedding_item = __recreate_or_extend_embedding(project_id, embedding_id)
+        if not embedding_item:
+            continue
         new_id = embedding_item.id
         time.sleep(2)
         while True:
@@ -179,49 +182,77 @@ def __handle_failed_embedding(
     general.commit()


-def __recreate_embedding(project_id: str, embedding_id: str) -> Embedding:
-    old_embedding_item = embedding.get(project_id, embedding_id)
-    old_id = old_embedding_item.id
-    new_embedding_item = embedding.create(
-        project_id,
-        old_embedding_item.attribute_id,
-        old_embedding_item.name,
-        old_embedding_item.created_by,
-        enums.EmbeddingState.INITIALIZING.value,
-        type=old_embedding_item.type,
-        model=old_embedding_item.model,
-        platform=old_embedding_item.platform,
-        api_token=old_embedding_item.api_token,
-        filter_attributes=old_embedding_item.filter_attributes,
-        additional_data=old_embedding_item.additional_data,
-        with_commit=False,
-    )
-    embedding.delete(project_id, embedding_id, with_commit=False)
-    embedding.delete_tensors(embedding_id, with_commit=False)
-    general.commit()
+def __recreate_or_extend_embedding(project_id: str, embedding_id: str) -> Embedding:

-    if (
-        new_embedding_item.platform == enums.EmbeddingPlatform.OPENAI.value
-        or new_embedding_item.platform == enums.EmbeddingPlatform.COHERE.value
-        or new_embedding_item.platform == enums.EmbeddingPlatform.AZURE.value
-    ):
-        agreement_item = agreement.get_by_xfkey(
-            project_id, old_id, enums.AgreementType.EMBEDDING.value
+    # check how many embeddings need to be recreated
+    old_embedding_item = embedding.get(project_id, embedding_id)
+    if not old_embedding_item:
+        return None
+    needs_full_recreation = False
+    if old_embedding_item.delta_full_recalculation_threshold == 0:
+        needs_full_recreation = True
+    elif old_embedding_item.delta_full_recalculation_threshold > 0:
+        already_deltaed = old_embedding_item.current_delta_record_count
+        full_count = record.count(project_id)
+        current_count = embedding.get_record_ids_count(embedding_id)
+        to_calc = full_count - current_count
+        if (
+            already_deltaed + to_calc
+            > old_embedding_item.delta_full_recalculation_threshold * full_count
+        ):
+            # only do a full recreation if the delta is larger than the threshold
+            needs_full_recreation = True
+        else:
+            old_embedding_item.current_delta_record_count += to_calc
+
+    if needs_full_recreation:
+        new_embedding_item = embedding.create(
+            project_id,
+            old_embedding_item.attribute_id,
+            old_embedding_item.name,
+            old_embedding_item.created_by,
+            enums.EmbeddingState.INITIALIZING.value,
+            type=old_embedding_item.type,
+            model=old_embedding_item.model,
+            platform=old_embedding_item.platform,
+            api_token=old_embedding_item.api_token,
+            filter_attributes=old_embedding_item.filter_attributes,
+            additional_data=old_embedding_item.additional_data,
+            with_commit=False,
         )
-        if not agreement_item:
-            new_embedding_item.state = enums.EmbeddingState.FAILED.value
-            general.commit()
-            raise ApiTokenImportError(
-                f"No agreement found for embedding {new_embedding_item.name}"
+        embedding.delete(project_id, embedding_id, with_commit=False)
+        embedding.delete_tensors(embedding_id, with_commit=False)
+        general.commit()
+
+        if (
+            new_embedding_item.platform == enums.EmbeddingPlatform.OPENAI.value
+            or new_embedding_item.platform == enums.EmbeddingPlatform.COHERE.value
+            or new_embedding_item.platform == enums.EmbeddingPlatform.AZURE.value
+        ):
+            agreement_item = agreement.get_by_xfkey(
+                project_id, embedding_id, enums.AgreementType.EMBEDDING.value
             )
-        agreement_item.xfkey = new_embedding_item.id
+            if not agreement_item:
+                new_embedding_item.state = enums.EmbeddingState.FAILED.value
+                general.commit()
+                raise ApiTokenImportError(
+                    f"No agreement found for embedding {new_embedding_item.name}"
+                )
+            agreement_item.xfkey = new_embedding_item.id
+            general.commit()
+
+        connector.request_deleting_embedding(project_id, embedding_id)
+    else:
        general.commit()

-    connector.request_deleting_embedding(project_id, old_id)
+    # request handles delta and full recreation
+    request_embedding_id = (
+        new_embedding_item.id if needs_full_recreation else embedding_id
+    )
     daemon.run_without_db_token(
-        connector.request_embedding, project_id, new_embedding_item.id
+        connector.request_embedding, project_id, request_embedding_id
     )
-    return new_embedding_item
+    return new_embedding_item if needs_full_recreation else old_embedding_item


 def update_embedding_payload(
@@ -262,3 +293,11 @@ def update_label_payloads_for_neural_search(
         embedding_ids=[str(e.id) for e in relevant_embeddings],
         record_ids=record_ids,
     )
+
+
+def remove_tensors_by_record_ids(
+    project_id: str, record_ids: List[str], embedding_id: Optional[str] = None
+) -> None:
+    if not record_ids:
+        return
+    embedding.delete_tensors_by_record_ids(project_id, record_ids, embedding_id)
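
The heart of the change is the threshold check in __recreate_or_extend_embedding: a threshold of 0 always forces a full recreation, while a positive threshold lets the embedding be extended until the records handled as delta so far plus the newly missing ones exceed that fraction of the project. A self-contained sketch of the decision with plain numbers (same column semantics as the migration above; values are hypothetical):

# Sketch only: mirrors the needs_full_recreation decision.
def needs_full_recreation(
    threshold: float,      # embedding.delta_full_recalculation_threshold
    already_deltaed: int,  # embedding.current_delta_record_count
    full_count: int,       # records in the project
    embedded_count: int,   # records that already have a tensor
) -> bool:
    if threshold == 0:
        return True  # delta handling disabled for this embedding
    to_calc = full_count - embedded_count
    return already_deltaed + to_calc > threshold * full_count

# 1000 records, 950 embedded, 100 already handled as delta:
print(needs_full_recreation(0.5, 100, 1000, 950))  # False -> extend (150 <= 500)
print(needs_full_recreation(0.1, 100, 1000, 950))  # True  -> full recreation (150 > 100)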

controller/transfer/project_transfer_manager.py

Lines changed: 1 addition & 1 deletion
@@ -939,7 +939,7 @@ def __post_processing_import_threaded(
     if not data.get(
         "embedding_tensors_data",
     ):
-        embedding_manager.recreate_embeddings(project_id, user_id=user_id)
+        embedding_manager.recreate_or_extend_embeddings(project_id, user_id=user_id)
     else:
         for old_id in embedding_ids:
             embedding_manager.request_tensor_upload(

controller/transfer/record_transfer_manager.py

Lines changed: 20 additions & 0 deletions
@@ -20,6 +20,7 @@

 from controller.upload_task import manager as upload_task_manager
 from controller.tokenization import manager as token_manager
+from controller.embedding import manager as embedding_manager
 from util import file, security
 from submodules.s3 import controller as s3
 from submodules.model import enums, UploadTask, Attribute
@@ -171,6 +172,21 @@ def download_file(project_id: str, task: UploadTask) -> str:
 def import_file(project_id: str, upload_task: UploadTask) -> None:
     # load data from s3 and do transfer task/notification management
     tmp_file_name, file_type = download_file(project_id, upload_task)
+    __import_file(project_id, upload_task, file_type, tmp_file_name)
+
+
+def import_file_record_dict(
+    project_id: str, upload_task: UploadTask, records: List[Dict[str, Any]]
+) -> None:
+    # load data from s3 and do transfer task/notification management
+    tmp_file_name = file.store_records_as_json_file(records)
+    file_type = "json"
+    __import_file(project_id, upload_task, file_type, tmp_file_name)
+
+
+def __import_file(
+    project_id: str, upload_task: UploadTask, file_type: str, tmp_file_name: str
+) -> None:
     upload_task_manager.update_task(
         project_id, upload_task.id, state=enums.UploadStates.IN_PROGRESS.value
     )
@@ -287,6 +303,10 @@ def update_records_and_labels(
     )
     token_manager.delete_token_statistics(updated_records)
     token_manager.delete_docbins(project_id, updated_records)
+    # remove embedding tensors if there are any to prep for delta migration
+    embedding_manager.remove_tensors_by_record_ids(
+        project_id, [str(r.id) for r in updated_records]
+    )
     return remaining_records_data, remaining_labels_data
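
Dropping the tensors of updated records is what makes them look "missing" to the embedding service, so the next run can re-embed only the delta. An illustrative, not verbatim, call sequence using the functions introduced in this commit (ids are placeholders):

# Illustrative wiring only; ids are placeholders.
from controller.embedding import manager as embedding_manager

project_id = "<project-id>"
updated_record_ids = ["<record-id-1>", "<record-id-2>"]

embedding_manager.remove_tensors_by_record_ids(project_id, updated_record_ids)
# extends the existing embeddings, unless the accumulated delta pushes past
# delta_full_recalculation_threshold and a full recreation is triggered
embedding_manager.recreate_or_extend_embeddings(project_id)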

fast_api/routes/task_execution.py

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ def calculate_attributes(
     attribute_calculation_task_execution: AttributeCalculationTaskExecutionBody,
 ):
     daemon.run_with_db_token(
-        attribute_manager.calculate_user_attribute_all_records,
+        attribute_manager.calculate_user_attribute_missing_records,
         attribute_calculation_task_execution.project_id,
         attribute_calculation_task_execution.organization_id,
         attribute_calculation_task_execution.user_id,

run-tests

Lines changed: 4 additions & 1 deletion
@@ -1,3 +1,6 @@
 #!/bin/bash

-docker exec -it refinery-gateway bash -c "cd /app && python -m pytest -v"
+
+# add -s to see print statements
+# add -v to see test names
+docker exec -it refinery-gateway bash -c "cd /app && python -m pytest -v -s"
