Delta migration #303

Merged 21 commits on May 23, 2025
39 changes: 39 additions & 0 deletions alembic/versions/96fbb404381e_delta_indicators_embeddings.py
@@ -0,0 +1,39 @@
"""delta indicators embeddings

Revision ID: 96fbb404381e
Revises: eb96f9b82cc1
Create Date: 2025-05-21 11:09:33.093313

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = '96fbb404381e'
down_revision = 'eb96f9b82cc1'
branch_labels = None
depends_on = None


def upgrade():
connection = op.get_bind()
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('embedding', sa.Column('delta_full_recalculation_threshold', sa.Float(), nullable=True))
op.add_column('embedding', sa.Column('current_delta_record_count', sa.Integer(), nullable=True))

update_sql = """
UPDATE public.embedding
SET delta_full_recalculation_threshold = 0.5,
current_delta_record_count = 0
WHERE delta_full_recalculation_threshold IS NULL
"""
    connection.execute(sa.text(update_sql))
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('embedding', 'current_delta_record_count')
op.drop_column('embedding', 'delta_full_recalculation_threshold')
# ### end Alembic commands ###
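The two new columns drive the delta-versus-full decision implemented in controller/embedding/manager.py below: delta_full_recalculation_threshold is the fraction of a project's records that may accumulate as incremental (delta) embedding work before a full recreation is forced (0 disables delta updates entirely), and current_delta_record_count tracks the delta work done since the last full run. A minimal sketch of that decision with a worked example, assuming exactly these column semantics (the standalone helper is hypothetical; the real check operates on the embedding item):

def needs_full_recreation(
    threshold: float, already_deltaed: int, full_count: int, current_count: int
) -> bool:
    if threshold == 0:
        return True  # threshold 0 means: always recreate fully
    to_calc = full_count - current_count  # records not yet in the embedding
    # force a full run once cumulative delta work exceeds threshold * full_count
    return already_deltaed + to_calc > threshold * full_count

# with the migration default of 0.5: 1000 records, 900 already embedded,
# 450 delta-embedded earlier -> 450 + 100 = 550 > 500, so recreate fully
assert needs_full_recreation(0.5, already_deltaed=450, full_count=1000, current_count=900)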
6 changes: 3 additions & 3 deletions api/transfer.py
@@ -4,7 +4,7 @@
from typing import Optional
from starlette.endpoints import HTTPEndpoint
from starlette.responses import PlainTextResponse
from controller.embedding.manager import recreate_embeddings
from controller.embedding.manager import recreate_or_extend_embeddings

from controller.transfer.cognition import (
import_preparator as cognition_preparator,
@@ -165,7 +165,7 @@ def __recalculate_missing_attributes_and_embeddings(
project_id: str, user_id: str
) -> None:
__calculate_missing_attributes(project_id, user_id)
recreate_embeddings(project_id)
recreate_or_extend_embeddings(project_id)


def __calculate_missing_attributes(project_id: str, user_id: str) -> None:
@@ -218,7 +218,7 @@ def __calculate_missing_attributes(project_id: str, user_id: str) -> None:
if current_att.state == enums.AttributeState.RUNNING.value:
continue
elif current_att.state == enums.AttributeState.INITIAL.value:
attribute_manager.calculate_user_attribute_all_records(
attribute_manager.calculate_user_attribute_missing_records(
project_id,
project.get_org_id(project_id),
user_id,
4 changes: 4 additions & 0 deletions conftest.py
@@ -12,6 +12,7 @@
project as project_bo,
general,
)
from submodules.s3 import controller as s3
from submodules.model.models import (
Organization,
User,
@@ -29,8 +30,10 @@ def database_session() -> Iterator[None]:
@pytest.fixture(scope="session")
def org() -> Iterator[Organization]:
org_item = organization_bo.create(name="test_org", with_commit=True)
s3.create_bucket(str(org_item.id))
yield org_item
organization_bo.delete(org_item.id, with_commit=True)
s3.remove_bucket(str(org_item.id), True)


@pytest.fixture(scope="session")
@@ -47,6 +50,7 @@ def refinery_project(org: Organization, user: User) -> Iterator[RefineryProject]
name="test_project",
description="test_description",
created_by=user.id,
tokenizer="en_core_web_sm",
with_commit=True,
)
yield project_item
17 changes: 13 additions & 4 deletions controller/attribute/manager.py
@@ -234,7 +234,7 @@ def __add_running_id(
general.remove_and_refresh_session(session_token)


def calculate_user_attribute_all_records(
def calculate_user_attribute_missing_records(
project_id: str,
org_id: str,
user_id: str,
@@ -285,7 +285,7 @@ def calculate_user_attribute_all_records(
project_id=project_id, message=f"calculate_attribute:started:{attribute_id}"
)
daemon.run_without_db_token(
__calculate_user_attribute_all_records,
__calculate_user_attribute_missing_records,
project_id,
org_id,
user_id,
@@ -294,7 +294,7 @@
)


def __calculate_user_attribute_all_records(
def __calculate_user_attribute_missing_records(
project_id: str,
org_id: str,
user_id: str,
@@ -303,9 +303,18 @@
) -> None:
session_token = general.get_ctx_token()

all_records_count = record.count(project_id)
count_delta = record.count_missing_delta(project_id, attribute_id)

if count_delta != all_records_count:
doc_bin = util.prepare_delta_records_doc_bin(
attribute_id=attribute_id, project_id=project_id
)
else:
doc_bin = "docbin_full"
try:
calculated_attributes = util.run_attribute_calculation_exec_env(
attribute_id=attribute_id, project_id=project_id, doc_bin="docbin_full"
attribute_id=attribute_id, project_id=project_id, doc_bin=doc_bin
)
if not calculated_attributes:
__notify_attribute_calculation_failed(
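The branch above decides which doc bin the attribute-calculation exec environment receives: when every record is missing the attribute, the precomputed "docbin_full" is reused; otherwise a doc bin covering only the missing records is built. A condensed sketch of that selection, assuming the record and util helpers shown in this diff:

def select_doc_bin(project_id: str, attribute_id: str) -> str:
    all_records_count = record.count(project_id)
    count_delta = record.count_missing_delta(project_id, attribute_id)
    if count_delta != all_records_count:
        # only some records lack the attribute -> tokenize just those
        return util.prepare_delta_records_doc_bin(
            attribute_id=attribute_id, project_id=project_id
        )
    # no record has the attribute yet, so the full doc bin covers the delta
    return "docbin_full"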
19 changes: 19 additions & 0 deletions controller/attribute/util.py
@@ -80,6 +80,25 @@ def prepare_sample_records_doc_bin(
return prefixed_doc_bin


def prepare_delta_records_doc_bin(attribute_id: str, project_id: str) -> str:
missing_records = record.get_missing_delta_record_ids(project_id, attribute_id)

delta_records_doc_bin = tokenization.get_doc_bin_table_to_json(
project_id=project_id,
missing_columns=record.get_missing_columns_str(project_id),
record_ids=missing_records,
)
project_item = project.get(project_id)
org_id = str(project_item.organization_id)
prefixed_doc_bin = f"{attribute_id}_doc_bin.json"
s3.put_object(
org_id,
project_id + "/" + prefixed_doc_bin,
delta_records_doc_bin,
)
return prefixed_doc_bin


def test_openai_llm_connection(api_key: str, model: str, is_o_series: bool = False):
# more here: https://platform.openai.com/docs/api-reference/making-requests
headers = {
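prepare_delta_records_doc_bin mirrors prepare_sample_records_doc_bin above: it serializes only the records still missing the attribute, uploads the doc bin to the organization's S3 bucket under "<project_id>/<attribute_id>_doc_bin.json", and returns the object name for the exec environment. A hedged usage sketch (variables stand in for real ids; the calls are the ones from this diff):

# inside controller/attribute/manager.py, the returned name replaces "docbin_full"
doc_bin = prepare_delta_records_doc_bin(
    attribute_id=attribute_id, project_id=project_id
)
calculated_attributes = run_attribute_calculation_exec_env(
    attribute_id=attribute_id, project_id=project_id, doc_bin=doc_bin
)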
119 changes: 81 additions & 38 deletions controller/embedding/manager.py
@@ -13,6 +13,7 @@
embedding,
agreement,
general,
record,
)
from submodules.model import daemon

@@ -99,7 +100,7 @@ def get_embedding_name(
return name


def recreate_embeddings(
def recreate_or_extend_embeddings(
project_id: str, embedding_ids: Optional[List[str]] = None, user_id: str = None
) -> None:
if not embedding_ids:
@@ -126,7 +127,9 @@ def recreate_embeddings(
embedding_item = embedding.get(project_id, embedding_id)
if not embedding_item:
continue
embedding_item = __recreate_embedding(project_id, embedding_id)
embedding_item = __recreate_or_extend_embedding(project_id, embedding_id)
if not embedding_item:
continue
new_id = embedding_item.id
time.sleep(2)
while True:
@@ -179,49 +182,78 @@ def __handle_failed_embedding(
general.commit()


def __recreate_embedding(project_id: str, embedding_id: str) -> Embedding:
    old_embedding_item = embedding.get(project_id, embedding_id)
    old_id = old_embedding_item.id
    new_embedding_item = embedding.create(
        project_id,
        old_embedding_item.attribute_id,
        old_embedding_item.name,
        old_embedding_item.created_by,
        enums.EmbeddingState.INITIALIZING.value,
        type=old_embedding_item.type,
        model=old_embedding_item.model,
        platform=old_embedding_item.platform,
        api_token=old_embedding_item.api_token,
        filter_attributes=old_embedding_item.filter_attributes,
        additional_data=old_embedding_item.additional_data,
        with_commit=False,
    )
    embedding.delete(project_id, embedding_id, with_commit=False)
    embedding.delete_tensors(embedding_id, with_commit=False)
    general.commit()

    if (
        new_embedding_item.platform == enums.EmbeddingPlatform.OPENAI.value
        or new_embedding_item.platform == enums.EmbeddingPlatform.COHERE.value
        or new_embedding_item.platform == enums.EmbeddingPlatform.AZURE.value
    ):
        agreement_item = agreement.get_by_xfkey(
            project_id, old_id, enums.AgreementType.EMBEDDING.value
        )
        if not agreement_item:
            new_embedding_item.state = enums.EmbeddingState.FAILED.value
            general.commit()
            raise ApiTokenImportError(
                f"No agreement found for embedding {new_embedding_item.name}"
            )
        agreement_item.xfkey = new_embedding_item.id
        general.commit()

    connector.request_deleting_embedding(project_id, old_id)
    daemon.run_without_db_token(
        connector.request_embedding, project_id, new_embedding_item.id
    )
    return new_embedding_item
def __recreate_or_extend_embedding(project_id: str, embedding_id: str) -> Embedding:

    # decide between delta extension and full recreation
    old_embedding_item = embedding.get(project_id, embedding_id)
    if not old_embedding_item:
        return None
    needs_full_recreation = False
    if old_embedding_item.delta_full_recalculation_threshold == 0:
        needs_full_recreation = True
    elif old_embedding_item.delta_full_recalculation_threshold > 0:
        already_deltaed = old_embedding_item.current_delta_record_count
        full_count = record.count(project_id)
        current_count = embedding.get_record_ids_count(embedding_id)
        to_calc = full_count - current_count
        if (
            already_deltaed + to_calc
            > old_embedding_item.delta_full_recalculation_threshold * full_count
        ):
            # only do a full recreation if the delta is larger than the threshold
            needs_full_recreation = True
        else:
            old_embedding_item.current_delta_record_count += to_calc

    if needs_full_recreation:
        old_id = old_embedding_item.id
        new_embedding_item = embedding.create(
            project_id,
            old_embedding_item.attribute_id,
            old_embedding_item.name,
            old_embedding_item.created_by,
            enums.EmbeddingState.INITIALIZING.value,
            type=old_embedding_item.type,
            model=old_embedding_item.model,
            platform=old_embedding_item.platform,
            api_token=old_embedding_item.api_token,
            filter_attributes=old_embedding_item.filter_attributes,
            additional_data=old_embedding_item.additional_data,
            with_commit=False,
        )
        embedding.delete(project_id, embedding_id, with_commit=False)
        embedding.delete_tensors(embedding_id, with_commit=False)
        general.commit()

        if (
            new_embedding_item.platform == enums.EmbeddingPlatform.OPENAI.value
            or new_embedding_item.platform == enums.EmbeddingPlatform.COHERE.value
            or new_embedding_item.platform == enums.EmbeddingPlatform.AZURE.value
        ):
            agreement_item = agreement.get_by_xfkey(
                project_id, old_id, enums.AgreementType.EMBEDDING.value
            )
            if not agreement_item:
                new_embedding_item.state = enums.EmbeddingState.FAILED.value
                general.commit()
                raise ApiTokenImportError(
                    f"No agreement found for embedding {new_embedding_item.name}"
                )
            agreement_item.xfkey = new_embedding_item.id
            general.commit()

        connector.request_deleting_embedding(project_id, old_id)
    else:
        general.commit()

    # request handles delta and full recreation
    request_embedding_id = (
        new_embedding_item.id if needs_full_recreation else embedding_id
    )
    daemon.run_without_db_token(
        connector.request_embedding, project_id, request_embedding_id
    )
    return new_embedding_item if needs_full_recreation else old_embedding_item


def update_embedding_payload(
@@ -262,3 +294,14 @@ def update_label_payloads_for_neural_search(
embedding_ids=[str(e.id) for e in relevant_embeddings],
record_ids=record_ids,
)


def remove_tensors_by_record_ids(
project_id: str, record_ids: List[str], embedding_id: Optional[str] = None
) -> None:
if not record_ids:
return
if embedding_id:
embedding.delete_tensors_by_record_ids(project_id, record_ids, embedding_id)
else:
embedding.delete_tensors_by_record_ids(project_id, record_ids)
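remove_tensors_by_record_ids is what makes delta extension pick up changed records: once their stale tensors are deleted, embedding.get_record_ids_count reports a gap against record.count, and __recreate_or_extend_embedding re-embeds exactly that gap instead of rebuilding everything. A hedged usage sketch (updated_records is illustrative):

# after an update import, drop stale tensors so the delta path re-embeds them
updated_ids = [str(r.id) for r in updated_records]
remove_tensors_by_record_ids(project_id, updated_ids)  # across all embeddings
# or limit the cleanup to one embedding:
remove_tensors_by_record_ids(project_id, updated_ids, embedding_id=embedding_id)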
2 changes: 1 addition & 1 deletion controller/transfer/project_transfer_manager.py
@@ -939,7 +939,7 @@ def __post_processing_import_threaded(
if not data.get(
"embedding_tensors_data",
):
embedding_manager.recreate_embeddings(project_id, user_id=user_id)
embedding_manager.recreate_or_extend_embeddings(project_id, user_id=user_id)
else:
for old_id in embedding_ids:
embedding_manager.request_tensor_upload(
20 changes: 20 additions & 0 deletions controller/transfer/record_transfer_manager.py
@@ -20,6 +20,7 @@

from controller.upload_task import manager as upload_task_manager
from controller.tokenization import manager as token_manager
from controller.embedding import manager as embedding_manager
from util import file, security
from submodules.s3 import controller as s3
from submodules.model import enums, UploadTask, Attribute
@@ -171,6 +172,21 @@ def download_file(project_id: str, task: UploadTask) -> str:
def import_file(project_id: str, upload_task: UploadTask) -> None:
# load data from s3 and do transfer task/notification management
tmp_file_name, file_type = download_file(project_id, upload_task)
__import_file(project_id, upload_task, file_type, tmp_file_name)


def import_file_record_dict(
project_id: str, upload_task: UploadTask, records: List[Dict[str, Any]]
) -> None:
# store the given records as a temporary json file and do transfer task/notification management
tmp_file_name = file.store_records_as_json_file(records)
file_type = "json"
__import_file(project_id, upload_task, file_type, tmp_file_name)


def __import_file(
project_id: str, upload_task: UploadTask, file_type: str, tmp_file_name: str
) -> None:
upload_task_manager.update_task(
project_id, upload_task.id, state=enums.UploadStates.IN_PROGRESS.value
)
@@ -287,6 +303,10 @@ def update_records_and_labels(
)
token_manager.delete_token_statistics(updated_records)
token_manager.delete_docbins(project_id, updated_records)
# remove embedding tensors if there are any to prep for delta migration
embedding_manager.remove_tensors_by_record_ids(
project_id, [str(r.id) for r in updated_records]
)
return remaining_records_data, remaining_labels_data


Expand Down
2 changes: 1 addition & 1 deletion fast_api/routes/task_execution.py
@@ -25,7 +25,7 @@ def calculate_attributes(
attribute_calculation_task_execution: AttributeCalculationTaskExecutionBody,
):
daemon.run_with_db_token(
attribute_manager.calculate_user_attribute_all_records,
attribute_manager.calculate_user_attribute_missing_records,
attribute_calculation_task_execution.project_id,
attribute_calculation_task_execution.organization_id,
attribute_calculation_task_execution.user_id,
5 changes: 4 additions & 1 deletion run-tests
@@ -1,3 +1,6 @@
#!/bin/bash

docker exec -it refinery-gateway bash -c "cd /app && python -m pytest -v"

# -s shows print statements
# -v shows test names
docker exec -it refinery-gateway bash -c "cd /app && python -m pytest -v -s"
2 changes: 1 addition & 1 deletion submodules/model