Skip to content

Commit 42fa11b

Browse files
authored
Adds delta count methods (#169)
* Adds delta count methods * Model update for new fileds * Adds record selection for missing embedding tensors * Adds distinct count options * removes print * Sanatized embedding id
1 parent 9d0ecff commit 42fa11b

File tree

3 files changed

+103
-11
lines changed

3 files changed

+103
-11
lines changed

business_objects/embedding.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from .. import enums
1010

1111
from ..util import prevent_sql_injection
12+
from sqlalchemy import distinct, func
1213

1314

1415
ALL_EMBEDDINGS_WHITELIST = {
@@ -587,6 +588,15 @@ def get_tensor_count(embedding_id: str) -> EmbeddingTensor:
587588
)
588589

589590

591+
def get_record_ids_count(embedding_id: str) -> int:
592+
# note that this is not the same as tensors since e.g. embedding lists are stored with sub_key
593+
return (
594+
session.query(func.count(distinct(models.EmbeddingTensor.record_id)))
595+
.filter(models.EmbeddingTensor.embedding_id == embedding_id)
596+
.scalar()
597+
)
598+
599+
590600
def get_tensor(
591601
embedding_id: str, record_id: Optional[str] = None, sub_key: Optional[int] = None
592602
) -> EmbeddingTensor:
@@ -782,6 +792,22 @@ def delete_tensors(embedding_id: str, with_commit: bool = False) -> None:
782792
general.flush_or_commit(with_commit)
783793

784794

795+
def delete_tensors_by_record_ids(
796+
project_id: str,
797+
record_ids: List[str],
798+
embedding_id: Optional[str] = None,
799+
with_commit: bool = False,
800+
) -> None:
801+
query = session.query(EmbeddingTensor).filter(
802+
EmbeddingTensor.project_id == project_id,
803+
EmbeddingTensor.record_id.in_(record_ids),
804+
)
805+
if embedding_id:
806+
query = query.filter(EmbeddingTensor.embedding_id == embedding_id)
807+
query.delete()
808+
general.flush_or_commit(with_commit)
809+
810+
785811
def delete_by_record_ids(
786812
project_id: str,
787813
embedding_id: str,

business_objects/record.py

Lines changed: 70 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -451,29 +451,51 @@ def get_full_record_data_for_id_group(
451451

452452

453453
def get_attribute_data(
454-
project_id: str, attribute_name: str
454+
project_id: str,
455+
attribute_name: str,
456+
only_missing: bool = False,
457+
embedding_id: Optional[str] = None,
455458
) -> Tuple[List[str], List[str]]:
456459
project_id = prevent_sql_injection(project_id, isinstance(project_id, str))
457460
attribute_name = prevent_sql_injection(
458461
attribute_name, isinstance(attribute_name, str)
459462
)
463+
if embedding_id:
464+
embedding_id = prevent_sql_injection(
465+
embedding_id, isinstance(embedding_id, str)
466+
)
460467
query = None
461-
order = __get_order_by(project_id)
468+
order = __get_order_by(project_id, prefix="r.")
469+
join_extension, where_add = "", ""
470+
if only_missing:
471+
if not embedding_id:
472+
raise ValueError("embedding_id must be provided if only_missing is True")
473+
join_extension, where_add = (
474+
f"""
475+
LEFT JOIN embedding_tensor et
476+
ON et.project_id = r.project_id
477+
AND et.record_id = r.id
478+
AND et.project_id = '{project_id}' AND et.embedding_id = '{embedding_id}'
479+
""",
480+
"AND et.id IS NULL",
481+
)
462482
if attribute.get_by_name(project_id, attribute_name).data_type == "EMBEDDING_LIST":
463483
query = f"""
464484
SELECT id::TEXT || '@' || sub_key id, att AS "{attribute_name}"
465485
FROM (
466-
SELECT id, value as att, ordinality - 1 as sub_key
467-
FROM record
468-
cross join json_array_elements_text((data::JSON->'{attribute_name}')) with ordinality
469-
WHERE project_id = '{project_id}'
486+
SELECT r.id, value as att, ordinality - 1 as sub_key
487+
FROM record r
488+
{join_extension}
489+
cross join json_array_elements_text((r.data::JSON->'{attribute_name}')) with ordinality
490+
WHERE r.project_id = '{project_id}' {where_add}
470491
{order}
471492
)x """
472493
else:
473494
query = f"""
474-
SELECT id::TEXT, data::JSON->'{attribute_name}' AS "{attribute_name}"
475-
FROM record
476-
WHERE project_id = '{project_id}'
495+
SELECT r.id::TEXT, r.data::JSON->'{attribute_name}' AS "{attribute_name}"
496+
FROM record r
497+
{join_extension}
498+
WHERE r.project_id = '{project_id}' {where_add}
477499
{order}
478500
"""
479501
result = general.execute_all(query)
@@ -485,6 +507,43 @@ def count(project_id: str) -> int:
485507
return session.query(Record).filter(Record.project_id == project_id).count()
486508

487509

510+
def count_missing_delta(project_id: str, attribute_id: str) -> int:
511+
project_id = prevent_sql_injection(project_id, isinstance(project_id, str))
512+
attribute_id = prevent_sql_injection(attribute_id, isinstance(attribute_id, str))
513+
query = f"""
514+
WITH n AS (
515+
SELECT NAME
516+
FROM attribute a
517+
WHERE id = '{attribute_id}'
518+
)
519+
SELECT COUNT(*)
520+
FROM record r, n
521+
WHERE r.project_id = '{project_id}'
522+
AND r.data->>n.name IS NULL
523+
"""
524+
value = general.execute_first(query)
525+
if not value or not value[0]:
526+
return 0
527+
return value[0]
528+
529+
530+
def get_missing_delta_record_ids(project_id: str, attribute_id: str) -> List[str]:
531+
project_id = prevent_sql_injection(project_id, isinstance(project_id, str))
532+
attribute_id = prevent_sql_injection(attribute_id, isinstance(attribute_id, str))
533+
query = f"""
534+
WITH n AS (
535+
SELECT NAME
536+
FROM attribute a
537+
WHERE id = '{attribute_id}'
538+
)
539+
SELECT r.id::TEXT
540+
FROM record r, n
541+
WHERE r.project_id = '{project_id}'
542+
AND r.data->>n.name IS NULL
543+
"""
544+
return [row[0] for row in general.execute_all(query)]
545+
546+
488547
def count_attribute_list_entries(project_id: str, attribute_name: str) -> int:
489548
project_id = prevent_sql_injection(project_id, isinstance(project_id, str))
490549
attribute_name = prevent_sql_injection(
@@ -809,7 +868,7 @@ def get_tokenized_records_from_db(
809868
)
810869

811870

812-
def __get_order_by(project_id: str, first_x: int = 3) -> str:
871+
def __get_order_by(project_id: str, first_x: int = 3, prefix: str = "") -> str:
813872
query = f"""
814873
SELECT name, data_type
815874
FROM attribute a
@@ -823,7 +882,7 @@ def __get_order_by(project_id: str, first_x: int = 3) -> str:
823882
for x in values:
824883
if order != "":
825884
order += ", "
826-
tmp = f"data->>'{x.name}'"
885+
tmp = f"{prefix}data->>'{x.name}'"
827886

828887
r_id = attribute.get_running_id_name(project_id)
829888
if x.data_type == "INTEGER" and x.name == r_id:

models.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -807,6 +807,13 @@ class Embedding(Base):
807807
)
808808
additional_data = Column(JSON)
809809

810+
# threshold indicates when the embedding should be completely recalculated
811+
delta_full_recalculation_threshold = Column(Float, default=0.5)
812+
# holds the current number of records that were caluclated with the previous PCA if new records + current delta > threshold we recreate completely
813+
# note that this number can be higher than expected because of updated records being recalculated as well
814+
# meaning in theory if someone updates the same record over and over again at some point the full recalculation will be triggered
815+
current_delta_record_count = Column(Integer, default=0)
816+
810817

811818
class EmbeddingTensor(Base):
812819
__tablename__ = Tablenames.EMBEDDING_TENSOR.value

0 commit comments

Comments
 (0)