diff --git a/business_objects/evaluation_group.py b/business_objects/evaluation_group.py new file mode 100644 index 00000000..b35c9f75 --- /dev/null +++ b/business_objects/evaluation_group.py @@ -0,0 +1,49 @@ +from typing import List + +from ..models import EvaluationGroup +from ..session import session +from . import general + + +def get(project_id: str, evaluation_group_id: str) -> EvaluationGroup: + query = session.query(EvaluationGroup).filter( + EvaluationGroup.project_id == project_id, + EvaluationGroup.id == evaluation_group_id, + ) + return query.first() + + +def get_all(project_id: str) -> List[EvaluationGroup]: + query = session.query(EvaluationGroup).filter( + EvaluationGroup.project_id == project_id, + ) + query = query.order_by(EvaluationGroup.name) + return query.all() + + +def create( + project_id: str, + name: str, + created_by: str, + evaluation_set_ids: List[str], + with_commit: bool = False, +) -> EvaluationGroup: + eval_group = EvaluationGroup( + project_id=project_id, + name=name, + created_by=created_by, + evaluation_set_ids=evaluation_set_ids, + ) + + general.add(eval_group, with_commit) + + return eval_group + + +def delete_all(project_id: str, group_ids: str, with_commit: bool = False): + query = session.query(EvaluationGroup).filter( + EvaluationGroup.project_id == project_id, + EvaluationGroup.id.in_(group_ids), + ) + query.delete(synchronize_session=False) + general.flush_or_commit(with_commit) diff --git a/business_objects/evaluation_run.py b/business_objects/evaluation_run.py new file mode 100644 index 00000000..ce811263 --- /dev/null +++ b/business_objects/evaluation_run.py @@ -0,0 +1,97 @@ +from typing import List, Optional + +from submodules.model.enums import EvaluationRunState + +from ..models import EvaluationRun +from ..session import session +from . import general + + +def get(project_id: str, evaluation_run_id: str) -> EvaluationRun: + query = session.query(EvaluationRun).filter( + EvaluationRun.project_id == project_id, + EvaluationRun.id == evaluation_run_id, + ) + return query.first() + + +def get_all_by_embedding_id(project_id: str, embedding_id: str) -> EvaluationRun: + query = session.query(EvaluationRun).filter( + EvaluationRun.project_id == project_id, + EvaluationRun.embedding_id == embedding_id, + ) + query = query.order_by(EvaluationRun.created_at.asc()) + return query.all() + + +def get_all_by_evaluation_group_id( + project_id: str, evaluation_group_id: str +) -> EvaluationRun: + query = session.query(EvaluationRun).filter( + EvaluationRun.project_id == project_id, + EvaluationRun.evaluation_group_id == evaluation_group_id, + ) + query = query.order_by(EvaluationRun.created_at.asc()) + return query.all() + + +def get_all(project_id: str) -> List[EvaluationRun]: + query = session.query(EvaluationRun).filter( + EvaluationRun.project_id == project_id, + ) + query = query.order_by(EvaluationRun.created_at.asc()) + return query.all() + + +def create( + project_id: str, + evaluation_group_id: str, + created_by: str, + embedding_id: str, + state: EvaluationRunState, + results: Optional[str] = None, + meta_info: Optional[str] = None, + with_commit: bool = False, +) -> EvaluationRun: + eval_run = EvaluationRun( + evaluation_group_id=evaluation_group_id, + created_by=created_by, + project_id=project_id, + embedding_id=embedding_id, + state=state, + results=results, + meta_info=meta_info, + ) + + general.add(eval_run, with_commit) + + return eval_run + + +def update( + project_id: str, + evaluation_run_id: str, + state: Optional[EvaluationRunState] = None, + results: Optional[str] = None, + meta_info: Optional[str] = None, + with_commit: bool = False, +) -> EvaluationRun: + eval_run: EvaluationRun = get(project_id, evaluation_run_id) + if state is not None: + eval_run.state = state + if results is not None: + eval_run.results = results + if meta_info is not None: + eval_run.meta_info = meta_info + + general.flush_or_commit(with_commit) + return eval_run + + +def delete_all(project_id: str, run_ids: str, with_commit: bool = False): + query = session.query(EvaluationRun).filter( + EvaluationRun.project_id == project_id, + EvaluationRun.id.in_(run_ids), + ) + query.delete(synchronize_session=False) + general.flush_or_commit(with_commit) diff --git a/business_objects/evaluation_set.py b/business_objects/evaluation_set.py new file mode 100644 index 00000000..3dd7281d --- /dev/null +++ b/business_objects/evaluation_set.py @@ -0,0 +1,70 @@ +from typing import List + +from submodules.model.util import prevent_sql_injection + +from ..models import EvaluationSet +from ..session import session +from . import general + + +def get(project_id: str, evaluation_set_id: str) -> EvaluationSet: + query = session.query(EvaluationSet).filter( + EvaluationSet.project_id == project_id, + EvaluationSet.id == evaluation_set_id, + ) + return query.first() + + +def get_all(project_id: str) -> List[EvaluationSet]: + query = session.query(EvaluationSet).filter( + EvaluationSet.project_id == project_id, + ) + query = query.order_by(EvaluationSet.question) + return query.all() + + +def get_by_evaluation_group_id( + project_id: str, evaluation_group_id: str +) -> List[EvaluationSet]: + project_id = prevent_sql_injection(project_id, isinstance(project_id, str)) + evaluation_group_id = prevent_sql_injection( + evaluation_group_id, isinstance(evaluation_group_id, str) + ) + + query = f""" + SELECT es.* + FROM evaluation_group eg + JOIN LATERAL jsonb_array_elements_text(eg.evaluation_set_ids::jsonb) AS elem(evaluation_set_id) ON TRUE + JOIN evaluation_set es ON es.id = elem.evaluation_set_id::uuid + WHERE eg.project_id = '{project_id}' + AND eg.id = '{evaluation_group_id}' + AND es.project_id = '{project_id}' + ORDER BY es.question; + """ + + return general.execute_all(query) + + +def create( + project_id: str, + question: str, + created_by: str, + record_ids: List[str], + with_commit: bool = False, +) -> EvaluationSet: + eval_set = EvaluationSet( + project_id=project_id, + question=question, + created_by=created_by, + record_ids=record_ids, + ) + general.add(eval_set, with_commit) + return eval_set + + +def delete_all(project_id: str, set_ids: List[str], with_commit: bool = False) -> None: + session.query(EvaluationSet).filter( + EvaluationSet.project_id == project_id, + EvaluationSet.id.in_(set_ids), + ).delete(synchronize_session=False) + general.flush_or_commit(with_commit) diff --git a/business_objects/knowledge_base.py b/business_objects/knowledge_base.py index afb7efe7..4e41db9b 100644 --- a/business_objects/knowledge_base.py +++ b/business_objects/knowledge_base.py @@ -1,4 +1,4 @@ -from typing import List, List, Optional +from typing import List, Optional from ..models import KnowledgeBase from ..exceptions import EntityAlreadyExistsException, EntityNotFoundException @@ -14,7 +14,6 @@ def get(project_id: str, base_id: str) -> KnowledgeBase: ) - def get_all_by_project_id(project_id: str) -> List[KnowledgeBase]: return ( session.query(KnowledgeBase) diff --git a/business_objects/playground_question.py b/business_objects/playground_question.py new file mode 100644 index 00000000..195c174a --- /dev/null +++ b/business_objects/playground_question.py @@ -0,0 +1,85 @@ +from typing import Any, List + +from ..models import PlaygroundQuestion +from ..session import session +from . import general + + +MAX_SAVED_QUESTIONS_HISTORY_PER_PROJECT = 100 + + +def get(project_id: str, question_id: str) -> PlaygroundQuestion: + query = session.query(PlaygroundQuestion).filter( + PlaygroundQuestion.project_id == project_id, + PlaygroundQuestion.id == question_id, + ) + return query.first() + + +def get_all(project_id: str) -> List[PlaygroundQuestion]: + query = session.query(PlaygroundQuestion).filter( + PlaygroundQuestion.project_id == project_id, + ) + query = query.order_by(PlaygroundQuestion.created_at.desc()) + return query.all() + + +def create( + project_id: str, + question: str, + with_commit: bool = False, +) -> Any: + + current_questions = ( + session.query(PlaygroundQuestion) + .filter( + PlaygroundQuestion.project_id == project_id, + ) + .all() + ) + + current_count = len(current_questions) + + if current_count >= MAX_SAVED_QUESTIONS_HISTORY_PER_PROJECT: + oldest = ( + session.query(PlaygroundQuestion) + .filter( + PlaygroundQuestion.project_id == project_id, + ) + .order_by(PlaygroundQuestion.created_at.asc()) + .limit(current_count - MAX_SAVED_QUESTIONS_HISTORY_PER_PROJECT + 1) + .all() + ) + ids = [q.id for q in oldest] + delete_all(project_id, ids, False) + + if question: + current_count_question = sum( + 1 for q in current_questions if str(q.question).lower() == question.lower() + ) + + if current_count_question == 0: + q = PlaygroundQuestion( + project_id=project_id, + question=question, + ) + general.add(q, with_commit) + return q + + return None + + +def delete_all(project_id: str, ids: List[str], with_commit: bool = False) -> None: + session.query(PlaygroundQuestion).filter( + PlaygroundQuestion.project_id == project_id, + PlaygroundQuestion.id.in_(ids), + ).delete(synchronize_session=False) + general.flush_or_commit(with_commit) + + +def delete(project_id: str, id: str, with_commit: bool = True) -> None: + session.query(PlaygroundQuestion).filter( + PlaygroundQuestion.project_id == project_id, + PlaygroundQuestion.id == id, + ).delete() + general.flush_or_commit(with_commit) diff --git a/enums.py b/enums.py index 3e64a5ef..df408e25 100644 --- a/enums.py +++ b/enums.py @@ -145,6 +145,10 @@ class Tablenames(Enum): PIPELINE_VERSION = ( "pipeline_version" # dump of previous versions to easily jump between ) + EVALUATION_SET = "evaluation_set" + EVALUATION_GROUP = "evaluation_group" + EVALUATION_RUN = "evaluation_run" + PLAYGROUND_QUESTION = "playground_question" def snake_case_to_pascal_case(self): # the type name (written in PascalCase) of a table is needed to create backrefs @@ -799,3 +803,10 @@ class ChangeAction(Enum): class PipelineVersionType(Enum): AUTO_SAVE = "AUTO_SAVE" # any save operation in relevant but only 10 per project NAMED_VERSION = "NAMED_VERSION" # any AUTO_SAVE that is considered worth keeping + + +class EvaluationRunState(Enum): + INITIATED = "INITIATED" + RUNNING = "RUNNING" + SUCCESS = "SUCCESS" + FAILED = "FAILED" diff --git a/models.py b/models.py index 33b45928..3dd742aa 100644 --- a/models.py +++ b/models.py @@ -1862,3 +1862,92 @@ class CustomerButton(Base): ForeignKey(f"{Tablenames.USER.value}.id", ondelete="SET NULL"), index=True, ) + + +class EvaluationSet(Base): + __tablename__ = Tablenames.EVALUATION_SET.value + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + question = Column(String) + created_at = Column(DateTime, default=sql.func.now()) + created_by = Column( + UUID(as_uuid=True), + ForeignKey(f"{Tablenames.USER.value}.id", ondelete="SET NULL"), + index=True, + ) + project_id = Column( + UUID(as_uuid=True), + ForeignKey(f"{Tablenames.PROJECT.value}.id", ondelete="CASCADE"), + index=True, + ) + record_ids = Column(JSON) + + +class EvaluationGroup(Base): + __tablename__ = Tablenames.EVALUATION_GROUP.value + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + name = Column(String) + created_at = Column(DateTime, default=sql.func.now()) + created_by = Column( + UUID(as_uuid=True), + ForeignKey(f"{Tablenames.USER.value}.id", ondelete="SET NULL"), + index=True, + ) + project_id = Column( + UUID(as_uuid=True), + ForeignKey(f"{Tablenames.PROJECT.value}.id", ondelete="CASCADE"), + index=True, + ) + evaluation_set_ids = Column(JSON) + + +class EvaluationRun(Base): + __tablename__ = Tablenames.EVALUATION_RUN.value + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + evaluation_group_id = Column( + UUID(as_uuid=True), + ForeignKey(f"{Tablenames.EVALUATION_GROUP.value}.id", ondelete="SET NULL"), + index=True, + ) + created_at = Column(DateTime, default=sql.func.now()) + created_by = Column( + UUID(as_uuid=True), + ForeignKey(f"{Tablenames.USER.value}.id", ondelete="SET NULL"), + index=True, + ) + project_id = Column( + UUID(as_uuid=True), + ForeignKey(f"{Tablenames.PROJECT.value}.id", ondelete="CASCADE"), + index=True, + ) + embedding_id = Column( + UUID(as_uuid=True), + ForeignKey(f"{Tablenames.EMBEDDING.value}.id", ondelete="SET NULL"), + index=True, + ) + state = Column(String) + results = Column(JSON) + meta_info = Column(JSON) + + +class PlaygroundQuestion(Base): + __tablename__ = Tablenames.PLAYGROUND_QUESTION.value + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + question = Column(String) + created_at = Column(DateTime, default=sql.func.now()) + project_id = Column( + UUID(as_uuid=True), + ForeignKey(f"{Tablenames.PROJECT.value}.id", ondelete="CASCADE"), + index=True, + ) + """ + Playground question can be extended with the below properties to allow the following: + - User can see questions with specific results relating to the embedding used + - Can be used for comparison with new results using same question but different embedding + """ + # embedding_id = Column( + # UUID(as_uuid=True), + # ForeignKey(f"{Tablenames.EMBEDDING.value}.id", ondelete="SET NULL"), + # index=True, + # ) + # record_ids = Column(JSON) + # meta_info = Column(JSON)