Skip to content

Commit 202a062

Browse files
Merge remote-tracking branch 'origin/main' into communication-style-classifier
2 parents cc6f672 + 24359ab commit 202a062

File tree

8 files changed

+460
-307
lines changed

8 files changed

+460
-307
lines changed

classifiers/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99

1010
from .lookup_lists import lookup_list
1111

12-
from .communication_style import (
13-
communication_style_classifier
12+
from .question_type import (
13+
question_type_classifier
1414
)
1515

1616
from .reference_quality import (
@@ -67,7 +67,7 @@
6767
bert_sentiment_german,
6868
special_character_classifier,
6969
chunked_sentence_complexity,
70-
communication_style_classifier
70+
question_type_classifier
7171
]:
7272
module_name = module.__name__.split(".")[-1]
7373
model_name = (
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Uses an `intfloat/multilingual-e5-small` model, which was finetuned on english and german examples of different question types. The model is hosted on Kern AIs own infrastructure and is meant to be used to classify text sequences by the labels `keyword-question`, `statement-question` or `interrogative-question`.
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from pydantic import BaseModel
2+
import requests
3+
4+
INPUT_EXAMPLE = {
5+
"text": "Sushi restaurants Barcelona",
6+
"model_name": "KernAI/multilingual-e5-question-type",
7+
}
8+
9+
10+
class QuestionTypeClassifierModel(BaseModel):
11+
text: str
12+
model_name: str
13+
14+
class Config:
15+
schema_extra = {"example": INPUT_EXAMPLE}
16+
17+
18+
def question_type_classifier(req: QuestionTypeClassifierModel):
19+
"""Uses custom E5 model to classify the question type of a text"""
20+
payload = {
21+
"model_name": req.model_name,
22+
"text": req.text
23+
}
24+
response = requests.post("https://free.api.kern.ai/inference", json=payload)
25+
if response.ok:
26+
return {"question_type": response.json()["label"]}
27+
return response.raise_for_status()
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
```python
2+
import requests
3+
4+
def question_type_classifier(text: str, model_name: str, request_url: str = "https://free.api.kern.ai/inference") -> str:
5+
"""
6+
@param text: text with a user query you want to classify
7+
@param model_name: Name of a model provided by Kern AI
8+
@param request_url: URL to the API endpoint of Kern AI
9+
@return: returns either 'keyword-question', 'interrogative-question' or 'statement-question'
10+
"""
11+
payload = {
12+
"model_name": model_name,
13+
"text": text
14+
}
15+
response = requests.post(request_url, json=payload)
16+
if response.ok:
17+
return response.json()["label"]
18+
return response.raise_for_status()
19+
20+
21+
# ↑ necessary bricks function
22+
# -----------------------------------------------------------------------------------------
23+
# ↓ example implementation
24+
25+
26+
model_name = "KernAI/multilingual-e5-question-type"
27+
28+
def example_integration():
29+
texts = ["Travel documents Germany", "Give me documents related to travel insurance.", "What is the content of these documents about?"]
30+
for text in texts:
31+
print(f"the question type of \"{text}\" is \"{question_type_classifier(text, model_name=model_name)}\"")
32+
33+
example_integration()
34+
```
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
```python
2+
import requests
3+
4+
ATTRIBUTE: str = "text" # only text attributes
5+
MODEL_NAME: str = "KernAI/multilingual-e5-question-type"
6+
REQUEST_URL: str = "https://free.api.kern.ai/inference"
7+
8+
def question_type_classifier(record):
9+
payload = {
10+
"model_name": MODEL_NAME,
11+
"text": record[ATTRIBUTE].text
12+
}
13+
response = requests.post(REQUEST_URL, json=payload)
14+
if response.ok:
15+
return response.json()["label"]
16+
```
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from util.configs import build_classifier_function_config
2+
from util.enums import State, RefineryDataType, BricksVariableType, SelectionType
3+
from . import question_type_classifier, INPUT_EXAMPLE
4+
5+
6+
def get_config():
7+
return build_classifier_function_config(
8+
function=question_type_classifier,
9+
input_example=INPUT_EXAMPLE,
10+
issue_id=344,
11+
tabler_icon="ZoomQuestion",
12+
min_refinery_version="1.7.0",
13+
state=State.PUBLIC.value,
14+
type="python_function",
15+
available_for=["refinery", "common"],
16+
part_of_group=[
17+
"question_type"
18+
], # first entry should be parent directory
19+
# bricks integrator information
20+
integrator_inputs={
21+
"name": "question_type_classifier",
22+
"refineryDataType": RefineryDataType.TEXT.value,
23+
"outputs": ["keyword-question", "statement-question", "interrogative-question"],
24+
"variables": {
25+
"ATTRIBUTE": {
26+
"selectionType": SelectionType.CHOICE.value,
27+
"addInfo": [
28+
BricksVariableType.ATTRIBUTE.value,
29+
BricksVariableType.GENERIC_STRING.value
30+
]
31+
},
32+
"MODEL_NAME": {
33+
"selectionType": SelectionType.STRING.value,
34+
"defaultValue": "KernAI/multilingual-e5-question-type",
35+
"addInfo": [
36+
BricksVariableType.GENERIC_STRING.value
37+
]
38+
},
39+
"REQUEST_URL": {
40+
"selectionType": SelectionType.STRING.value,
41+
"defaultValue": "https://free.api.kern.ai/inference",
42+
"addInfo": [
43+
BricksVariableType.GENERIC_STRING.value
44+
]
45+
}
46+
}
47+
}
48+
)

0 commit comments

Comments
 (0)