+import traceback
+
+from submodules.model.cognition_objects import markdown_file, markdown_dataset
+from handler.tokenizer_handler import get_tokenizer
+from submodules.model.business_objects import general
+from submodules.model.enums import CognitionMarkdownFileState
+from spacy.language import Language
+from spacy.tokens import Doc
+
+SEGMENT_DIVIDER = "\n\n"
+
+
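+# Entry point for content "rework" steps on a markdown file; currently only the
+# "SEGMENT_SENTENCES" step is implemented, every other step is treated as a no-op.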
+def rework_markdown_file_content(org_id: str, file_id: str, step: str) -> bool:
+    if step == "SEGMENT_SENTENCES":
+        return __rework_segment_sentences(org_id, file_id)
+    return True
+
+
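+# Re-segments the file content into sentence-like units joined by SEGMENT_DIVIDER.
+# Content that exceeds the tokenizer's max length is chunked by byte size first.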
+def __rework_segment_sentences(org_id: str, file_id: str) -> bool:
+    markdown_file_item = markdown_file.get(org_id, file_id)
+    if markdown_file_item is None:
+        return False
+
+    dataset_item = markdown_dataset.get(org_id, markdown_file_item.dataset_id)
+    if dataset_item is None:
+        return False
+    content = markdown_file_item.content
+    try:
+        nlp = get_tokenizer(dataset_item.tokenizer)
+        max_length = __lookup_final_max_length(nlp)
+        # Split the content into smaller chunks if it's too large for the tokenizer
+        if __utf8len(content) > max_length:
+            # -100 keeps a small safety margin below the tokenizer limit
+            chunks = __chunk_text_on_bytes(content, max_length - 100)
+            processed_chunks = []
+
+            for chunk in chunks:
+                doc = nlp(chunk)
+                processed_chunks.append(SEGMENT_DIVIDER.join(__segment_sentences(doc)))
+
+            content = SEGMENT_DIVIDER.join(processed_chunks)
+        else:
+            doc = nlp(content)
+            content = SEGMENT_DIVIDER.join(__segment_sentences(doc))
+        markdown_file_item.content = content
+        general.commit()
+        return True
+    except Exception:
+        full_traceback = traceback.format_exc()
+        print(full_traceback, flush=True)
+        markdown_file.update(
+            org_id=org_id,
+            markdown_file_id=file_id,
+            state=CognitionMarkdownFileState.FAILED.value,
+            error=full_traceback,  # store the full stack trace instead of just the error message
+        )
+        return False
+
+
+# custom segmentation rule to build very likely sentences from a chunk of text:
+# spaCy sentence fragments that do not end in terminal punctuation (e.g. markdown
+# headings or list items) are merged into the following sentence
+def __segment_sentences(doc: Doc) -> list:
+    sentences = []
+    current_sentence = None
+    for sent in doc.sents:
+        if len(sent.text.strip()) == 0:
+            continue
+        last_char = sent.text.strip()[-1]
+
+        if current_sentence is None:
+            current_sentence = sent.text
+        else:
+            current_sentence += " " + sent.text
+
+        if last_char in [".", ";", "?", "!"]:
+            sentences.append(current_sentence)
+            current_sentence = None
+
+    if current_sentence is not None:
+        sentences.append(current_sentence)
+    return sentences
+
+
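+# naive character-count chunking; with multi-byte characters a chunk can exceed
+# chunk_size bytes, which is why the byte-aware variant below is used for segmentation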
+def __chunk_text(text: str, chunk_size: int = 1_000_000):
+    return [text[i : i + chunk_size] for i in range(0, len(text), chunk_size)]
+
+
+# does not split after exactly x bytes but ensures that at most max_chunk_size bytes
+# are used per chunk without cutting a multi-byte character apart
+def __chunk_text_on_bytes(text: str, max_chunk_size: int = 1_000_000):
+    # ratio of characters to bytes, used for the initial guess of the cut position
+    factor = len(text) / __utf8len(text)
+    increase_by = int(max(min(max_chunk_size * 0.1, 10), 1))
+    initial_size_guess = int(max(max_chunk_size * factor - 10, 1))
+    final_list = []
+    remaining = text
+    while len(remaining):
+        part = remaining[:initial_size_guess]
+        if __utf8len(part) > max_chunk_size:
+            # guess was too large in bytes -> shrink it and retry
+            initial_size_guess = int(max(initial_size_guess - min(max_chunk_size * 0.001, 10), 1))
+            continue
+        cut_after = initial_size_guess
+        # grow the chunk in small character steps until the byte limit (or the end of the text) is reached
+        while __utf8len(part) < max_chunk_size and part != remaining:
+            cut_after = min(len(remaining), cut_after + increase_by)
+            part = remaining[:cut_after]
+
+        if __utf8len(part) > max_chunk_size:
+            cut_after -= increase_by
+        final_list.append(remaining[:cut_after])
+        remaining = remaining[cut_after:]
+
+    return final_list
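+# illustrative example: __chunk_text_on_bytes("私私私私", 10) -> ["私私私", "私"]
+# (each "私" is 3 bytes in UTF-8, so at most three of them fit into a 10 byte chunk)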
+
+
+MAX_LENGTH_OVERWRITE = {
+    # Japanese is limited by sudachi, so spaCy's max_length only applies if it is smaller than the sudachi limit
+    "ja": 49149
+}
+
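+# returns the effective maximum text length for the pipeline: the language-specific
+# override if it is stricter than nlp.max_length, otherwise nlp.max_length itself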
+def __lookup_final_max_length(nlp: Language) -> int:
+    overwrite = MAX_LENGTH_OVERWRITE.get(nlp.meta["lang"])
+
+    if overwrite and overwrite < nlp.max_length:
+        return overwrite
+    return nlp.max_length
+
+
+# note that "H" takes 1 byte while "私" takes 3 bytes in UTF-8; len() counts both as one
+# character, but the byte size is what runs into issues with the memory spaCy reserves/allocates
+def __utf8len(s: str) -> int:
+    return len(s.encode("utf-8"))