Skip to content

Commit 6dbd3e6

Browse files
committed
Allow search w/ user files
1 parent 9dbe12c commit 6dbd3e6

File tree

6 files changed

+170
-248
lines changed

6 files changed

+170
-248
lines changed

backend/onyx/chat/process_message.py

Lines changed: 47 additions & 207 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
4444
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
4545
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
46+
from onyx.chat.user_files.parse_user_files import parse_user_files
4647
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
4748
from onyx.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
4849
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
@@ -52,11 +53,9 @@
5253
from onyx.configs.constants import MessageType
5354
from onyx.configs.constants import MilestoneRecordType
5455
from onyx.configs.constants import NO_AUTH_USER_ID
55-
from onyx.context.search.enums import LLMEvaluationType
5656
from onyx.context.search.enums import OptionalSearchSetting
5757
from onyx.context.search.enums import QueryFlow
5858
from onyx.context.search.enums import SearchType
59-
from onyx.context.search.models import BaseFilters
6059
from onyx.context.search.models import InferenceSection
6160
from onyx.context.search.models import RetrievalDetails
6261
from onyx.context.search.retrieval.search_runner import (
@@ -95,9 +94,7 @@
9594
from onyx.file_store.models import ChatFileType
9695
from onyx.file_store.models import FileDescriptor
9796
from onyx.file_store.models import InMemoryChatFile
98-
from onyx.file_store.utils import get_user_files
9997
from onyx.file_store.utils import load_all_chat_files
100-
from onyx.file_store.utils import load_in_memory_chat_files
10198
from onyx.file_store.utils import save_files
10299
from onyx.llm.exceptions import GenAIDisabledException
103100
from onyx.llm.factory import get_llms_for_persona
@@ -312,54 +309,32 @@ def _handle_internet_search_tool_response_summary(
312309
def _get_force_search_settings(
313310
new_msg_req: CreateChatMessageRequest,
314311
tools: list[Tool],
315-
user_file_ids: list[int],
316-
user_folder_ids: list[int],
312+
search_tool_override_kwargs: SearchToolOverrideKwargs | None,
317313
) -> ForceUseTool:
318314
internet_search_available = any(
319315
isinstance(tool, InternetSearchTool) for tool in tools
320316
)
321317
search_tool_available = any(isinstance(tool, SearchTool) for tool in tools)
322318

323319
if not internet_search_available and not search_tool_available:
324-
if new_msg_req.force_user_file_search:
325-
return ForceUseTool(force_use=True, tool_name=SearchTool._NAME)
326-
else:
327-
# Does not matter much which tool is set here as force is false and neither tool is available
328-
return ForceUseTool(force_use=False, tool_name=SearchTool._NAME)
329-
330-
tool_name = SearchTool._NAME if search_tool_available else InternetSearchTool._NAME
320+
# The tool name chosen here is irrelevant: force_use is False and neither search tool is available.
321+
return ForceUseTool(force_use=False, tool_name=SearchTool._NAME)
331322
# Currently, the internet search tool does not support query override
332323
args = (
333324
{"query": new_msg_req.query_override}
334-
if new_msg_req.query_override and tool_name == SearchTool._NAME
325+
if new_msg_req.query_override and search_tool_available
335326
else None
336327
)
337328

338-
# Create override_kwargs for the search tool if user_file_ids are provided
339-
override_kwargs = None
340-
if (user_file_ids or user_folder_ids) and tool_name == SearchTool._NAME:
341-
override_kwargs = SearchToolOverrideKwargs(
342-
force_no_rerank=False,
343-
alternate_db_session=None,
344-
retrieved_sections_callback=None,
345-
skip_query_analysis=False,
346-
user_file_ids=user_file_ids,
347-
user_folder_ids=user_folder_ids,
348-
)
349-
350-
if new_msg_req.file_descriptors:
351-
# If user has uploaded files they're using, don't run any of the search tools
352-
return ForceUseTool(force_use=False, tool_name=tool_name)
353-
354329
should_force_search = any(
355330
[
356-
new_msg_req.force_user_file_search,
357331
new_msg_req.retrieval_options
358332
and new_msg_req.retrieval_options.run_search
359333
== OptionalSearchSetting.ALWAYS,
360334
new_msg_req.search_doc_ids,
361335
new_msg_req.query_override is not None,
362336
DISABLE_LLM_CHOOSE_SEARCH,
337+
search_tool_override_kwargs is not None,
363338
]
364339
)
365340

@@ -369,13 +344,18 @@ def _get_force_search_settings(
369344

370345
return ForceUseTool(
371346
force_use=True,
372-
tool_name=tool_name,
347+
tool_name=SearchTool._NAME,
373348
args=args,
374-
override_kwargs=override_kwargs,
349+
override_kwargs=search_tool_override_kwargs,
375350
)
376351

377352
return ForceUseTool(
378-
force_use=False, tool_name=tool_name, args=args, override_kwargs=override_kwargs
353+
force_use=False,
354+
tool_name=(
355+
SearchTool._NAME if search_tool_available else InternetSearchTool._NAME
356+
),
357+
args=args,
358+
override_kwargs=None,
379359
)
380360

381361

@@ -488,7 +468,6 @@ def _process_tool_response(
488468
retrieval_options: RetrievalDetails | None,
489469
user_file_files: list[UserFile] | None,
490470
user_files: list[InMemoryChatFile] | None,
491-
file_id_to_user_file: dict[str, InMemoryChatFile],
492471
search_for_ordering_only: bool,
493472
) -> Generator[ChatPacket, None, dict[SubQuestionKey, AnswerPostInfo]]:
494473
level, level_question_num = (
@@ -540,7 +519,7 @@ def _process_tool_response(
540519
yield from _get_user_knowledge_files(
541520
info=info,
542521
user_files=user_files,
543-
file_id_to_user_file=file_id_to_user_file,
522+
file_id_to_user_file={file.file_id: file for file in user_files},
544523
)
545524

546525
yield info.qa_docs_response
@@ -665,8 +644,6 @@ def stream_chat_message_objects(
665644

666645
try:
667646
# Move these variables inside the try block
668-
file_id_to_user_file = {}
669-
670647
user_id = user.id if user is not None else None
671648

672649
chat_session = get_chat_session_by_id(
@@ -840,60 +817,22 @@ def stream_chat_message_objects(
840817
for folder in persona.user_folders:
841818
user_folder_ids.append(folder.id)
842819

843-
# Initialize flag for user file search
844-
use_search_for_user_files = False
845-
846-
user_files: list[InMemoryChatFile] | None = None
847-
search_for_ordering_only = False
848-
user_file_files: list[UserFile] | None = None
849-
if user_file_ids or user_folder_ids:
850-
# Load user files
851-
user_files = load_in_memory_chat_files(
852-
user_file_ids or [],
853-
user_folder_ids or [],
854-
db_session,
855-
)
856-
user_file_files = get_user_files(
857-
user_file_ids or [],
858-
user_folder_ids or [],
859-
db_session,
860-
)
861-
# Store mapping of file_id to file for later reordering
862-
if user_files:
863-
file_id_to_user_file = {file.file_id: file for file in user_files}
864-
865-
# Calculate token count for the files
866-
from onyx.db.user_documents import calculate_user_files_token_count
867-
from onyx.chat.prompt_builder.citations_prompt import (
868-
compute_max_document_tokens_for_persona,
869-
)
870-
871-
total_tokens = calculate_user_files_token_count(
872-
user_file_ids or [],
873-
user_folder_ids or [],
874-
db_session,
875-
)
876-
877-
# Calculate available tokens for documents based on prompt, user input, etc.
878-
available_tokens = compute_max_document_tokens_for_persona(
879-
db_session=db_session,
880-
persona=persona,
881-
actual_user_input=message_text, # Use the actual user message
882-
)
883-
884-
logger.debug(
885-
f"Total file tokens: {total_tokens}, Available tokens: {available_tokens}"
886-
)
887-
888-
# ALWAYS use search for user files, but track if we need it for context or just ordering
889-
use_search_for_user_files = True
890-
# If files are small enough for context, we'll just use search for ordering
891-
search_for_ordering_only = total_tokens <= available_tokens
892-
893-
if search_for_ordering_only:
894-
# Add original user files to context since they fit
895-
if user_files:
896-
latest_query_files.extend(user_files)
820+
# Load user files into memory and create search tool override kwargs if needed.
821+
# If we have enough tokens and no folders, we don't need to use search;
822+
# we can just pass the files into the prompt directly.
823+
(
824+
in_memory_user_files,
825+
user_file_models,
826+
search_tool_override_kwargs_for_user_files,
827+
) = parse_user_files(
828+
user_file_ids=user_file_ids,
829+
user_folder_ids=user_folder_ids,
830+
db_session=db_session,
831+
persona=persona,
832+
actual_user_input=message_text,
833+
)
834+
if not search_tool_override_kwargs_for_user_files:
835+
latest_query_files.extend(in_memory_user_files)
897836

898837
if user_message:
899838
attach_files_to_chat_message(
@@ -1052,10 +991,13 @@ def create_response(
1052991
prompt_config=prompt_config,
1053992
db_session=db_session,
1054993
user=user,
1055-
user_knowledge_present=bool(user_files or user_folder_ids),
1056994
llm=llm,
1057995
fast_llm=fast_llm,
1058-
use_file_search=new_msg_req.force_user_file_search,
996+
run_search_setting=(
997+
retrieval_options.run_search
998+
if retrieval_options
999+
else OptionalSearchSetting.AUTO
1000+
),
10591001
search_tool_config=SearchToolConfig(
10601002
answer_style_config=answer_style_config,
10611003
document_pruning_config=document_pruning_config,
@@ -1086,128 +1028,23 @@ def create_response(
10861028
tools.extend(tool_list)
10871029

10881030
force_use_tool = _get_force_search_settings(
1089-
new_msg_req, tools, user_file_ids, user_folder_ids
1031+
new_msg_req, tools, search_tool_override_kwargs_for_user_files
10901032
)
10911033

1092-
# Set force_use if user files exceed token limit
1093-
if use_search_for_user_files:
1094-
try:
1095-
# Check if search tool is available in the tools list
1096-
search_tool_available = any(
1097-
isinstance(tool, SearchTool) for tool in tools
1098-
)
1099-
1100-
# If no search tool is available, add one
1101-
if not search_tool_available:
1102-
logger.info("No search tool available, creating one for user files")
1103-
# Create a basic search tool config
1104-
search_tool_config = SearchToolConfig(
1105-
answer_style_config=answer_style_config,
1106-
document_pruning_config=document_pruning_config,
1107-
retrieval_options=retrieval_options or RetrievalDetails(),
1108-
)
1109-
1110-
# Create and add the search tool
1111-
search_tool = SearchTool(
1112-
db_session=db_session,
1113-
user=user,
1114-
persona=persona,
1115-
retrieval_options=search_tool_config.retrieval_options,
1116-
prompt_config=prompt_config,
1117-
llm=llm,
1118-
fast_llm=fast_llm,
1119-
pruning_config=search_tool_config.document_pruning_config,
1120-
answer_style_config=search_tool_config.answer_style_config,
1121-
evaluation_type=(
1122-
LLMEvaluationType.BASIC
1123-
if persona.llm_relevance_filter
1124-
else LLMEvaluationType.SKIP
1125-
),
1126-
bypass_acl=bypass_acl,
1127-
)
1128-
1129-
# Add the search tool to the tools list
1130-
tools.append(search_tool)
1131-
1132-
logger.info(
1133-
"Added search tool for user files that exceed token limit"
1134-
)
1135-
1136-
# Now set force_use_tool.force_use to True
1137-
force_use_tool.force_use = True
1138-
force_use_tool.tool_name = SearchTool._NAME
1139-
1140-
# Set query argument if not already set
1141-
if not force_use_tool.args:
1142-
force_use_tool.args = {"query": final_msg.message}
1143-
1144-
# Pass the user file IDs to the search tool
1145-
if user_file_ids or user_folder_ids:
1146-
# Create a BaseFilters object with user_file_ids
1147-
if not retrieval_options:
1148-
retrieval_options = RetrievalDetails()
1149-
if not retrieval_options.filters:
1150-
retrieval_options.filters = BaseFilters()
1151-
1152-
# Set user file and folder IDs in the filters
1153-
retrieval_options.filters.user_file_ids = user_file_ids
1154-
retrieval_options.filters.user_folder_ids = user_folder_ids
1155-
1156-
# Create override kwargs for the search tool
1157-
1158-
override_kwargs = SearchToolOverrideKwargs(
1159-
force_no_rerank=search_for_ordering_only, # Skip reranking for ordering-only
1160-
alternate_db_session=None,
1161-
retrieved_sections_callback=None,
1162-
skip_query_analysis=search_for_ordering_only, # Skip query analysis for ordering-only
1163-
user_file_ids=user_file_ids,
1164-
user_folder_ids=user_folder_ids,
1165-
ordering_only=search_for_ordering_only, # Set ordering_only flag for fast path
1166-
)
1167-
1168-
# Set the override kwargs in the force_use_tool
1169-
force_use_tool.override_kwargs = override_kwargs
1170-
1171-
if search_for_ordering_only:
1172-
logger.info(
1173-
"Fast path: Configured search tool with optimized settings for ordering-only"
1174-
)
1175-
logger.info(
1176-
"Fast path: Skipping reranking and query analysis for ordering-only mode"
1177-
)
1178-
logger.info(
1179-
f"Using {len(user_file_ids or [])} files and {len(user_folder_ids or [])} folders"
1180-
)
1181-
else:
1182-
logger.info(
1183-
"Configured search tool to use ",
1184-
f"{len(user_file_ids or [])} files and {len(user_folder_ids or [])} folders",
1185-
)
1186-
except Exception as e:
1187-
logger.exception(
1188-
f"Error configuring search tool for user files: {str(e)}"
1189-
)
1190-
use_search_for_user_files = False
1191-
11921034
# TODO: unify message history with single message history
11931035
message_history = [
11941036
PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
11951037
]
1196-
if not use_search_for_user_files and user_files:
1038+
if not search_tool_override_kwargs_for_user_files and in_memory_user_files:
11971039
yield UserKnowledgeFilePacket(
11981040
user_files=[
11991041
FileDescriptor(
1200-
id=str(file.file_id), type=ChatFileType.USER_KNOWLEDGE
1042+
id=str(file.file_id), type=file.file_type, name=file.filename
12011043
)
1202-
for file in user_files
1044+
for file in in_memory_user_files
12031045
]
12041046
)
12051047

1206-
if search_for_ordering_only:
1207-
logger.info(
1208-
"Performance: Forcing LLMEvaluationType.SKIP to prevent chunk evaluation for ordering-only search"
1209-
)
1210-
12111048
prompt_builder = AnswerPromptBuilder(
12121049
user_message=default_build_user_message(
12131050
user_query=final_msg.message,
@@ -1265,10 +1102,13 @@ def create_response(
12651102
selected_db_search_docs=selected_db_search_docs,
12661103
info_by_subq=info_by_subq,
12671104
retrieval_options=retrieval_options,
1268-
user_file_files=user_file_files,
1269-
user_files=user_files,
1270-
file_id_to_user_file=file_id_to_user_file,
1271-
search_for_ordering_only=search_for_ordering_only,
1105+
user_file_files=user_file_models,
1106+
user_files=in_memory_user_files,
1107+
search_for_ordering_only=(
1108+
search_tool_override_kwargs_for_user_files is not None
1109+
and search_tool_override_kwargs_for_user_files.ordering_only
1110+
is True
1111+
),
12721112
)
12731113

12741114
elif isinstance(packet, StreamStopInfo):

0 commit comments

Comments
 (0)