Skip to content

Commit 8e59d8e

Browse files
anmarhindiLennartSchmidtKernlumburovskalina
authored
Question Playground (#152)
* Add initial table * Add evaluation models * add evaluation set accessors * add evaluation group accessors * add eval run bobj * add update bobj * by group id * deletion * Delete evaluation runs * Add playground question model * Add limit check for questions per project * Remove unused import * Remove unused fields * Remove unused import * Change model fields * Remove ifs during eval run create * Add optional state param * Add flush_or_commit * Rename to ids for readability * Rename delete_all func param to ids for readability * Add comment in PlaygroundQuestion model for extension properties * join query for eval group * Add join query for eval group id * Add check for duplicate questions * Deleting questions and order desc * Add question check --------- Co-authored-by: LennartSchmidtKern <lennart.schmidt@kern.ai> Co-authored-by: Lina <lina.lumburovska@kern.ai>
1 parent cf337c2 commit 8e59d8e

File tree

7 files changed

+402
-2
lines changed

7 files changed

+402
-2
lines changed

business_objects/evaluation_group.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from typing import List
2+
3+
from ..models import EvaluationGroup
4+
from ..session import session
5+
from . import general
6+
7+
8+
def get(project_id: str, evaluation_group_id: str) -> EvaluationGroup:
9+
query = session.query(EvaluationGroup).filter(
10+
EvaluationGroup.project_id == project_id,
11+
EvaluationGroup.id == evaluation_group_id,
12+
)
13+
return query.first()
14+
15+
16+
def get_all(project_id: str) -> List[EvaluationGroup]:
17+
query = session.query(EvaluationGroup).filter(
18+
EvaluationGroup.project_id == project_id,
19+
)
20+
query = query.order_by(EvaluationGroup.name)
21+
return query.all()
22+
23+
24+
def create(
25+
project_id: str,
26+
name: str,
27+
created_by: str,
28+
evaluation_set_ids: List[str],
29+
with_commit: bool = False,
30+
) -> EvaluationGroup:
31+
eval_group = EvaluationGroup(
32+
project_id=project_id,
33+
name=name,
34+
created_by=created_by,
35+
evaluation_set_ids=evaluation_set_ids,
36+
)
37+
38+
general.add(eval_group, with_commit)
39+
40+
return eval_group
41+
42+
43+
def delete_all(project_id: str, group_ids: str, with_commit: bool = False):
44+
query = session.query(EvaluationGroup).filter(
45+
EvaluationGroup.project_id == project_id,
46+
EvaluationGroup.id.in_(group_ids),
47+
)
48+
query.delete(synchronize_session=False)
49+
general.flush_or_commit(with_commit)

business_objects/evaluation_run.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
from typing import List, Optional
2+
3+
from submodules.model.enums import EvaluationRunState
4+
5+
from ..models import EvaluationRun
6+
from ..session import session
7+
from . import general
8+
9+
10+
def get(project_id: str, evaluation_run_id: str) -> EvaluationRun:
11+
query = session.query(EvaluationRun).filter(
12+
EvaluationRun.project_id == project_id,
13+
EvaluationRun.id == evaluation_run_id,
14+
)
15+
return query.first()
16+
17+
18+
def get_all_by_embedding_id(project_id: str, embedding_id: str) -> EvaluationRun:
19+
query = session.query(EvaluationRun).filter(
20+
EvaluationRun.project_id == project_id,
21+
EvaluationRun.embedding_id == embedding_id,
22+
)
23+
query = query.order_by(EvaluationRun.created_at.asc())
24+
return query.all()
25+
26+
27+
def get_all_by_evaluation_group_id(
28+
project_id: str, evaluation_group_id: str
29+
) -> EvaluationRun:
30+
query = session.query(EvaluationRun).filter(
31+
EvaluationRun.project_id == project_id,
32+
EvaluationRun.evaluation_group_id == evaluation_group_id,
33+
)
34+
query = query.order_by(EvaluationRun.created_at.asc())
35+
return query.all()
36+
37+
38+
def get_all(project_id: str) -> List[EvaluationRun]:
39+
query = session.query(EvaluationRun).filter(
40+
EvaluationRun.project_id == project_id,
41+
)
42+
query = query.order_by(EvaluationRun.created_at.asc())
43+
return query.all()
44+
45+
46+
def create(
47+
project_id: str,
48+
evaluation_group_id: str,
49+
created_by: str,
50+
embedding_id: str,
51+
state: EvaluationRunState,
52+
results: Optional[str] = None,
53+
meta_info: Optional[str] = None,
54+
with_commit: bool = False,
55+
) -> EvaluationRun:
56+
eval_run = EvaluationRun(
57+
evaluation_group_id=evaluation_group_id,
58+
created_by=created_by,
59+
project_id=project_id,
60+
embedding_id=embedding_id,
61+
state=state,
62+
results=results,
63+
meta_info=meta_info,
64+
)
65+
66+
general.add(eval_run, with_commit)
67+
68+
return eval_run
69+
70+
71+
def update(
72+
project_id: str,
73+
evaluation_run_id: str,
74+
state: Optional[EvaluationRunState] = None,
75+
results: Optional[str] = None,
76+
meta_info: Optional[str] = None,
77+
with_commit: bool = False,
78+
) -> EvaluationRun:
79+
eval_run: EvaluationRun = get(project_id, evaluation_run_id)
80+
if state is not None:
81+
eval_run.state = state
82+
if results is not None:
83+
eval_run.results = results
84+
if meta_info is not None:
85+
eval_run.meta_info = meta_info
86+
87+
general.flush_or_commit(with_commit)
88+
return eval_run
89+
90+
91+
def delete_all(project_id: str, run_ids: str, with_commit: bool = False):
92+
query = session.query(EvaluationRun).filter(
93+
EvaluationRun.project_id == project_id,
94+
EvaluationRun.id.in_(run_ids),
95+
)
96+
query.delete(synchronize_session=False)
97+
general.flush_or_commit(with_commit)

business_objects/evaluation_set.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
from typing import List
2+
3+
from submodules.model.util import prevent_sql_injection
4+
5+
from ..models import EvaluationSet
6+
from ..session import session
7+
from . import general
8+
9+
10+
def get(project_id: str, evaluation_set_id: str) -> EvaluationSet:
11+
query = session.query(EvaluationSet).filter(
12+
EvaluationSet.project_id == project_id,
13+
EvaluationSet.id == evaluation_set_id,
14+
)
15+
return query.first()
16+
17+
18+
def get_all(project_id: str) -> List[EvaluationSet]:
19+
query = session.query(EvaluationSet).filter(
20+
EvaluationSet.project_id == project_id,
21+
)
22+
query = query.order_by(EvaluationSet.question)
23+
return query.all()
24+
25+
26+
def get_by_evaluation_group_id(
27+
project_id: str, evaluation_group_id: str
28+
) -> List[EvaluationSet]:
29+
project_id = prevent_sql_injection(project_id, isinstance(project_id, str))
30+
evaluation_group_id = prevent_sql_injection(
31+
evaluation_group_id, isinstance(evaluation_group_id, str)
32+
)
33+
34+
query = f"""
35+
SELECT es.*
36+
FROM evaluation_group eg
37+
JOIN LATERAL jsonb_array_elements_text(eg.evaluation_set_ids::jsonb) AS elem(evaluation_set_id) ON TRUE
38+
JOIN evaluation_set es ON es.id = elem.evaluation_set_id::uuid
39+
WHERE eg.project_id = '{project_id}'
40+
AND eg.id = '{evaluation_group_id}'
41+
AND es.project_id = '{project_id}'
42+
ORDER BY es.question;
43+
"""
44+
45+
return general.execute_all(query)
46+
47+
48+
def create(
49+
project_id: str,
50+
question: str,
51+
created_by: str,
52+
record_ids: List[str],
53+
with_commit: bool = False,
54+
) -> EvaluationSet:
55+
eval_set = EvaluationSet(
56+
project_id=project_id,
57+
question=question,
58+
created_by=created_by,
59+
record_ids=record_ids,
60+
)
61+
general.add(eval_set, with_commit)
62+
return eval_set
63+
64+
65+
def delete_all(project_id: str, set_ids: List[str], with_commit: bool = False) -> None:
66+
session.query(EvaluationSet).filter(
67+
EvaluationSet.project_id == project_id,
68+
EvaluationSet.id.in_(set_ids),
69+
).delete(synchronize_session=False)
70+
general.flush_or_commit(with_commit)

business_objects/knowledge_base.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, List, Optional
1+
from typing import List, Optional
22

33
from ..models import KnowledgeBase
44
from ..exceptions import EntityAlreadyExistsException, EntityNotFoundException
@@ -14,7 +14,6 @@ def get(project_id: str, base_id: str) -> KnowledgeBase:
1414
)
1515

1616

17-
1817
def get_all_by_project_id(project_id: str) -> List[KnowledgeBase]:
1918
return (
2019
session.query(KnowledgeBase)
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
from typing import Any, List
2+
3+
from ..models import PlaygroundQuestion
4+
from ..session import session
5+
from . import general
6+
7+
8+
MAX_SAVED_QUESTIONS_HISTORY_PER_PROJECT = 100
9+
10+
11+
def get(project_id: str, question_id: str) -> PlaygroundQuestion:
12+
query = session.query(PlaygroundQuestion).filter(
13+
PlaygroundQuestion.project_id == project_id,
14+
PlaygroundQuestion.id == question_id,
15+
)
16+
return query.first()
17+
18+
19+
def get_all(project_id: str) -> List[PlaygroundQuestion]:
20+
query = session.query(PlaygroundQuestion).filter(
21+
PlaygroundQuestion.project_id == project_id,
22+
)
23+
query = query.order_by(PlaygroundQuestion.created_at.desc())
24+
return query.all()
25+
26+
27+
def create(
28+
project_id: str,
29+
question: str,
30+
with_commit: bool = False,
31+
) -> Any:
32+
33+
current_questions = (
34+
session.query(PlaygroundQuestion)
35+
.filter(
36+
PlaygroundQuestion.project_id == project_id,
37+
)
38+
.all()
39+
)
40+
41+
current_count = len(current_questions)
42+
43+
if current_count >= MAX_SAVED_QUESTIONS_HISTORY_PER_PROJECT:
44+
oldest = (
45+
session.query(PlaygroundQuestion)
46+
.filter(
47+
PlaygroundQuestion.project_id == project_id,
48+
)
49+
.order_by(PlaygroundQuestion.created_at.asc())
50+
.limit(current_count - MAX_SAVED_QUESTIONS_HISTORY_PER_PROJECT + 1)
51+
.all()
52+
)
53+
ids = [q.id for q in oldest]
54+
delete_all(project_id, ids, False)
55+
56+
if question:
57+
current_count_question = sum(
58+
1 for q in current_questions if str(q.question).lower() == question.lower()
59+
)
60+
61+
if current_count_question == 0:
62+
q = PlaygroundQuestion(
63+
project_id=project_id,
64+
question=question,
65+
)
66+
general.add(q, with_commit)
67+
return q
68+
69+
return None
70+
71+
72+
def delete_all(project_id: str, ids: List[str], with_commit: bool = False) -> None:
73+
session.query(PlaygroundQuestion).filter(
74+
PlaygroundQuestion.project_id == project_id,
75+
PlaygroundQuestion.id.in_(ids),
76+
).delete(synchronize_session=False)
77+
general.flush_or_commit(with_commit)
78+
79+
80+
def delete(project_id: str, id: str, with_commit: bool = True) -> None:
81+
session.query(PlaygroundQuestion).filter(
82+
PlaygroundQuestion.project_id == project_id,
83+
PlaygroundQuestion.id == id,
84+
).delete()
85+
general.flush_or_commit(with_commit)

enums.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,10 @@ class Tablenames(Enum):
145145
PIPELINE_VERSION = (
146146
"pipeline_version" # dump of previous versions to easily jump between
147147
)
148+
EVALUATION_SET = "evaluation_set"
149+
EVALUATION_GROUP = "evaluation_group"
150+
EVALUATION_RUN = "evaluation_run"
151+
PLAYGROUND_QUESTION = "playground_question"
148152

149153
def snake_case_to_pascal_case(self):
150154
# the type name (written in PascalCase) of a table is needed to create backrefs
@@ -799,3 +803,10 @@ class ChangeAction(Enum):
799803
class PipelineVersionType(Enum):
800804
AUTO_SAVE = "AUTO_SAVE" # any save operation in relevant but only 10 per project
801805
NAMED_VERSION = "NAMED_VERSION" # any AUTO_SAVE that is considered worth keeping
806+
807+
808+
class EvaluationRunState(Enum):
809+
INITIATED = "INITIATED"
810+
RUNNING = "RUNNING"
811+
SUCCESS = "SUCCESS"
812+
FAILED = "FAILED"

0 commit comments

Comments
 (0)