diff --git a/alembic/versions/96fbb404381e_delta_indicators_embeddings.py b/alembic/versions/96fbb404381e_delta_indicators_embeddings.py new file mode 100644 index 00000000..ddcff207 --- /dev/null +++ b/alembic/versions/96fbb404381e_delta_indicators_embeddings.py @@ -0,0 +1,39 @@ +"""delta indicators embeddings + +Revision ID: 96fbb404381e +Revises: eb96f9b82cc1 +Create Date: 2025-05-21 11:09:33.093313 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '96fbb404381e' +down_revision = 'eb96f9b82cc1' +branch_labels = None +depends_on = None + + +def upgrade(): + connection = op.get_bind() + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('embedding', sa.Column('delta_full_recalculation_threshold', sa.Float(), nullable=True)) + op.add_column('embedding', sa.Column('current_delta_record_count', sa.Integer(), nullable=True)) + + update_sql = """ + UPDATE public.embedding + SET delta_full_recalculation_threshold = 0.5, + current_delta_record_count = 0 + WHERE delta_full_recalculation_threshold IS NULL + """ + connection.execute(sa.text(update_sql)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_column('embedding', 'current_delta_record_count') + op.drop_column('embedding', 'delta_full_recalculation_threshold') + # ### end Alembic commands ### diff --git a/api/transfer.py b/api/transfer.py index a49ca365..d80b2b1c 100644 --- a/api/transfer.py +++ b/api/transfer.py @@ -4,7 +4,7 @@ from typing import Optional from starlette.endpoints import HTTPEndpoint from starlette.responses import PlainTextResponse -from controller.embedding.manager import recreate_embeddings +from controller.embedding.manager import recreate_or_extend_embeddings from controller.transfer.cognition import ( import_preparator as cognition_preparator, @@ -165,7 +165,7 @@ def __recalculate_missing_attributes_and_embeddings( project_id: str, user_id: str ) -> None: __calculate_missing_attributes(project_id, user_id) - recreate_embeddings(project_id) + recreate_or_extend_embeddings(project_id) def __calculate_missing_attributes(project_id: str, user_id: str) -> None: @@ -218,7 +218,7 @@ def __calculate_missing_attributes(project_id: str, user_id: str) -> None: if current_att.state == enums.AttributeState.RUNNING.value: continue elif current_att.state == enums.AttributeState.INITIAL.value: - attribute_manager.calculate_user_attribute_all_records( + attribute_manager.calculate_user_attribute_missing_records( project_id, project.get_org_id(project_id), user_id, diff --git a/conftest.py b/conftest.py index 4bef7568..63a4bd04 100644 --- a/conftest.py +++ b/conftest.py @@ -12,6 +12,7 @@ project as project_bo, general, ) +from submodules.s3 import controller as s3 from submodules.model.models import ( Organization, User, @@ -29,8 +30,10 @@ def database_session() -> Iterator[None]: @pytest.fixture(scope="session") def org() -> Iterator[Organization]: org_item = organization_bo.create(name="test_org", with_commit=True) + s3.create_bucket(str(org_item.id)) yield org_item organization_bo.delete(org_item.id, with_commit=True) + s3.remove_bucket(str(org_item.id), True) 
@pytest.fixture(scope="session") @@ -47,6 +50,7 @@ def refinery_project(org: Organization, user: User) -> Iterator[RefineryProject] name="test_project", description="test_description", created_by=user.id, + tokenizer="en_core_web_sm", with_commit=True, ) yield project_item diff --git a/controller/attribute/manager.py b/controller/attribute/manager.py index d6d3bce0..5736494a 100644 --- a/controller/attribute/manager.py +++ b/controller/attribute/manager.py @@ -234,7 +234,7 @@ def __add_running_id( general.remove_and_refresh_session(session_token) -def calculate_user_attribute_all_records( +def calculate_user_attribute_missing_records( project_id: str, org_id: str, user_id: str, @@ -285,7 +285,7 @@ def calculate_user_attribute_all_records( project_id=project_id, message=f"calculate_attribute:started:{attribute_id}" ) daemon.run_without_db_token( - __calculate_user_attribute_all_records, + __calculate_user_attribute_missing_records, project_id, org_id, user_id, @@ -294,7 +294,7 @@ def calculate_user_attribute_all_records( ) -def __calculate_user_attribute_all_records( +def __calculate_user_attribute_missing_records( project_id: str, org_id: str, user_id: str, @@ -303,9 +303,18 @@ def __calculate_user_attribute_all_records( ) -> None: session_token = general.get_ctx_token() + all_records_count = record.count(project_id) + count_delta = record.count_missing_delta(project_id, attribute_id) + + if count_delta != all_records_count: + doc_bin = util.prepare_delta_records_doc_bin( + attribute_id=attribute_id, project_id=project_id + ) + else: + doc_bin = "docbin_full" try: calculated_attributes = util.run_attribute_calculation_exec_env( - attribute_id=attribute_id, project_id=project_id, doc_bin="docbin_full" + attribute_id=attribute_id, project_id=project_id, doc_bin=doc_bin ) if not calculated_attributes: __notify_attribute_calculation_failed( diff --git a/controller/attribute/util.py b/controller/attribute/util.py index b106855d..b715c9d4 100644 --- 
a/controller/attribute/util.py +++ b/controller/attribute/util.py @@ -64,10 +64,23 @@ def prepare_sample_records_doc_bin( ) -> str: sample_records = record.get_attribute_calculation_sample_records(project_id) + return __prepare_records_doc_bin( + attribute_id, project_id, record_ids or [r[0] for r in sample_records] + ) + + +def prepare_delta_records_doc_bin(attribute_id: str, project_id: str) -> str: + missing_records = record.get_missing_delta_record_ids(project_id, attribute_id) + return __prepare_records_doc_bin(attribute_id, project_id, missing_records) + + +def __prepare_records_doc_bin( + attribute_id: str, project_id: str, record_ids: List[str] +) -> str: sample_records_doc_bin = tokenization.get_doc_bin_table_to_json( project_id=project_id, missing_columns=record.get_missing_columns_str(project_id), - record_ids=record_ids or [r[0] for r in sample_records], + record_ids=record_ids, ) project_item = project.get(project_id) org_id = str(project_item.organization_id) diff --git a/controller/embedding/manager.py b/controller/embedding/manager.py index 6f5ee4a8..4a554ef5 100644 --- a/controller/embedding/manager.py +++ b/controller/embedding/manager.py @@ -13,6 +13,7 @@ embedding, agreement, general, + record, ) from submodules.model import daemon @@ -99,7 +100,7 @@ def get_embedding_name( return name -def recreate_embeddings( +def recreate_or_extend_embeddings( project_id: str, embedding_ids: Optional[List[str]] = None, user_id: str = None ) -> None: if not embedding_ids: @@ -126,7 +127,9 @@ def recreate_embeddings( embedding_item = embedding.get(project_id, embedding_id) if not embedding_item: continue - embedding_item = __recreate_embedding(project_id, embedding_id) + embedding_item = __recreate_or_extend_embedding(project_id, embedding_id) + if not embedding_item: + continue new_id = embedding_item.id time.sleep(2) while True: @@ -179,49 +182,77 @@ def __handle_failed_embedding( general.commit() -def __recreate_embedding(project_id: str, embedding_id: str) 
-> Embedding: - old_embedding_item = embedding.get(project_id, embedding_id) - old_id = old_embedding_item.id - new_embedding_item = embedding.create( - project_id, - old_embedding_item.attribute_id, - old_embedding_item.name, - old_embedding_item.created_by, - enums.EmbeddingState.INITIALIZING.value, - type=old_embedding_item.type, - model=old_embedding_item.model, - platform=old_embedding_item.platform, - api_token=old_embedding_item.api_token, - filter_attributes=old_embedding_item.filter_attributes, - additional_data=old_embedding_item.additional_data, - with_commit=False, - ) - embedding.delete(project_id, embedding_id, with_commit=False) - embedding.delete_tensors(embedding_id, with_commit=False) - general.commit() +def __recreate_or_extend_embedding(project_id: str, embedding_id: str) -> Embedding: - if ( - new_embedding_item.platform == enums.EmbeddingPlatform.OPENAI.value - or new_embedding_item.platform == enums.EmbeddingPlatform.COHERE.value - or new_embedding_item.platform == enums.EmbeddingPlatform.AZURE.value - ): - agreement_item = agreement.get_by_xfkey( - project_id, old_id, enums.AgreementType.EMBEDDING.value + # check how many embeddings need to be recreated + old_embedding_item = embedding.get(project_id, embedding_id) + if not old_embedding_item: + return None + needs_full_recreation = False + if old_embedding_item.delta_full_recalculation_threshold == 0: + needs_full_recreation = True + elif old_embedding_item.delta_full_recalculation_threshold > 0: + already_deltaed = old_embedding_item.current_delta_record_count + full_count = record.count(project_id) + current_count = embedding.get_record_ids_count(embedding_id) + to_calc = full_count - current_count + if ( + already_deltaed + to_calc + > old_embedding_item.delta_full_recalculation_threshold * full_count + ): + # only to a full recreation if the delta is larger than the threshold + needs_full_recreation = True + else: + old_embedding_item.current_delta_record_count += to_calc + # + if 
needs_full_recreation: + new_embedding_item = embedding.create( + project_id, + old_embedding_item.attribute_id, + old_embedding_item.name, + old_embedding_item.created_by, + enums.EmbeddingState.INITIALIZING.value, + type=old_embedding_item.type, + model=old_embedding_item.model, + platform=old_embedding_item.platform, + api_token=old_embedding_item.api_token, + filter_attributes=old_embedding_item.filter_attributes, + additional_data=old_embedding_item.additional_data, + with_commit=False, ) - if not agreement_item: - new_embedding_item.state = enums.EmbeddingState.FAILED.value - general.commit() - raise ApiTokenImportError( - f"No agreement found for embedding {new_embedding_item.name}" + embedding.delete(project_id, embedding_id, with_commit=False) + embedding.delete_tensors(embedding_id, with_commit=False) + general.commit() + + if ( + new_embedding_item.platform == enums.EmbeddingPlatform.OPENAI.value + or new_embedding_item.platform == enums.EmbeddingPlatform.COHERE.value + or new_embedding_item.platform == enums.EmbeddingPlatform.AZURE.value + ): + agreement_item = agreement.get_by_xfkey( + project_id, embedding_id, enums.AgreementType.EMBEDDING.value ) - agreement_item.xfkey = new_embedding_item.id + if not agreement_item: + new_embedding_item.state = enums.EmbeddingState.FAILED.value + general.commit() + raise ApiTokenImportError( + f"No agreement found for embedding {new_embedding_item.name}" + ) + agreement_item.xfkey = new_embedding_item.id + general.commit() + + connector.request_deleting_embedding(project_id, embedding_id) + else: general.commit() - connector.request_deleting_embedding(project_id, old_id) + # request handles delta and full recreation + request_embedding_id = ( + new_embedding_item.id if needs_full_recreation else embedding_id + ) daemon.run_without_db_token( - connector.request_embedding, project_id, new_embedding_item.id + connector.request_embedding, project_id, request_embedding_id ) - return new_embedding_item + return 
new_embedding_item if needs_full_recreation else old_embedding_item def update_embedding_payload( @@ -262,3 +293,11 @@ def update_label_payloads_for_neural_search( embedding_ids=[str(e.id) for e in relevant_embeddings], record_ids=record_ids, ) + + +def remove_tensors_by_record_ids( + project_id: str, record_ids: List[str], embedding_id: Optional[str] = None +) -> None: + if not record_ids: + return + embedding.delete_tensors_by_record_ids(project_id, record_ids, embedding_id) diff --git a/controller/transfer/project_transfer_manager.py b/controller/transfer/project_transfer_manager.py index f6cb50a6..d4bb9061 100644 --- a/controller/transfer/project_transfer_manager.py +++ b/controller/transfer/project_transfer_manager.py @@ -939,7 +939,7 @@ def __post_processing_import_threaded( if not data.get( "embedding_tensors_data", ): - embedding_manager.recreate_embeddings(project_id, user_id=user_id) + embedding_manager.recreate_or_extend_embeddings(project_id, user_id=user_id) else: for old_id in embedding_ids: embedding_manager.request_tensor_upload( diff --git a/controller/transfer/record_transfer_manager.py b/controller/transfer/record_transfer_manager.py index cf6f51fa..c5d0da66 100644 --- a/controller/transfer/record_transfer_manager.py +++ b/controller/transfer/record_transfer_manager.py @@ -20,6 +20,7 @@ from controller.upload_task import manager as upload_task_manager from controller.tokenization import manager as token_manager +from controller.embedding import manager as embedding_manager from util import file, security from submodules.s3 import controller as s3 from submodules.model import enums, UploadTask, Attribute @@ -171,6 +172,21 @@ def download_file(project_id: str, task: UploadTask) -> str: def import_file(project_id: str, upload_task: UploadTask) -> None: # load data from s3 and do transfer task/notification management tmp_file_name, file_type = download_file(project_id, upload_task) + __import_file(project_id, upload_task, file_type, tmp_file_name) + 
+ +def import_file_record_dict( + project_id: str, upload_task: UploadTask, records: List[Dict[str, Any]] +) -> None: + # load data from s3 and do transfer task/notification management + tmp_file_name = file.store_records_as_json_file(records) + file_type = "json" + __import_file(project_id, upload_task, file_type, tmp_file_name) + + +def __import_file( + project_id: str, upload_task: UploadTask, file_type: str, tmp_file_name: str +) -> None: upload_task_manager.update_task( project_id, upload_task.id, state=enums.UploadStates.IN_PROGRESS.value ) @@ -287,6 +303,10 @@ def update_records_and_labels( ) token_manager.delete_token_statistics(updated_records) token_manager.delete_docbins(project_id, updated_records) + # remove embedding tensors if there are any to prep for delta migration + embedding_manager.remove_tensors_by_record_ids( + project_id, [str(r.id) for r in updated_records] + ) return remaining_records_data, remaining_labels_data diff --git a/fast_api/routes/task_execution.py b/fast_api/routes/task_execution.py index 581b3563..9f3ee6c1 100644 --- a/fast_api/routes/task_execution.py +++ b/fast_api/routes/task_execution.py @@ -25,7 +25,7 @@ def calculate_attributes( attribute_calculation_task_execution: AttributeCalculationTaskExecutionBody, ): daemon.run_with_db_token( - attribute_manager.calculate_user_attribute_all_records, + attribute_manager.calculate_user_attribute_missing_records, attribute_calculation_task_execution.project_id, attribute_calculation_task_execution.organization_id, attribute_calculation_task_execution.user_id, diff --git a/run-tests b/run-tests index 3452b470..f86e6ee6 100755 --- a/run-tests +++ b/run-tests @@ -1,3 +1,6 @@ #!/bin/bash -docker exec -it refinery-gateway bash -c "cd /app && python -m pytest -v" \ No newline at end of file + +# add -s to see print statements +# add -v to see test names +docker exec -it refinery-gateway bash -c "cd /app && python -m pytest -v -s" \ No newline at end of file diff --git a/submodules/model 
b/submodules/model index 9d0ecff3..42fa11b7 160000 --- a/submodules/model +++ b/submodules/model @@ -1 +1 @@ -Subproject commit 9d0ecff36599cf1bc79da80c6db788ba36208171 +Subproject commit 42fa11b74903c17533c765d89bef6d0fb9bbf93a diff --git a/tests/fast_api/routes/test_project.py b/tests/fast_api/routes/test_project.py index e5378ace..f88131ee 100644 --- a/tests/fast_api/routes/test_project.py +++ b/tests/fast_api/routes/test_project.py @@ -1,7 +1,18 @@ from fastapi.testclient import TestClient -from submodules.model.models import Project as RefineryProject +from submodules.model.models import Project as RefineryProject, User -from submodules.model.business_objects import general +from controller.transfer import record_transfer_manager +from api import transfer as transfer_api +from controller.upload_task import manager as upload_task_manager +from submodules.model.business_objects import ( + general, + record as record_bo, + attribute as attribute_bo, + embedding as embedding_bo, +) +from submodules.model import enums +import json +import time def test_get_project_by_project_id( @@ -27,3 +38,114 @@ def test_update_project_name_description( general.refresh(refinery_project) assert refinery_project.name == "new_name" assert refinery_project.description == "new_description" + + +def test_upload_records_to_project( + client: TestClient, refinery_project: RefineryProject, user: User +): + upload_task = upload_task_manager.create_upload_task( + str(user.id), + str(refinery_project.id), + "dummy_file_name.csv", + "records", + "", + upload_type=enums.UploadTypes.DEFAULT.value, + key=None, + ) + record_transfer_manager.import_file_record_dict( + refinery_project.id, + upload_task, + [ + {"running_id": 1, "data": "hello world"}, + {"running_id": 2, "data": "hello world 2"}, + ], + ) + + assert record_bo.count(refinery_project.id) == 2 + attributes = attribute_bo.get_all(project_id=refinery_project.id) + assert len(attributes) == 2 + + for attribute in attributes: + if 
attribute.name != "running_id": + continue + att = attribute_bo.update( + refinery_project.id, attribute.id, is_primary_key=True, with_commit=True + ) + assert att is not None + assert att.is_primary_key is True + + +## in same file to ensure it's run in correct order +def test_create_embedding(client: TestClient, refinery_project: RefineryProject): + + att = attribute_bo.get_by_name(refinery_project.id, "data") + + assert att is not None + + response = client.post( + f"/api/v1/embedding/{refinery_project.id}/create-embedding", + json={ + "attribute_id": str(att.id), + "config": json.dumps( + { + "platform": "huggingface", + "termsText": None, + "termsAccepted": False, + "embeddingType": "ON_ATTRIBUTE", + "filterAttributes": [], + "model": "distilbert-base-uncased", + } + ), + }, + ) + + assert response.status_code == 200 + + for _ in range(20): + time.sleep(1) + all = embedding_bo.get_all_by_attribute_ids(refinery_project.id, [str(att.id)]) + if len(all) > 0: + break + assert len(all) > 0 + assert all[0].type == enums.EmbeddingType.ON_ATTRIBUTE.value + + # quite long since for a fresh start the model needs to be downloaded! 
+ for _ in range(60): + time.sleep(1) + count = embedding_bo.get_tensor_count(all[0].id) + if count > 0: + break + assert count > 0 + + +def test_update_records_to_project( + client: TestClient, refinery_project: RefineryProject, user: User +): + + upload_task = upload_task_manager.create_upload_task( + str(user.id), + str(refinery_project.id), + "dummy_file_name.csv", + "records", + "", + upload_type=enums.UploadTypes.DEFAULT.value, + key=None, + ) + record_transfer_manager.import_file_record_dict( + refinery_project.id, + upload_task, + [{"running_id": 1, "data": "goodbye world"}], + ) + + assert record_bo.count(refinery_project.id) == 2 + all_records = record_bo.get_all(refinery_project.id) + + assert len(all_records) == 2 + assert any(r.data["data"] == "goodbye world" for r in all_records) + transfer_api.__recalculate_missing_attributes_and_embeddings( + project_id=refinery_project.id, user_id=user.id + ) + time.sleep(5) + emb = embedding_bo.get_all_embeddings_by_project_id(refinery_project.id) + assert len(emb) > 0 + assert emb[0].current_delta_record_count > 0 diff --git a/util/file.py b/util/file.py index 52a32f8e..97a49e2b 100644 --- a/util/file.py +++ b/util/file.py @@ -2,7 +2,7 @@ import os import pyminizip -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple, List from zipfile import ZipFile from exceptions.exceptions import BadPasswordError @@ -41,6 +41,13 @@ def zip_to_json_file(zip_file_path: str, key: Optional[str] = None) -> str: return file_name +def store_records_as_json_file(record_data: List[Dict[str, Any]]) -> str: + file_name = __get_free_file_path("tmpdummy.json") + with open(file_name, "w") as f: + json.dump(record_data, f) + return file_name + + def file_to_zip(file_path: str, key: Optional[str] = None) -> Tuple[str, str]: zip_path = f"{file_path}.zip" pyminizip.compress(file_path, None, zip_path, key, 0)