
Commit 42edf3c

Task master cancel (#71)
* cancel tokenization
* cancel tokenization
* model
* returning tok id
* model
* error handling
* model merge
* model update

1 parent 1cf48de · commit 42edf3c

File tree

4 files changed: +70 -21 lines


app.py

Lines changed: 10 additions & 4 deletions
@@ -40,27 +40,33 @@ def tokenize_record(request: Request) -> responses.PlainTextResponse:
 def tokenize_calculated_attribute(
     request: AttributeTokenizationRequest,
 ) -> responses.PlainTextResponse:
-    task_manager.start_tokenization_task(
+    record_tokenization_task_id = task_manager.start_tokenization_task(
         request.project_id,
         request.user_id,
         enums.TokenizationTaskTypes.ATTRIBUTE.value,
         request.include_rats,
         False,
         request.attribute_id,
     )
-    return responses.PlainTextResponse(status_code=status.HTTP_200_OK)
+    return responses.JSONResponse(
+        content={"tokenization_task_id": str(record_tokenization_task_id)},
+        status_code=status.HTTP_200_OK,
+    )
 
 
 @app.post("/tokenize_project")
 def tokenize_project(request: Request) -> responses.PlainTextResponse:
-    task_manager.start_tokenization_task(
+    record_tokenization_task_id = task_manager.start_tokenization_task(
         request.project_id,
         request.user_id,
         enums.TokenizationTaskTypes.PROJECT.value,
         request.include_rats,
         request.only_uploaded_attributes,
     )
-    return responses.PlainTextResponse(status_code=status.HTTP_200_OK)
+    return responses.JSONResponse(
+        content={"tokenization_task_id": str(record_tokenization_task_id)},
+        status_code=status.HTTP_200_OK,
+    )
 
 
 # rats = record_attribute_token_statistics
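Both endpoints now answer with the id of the tokenization task they started instead of an empty 200, so a caller can hold on to that id (for example to poll or cancel the task later). A minimal client sketch of how the new response could be consumed; the base URL and the exact request body fields are assumptions for illustration, not taken from this commit:

```python
# Minimal client sketch; base URL and payload fields are assumptions.
import requests

BASE_URL = "http://localhost:80"  # assumed address of the tokenizer service

payload = {
    "project_id": "some-project-id",   # hypothetical ids for illustration
    "user_id": "some-user-id",
    "include_rats": True,
    "only_uploaded_attributes": False,
}

response = requests.post(f"{BASE_URL}/tokenize_project", json=payload)
response.raise_for_status()

# The endpoint now returns a JSON body instead of an empty 200.
task_id = response.json()["tokenization_task_id"]
print(f"started tokenization task {task_id}")
```

Note that start_tokenization_task can return None (see controller/task_manager.py below), in which case the serialized id in the response is the literal string "None".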

controller/task_manager.py

Lines changed: 12 additions & 5 deletions
@@ -78,7 +78,10 @@ def start_tokenization_task(
        attribute_name,
        include_rats,
    )
-   return status.HTTP_200_OK
+   record_tokenization_task_id = None
+   if task:
+       record_tokenization_task_id = task.id
+   return record_tokenization_task_id
 
 
 def start_rats_task(
@@ -87,7 +90,9 @@ def start_rats_task(
     only_uploaded_attributes: bool = False,
     attribute_id: Optional[str] = None,
 ) -> int:
-    if tokenization.is_doc_bin_creation_running_or_queued(project_id, only_running=True):
+    if tokenization.is_doc_bin_creation_running_or_queued(
+        project_id, only_running=True
+    ):
         # at the end of doc bin creation rats will be calculated
         return
 
@@ -102,9 +107,11 @@ def start_rats_task(
        project_id,
        user_id,
        enums.TokenizerTask.TYPE_TOKEN_STATISTICS.value,
-       scope=enums.RecordTokenizationScope.ATTRIBUTE.value
-       if attribute_id
-       else enums.RecordTokenizationScope.PROJECT.value,
+       scope=(
+           enums.RecordTokenizationScope.ATTRIBUTE.value
+           if attribute_id
+           else enums.RecordTokenizationScope.PROJECT.value
+       ),
        attribute_name=attribute_name,
        with_commit=True,
    )
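start_tokenization_task now hands back the id of the created task (or None when no task was created) rather than an HTTP status code, and the calling endpoints forward that id to the client. A hedged sketch of how a caller might deal with this contract; the helper name and the 409 response for the None case are illustrative choices, not part of this commit:

```python
# Hypothetical caller-side handling of the new return value; names are illustrative.
from fastapi import responses, status


def start_and_report(task_manager, request):
    # start_tokenization_task now returns the task id or None.
    task_id = task_manager.start_tokenization_task(
        request.project_id,
        request.user_id,
        "PROJECT",  # stands in for enums.TokenizationTaskTypes.PROJECT.value
        request.include_rats,
        request.only_uploaded_attributes,
    )
    if task_id is None:
        # No task was created; how to signal this is a design choice, 409 is just one option.
        return responses.PlainTextResponse(status_code=status.HTTP_409_CONFLICT)
    return responses.JSONResponse(
        content={"tokenization_task_id": str(task_id)},
        status_code=status.HTTP_200_OK,
    )
```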

controller/tokenization_manager.py

Lines changed: 47 additions & 11 deletions
@@ -54,7 +54,16 @@ def tokenize_calculated_attribute(
             record_tokenized_entries[x : x + chunk_size]
             for x in range(0, len(record_tokenized_entries), chunk_size)
         ]
+        tokenization_cancelled = False
         for idx, chunk in enumerate(chunks):
+            record_tokenization_task = tokenization.get(project_id, task_id)
+            if (
+                not record_tokenization_task
+                or record_tokenization_task.state
+                == enums.TokenizerTask.STATE_FAILED.value
+            ):
+                tokenization_cancelled = True
+                break
             values = [
                 add_attribute_to_docbin(tokenizer, record_tokenized_item)
                 for record_tokenized_item in chunk
@@ -69,9 +78,20 @@ def tokenize_calculated_attribute(
             update_tokenization_progress(
                 project_id, tokenization_task, progress_per_chunk
             )
-        finalize_task(
-            project_id, user_id, non_text_attributes, tokenization_task, include_rats
-        )
+        if not tokenization_cancelled:
+            finalize_task(
+                project_id,
+                user_id,
+                non_text_attributes,
+                tokenization_task,
+                include_rats,
+            )
+        else:
+            send_websocket_update(
+                project_id,
+                False,
+                ["docbin", "state", str(record_tokenization_task.state)],
+            )
     except Exception:
         __handle_error(project_id, user_id, task_id)
     finally:
@@ -116,7 +136,16 @@ def tokenize_initial_project(
         chunks = [
             records[x : x + chunk_size] for x in range(0, len(records), chunk_size)
         ]
+        tokenization_cancelled = False
         for idx, record_chunk in enumerate(chunks):
+            record_tokenization_task = tokenization.get(project_id, task_id)
+            if (
+                not record_tokenization_task
+                or record_tokenization_task.state
+                == enums.TokenizerTask.STATE_FAILED.value
+            ):
+                tokenization_cancelled = True
+                break
             entries = []
             for record_item in record_chunk:
                 if __remove_from_priority_queue(project_id, record_item.id):
@@ -131,14 +160,21 @@ def tokenize_initial_project(
             update_tokenization_progress(
                 project_id, tokenization_task, progress_per_chunk
             )
-        finalize_task(
-            project_id,
-            user_id,
-            non_text_attributes,
-            tokenization_task,
-            include_rats,
-            only_uploaded_attributes,
-        )
+        if not tokenization_cancelled:
+            finalize_task(
+                project_id,
+                user_id,
+                non_text_attributes,
+                tokenization_task,
+                include_rats,
+                only_uploaded_attributes,
+            )
+        else:
+            send_websocket_update(
+                project_id,
+                False,
+                ["docbin", "state", str(record_tokenization_task.state)],
+            )
     except Exception:
         __handle_error(project_id, user_id, task_id)
     finally:
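Both tokenization loops now re-read the task between chunks and bail out when it has disappeared or was flagged as failed, which is how a cancel from the task master takes effect mid-run: the finalize step is skipped and a websocket update broadcasts the final state instead. A stripped-down, self-contained sketch of that pattern, with a toy in-memory task table standing in for the tokenization data access layer (TaskState, get_task, and run_chunked are illustrative names, not the project's API):

```python
# Cooperative cancellation between work chunks; names are illustrative, not the project's API.
from enum import Enum
from typing import Optional


class TaskState(str, Enum):
    CREATED = "CREATED"
    IN_PROGRESS = "IN_PROGRESS"
    FINISHED = "FINISHED"
    FAILED = "FAILED"  # a cancel is surfaced here as a FAILED state


TASKS: dict[str, TaskState] = {}  # toy stand-in for the tokenization task table


def get_task(task_id: str) -> Optional[TaskState]:
    return TASKS.get(task_id)


def run_chunked(task_id: str, chunks: list[list[str]]) -> None:
    cancelled = False
    for chunk in chunks:
        task_state = get_task(task_id)
        # Re-check the task before every chunk: gone or FAILED means stop now.
        if task_state is None or task_state == TaskState.FAILED:
            cancelled = True
            break
        print(f"processing {len(chunk)} records")  # stands in for the real chunk work
    if not cancelled:
        print("finalizing task")  # finalize_task(...) in the real code
    else:
        print("notifying clients of cancellation")  # send_websocket_update(...) in the real code


# Usage: flipping the state before the next chunk is reached stops the run.
TASKS["t1"] = TaskState.IN_PROGRESS
TASKS["t1"] = TaskState.FAILED
run_chunked("t1", [["a", "b"], ["c"]])
```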
