4 changes: 2 additions & 2 deletions backend/ee/onyx/server/reporting/usage_export_generation.py
@@ -67,7 +67,7 @@ def generate_chat_messages_report(
file_id = file_store.save_file(
content=temp_file,
display_name=file_name,
file_origin=FileOrigin.OTHER,
file_origin=FileOrigin.GENERATED_REPORT,
file_type="text/csv",
)

@@ -99,7 +99,7 @@ def generate_user_report(
file_id = file_store.save_file(
content=temp_file,
display_name=file_name,
file_origin=FileOrigin.OTHER,
file_origin=FileOrigin.GENERATED_REPORT,
file_type="text/csv",
)

45 changes: 34 additions & 11 deletions backend/onyx/background/celery/celery_utils.py
@@ -1,3 +1,5 @@
from collections.abc import Generator
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from pathlib import Path
@@ -8,10 +10,12 @@

from onyx.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from onyx.configs.app_configs import VESPA_REQUEST_TIMEOUT
from onyx.connectors.connector_runner import batched_doc_ids
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
from onyx.connectors.interfaces import BaseConnector
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SlimConnector
@@ -22,12 +26,14 @@


logger = setup_logger()
PRUNING_CHECKPOINTED_BATCH_SIZE = 32


def document_batch_to_ids(
doc_batch: list[Document],
) -> set[str]:
return {doc.id for doc in doc_batch}
doc_batch: Iterator[list[Document]],
) -> Generator[set[str], None, None]:
for doc_list in doc_batch:
yield {doc.id for doc in doc_list}


def extract_ids_from_runnable_connector(
@@ -46,33 +52,50 @@ def extract_ids_from_runnable_connector(
for metadata_batch in runnable_connector.retrieve_all_slim_documents():
all_connector_doc_ids.update({doc.id for doc in metadata_batch})

doc_batch_generator = None
doc_batch_id_generator = None

if isinstance(runnable_connector, LoadConnector):
doc_batch_generator = runnable_connector.load_from_state()
doc_batch_id_generator = document_batch_to_ids(
runnable_connector.load_from_state()
)
elif isinstance(runnable_connector, PollConnector):
start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
end = datetime.now(timezone.utc).timestamp()
doc_batch_generator = runnable_connector.poll_source(start=start, end=end)
doc_batch_id_generator = document_batch_to_ids(
runnable_connector.poll_source(start=start, end=end)
)
elif isinstance(runnable_connector, CheckpointedConnector):
start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
end = datetime.now(timezone.utc).timestamp()
checkpoint = runnable_connector.build_dummy_checkpoint()
checkpoint_generator = runnable_connector.load_from_checkpoint(
start=start, end=end, checkpoint=checkpoint
)
doc_batch_id_generator = batched_doc_ids(
checkpoint_generator, batch_size=PRUNING_CHECKPOINTED_BATCH_SIZE
)
else:
raise RuntimeError("Pruning job could not find a valid runnable_connector.")

doc_batch_processing_func = document_batch_to_ids
# this function is called per batch for rate limiting
def doc_batch_processing_func(doc_batch_ids: set[str]) -> set[str]:
return doc_batch_ids

if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE:
doc_batch_processing_func = rate_limit_builder(
max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
)(document_batch_to_ids)
for doc_batch in doc_batch_generator:
)(lambda x: x)
for doc_batch_ids in doc_batch_id_generator:
if callback:
if callback.should_stop():
raise RuntimeError(
"extract_ids_from_runnable_connector: Stop signal detected"
)

all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch_ids))

if callback:
callback.progress("extract_ids_from_runnable_connector", len(doc_batch))
callback.progress("extract_ids_from_runnable_connector", len(doc_batch_ids))

return all_connector_doc_ids

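For context, the refactor above means the pruning loop now consumes sets of document IDs instead of lists of Document objects, and the optional rate limiter throttles per ID batch. A minimal sketch of that composition, using a stand-in Doc type and a fake connector rather than the onyx classes:

```python
# Sketch of the refactored pruning flow: a connector yields batches of
# documents, document_batch_to_ids lazily turns each batch into a set of IDs,
# and the caller accumulates those sets. The Doc dataclass and the fake
# connector below are stand-ins, not the onyx implementations.
from collections.abc import Generator, Iterator
from dataclasses import dataclass


@dataclass
class Doc:  # stand-in for onyx.connectors.models.Document
    id: str


def document_batch_to_ids(
    doc_batches: Iterator[list[Doc]],
) -> Generator[set[str], None, None]:
    for doc_list in doc_batches:
        yield {doc.id for doc in doc_list}


def fake_load_from_state() -> Generator[list[Doc], None, None]:
    yield [Doc("a"), Doc("b")]
    yield [Doc("b"), Doc("c")]


all_connector_doc_ids: set[str] = set()
for id_batch in document_batch_to_ids(fake_load_from_state()):
    # in celery_utils this update is wrapped by rate_limit_builder(...) when
    # MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE is configured
    all_connector_doc_ids.update(id_batch)

print(sorted(all_connector_doc_ids))  # ['a', 'b', 'c']
```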
22 changes: 22 additions & 0 deletions backend/onyx/connectors/connector_runner.py
@@ -25,6 +25,28 @@
CT = TypeVar("CT", bound=ConnectorCheckpoint)


def batched_doc_ids(
checkpoint_connector_generator: CheckpointOutput[CT],
batch_size: int,
) -> Generator[set[str], None, None]:
batch: set[str] = set()
for document, failure, next_checkpoint in CheckpointOutputWrapper[CT]()(
checkpoint_connector_generator
):
if document is not None:
batch.add(document.id)
elif (
failure and failure.failed_document and failure.failed_document.document_id
):
batch.add(failure.failed_document.document_id)

if len(batch) >= batch_size:
yield batch
batch = set()
if len(batch) > 0:
yield batch


class CheckpointOutputWrapper(Generic[CT]):
"""
Wraps a CheckpointOutput generator to give things back in a more digestible format,
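The new batched_doc_ids helper groups document IDs (including IDs referenced by failures) into sets of at most batch_size and flushes the final partial batch. A simplified, self-contained illustration of that batching behavior, using plain ID strings in place of a CheckpointOutput generator:

```python
# Simplified illustration of the batching behavior: IDs are collected into
# sets of at most batch_size, with a final partial batch flushed at the end.
# The real helper unwraps documents and failures from a CheckpointOutput
# generator via CheckpointOutputWrapper; plain strings stand in for that here.
from collections.abc import Generator, Iterable


def batched_ids(
    ids: Iterable[str], batch_size: int
) -> Generator[set[str], None, None]:
    batch: set[str] = set()
    for doc_id in ids:
        batch.add(doc_id)
        if len(batch) >= batch_size:
            yield batch
            batch = set()
    if batch:
        yield batch


print(list(batched_ids(["a", "b", "c", "d", "e"], batch_size=2)))
# [{'a', 'b'}, {'c', 'd'}, {'e'}]  (set ordering may vary)
```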
8 changes: 5 additions & 3 deletions backend/onyx/file_processing/file_validation.py
@@ -32,9 +32,11 @@ def is_valid_image_type(mime_type: str) -> bool:
Returns:
True if the MIME type is a valid image type, False otherwise
"""
if not mime_type:
return False
return mime_type.startswith("image/") and mime_type not in EXCLUDED_IMAGE_TYPES
return (
bool(mime_type)
and mime_type.startswith("image/")
and mime_type not in EXCLUDED_IMAGE_TYPES
)


def is_supported_by_vision_llm(mime_type: str) -> bool:
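The rewrite above folds the empty-string guard into a single boolean expression; behavior is unchanged. A standalone check of the three cases (the EXCLUDED_IMAGE_TYPES value here is illustrative, not the onyx constant):

```python
# Behaviorally equivalent standalone version; EXCLUDED_IMAGE_TYPES is an
# example value, not necessarily the onyx constant.
EXCLUDED_IMAGE_TYPES = {"image/svg+xml"}


def is_valid_image_type(mime_type: str) -> bool:
    return (
        bool(mime_type)
        and mime_type.startswith("image/")
        and mime_type not in EXCLUDED_IMAGE_TYPES
    )


assert is_valid_image_type("image/png")
assert not is_valid_image_type("")             # falsy input short-circuits
assert not is_valid_image_type("image/svg+xml")
```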
1 change: 0 additions & 1 deletion backend/onyx/file_store/utils.py
@@ -46,7 +46,6 @@ def store_user_file_plaintext(user_file_id: int, plaintext_content: str) -> bool
# Get plaintext file name
plaintext_file_name = user_file_id_to_plaintext_file_name(user_file_id)

# Use a separate session to avoid committing the caller's transaction
try:
file_store = get_default_file_store()
file_content = BytesIO(plaintext_content.encode("utf-8"))
46 changes: 21 additions & 25 deletions backend/onyx/indexing/indexing_pipeline.py
@@ -867,31 +867,27 @@ def index_doc_batch(
user_file_id_to_raw_text: dict[int, str] = {}
for document_id in updatable_ids:
# Only calculate token counts for documents that have a user file ID
if (
document_id in doc_id_to_user_file_id
and doc_id_to_user_file_id[document_id] is not None
):
user_file_id = doc_id_to_user_file_id[document_id]
if not user_file_id:
continue
document_chunks = [
chunk
for chunk in chunks_with_embeddings
if chunk.source_document.id == document_id
]
if document_chunks:
combined_content = " ".join(
[chunk.content for chunk in document_chunks]
)
token_count = (
len(llm_tokenizer.encode(combined_content))
if llm_tokenizer
else 0
)
user_file_id_to_token_count[user_file_id] = token_count
user_file_id_to_raw_text[user_file_id] = combined_content
else:
user_file_id_to_token_count[user_file_id] = None

user_file_id = doc_id_to_user_file_id.get(document_id)
if user_file_id is None:
continue

document_chunks = [
chunk
for chunk in chunks_with_embeddings
if chunk.source_document.id == document_id
]
if document_chunks:
combined_content = " ".join(
[chunk.content for chunk in document_chunks]
)
token_count = (
len(llm_tokenizer.encode(combined_content)) if llm_tokenizer else 0
)
user_file_id_to_token_count[user_file_id] = token_count
user_file_id_to_raw_text[user_file_id] = combined_content
else:
user_file_id_to_token_count[user_file_id] = None

# we're concerned about race conditions where multiple simultaneous indexings might result
# in one set of metadata overwriting another one in vespa.
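The refactor replaces the nested membership-and-None checks with a single .get() lookup and an early continue, then aggregates chunk content per user file as before. A self-contained sketch of that control flow with stand-in data (the real loop uses chunks_with_embeddings and llm_tokenizer):

```python
# Stand-in data illustrating the flattened loop: documents without a user
# file ID are skipped via a guard clause, and chunk content is aggregated per
# user file. A word count stands in for llm_tokenizer.encode().
doc_id_to_user_file_id: dict[str, int | None] = {"doc-1": 11, "doc-2": None}
chunks = [("doc-1", "hello"), ("doc-1", "world"), ("doc-2", "ignored")]

user_file_id_to_token_count: dict[int, int | None] = {}
user_file_id_to_raw_text: dict[int, str] = {}

for document_id in ["doc-1", "doc-2", "doc-3"]:
    user_file_id = doc_id_to_user_file_id.get(document_id)
    if user_file_id is None:
        continue  # guard clause replaces the old nested if/else

    document_chunks = [text for doc, text in chunks if doc == document_id]
    if document_chunks:
        combined_content = " ".join(document_chunks)
        user_file_id_to_token_count[user_file_id] = len(combined_content.split())
        user_file_id_to_raw_text[user_file_id] = combined_content
    else:
        user_file_id_to_token_count[user_file_id] = None

print(user_file_id_to_token_count)  # {11: 2}
```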
28 changes: 19 additions & 9 deletions backend/onyx/server/documents/connector.py
@@ -1,3 +1,4 @@
import io
import json
import mimetypes
import os
@@ -101,8 +102,9 @@
from onyx.db.models import IndexingStatus
from onyx.db.models import User
from onyx.db.models import UserGroup__ConnectorCredentialPair
from onyx.file_processing.extract_file_text import convert_docx_to_txt
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.models import ChatFileType
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.server.documents.models import AuthStatus
from onyx.server.documents.models import AuthUrl
@@ -124,6 +126,7 @@
from onyx.server.documents.models import ObjectCreationIdResponse
from onyx.server.documents.models import RunConnectorRequest
from onyx.server.models import StatusResponse
from onyx.server.query_and_chat.chat_utils import mime_type_to_chat_file_type
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
@@ -438,7 +441,9 @@ def is_zip_file(file: UploadFile) -> bool:
)


def upload_files(files: list[UploadFile]) -> FileUploadResponse:
def upload_files(
files: list[UploadFile], file_origin: FileOrigin = FileOrigin.CONNECTOR
) -> FileUploadResponse:
for file in files:
if not file.filename:
raise HTTPException(status_code=400, detail="File name cannot be empty")
@@ -487,12 +492,17 @@ def should_process_file(file_path: str) -> bool:
# For mypy, actual check happens at start of function
assert file.filename is not None

# Special handling for docx files - only store the plaintext version
if file.content_type and file.content_type.startswith(
"application/vnd.openxmlformats-officedocument.wordprocessingml.document"
):
docx_file_id = convert_docx_to_txt(file, file_store)
deduped_file_paths.append(docx_file_id)
# Special handling for doc files - only store the plaintext version
file_type = mime_type_to_chat_file_type(file.content_type)
if file_type == ChatFileType.DOC:
extracted_text = extract_file_text(file.file, file.filename or "")
text_file_id = file_store.save_file(
content=io.BytesIO(extracted_text.encode()),
display_name=file.filename,
file_origin=file_origin,
file_type="text/plain",
)
deduped_file_paths.append(text_file_id)
deduped_file_names.append(file.filename)
continue

@@ -520,7 +530,7 @@ def upload_files_api(
files: list[UploadFile],
_: User = Depends(current_curator_or_admin_user),
) -> FileUploadResponse:
return upload_files(files)
return upload_files(files, FileOrigin.OTHER)


@router.get("/admin/connector")
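Two things change in this file: upload_files now accepts a file_origin (the admin API route passes FileOrigin.OTHER, connector uploads keep the CONNECTOR default), and any upload whose MIME type maps to ChatFileType.DOC is reduced to extracted plaintext before storage. A rough, self-contained sketch of that path, with simplified stand-ins for the onyx file store and text extractor:

```python
# Rough sketch of the new DOC handling: extract text up front, store only the
# plaintext as text/plain, and record the caller-provided FileOrigin. The
# enum, extractor, and in-memory store are stand-ins for the onyx versions.
import io
from enum import Enum


class FileOrigin(str, Enum):
    CONNECTOR = "connector"
    OTHER = "other"


def extract_file_text(file_obj: io.BytesIO, file_name: str) -> str:
    # stand-in extractor; the real one handles office/PDF formats
    return file_obj.read().decode("utf-8", errors="replace")


class InMemoryFileStore:
    def __init__(self) -> None:
        self.files: dict[str, tuple[bytes, str, str, str]] = {}

    def save_file(
        self,
        content: io.BytesIO,
        display_name: str,
        file_origin: FileOrigin,
        file_type: str,
    ) -> str:
        file_id = f"file-{len(self.files)}"
        self.files[file_id] = (content.read(), display_name, file_origin.value, file_type)
        return file_id


def upload_doc_as_plaintext(
    raw: bytes,
    filename: str,
    store: InMemoryFileStore,
    file_origin: FileOrigin = FileOrigin.CONNECTOR,
) -> str:
    extracted_text = extract_file_text(io.BytesIO(raw), filename)
    return store.save_file(
        content=io.BytesIO(extracted_text.encode()),
        display_name=filename,
        file_origin=file_origin,
        file_type="text/plain",
    )


store = InMemoryFileStore()
file_id = upload_doc_as_plaintext(b"quarterly report", "report.docx", store, FileOrigin.OTHER)
print(store.files[file_id])  # (b'quarterly report', 'report.docx', 'other', 'text/plain')
```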