From cad322aafb82baa6fee5eb78eed4e7d5d9464050 Mon Sep 17 00:00:00 2001 From: Evan Lohn Date: Tue, 29 Jul 2025 18:18:12 -0700 Subject: [PATCH 1/4] file processing refactor --- .../onyx/background/celery/celery_utils.py | 13 ++ backend/onyx/connectors/connector_runner.py | 18 ++ .../onyx/file_processing/file_validation.py | 8 +- backend/onyx/file_store/utils.py | 1 - backend/onyx/indexing/indexing_pipeline.py | 47 ++--- backend/onyx/server/documents/connector.py | 25 ++- .../server/query_and_chat/chat_backend.py | 182 +++++++++--------- 7 files changed, 166 insertions(+), 128 deletions(-) diff --git a/backend/onyx/background/celery/celery_utils.py b/backend/onyx/background/celery/celery_utils.py index 5c4d17b6109..840bda92011 100644 --- a/backend/onyx/background/celery/celery_utils.py +++ b/backend/onyx/background/celery/celery_utils.py @@ -8,10 +8,12 @@ from onyx.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE from onyx.configs.app_configs import VESPA_REQUEST_TIMEOUT +from onyx.connectors.connector_runner import batched_docs from onyx.connectors.cross_connector_utils.rate_limit_wrapper import ( rate_limit_builder, ) from onyx.connectors.interfaces import BaseConnector +from onyx.connectors.interfaces import CheckpointedConnector from onyx.connectors.interfaces import LoadConnector from onyx.connectors.interfaces import PollConnector from onyx.connectors.interfaces import SlimConnector @@ -22,6 +24,7 @@ logger = setup_logger() +PRUNING_CHECKPOINTED_BATCH_SIZE = 32 def document_batch_to_ids( @@ -54,6 +57,16 @@ def extract_ids_from_runnable_connector( start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp() end = datetime.now(timezone.utc).timestamp() doc_batch_generator = runnable_connector.poll_source(start=start, end=end) + elif isinstance(runnable_connector, CheckpointedConnector): + start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp() + end = datetime.now(timezone.utc).timestamp() + checkpoint = runnable_connector.build_dummy_checkpoint() + checkpoint_generator = runnable_connector.load_from_checkpoint( + start=start, end=end, checkpoint=checkpoint + ) + doc_batch_generator = batched_docs( + checkpoint_generator, batch_size=PRUNING_CHECKPOINTED_BATCH_SIZE + ) else: raise RuntimeError("Pruning job could not find a valid runnable_connector.") diff --git a/backend/onyx/connectors/connector_runner.py b/backend/onyx/connectors/connector_runner.py index 5555a988837..e73c37300b2 100644 --- a/backend/onyx/connectors/connector_runner.py +++ b/backend/onyx/connectors/connector_runner.py @@ -25,6 +25,24 @@ CT = TypeVar("CT", bound=ConnectorCheckpoint) +def batched_docs( + checkpoint_connector_generator: CheckpointOutput[CT], + batch_size: int, +) -> Generator[list[Document], None, None]: + batch: list[Document] = [] + for document, failure, next_checkpoint in CheckpointOutputWrapper[CT]()( + checkpoint_connector_generator + ): + if document is None: + continue + batch.append(document) + if len(batch) >= batch_size: + yield batch + batch = [] + if len(batch) > 0: + yield batch + + class CheckpointOutputWrapper(Generic[CT]): """ Wraps a CheckpointOutput generator to give things back in a more digestible format, diff --git a/backend/onyx/file_processing/file_validation.py b/backend/onyx/file_processing/file_validation.py index 34f33dd2f55..0584bcd0831 100644 --- a/backend/onyx/file_processing/file_validation.py +++ b/backend/onyx/file_processing/file_validation.py @@ -32,9 +32,11 @@ def is_valid_image_type(mime_type: str) -> bool: Returns: True if the MIME type is a 
valid image type, False otherwise """ - if not mime_type: - return False - return mime_type.startswith("image/") and mime_type not in EXCLUDED_IMAGE_TYPES + return ( + bool(mime_type) + and mime_type.startswith("image/") + and mime_type not in EXCLUDED_IMAGE_TYPES + ) def is_supported_by_vision_llm(mime_type: str) -> bool: diff --git a/backend/onyx/file_store/utils.py b/backend/onyx/file_store/utils.py index f4a041bf1a8..107504b33bb 100644 --- a/backend/onyx/file_store/utils.py +++ b/backend/onyx/file_store/utils.py @@ -46,7 +46,6 @@ def store_user_file_plaintext(user_file_id: int, plaintext_content: str) -> bool # Get plaintext file name plaintext_file_name = user_file_id_to_plaintext_file_name(user_file_id) - # Use a separate session to avoid committing the caller's transaction try: file_store = get_default_file_store() file_content = BytesIO(plaintext_content.encode("utf-8")) diff --git a/backend/onyx/indexing/indexing_pipeline.py b/backend/onyx/indexing/indexing_pipeline.py index 167f3a0c6a9..050ec9ad46d 100644 --- a/backend/onyx/indexing/indexing_pipeline.py +++ b/backend/onyx/indexing/indexing_pipeline.py @@ -868,30 +868,31 @@ def index_doc_batch( for document_id in updatable_ids: # Only calculate token counts for documents that have a user file ID if ( - document_id in doc_id_to_user_file_id - and doc_id_to_user_file_id[document_id] is not None + document_id not in doc_id_to_user_file_id + or doc_id_to_user_file_id[document_id] is None ): - user_file_id = doc_id_to_user_file_id[document_id] - if not user_file_id: - continue - document_chunks = [ - chunk - for chunk in chunks_with_embeddings - if chunk.source_document.id == document_id - ] - if document_chunks: - combined_content = " ".join( - [chunk.content for chunk in document_chunks] - ) - token_count = ( - len(llm_tokenizer.encode(combined_content)) - if llm_tokenizer - else 0 - ) - user_file_id_to_token_count[user_file_id] = token_count - user_file_id_to_raw_text[user_file_id] = combined_content - else: - user_file_id_to_token_count[user_file_id] = None + continue + + user_file_id = doc_id_to_user_file_id[document_id] + if user_file_id is None: + continue + + document_chunks = [ + chunk + for chunk in chunks_with_embeddings + if chunk.source_document.id == document_id + ] + if document_chunks: + combined_content = " ".join( + [chunk.content for chunk in document_chunks] + ) + token_count = ( + len(llm_tokenizer.encode(combined_content)) if llm_tokenizer else 0 + ) + user_file_id_to_token_count[user_file_id] = token_count + user_file_id_to_raw_text[user_file_id] = combined_content + else: + user_file_id_to_token_count[user_file_id] = None # we're concerned about race conditions where multiple simultaneous indexings might result # in one set of metadata overwriting another one in vespa. 
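As a side note on the indexing_pipeline.py hunk above: the refactor replaces the nested conditional with guard clauses, but the per-user-file token-counting flow is otherwise unchanged. A minimal, illustrative sketch of that flow follows; Chunk, the whitespace "tokenizer", and count_tokens_per_user_file are simplified stand-ins for this example only, not the Onyx types or helpers used in index_doc_batch.

from dataclasses import dataclass


@dataclass
class Chunk:
    document_id: str
    content: str


def count_tokens_per_user_file(
    updatable_ids: list[str],
    doc_id_to_user_file_id: dict[str, int | None],
    chunks: list[Chunk],
) -> dict[int, int | None]:
    user_file_id_to_token_count: dict[int, int | None] = {}
    for document_id in updatable_ids:
        # guard clause: skip documents that are not backed by a user file
        user_file_id = doc_id_to_user_file_id.get(document_id)
        if user_file_id is None:
            continue

        document_chunks = [c for c in chunks if c.document_id == document_id]
        if document_chunks:
            combined_content = " ".join(c.content for c in document_chunks)
            # stand-in for llm_tokenizer.encode(...); counts whitespace-split tokens
            user_file_id_to_token_count[user_file_id] = len(combined_content.split())
        else:
            user_file_id_to_token_count[user_file_id] = None
    return user_file_id_to_token_count
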
diff --git a/backend/onyx/server/documents/connector.py b/backend/onyx/server/documents/connector.py index eb6a3ca57b8..50a933aa177 100644 --- a/backend/onyx/server/documents/connector.py +++ b/backend/onyx/server/documents/connector.py @@ -1,3 +1,4 @@ +import io import json import mimetypes import os @@ -101,8 +102,9 @@ from onyx.db.models import IndexingStatus from onyx.db.models import User from onyx.db.models import UserGroup__ConnectorCredentialPair -from onyx.file_processing.extract_file_text import convert_docx_to_txt +from onyx.file_processing.extract_file_text import extract_file_text from onyx.file_store.file_store import get_default_file_store +from onyx.file_store.models import ChatFileType from onyx.key_value_store.interface import KvKeyNotFoundError from onyx.server.documents.models import AuthStatus from onyx.server.documents.models import AuthUrl @@ -124,6 +126,7 @@ from onyx.server.documents.models import ObjectCreationIdResponse from onyx.server.documents.models import RunConnectorRequest from onyx.server.models import StatusResponse +from onyx.server.query_and_chat.chat_utils import mime_type_to_chat_file_type from onyx.utils.logger import setup_logger from onyx.utils.telemetry import create_milestone_and_report from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel @@ -484,15 +487,17 @@ def should_process_file(file_path: str) -> bool: deduped_file_names.append(os.path.basename(file_info)) continue - # For mypy, actual check happens at start of function - assert file.filename is not None - - # Special handling for docx files - only store the plaintext version - if file.content_type and file.content_type.startswith( - "application/vnd.openxmlformats-officedocument.wordprocessingml.document" - ): - docx_file_id = convert_docx_to_txt(file, file_store) - deduped_file_paths.append(docx_file_id) + # Special handling for doc files - only store the plaintext version + file_type = mime_type_to_chat_file_type(file.content_type) + if file_type == ChatFileType.DOC: + extracted_text = extract_file_text(file.file, file.filename or "") + text_file_id = file_store.save_file( + content=io.BytesIO(extracted_text.encode()), + display_name=file.filename, + file_origin=FileOrigin.CHAT_UPLOAD, + file_type="text/plain", + ) + deduped_file_paths.append(text_file_id) deduped_file_names.append(file.filename) continue diff --git a/backend/onyx/server/query_and_chat/chat_backend.py b/backend/onyx/server/query_and_chat/chat_backend.py index a1e60f7af22..82ddaf8318f 100644 --- a/backend/onyx/server/query_and_chat/chat_backend.py +++ b/backend/onyx/server/query_and_chat/chat_backend.py @@ -1,6 +1,5 @@ import asyncio import datetime -import io import json import os import time @@ -31,7 +30,6 @@ from onyx.configs.app_configs import WEB_DOMAIN from onyx.configs.chat_configs import HARD_DELETE_CHATS from onyx.configs.constants import DocumentSource -from onyx.configs.constants import FileOrigin from onyx.configs.constants import MessageType from onyx.configs.constants import MilestoneRecordType from onyx.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS @@ -63,9 +61,7 @@ from onyx.db.persona import get_persona_by_id from onyx.db.user_documents import create_user_files from onyx.file_processing.extract_file_text import docx_to_txt_filename -from onyx.file_processing.extract_file_text import extract_file_text from onyx.file_store.file_store import get_default_file_store -from onyx.file_store.models import ChatFileType from onyx.file_store.models import FileDescriptor from 
onyx.llm.exceptions import GenAIDisabledException from onyx.llm.factory import get_default_llms @@ -717,106 +713,110 @@ def upload_files_for_chat( ): raise HTTPException( status_code=400, - detail="File size must be less than 20MB", + detail="Images must be less than 20MB", ) - file_store = get_default_file_store() - - file_info: list[tuple[str, str | None, ChatFileType]] = [] - for file in files: - file_type = mime_type_to_chat_file_type(file.content_type) + # file_store = get_default_file_store() - file_content = file.file.read() # Read the file content + # file_info: list[tuple[str, str | None, ChatFileType]] = [] + # for file in files: + # file_type = mime_type_to_chat_file_type(file.content_type) - # NOTE: Image conversion to JPEG used to be enforced here. - # This was removed to: - # 1. Preserve original file content for downloads - # 2. Maintain transparency in formats like PNG - # 3. Ameliorate issue with file conversion - file_content_io = io.BytesIO(file_content) + # file_content = file.file.read() # Read the file content - new_content_type = file.content_type + # # NOTE: Image conversion to JPEG used to be enforced here. + # # This was removed to: + # # 1. Preserve original file content for downloads + # # 2. Maintain transparency in formats like PNG + # # 3. Ameliorate issue with file conversion + # file_content_io = io.BytesIO(file_content) - # Store the file normally - file_id = file_store.save_file( - content=file_content_io, - display_name=file.filename, - file_origin=FileOrigin.CHAT_UPLOAD, - file_type=new_content_type or file_type.value, - ) + # new_content_type = file.content_type - # 4) If the file is a doc, extract text and store that separately - if file_type == ChatFileType.DOC: - # Re-wrap bytes in a fresh BytesIO so we start at position 0 - extracted_text_io = io.BytesIO(file_content) - extracted_text = extract_file_text( - file=extracted_text_io, # use the bytes we already read - file_name=file.filename or "", - ) + # # Store the file normally + # file_id = file_store.save_file( + # content=file_content_io, + # display_name=file.filename, + # file_origin=FileOrigin.CHAT_UPLOAD, + # file_type=new_content_type or file_type.value, + # ) - text_file_id = file_store.save_file( - content=io.BytesIO(extracted_text.encode()), - display_name=file.filename, - file_origin=FileOrigin.CHAT_UPLOAD, - file_type="text/plain", - ) - # Return the text file as the "main" file descriptor for doc types - file_info.append((text_file_id, file.filename, ChatFileType.PLAIN_TEXT)) - else: - file_info.append((file_id, file.filename, file_type)) - - # 5) Create a user file for each uploaded file - user_files = create_user_files([file], RECENT_DOCS_FOLDER_ID, user, db_session) - for user_file in user_files: - # 6) Create connector - connector_base = ConnectorBase( - name=f"UserFile-{int(time.time())}", - source=DocumentSource.FILE, - input_type=InputType.LOAD_STATE, - connector_specific_config={ - "file_locations": [user_file.file_id], - "file_names": [user_file.name], - "zip_metadata": {}, - }, - refresh_freq=None, - prune_freq=None, - indexing_start=None, - ) - connector = create_connector( - db_session=db_session, - connector_data=connector_base, - ) + # # 4) If the file is a doc, extract text and store that separately + # if file_type == ChatFileType.DOC: + # # Re-wrap bytes in a fresh BytesIO so we start at position 0 + # extracted_text_io = io.BytesIO(file_content) + # extracted_text = extract_file_text( + # file=extracted_text_io, # use the bytes we already read + # 
file_name=file.filename or "", + # ) + + # text_file_id = file_store.save_file( + # content=io.BytesIO(extracted_text.encode()), + # display_name=file.filename, + # file_origin=FileOrigin.CHAT_UPLOAD, + # file_type="text/plain", + # ) + # # Return the text file as the "main" file descriptor for doc types + # file_info.append((text_file_id, file.filename, ChatFileType.PLAIN_TEXT)) + # else: + # file_info.append((file_id, file.filename, file_type)) + + # 5) Create a user file for each uploaded file + user_files = create_user_files(files, RECENT_DOCS_FOLDER_ID, user, db_session) + for user_file in user_files: + # 6) Create connector + connector_base = ConnectorBase( + name=f"UserFile-{int(time.time())}", + source=DocumentSource.FILE, + input_type=InputType.LOAD_STATE, + connector_specific_config={ + "file_locations": [user_file.file_id], + "file_names": [user_file.name], + "zip_metadata": {}, + }, + refresh_freq=None, + prune_freq=None, + indexing_start=None, + ) + connector = create_connector( + db_session=db_session, + connector_data=connector_base, + ) - # 7) Create credential - credential_info = CredentialBase( - credential_json={}, - admin_public=True, - source=DocumentSource.FILE, - curator_public=True, - groups=[], - name=f"UserFileCredential-{int(time.time())}", - is_user_file=True, - ) - credential = create_credential(credential_info, user, db_session) + # 7) Create credential + credential_info = CredentialBase( + credential_json={}, + admin_public=True, + source=DocumentSource.FILE, + curator_public=True, + groups=[], + name=f"UserFileCredential-{int(time.time())}", + is_user_file=True, + ) + credential = create_credential(credential_info, user, db_session) - # 8) Create connector credential pair - cc_pair = add_credential_to_connector( - db_session=db_session, - user=user, - connector_id=connector.id, - credential_id=credential.id, - cc_pair_name=f"UserFileCCPair-{int(time.time())}", - access_type=AccessType.PRIVATE, - auto_sync_options=None, - groups=[], - ) - user_file.cc_pair_id = cc_pair.data - db_session.commit() + # 8) Create connector credential pair + cc_pair = add_credential_to_connector( + db_session=db_session, + user=user, + connector_id=connector.id, + credential_id=credential.id, + cc_pair_name=f"UserFileCCPair-{int(time.time())}", + access_type=AccessType.PRIVATE, + auto_sync_options=None, + groups=[], + ) + user_file.cc_pair_id = cc_pair.data + db_session.commit() return { "files": [ - {"id": file_id, "type": file_type, "name": file_name} - for file_id, file_name, file_type in file_info + { + "id": user_file.file_id, + "type": mime_type_to_chat_file_type(user_file.content_type), + "name": user_file.name, + } + for user_file in user_files ] } From de7ef73050290e7c5d68ed9b16f859a3c3a4552a Mon Sep 17 00:00:00 2001 From: Evan Lohn Date: Mon, 4 Aug 2025 21:35:00 -0700 Subject: [PATCH 2/4] mypy --- backend/onyx/server/documents/connector.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/backend/onyx/server/documents/connector.py b/backend/onyx/server/documents/connector.py index 50a933aa177..5a7e3be30d7 100644 --- a/backend/onyx/server/documents/connector.py +++ b/backend/onyx/server/documents/connector.py @@ -487,6 +487,9 @@ def should_process_file(file_path: str) -> bool: deduped_file_names.append(os.path.basename(file_info)) continue + # For mypy, actual check happens at start of function + assert file.filename is not None + # Special handling for doc files - only store the plaintext version file_type = mime_type_to_chat_file_type(file.content_type) if file_type 
== ChatFileType.DOC: From ab4a917e7a7f6a2a9c086de9480a55d445cf5065 Mon Sep 17 00:00:00 2001 From: Evan Lohn Date: Wed, 6 Aug 2025 20:15:16 -0700 Subject: [PATCH 3/4] CW comments --- .../reporting/usage_export_generation.py | 4 +- backend/onyx/indexing/indexing_pipeline.py | 7 +-- backend/onyx/server/documents/connector.py | 8 ++-- .../server/query_and_chat/chat_backend.py | 45 ------------------- 4 files changed, 8 insertions(+), 56 deletions(-) diff --git a/backend/ee/onyx/server/reporting/usage_export_generation.py b/backend/ee/onyx/server/reporting/usage_export_generation.py index 97ec2d03c3c..2391dbc226a 100644 --- a/backend/ee/onyx/server/reporting/usage_export_generation.py +++ b/backend/ee/onyx/server/reporting/usage_export_generation.py @@ -67,7 +67,7 @@ def generate_chat_messages_report( file_id = file_store.save_file( content=temp_file, display_name=file_name, - file_origin=FileOrigin.OTHER, + file_origin=FileOrigin.GENERATED_REPORT, file_type="text/csv", ) @@ -99,7 +99,7 @@ def generate_user_report( file_id = file_store.save_file( content=temp_file, display_name=file_name, - file_origin=FileOrigin.OTHER, + file_origin=FileOrigin.GENERATED_REPORT, file_type="text/csv", ) diff --git a/backend/onyx/indexing/indexing_pipeline.py b/backend/onyx/indexing/indexing_pipeline.py index 050ec9ad46d..29592b7d3e2 100644 --- a/backend/onyx/indexing/indexing_pipeline.py +++ b/backend/onyx/indexing/indexing_pipeline.py @@ -867,13 +867,8 @@ def index_doc_batch( user_file_id_to_raw_text: dict[int, str] = {} for document_id in updatable_ids: # Only calculate token counts for documents that have a user file ID - if ( - document_id not in doc_id_to_user_file_id - or doc_id_to_user_file_id[document_id] is None - ): - continue - user_file_id = doc_id_to_user_file_id[document_id] + user_file_id = doc_id_to_user_file_id.get(document_id) if user_file_id is None: continue diff --git a/backend/onyx/server/documents/connector.py b/backend/onyx/server/documents/connector.py index 5a7e3be30d7..cb1bc0c6cd0 100644 --- a/backend/onyx/server/documents/connector.py +++ b/backend/onyx/server/documents/connector.py @@ -441,7 +441,9 @@ def is_zip_file(file: UploadFile) -> bool: ) -def upload_files(files: list[UploadFile]) -> FileUploadResponse: +def upload_files( + files: list[UploadFile], file_origin: FileOrigin = FileOrigin.CONNECTOR +) -> FileUploadResponse: for file in files: if not file.filename: raise HTTPException(status_code=400, detail="File name cannot be empty") @@ -497,7 +499,7 @@ def should_process_file(file_path: str) -> bool: text_file_id = file_store.save_file( content=io.BytesIO(extracted_text.encode()), display_name=file.filename, - file_origin=FileOrigin.CHAT_UPLOAD, + file_origin=file_origin, file_type="text/plain", ) deduped_file_paths.append(text_file_id) @@ -528,7 +530,7 @@ def upload_files_api( files: list[UploadFile], _: User = Depends(current_curator_or_admin_user), ) -> FileUploadResponse: - return upload_files(files) + return upload_files(files, FileOrigin.OTHER) @router.get("/admin/connector") diff --git a/backend/onyx/server/query_and_chat/chat_backend.py b/backend/onyx/server/query_and_chat/chat_backend.py index 82ddaf8318f..f95e49d21d4 100644 --- a/backend/onyx/server/query_and_chat/chat_backend.py +++ b/backend/onyx/server/query_and_chat/chat_backend.py @@ -716,51 +716,6 @@ def upload_files_for_chat( detail="Images must be less than 20MB", ) - # file_store = get_default_file_store() - - # file_info: list[tuple[str, str | None, ChatFileType]] = [] - # for file in files: - # file_type = 
mime_type_to_chat_file_type(file.content_type) - - # file_content = file.file.read() # Read the file content - - # # NOTE: Image conversion to JPEG used to be enforced here. - # # This was removed to: - # # 1. Preserve original file content for downloads - # # 2. Maintain transparency in formats like PNG - # # 3. Ameliorate issue with file conversion - # file_content_io = io.BytesIO(file_content) - - # new_content_type = file.content_type - - # # Store the file normally - # file_id = file_store.save_file( - # content=file_content_io, - # display_name=file.filename, - # file_origin=FileOrigin.CHAT_UPLOAD, - # file_type=new_content_type or file_type.value, - # ) - - # # 4) If the file is a doc, extract text and store that separately - # if file_type == ChatFileType.DOC: - # # Re-wrap bytes in a fresh BytesIO so we start at position 0 - # extracted_text_io = io.BytesIO(file_content) - # extracted_text = extract_file_text( - # file=extracted_text_io, # use the bytes we already read - # file_name=file.filename or "", - # ) - - # text_file_id = file_store.save_file( - # content=io.BytesIO(extracted_text.encode()), - # display_name=file.filename, - # file_origin=FileOrigin.CHAT_UPLOAD, - # file_type="text/plain", - # ) - # # Return the text file as the "main" file descriptor for doc types - # file_info.append((text_file_id, file.filename, ChatFileType.PLAIN_TEXT)) - # else: - # file_info.append((file_id, file.filename, file_type)) - # 5) Create a user file for each uploaded file user_files = create_user_files(files, RECENT_DOCS_FOLDER_ID, user, db_session) for user_file in user_files: From 6f53c436aeb91b5254e9f7cbcec05c5d7f0fe7dc Mon Sep 17 00:00:00 2001 From: Evan Lohn Date: Thu, 7 Aug 2025 15:45:08 -0700 Subject: [PATCH 4/4] address CW --- .../onyx/background/celery/celery_utils.py | 36 ++++++++++++------- backend/onyx/connectors/connector_runner.py | 18 ++++++---- 2 files changed, 34 insertions(+), 20 deletions(-) diff --git a/backend/onyx/background/celery/celery_utils.py b/backend/onyx/background/celery/celery_utils.py index 840bda92011..0e7a36ac4cd 100644 --- a/backend/onyx/background/celery/celery_utils.py +++ b/backend/onyx/background/celery/celery_utils.py @@ -1,3 +1,5 @@ +from collections.abc import Generator +from collections.abc import Iterator from datetime import datetime from datetime import timezone from pathlib import Path @@ -8,7 +10,7 @@ from onyx.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE from onyx.configs.app_configs import VESPA_REQUEST_TIMEOUT -from onyx.connectors.connector_runner import batched_docs +from onyx.connectors.connector_runner import batched_doc_ids from onyx.connectors.cross_connector_utils.rate_limit_wrapper import ( rate_limit_builder, ) @@ -28,9 +30,10 @@ def document_batch_to_ids( - doc_batch: list[Document], -) -> set[str]: - return {doc.id for doc in doc_batch} + doc_batch: Iterator[list[Document]], +) -> Generator[set[str], None, None]: + for doc_list in doc_batch: + yield {doc.id for doc in doc_list} def extract_ids_from_runnable_connector( @@ -49,14 +52,18 @@ def extract_ids_from_runnable_connector( for metadata_batch in runnable_connector.retrieve_all_slim_documents(): all_connector_doc_ids.update({doc.id for doc in metadata_batch}) - doc_batch_generator = None + doc_batch_id_generator = None if isinstance(runnable_connector, LoadConnector): - doc_batch_generator = runnable_connector.load_from_state() + doc_batch_id_generator = document_batch_to_ids( + runnable_connector.load_from_state() + ) elif 
isinstance(runnable_connector, PollConnector): start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp() end = datetime.now(timezone.utc).timestamp() - doc_batch_generator = runnable_connector.poll_source(start=start, end=end) + doc_batch_id_generator = document_batch_to_ids( + runnable_connector.poll_source(start=start, end=end) + ) elif isinstance(runnable_connector, CheckpointedConnector): start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp() end = datetime.now(timezone.utc).timestamp() @@ -64,28 +71,31 @@ def extract_ids_from_runnable_connector( checkpoint_generator = runnable_connector.load_from_checkpoint( start=start, end=end, checkpoint=checkpoint ) - doc_batch_generator = batched_docs( + doc_batch_id_generator = batched_doc_ids( checkpoint_generator, batch_size=PRUNING_CHECKPOINTED_BATCH_SIZE ) else: raise RuntimeError("Pruning job could not find a valid runnable_connector.") - doc_batch_processing_func = document_batch_to_ids + # this function is called per batch for rate limiting + def doc_batch_processing_func(doc_batch_ids: set[str]) -> set[str]: + return doc_batch_ids + if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE: doc_batch_processing_func = rate_limit_builder( max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60 - )(document_batch_to_ids) - for doc_batch in doc_batch_generator: + )(lambda x: x) + for doc_batch_ids in doc_batch_id_generator: if callback: if callback.should_stop(): raise RuntimeError( "extract_ids_from_runnable_connector: Stop signal detected" ) - all_connector_doc_ids.update(doc_batch_processing_func(doc_batch)) + all_connector_doc_ids.update(doc_batch_processing_func(doc_batch_ids)) if callback: - callback.progress("extract_ids_from_runnable_connector", len(doc_batch)) + callback.progress("extract_ids_from_runnable_connector", len(doc_batch_ids)) return all_connector_doc_ids diff --git a/backend/onyx/connectors/connector_runner.py b/backend/onyx/connectors/connector_runner.py index e73c37300b2..b915ecafcd5 100644 --- a/backend/onyx/connectors/connector_runner.py +++ b/backend/onyx/connectors/connector_runner.py @@ -25,20 +25,24 @@ CT = TypeVar("CT", bound=ConnectorCheckpoint) -def batched_docs( +def batched_doc_ids( checkpoint_connector_generator: CheckpointOutput[CT], batch_size: int, -) -> Generator[list[Document], None, None]: - batch: list[Document] = [] +) -> Generator[set[str], None, None]: + batch: set[str] = set() for document, failure, next_checkpoint in CheckpointOutputWrapper[CT]()( checkpoint_connector_generator ): - if document is None: - continue - batch.append(document) + if document is not None: + batch.add(document.id) + elif ( + failure and failure.failed_document and failure.failed_document.document_id + ): + batch.add(failure.failed_document.document_id) + if len(batch) >= batch_size: yield batch - batch = [] + batch = set() if len(batch) > 0: yield batch
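
To round out the connector_runner changes, here is a small self-contained sketch of the batching semantics that batched_doc_ids implements; batch_ids and the plain iterable of string IDs are stand-ins for this illustration only and do not reflect the real CheckpointOutputWrapper plumbing.

from collections.abc import Generator
from collections.abc import Iterable


def batch_ids(
    doc_ids: Iterable[str], batch_size: int
) -> Generator[set[str], None, None]:
    """Group IDs into sets of at most batch_size, flushing any remainder at the end."""
    batch: set[str] = set()
    for doc_id in doc_ids:
        batch.add(doc_id)
        if len(batch) >= batch_size:
            yield batch
            batch = set()
    if batch:
        yield batch


# five IDs with batch_size=2 -> three batches of sizes 2, 2, and 1
print(list(batch_ids((f"doc-{i}" for i in range(5)), batch_size=2)))

Because each batch is a set, duplicate IDs within a batch are collapsed rather than double-counted, and in the real generator failed documents contribute their document_id just like successfully loaded ones.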