Commit 1a6ff9e: file processing refactor
Parent: 146628e

7 files changed: +166 additions, -128 deletions

backend/onyx/background/celery/celery_utils.py
13 additions, 0 deletions

@@ -8,10 +8,12 @@
 
 from onyx.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
 from onyx.configs.app_configs import VESPA_REQUEST_TIMEOUT
+from onyx.connectors.connector_runner import batched_docs
 from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
     rate_limit_builder,
 )
 from onyx.connectors.interfaces import BaseConnector
+from onyx.connectors.interfaces import CheckpointedConnector
 from onyx.connectors.interfaces import LoadConnector
 from onyx.connectors.interfaces import PollConnector
 from onyx.connectors.interfaces import SlimConnector
@@ -22,6 +24,7 @@
 
 
 logger = setup_logger()
+PRUNING_CHECKPOINTED_BATCH_SIZE = 32
 
 
 def document_batch_to_ids(
@@ -54,6 +57,16 @@ def extract_ids_from_runnable_connector(
         start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
         end = datetime.now(timezone.utc).timestamp()
         doc_batch_generator = runnable_connector.poll_source(start=start, end=end)
+    elif isinstance(runnable_connector, CheckpointedConnector):
+        start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
+        end = datetime.now(timezone.utc).timestamp()
+        checkpoint = runnable_connector.build_dummy_checkpoint()
+        checkpoint_generator = runnable_connector.load_from_checkpoint(
+            start=start, end=end, checkpoint=checkpoint
+        )
+        doc_batch_generator = batched_docs(
+            checkpoint_generator, batch_size=PRUNING_CHECKPOINTED_BATCH_SIZE
+        )
     else:
         raise RuntimeError("Pruning job could not find a valid runnable_connector.")
 
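Both the existing PollConnector branch (visible in context) and the new CheckpointedConnector branch prune across the connector's full history by polling a window from the Unix epoch to the current time. A small sketch of that window construction:

    from datetime import datetime, timezone

    # The pruning window spans the connector's entire history: POSIX
    # timestamps from the Unix epoch (0.0) up to "now".
    start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
    end = datetime.now(timezone.utc).timestamp()

    assert start == 0.0
    assert end > start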

backend/onyx/connectors/connector_runner.py
18 additions, 0 deletions

@@ -25,6 +25,24 @@
 CT = TypeVar("CT", bound=ConnectorCheckpoint)
 
 
+def batched_docs(
+    checkpoint_connector_generator: CheckpointOutput[CT],
+    batch_size: int,
+) -> Generator[list[Document], None, None]:
+    batch: list[Document] = []
+    for document, failure, next_checkpoint in CheckpointOutputWrapper[CT]()(
+        checkpoint_connector_generator
+    ):
+        if document is None:
+            continue
+        batch.append(document)
+        if len(batch) >= batch_size:
+            yield batch
+            batch = []
+    if len(batch) > 0:
+        yield batch
+
+
 class CheckpointOutputWrapper(Generic[CT]):
     """
     Wraps a CheckpointOutput generator to give things back in a more digestible format,
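The new helper drops yields that carry no document (checkpoint- or failure-only entries) and re-chunks the remaining documents into fixed-size batches, flushing any remainder at the end. A simplified, self-contained model of that behavior, with the checkpoint plumbing stubbed out:

    from collections.abc import Generator, Iterator
    from typing import Optional, TypeVar

    T = TypeVar("T")

    def batched(
        items: Iterator[Optional[T]], batch_size: int
    ) -> Generator[list[T], None, None]:
        """Model of batched_docs: skip None entries (checkpoint-only yields)
        and emit fixed-size batches, flushing the remainder at the end."""
        batch: list[T] = []
        for item in items:
            if item is None:
                continue
            batch.append(item)
            if len(batch) >= batch_size:
                yield batch
                batch = []
        if batch:
            yield batch

    # 7 documents with interleaved checkpoint-only (None) yields, batch_size=3
    docs = iter(["d1", None, "d2", "d3", "d4", None, "d5", "d6", "d7"])
    assert [len(b) for b in batched(docs, batch_size=3)] == [3, 3, 1]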

backend/onyx/file_processing/file_validation.py
5 additions, 3 deletions

@@ -32,9 +32,11 @@ def is_valid_image_type(mime_type: str) -> bool:
     Returns:
         True if the MIME type is a valid image type, False otherwise
     """
-    if not mime_type:
-        return False
-    return mime_type.startswith("image/") and mime_type not in EXCLUDED_IMAGE_TYPES
+    return (
+        bool(mime_type)
+        and mime_type.startswith("image/")
+        and mime_type not in EXCLUDED_IMAGE_TYPES
+    )
 
 
 def is_supported_by_vision_llm(mime_type: str) -> bool:
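The refactor collapses the early return into a single short-circuiting boolean chain; behavior is unchanged. A small demonstration (the EXCLUDED_IMAGE_TYPES value here is an illustrative stand-in, not the module's real set):

    EXCLUDED_IMAGE_TYPES = {"image/svg+xml"}  # illustrative stand-in

    def is_valid_image_type(mime_type: str) -> bool:
        return (
            bool(mime_type)
            and mime_type.startswith("image/")
            and mime_type not in EXCLUDED_IMAGE_TYPES
        )

    assert is_valid_image_type("image/png") is True
    assert is_valid_image_type("") is False               # falsy input short-circuits
    assert is_valid_image_type("text/plain") is False     # not an image MIME type
    assert is_valid_image_type("image/svg+xml") is False  # explicitly excluded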

backend/onyx/file_store/utils.py
0 additions, 1 deletion

@@ -46,7 +46,6 @@ def store_user_file_plaintext(user_file_id: int, plaintext_content: str) -> bool
     # Get plaintext file name
     plaintext_file_name = user_file_id_to_plaintext_file_name(user_file_id)
 
-    # Use a separate session to avoid committing the caller's transaction
     try:
         file_store = get_default_file_store()
         file_content = BytesIO(plaintext_content.encode("utf-8"))

backend/onyx/indexing/indexing_pipeline.py
24 additions, 23 deletions

@@ -868,30 +868,31 @@ def index_doc_batch(
     for document_id in updatable_ids:
         # Only calculate token counts for documents that have a user file ID
         if (
-            document_id in doc_id_to_user_file_id
-            and doc_id_to_user_file_id[document_id] is not None
+            document_id not in doc_id_to_user_file_id
+            or doc_id_to_user_file_id[document_id] is None
         ):
-            user_file_id = doc_id_to_user_file_id[document_id]
-            if not user_file_id:
-                continue
-            document_chunks = [
-                chunk
-                for chunk in chunks_with_embeddings
-                if chunk.source_document.id == document_id
-            ]
-            if document_chunks:
-                combined_content = " ".join(
-                    [chunk.content for chunk in document_chunks]
-                )
-                token_count = (
-                    len(llm_tokenizer.encode(combined_content))
-                    if llm_tokenizer
-                    else 0
-                )
-                user_file_id_to_token_count[user_file_id] = token_count
-                user_file_id_to_raw_text[user_file_id] = combined_content
-            else:
-                user_file_id_to_token_count[user_file_id] = None
+            continue
+
+        user_file_id = doc_id_to_user_file_id[document_id]
+        if user_file_id is None:
+            continue
+
+        document_chunks = [
+            chunk
+            for chunk in chunks_with_embeddings
+            if chunk.source_document.id == document_id
+        ]
+        if document_chunks:
+            combined_content = " ".join(
+                [chunk.content for chunk in document_chunks]
+            )
+            token_count = (
+                len(llm_tokenizer.encode(combined_content)) if llm_tokenizer else 0
+            )
+            user_file_id_to_token_count[user_file_id] = token_count
+            user_file_id_to_raw_text[user_file_id] = combined_content
+        else:
+            user_file_id_to_token_count[user_file_id] = None
 
     # we're concerned about race conditions where multiple simultaneous indexings might result
     # in one set of metadata overwriting another one in vespa.
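This hunk applies the guard-clause pattern: invert the condition, `continue` early, and dedent the happy path by one level. One subtle change in the rewrite: the inner check moves from falsiness (`if not user_file_id`) to an explicit `is None`, so a user file ID of 0 would now be processed rather than skipped. A minimal sketch of the pattern on illustrative data:

    # Guard clauses replace the nested `if`: skip docs without a user file ID,
    # then run the happy path one indent level shallower than before.
    doc_id_to_user_file_id = {"doc-1": 41, "doc-2": None}

    token_counts: dict[int, int] = {}
    for document_id in ["doc-1", "doc-2", "doc-3"]:
        if document_id not in doc_id_to_user_file_id:
            continue
        user_file_id = doc_id_to_user_file_id[document_id]
        if user_file_id is None:
            continue

        # Stand-in for the real chunk-joining and tokenizer logic.
        token_counts[user_file_id] = len(f"content for {document_id}".split())

    assert token_counts == {41: 3}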

backend/onyx/server/documents/connector.py
15 additions, 10 deletions

@@ -1,3 +1,4 @@
+import io
 import json
 import mimetypes
 import os
@@ -101,8 +102,9 @@
 from onyx.db.models import IndexingStatus
 from onyx.db.models import User
 from onyx.db.models import UserGroup__ConnectorCredentialPair
-from onyx.file_processing.extract_file_text import convert_docx_to_txt
+from onyx.file_processing.extract_file_text import extract_file_text
 from onyx.file_store.file_store import get_default_file_store
+from onyx.file_store.models import ChatFileType
 from onyx.key_value_store.interface import KvKeyNotFoundError
 from onyx.server.documents.models import AuthStatus
 from onyx.server.documents.models import AuthUrl
@@ -124,6 +126,7 @@
 from onyx.server.documents.models import ObjectCreationIdResponse
 from onyx.server.documents.models import RunConnectorRequest
 from onyx.server.models import StatusResponse
+from onyx.server.query_and_chat.chat_utils import mime_type_to_chat_file_type
 from onyx.utils.logger import setup_logger
 from onyx.utils.telemetry import create_milestone_and_report
 from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
@@ -484,15 +487,17 @@ def should_process_file(file_path: str) -> bool:
             deduped_file_names.append(os.path.basename(file_info))
             continue
 
-        # For mypy, actual check happens at start of function
-        assert file.filename is not None
-
-        # Special handling for docx files - only store the plaintext version
-        if file.content_type and file.content_type.startswith(
-            "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
-        ):
-            docx_file_id = convert_docx_to_txt(file, file_store)
-            deduped_file_paths.append(docx_file_id)
+        # Special handling for doc files - only store the plaintext version
+        file_type = mime_type_to_chat_file_type(file.content_type)
+        if file_type == ChatFileType.DOC:
+            extracted_text = extract_file_text(file.file, file.filename or "")
+            text_file_id = file_store.save_file(
+                content=io.BytesIO(extracted_text.encode()),
+                display_name=file.filename,
+                file_origin=FileOrigin.CHAT_UPLOAD,
+                file_type="text/plain",
+            )
+            deduped_file_paths.append(text_file_id)
             deduped_file_names.append(file.filename)
             continue
 
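The upload path previously special-cased only the docx MIME type and delegated to convert_docx_to_txt; it now routes any MIME type that maps to ChatFileType.DOC through generic text extraction and stores only the resulting text/plain file. The mypy assert also becomes unnecessary since the new call tolerates a missing filename via `file.filename or ""`. A self-contained sketch of the new flow; the extract_file_text stub and in-memory file store below are illustrative stand-ins, not the real onyx implementations:

    import io
    from typing import BinaryIO

    def extract_file_text(file: BinaryIO, file_name: str) -> str:
        # The real implementation dispatches on file type; this stub just decodes.
        return file.read().decode("utf-8", errors="replace")

    class InMemoryFileStore:
        def __init__(self) -> None:
            self._files: dict[str, bytes] = {}

        def save_file(
            self,
            content: io.BytesIO,
            display_name: str,
            file_origin: str,
            file_type: str,
        ) -> str:
            file_id = f"file-{len(self._files)}"
            self._files[file_id] = content.read()
            return file_id

    file_store = InMemoryFileStore()
    upload = io.BytesIO(b"body text of an uploaded .docx")
    extracted_text = extract_file_text(upload, "report.docx")
    text_file_id = file_store.save_file(
        content=io.BytesIO(extracted_text.encode()),
        display_name="report.docx",
        file_origin="chat_upload",  # stands in for FileOrigin.CHAT_UPLOAD
        file_type="text/plain",
    )
    assert file_store._files[text_file_id] == b"body text of an uploaded .docx"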
