Skip to content

Commit 6a76788

Browse files
committed
.
1 parent 2d4273e commit 6a76788

File tree

4 files changed

+35
-24
lines changed

4 files changed

+35
-24
lines changed

backend/model_server/custom_models.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
from typing import cast
2-
from typing import Optional
3-
from typing import TYPE_CHECKING
42

53
import numpy as np
64
import torch
75
import torch.nn.functional as F
86
from fastapi import APIRouter
97
from huggingface_hub import snapshot_download # type: ignore
8+
from setfit import SetFitModel # type: ignore[import]
109
from transformers import AutoTokenizer # type: ignore
1110
from transformers import BatchEncoding # type: ignore
1211
from transformers import PreTrainedTokenizer # type: ignore
@@ -38,9 +37,6 @@
3837
from shared_configs.model_server_models import IntentRequest
3938
from shared_configs.model_server_models import IntentResponse
4039

41-
if TYPE_CHECKING:
42-
from setfit import SetFitModel # type: ignore[import-untyped]
43-
4440
logger = setup_logger()
4541

4642
router = APIRouter(prefix="/custom")
@@ -51,8 +47,7 @@
5147
_INTENT_TOKENIZER: PreTrainedTokenizer | None = None
5248
_INTENT_MODEL: HybridClassifier | None = None
5349

54-
55-
_INFORMATION_CONTENT_MODEL: Optional["SetFitModel"] = None
50+
_INFORMATION_CONTENT_MODEL: SetFitModel | None = None
5651

5752
_INFORMATION_CONTENT_MODEL_PROMPT_PREFIX: str = "" # specific to the model version!
5853

@@ -146,10 +141,8 @@ def get_local_intent_model(
146141
def get_local_information_content_model(
147142
model_name_or_path: str = INFORMATION_CONTENT_MODEL_VERSION,
148143
tag: str | None = INFORMATION_CONTENT_MODEL_TAG,
149-
) -> "SetFitModel":
144+
) -> SetFitModel:
150145
global _INFORMATION_CONTENT_MODEL
151-
from setfit import SetFitModel # type: ignore
152-
153146
if _INFORMATION_CONTENT_MODEL is None:
154147
try:
155148
# Calculate where the cache should be, then load from local if available

backend/model_server/encoders.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
from fastapi import HTTPException
88
from fastapi import Request
99
from litellm.exceptions import RateLimitError
10-
from sentence_transformers import CrossEncoder
11-
from sentence_transformers import SentenceTransformer
10+
from sentence_transformers import CrossEncoder # type: ignore
11+
from sentence_transformers import SentenceTransformer # type: ignore
1212

1313
from model_server.utils import simple_log_function_time
1414
from onyx.utils.logger import setup_logger
@@ -25,8 +25,8 @@
2525
router = APIRouter(prefix="/encoder")
2626

2727

28-
_GLOBAL_MODELS_DICT: dict[str, SentenceTransformer] = {}
29-
_RERANK_MODEL: Optional[CrossEncoder] = None
28+
_GLOBAL_MODELS_DICT: dict[str, "SentenceTransformer"] = {}
29+
_RERANK_MODEL: Optional["CrossEncoder"] = None
3030

3131
# If we are not only indexing, we don't want to retry for very long
3232
_RETRY_DELAY = 10 if INDEXING_ONLY else 0.1
@@ -36,13 +36,14 @@
3636
def get_embedding_model(
3737
model_name: str,
3838
max_context_length: int,
39-
) -> SentenceTransformer:
39+
) -> "SentenceTransformer":
4040
"""
4141
Loads or returns a cached SentenceTransformer, sets max_seq_length, pins device,
4242
pre-warms rotary caches once, and wraps encode() with a lock to avoid cache races.
4343
"""
44+
from sentence_transformers import SentenceTransformer # type: ignore
4445

45-
def _prewarm_rope(st_model: SentenceTransformer, target_len: int) -> None:
46+
def _prewarm_rope(st_model: "SentenceTransformer", target_len: int) -> None:
4647
"""
4748
Build RoPE cos/sin caches once on the final device/dtype so later forwards only read.
4849
Works by calling the underlying HF model directly with dummy IDs/attention.
@@ -101,7 +102,7 @@ def get_local_reranking_model(
101102

102103

103104
def _concurrent_embedding(
104-
texts: list[str], model: SentenceTransformer, normalize_embeddings: bool
105+
texts: list[str], model: "SentenceTransformer", normalize_embeddings: bool
105106
) -> Any:
106107
"""Synchronous wrapper for concurrent_embedding to use with run_in_executor."""
107108
for _ in range(ENCODING_RETRIES):

backend/onyx/file_processing/extract_file_text.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,12 @@
1515
from typing import Any
1616
from typing import IO
1717
from typing import NamedTuple
18+
from typing import Optional
19+
from typing import TYPE_CHECKING
1820
from zipfile import BadZipFile
1921

2022
import chardet
2123
import openpyxl
22-
from markitdown import FileConversionException
23-
from markitdown import MarkItDown
24-
from markitdown import StreamInfo
25-
from markitdown import UnsupportedFormatException
2624
from PIL import Image
2725
from pypdf import PdfReader
2826
from pypdf.errors import PdfStreamError
@@ -37,6 +35,11 @@
3735
from onyx.utils.file_types import WORD_PROCESSING_MIME_TYPE
3836
from onyx.utils.logger import setup_logger
3937

38+
39+
if TYPE_CHECKING:
40+
from markitdown import MarkItDown
41+
42+
4043
logger = setup_logger()
4144

4245
# NOTE(rkuo): Unify this with upload_files_for_chat and file_validation.py
@@ -85,17 +88,19 @@
8588
"image/webp",
8689
]
8790

88-
_MARKITDOWN_CONVERTER: MarkItDown | None = None
91+
_MARKITDOWN_CONVERTER: Optional["MarkItDown"] = None
8992

9093
KNOWN_OPENPYXL_BUGS = [
9194
"Value must be either numerical or a string containing a wildcard",
9295
"File contains no valid workbook part",
9396
]
9497

9598

96-
def get_markitdown_converter() -> MarkItDown:
99+
def get_markitdown_converter() -> "MarkItDown":
97100
global _MARKITDOWN_CONVERTER
98101
if _MARKITDOWN_CONVERTER is None:
102+
from markitdown import MarkItDown
103+
99104
_MARKITDOWN_CONVERTER = MarkItDown(enable_plugins=False)
100105
return _MARKITDOWN_CONVERTER
101106

@@ -357,6 +362,12 @@ def docx_to_text_and_images(
357362
of avoiding materializing the list of images in memory.
358363
The images list returned is empty in this case.
359364
"""
365+
from markitdown import (
366+
FileConversionException,
367+
StreamInfo,
368+
UnsupportedFormatException,
369+
)
370+
360371
md = get_markitdown_converter()
361372
try:
362373
doc = md.convert(
@@ -393,6 +404,12 @@ def docx_to_text_and_images(
393404

394405

395406
def pptx_to_text(file: IO[Any], file_name: str = "") -> str:
407+
from markitdown import (
408+
FileConversionException,
409+
StreamInfo,
410+
UnsupportedFormatException,
411+
)
412+
396413
md = get_markitdown_converter()
397414
stream_info = StreamInfo(
398415
mimetype=PRESENTATION_MIME_TYPE, filename=file_name or None, extension=".pptx"

backend/scripts/check_lazy_imports.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
logger = logging.getLogger(__name__)
1717

18-
_MODULES_TO_LAZY_IMPORT = {"vertexai"}
18+
_MODULES_TO_LAZY_IMPORT = {"vertexai", "markitdown"}
1919

2020

2121
@dataclass

0 commit comments

Comments
 (0)