From 5a0c6d003607dfb9a7445a6a87df9a6062b73bc6 Mon Sep 17 00:00:00 2001 From: Weves Date: Wed, 2 Oct 2024 17:50:54 -0700 Subject: [PATCH 01/10] Fix Fix Refactor more more fix refactor Fix circular imports Refactor Move tests around --- backend/danswer/chat/models.py | 2 +- backend/danswer/chat/process_message.py | 86 ++- backend/danswer/llm/answering/answer.py | 588 ++++-------------- .../llm/answering/llm_response_handler.py | 83 +++ .../danswer/llm/answering/prompts/build.py | 72 +-- .../llm/answering/prompts/citations_prompt.py | 14 +- .../llm/answering/prompts/quotes_prompt.py | 25 +- .../danswer/llm/answering/prune_and_merge.py | 2 +- .../stream_processing/citation_processing.py | 232 +++---- .../citation_response_handler.py | 61 ++ .../answering/tool/tool_response_handler.py | 205 ++++++ backend/danswer/llm/utils.py | 22 + .../one_shot_answer/answer_question.py | 37 +- .../danswer/server/features/persona/models.py | 2 +- backend/danswer/server/features/tool/api.py | 14 +- backend/danswer/tools/base_tool.py | 59 ++ backend/danswer/tools/built_in_tools.py | 10 +- .../custom/custom_tool_prompt_builder.py | 21 - backend/danswer/tools/tool.py | 27 +- .../custom/base_tool_types.py | 0 .../custom/custom_tool.py | 42 +- .../custom/custom_tool_prompts.py | 0 .../custom/openapi_parsing.py | 0 .../images/image_generation_tool.py | 38 +- .../images/prompt.py | 0 .../internet_search/internet_search_tool.py | 48 +- .../internet_search/models.py | 0 .../search/search_tool.py | 51 +- .../search/search_utils.py | 0 .../search_like_tool_utils.py | 71 +++ backend/danswer/tools/tool_runner.py | 2 +- .../ee/danswer/server/query_and_chat/utils.py | 2 +- .../tests/dev_apis/test_simple_chat_api.py | 3 + .../unit/danswer/llm/answering/conftest.py | 113 ++++ .../test_citation_processing.py | 18 +- .../unit/danswer/llm/answering/test_answer.py | 422 +++++++++++++ .../danswer/llm/answering/test_skip_gen_ai.py | 35 +- .../danswer/tools/custom/test_custom_tools.py | 24 +- 38 files changed, 1634 insertions(+), 797 deletions(-) create mode 100644 backend/danswer/llm/answering/llm_response_handler.py create mode 100644 backend/danswer/llm/answering/stream_processing/citation_response_handler.py create mode 100644 backend/danswer/llm/answering/tool/tool_response_handler.py create mode 100644 backend/danswer/tools/base_tool.py delete mode 100644 backend/danswer/tools/custom/custom_tool_prompt_builder.py rename backend/danswer/tools/{ => tool_implementations}/custom/base_tool_types.py (100%) rename backend/danswer/tools/{ => tool_implementations}/custom/custom_tool.py (88%) rename backend/danswer/tools/{ => tool_implementations}/custom/custom_tool_prompts.py (100%) rename backend/danswer/tools/{ => tool_implementations}/custom/openapi_parsing.py (100%) rename backend/danswer/tools/{ => tool_implementations}/images/image_generation_tool.py (86%) rename backend/danswer/tools/{ => tool_implementations}/images/prompt.py (100%) rename backend/danswer/tools/{ => tool_implementations}/internet_search/internet_search_tool.py (81%) rename backend/danswer/tools/{ => tool_implementations}/internet_search/models.py (100%) rename backend/danswer/tools/{ => tool_implementations}/search/search_tool.py (87%) rename backend/danswer/tools/{ => tool_implementations}/search/search_utils.py (100%) create mode 100644 backend/danswer/tools/tool_implementations/search_like_tool_utils.py create mode 100644 backend/tests/unit/danswer/llm/answering/conftest.py create mode 100644 backend/tests/unit/danswer/llm/answering/test_answer.py diff --git 
a/backend/danswer/chat/models.py b/backend/danswer/chat/models.py index 97d5b9e7275..d5925fc2ed9 100644 --- a/backend/danswer/chat/models.py +++ b/backend/danswer/chat/models.py @@ -10,7 +10,7 @@ from danswer.search.enums import SearchType from danswer.search.models import RetrievalDocs from danswer.search.models import SearchResponse -from danswer.tools.custom.base_tool_types import ToolResultType +from danswer.tools.tool_implementations.custom.base_tool_types import ToolResultType class LlmDoc(BaseModel): diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index f58a34c3243..4ff30dd3c04 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -77,31 +77,49 @@ from danswer.server.query_and_chat.models import CreateChatMessageRequest from danswer.server.utils import get_json_line from danswer.tools.built_in_tools import get_built_in_tool_by_id -from danswer.tools.custom.custom_tool import ( +from danswer.tools.force import ForceUseTool +from danswer.tools.models import DynamicSchemaInfo +from danswer.tools.models import ToolResponse +from danswer.tools.tool import Tool +from danswer.tools.tool_implementations.custom.custom_tool import ( build_custom_tools_from_openapi_schema_and_headers, ) -from danswer.tools.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID -from danswer.tools.custom.custom_tool import CustomToolCallSummary -from danswer.tools.force import ForceUseTool -from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID -from danswer.tools.images.image_generation_tool import ImageGenerationResponse -from danswer.tools.images.image_generation_tool import ImageGenerationTool -from danswer.tools.internet_search.internet_search_tool import ( +from danswer.tools.tool_implementations.custom.custom_tool import ( + CUSTOM_TOOL_RESPONSE_ID, +) +from danswer.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary +from danswer.tools.tool_implementations.images.image_generation_tool import ( + IMAGE_GENERATION_RESPONSE_ID, +) +from danswer.tools.tool_implementations.images.image_generation_tool import ( + ImageGenerationResponse, +) +from danswer.tools.tool_implementations.images.image_generation_tool import ( + ImageGenerationTool, +) +from danswer.tools.tool_implementations.internet_search.internet_search_tool import ( INTERNET_SEARCH_RESPONSE_ID, ) -from danswer.tools.internet_search.internet_search_tool import ( +from danswer.tools.tool_implementations.internet_search.internet_search_tool import ( internet_search_response_to_search_docs, ) -from danswer.tools.internet_search.internet_search_tool import InternetSearchResponse -from danswer.tools.internet_search.internet_search_tool import InternetSearchTool -from danswer.tools.models import DynamicSchemaInfo -from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID -from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID -from danswer.tools.search.search_tool import SearchResponseSummary -from danswer.tools.search.search_tool import SearchTool -from danswer.tools.search.search_tool import SECTION_RELEVANCE_LIST_ID -from danswer.tools.tool import Tool -from danswer.tools.tool import ToolResponse +from danswer.tools.tool_implementations.internet_search.internet_search_tool import ( + InternetSearchResponse, +) +from danswer.tools.tool_implementations.internet_search.internet_search_tool import ( + InternetSearchTool, +) +from danswer.tools.tool_implementations.search.search_tool 
import ( + FINAL_CONTEXT_DOCUMENTS_ID, +) +from danswer.tools.tool_implementations.search.search_tool import ( + SEARCH_RESPONSE_SUMMARY_ID, +) +from danswer.tools.tool_implementations.search.search_tool import SearchResponseSummary +from danswer.tools.tool_implementations.search.search_tool import SearchTool +from danswer.tools.tool_implementations.search.search_tool import ( + SECTION_RELEVANCE_LIST_ID, +) from danswer.tools.tool_runner import ToolCallFinalResult from danswer.tools.utils import compute_all_tool_tokens from danswer.tools.utils import explicit_tool_calling_supported @@ -532,6 +550,13 @@ def stream_chat_message_objects( if not persona else PromptConfig.from_model(persona.prompts[0]) ) + answer_style_config = AnswerStyleConfig( + citation_config=CitationConfig( + all_docs_useful=selected_db_search_docs is not None + ), + document_pruning_config=document_pruning_config, + structured_response_format=new_msg_req.structured_response_format, + ) # find out what tools to use search_tool: SearchTool | None = None @@ -550,13 +575,16 @@ def stream_chat_message_objects( llm=llm, fast_llm=fast_llm, pruning_config=document_pruning_config, + answer_style_config=answer_style_config, selected_sections=selected_sections, chunks_above=new_msg_req.chunks_above, chunks_below=new_msg_req.chunks_below, full_doc=new_msg_req.full_doc, - evaluation_type=LLMEvaluationType.BASIC - if persona.llm_relevance_filter - else LLMEvaluationType.SKIP, + evaluation_type=( + LLMEvaluationType.BASIC + if persona.llm_relevance_filter + else LLMEvaluationType.SKIP + ), ) tool_dict[db_tool_model.id] = [search_tool] elif tool_cls.__name__ == ImageGenerationTool.__name__: @@ -626,7 +654,11 @@ def stream_chat_message_objects( "Internet search tool requires a Bing API key, please contact your Danswer admin to get it added!" 
) tool_dict[db_tool_model.id] = [ - InternetSearchTool(api_key=bing_api_key) + InternetSearchTool( + api_key=bing_api_key, + answer_style_config=answer_style_config, + prompt_config=prompt_config, + ) ] continue @@ -667,13 +699,7 @@ def stream_chat_message_objects( is_connected=is_connected, question=final_msg.message, latest_query_files=latest_query_files, - answer_style_config=AnswerStyleConfig( - citation_config=CitationConfig( - all_docs_useful=selected_db_search_docs is not None - ), - document_pruning_config=document_pruning_config, - structured_response_format=new_msg_req.structured_response_format, - ), + answer_style_config=answer_style_config, prompt_config=prompt_config, llm=( llm diff --git a/backend/danswer/llm/answering/answer.py b/backend/danswer/llm/answering/answer.py index d2aeb1b14c4..0aea52c303b 100644 --- a/backend/danswer/llm/answering/answer.py +++ b/backend/danswer/llm/answering/answer.py @@ -1,72 +1,38 @@ -import itertools from collections.abc import Callable from collections.abc import Iterator -from typing import Any -from typing import cast from uuid import uuid4 from langchain.schema.messages import BaseMessage from langchain_core.messages import AIMessageChunk -from langchain_core.messages import HumanMessage +from langchain_core.messages import ToolCall -from danswer.chat.chat_utils import llm_doc_from_inference_section from danswer.chat.models import AnswerQuestionPossibleReturn from danswer.chat.models import CitationInfo from danswer.chat.models import DanswerAnswerPiece -from danswer.chat.models import LlmDoc -from danswer.chat.models import StreamStopInfo -from danswer.chat.models import StreamStopReason -from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE from danswer.file_store.utils import InMemoryChatFile +from danswer.llm.answering.llm_response_handler import LLMCall +from danswer.llm.answering.llm_response_handler import LLMResponseHandlerManager from danswer.llm.answering.models import AnswerStyleConfig from danswer.llm.answering.models import PreviousMessage from danswer.llm.answering.models import PromptConfig -from danswer.llm.answering.models import StreamProcessor from danswer.llm.answering.prompts.build import AnswerPromptBuilder from danswer.llm.answering.prompts.build import default_build_system_message from danswer.llm.answering.prompts.build import default_build_user_message -from danswer.llm.answering.prompts.citations_prompt import ( - build_citations_system_message, +from danswer.llm.answering.stream_processing.citation_response_handler import ( + CitationResponseHandler, ) -from danswer.llm.answering.prompts.citations_prompt import build_citations_user_message -from danswer.llm.answering.prompts.quotes_prompt import build_quotes_user_message -from danswer.llm.answering.stream_processing.citation_processing import ( - build_citation_processor, +from danswer.llm.answering.stream_processing.citation_response_handler import ( + DummyAnswerResponseHandler, ) -from danswer.llm.answering.stream_processing.quotes_processing import ( - build_quotes_processor, -) -from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping from danswer.llm.answering.stream_processing.utils import map_document_id_order +from danswer.llm.answering.tool.tool_response_handler import ToolResponseHandler from danswer.llm.interfaces import LLM -from danswer.llm.interfaces import ToolChoiceOptions from danswer.natural_language_processing.utils import get_tokenizer -from danswer.tools.custom.custom_tool_prompt_builder import ( - 
build_user_message_for_custom_tool_for_non_tool_calling_llm, -) -from danswer.tools.force import filter_tools_for_force_tool_use from danswer.tools.force import ForceUseTool -from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID -from danswer.tools.images.image_generation_tool import ImageGenerationResponse -from danswer.tools.images.image_generation_tool import ImageGenerationTool -from danswer.tools.images.prompt import build_image_generation_user_prompt -from danswer.tools.internet_search.internet_search_tool import InternetSearchTool -from danswer.tools.message import build_tool_message -from danswer.tools.message import ToolCallSummary -from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID -from danswer.tools.search.search_tool import SEARCH_DOC_CONTENT_ID -from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID -from danswer.tools.search.search_tool import SearchResponseSummary -from danswer.tools.search.search_tool import SearchTool +from danswer.tools.models import ToolResponse from danswer.tools.tool import Tool -from danswer.tools.tool import ToolResponse -from danswer.tools.tool_runner import ( - check_which_tools_should_run_for_non_tool_calling_llm, -) -from danswer.tools.tool_runner import ToolCallFinalResult +from danswer.tools.tool_implementations.search.search_tool import SearchTool from danswer.tools.tool_runner import ToolCallKickoff -from danswer.tools.tool_runner import ToolRunner -from danswer.tools.tool_selection import select_single_tool_for_non_tool_calling_llm from danswer.tools.utils import explicit_tool_calling_supported from danswer.utils.logger import setup_logger @@ -74,29 +40,9 @@ logger = setup_logger() -def _get_answer_stream_processor( - context_docs: list[LlmDoc], - doc_id_to_rank_map: DocumentIdOrderMapping, - answer_style_configs: AnswerStyleConfig, -) -> StreamProcessor: - if answer_style_configs.citation_config: - return build_citation_processor( - context_docs=context_docs, doc_id_to_rank_map=doc_id_to_rank_map - ) - if answer_style_configs.quotes_config: - return build_quotes_processor( - context_docs=context_docs, is_json_prompt=not (QA_PROMPT_OVERRIDE == "weak") - ) - - raise RuntimeError("Not implemented yet") - - AnswerStream = Iterator[AnswerQuestionPossibleReturn | ToolCallKickoff | ToolResponse] -logger = setup_logger() - - class Answer: def __init__( self, @@ -136,8 +82,6 @@ def __init__( self.tools = tools or [] self.force_use_tool = force_use_tool - self.skip_explicit_tool_calling = skip_explicit_tool_calling - self.message_history = message_history or [] # used for QA flow where we only want to send a single message self.single_message_history = single_message_history @@ -162,335 +106,132 @@ def __init__( self.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation self._is_cancelled = False - def _update_prompt_builder_for_search_tool( - self, prompt_builder: AnswerPromptBuilder, final_context_documents: list[LlmDoc] - ) -> None: - if self.answer_style_config.citation_config: - prompt_builder.update_system_prompt( - build_citations_system_message(self.prompt_config) - ) - prompt_builder.update_user_prompt( - build_citations_user_message( - question=self.question, - prompt_config=self.prompt_config, - context_docs=final_context_documents, - files=self.latest_query_files, - all_doc_useful=( - self.answer_style_config.citation_config.all_docs_useful - if self.answer_style_config.citation_config - else False - ), - history_message=self.single_message_history or "", - ) - ) 
- elif self.answer_style_config.quotes_config: - prompt_builder.update_user_prompt( - build_quotes_user_message( - question=self.question, - context_docs=final_context_documents, - history_str=self.single_message_history or "", - prompt=self.prompt_config, - ) + self.using_tool_calling_llm = ( + explicit_tool_calling_supported( + self.llm.config.model_provider, self.llm.config.model_name ) + and not skip_explicit_tool_calling + ) - def _raw_output_for_explicit_tool_calling_llms( - self, - ) -> Iterator[ - str | StreamStopInfo | ToolCallKickoff | ToolResponse | ToolCallFinalResult - ]: - prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config) - - tool_call_chunk: AIMessageChunk | None = None - if self.force_use_tool.force_use and self.force_use_tool.args is not None: - # if we are forcing a tool WITH args specified, we don't need to check which tools to run - # / need to generate the args - tool_call_chunk = AIMessageChunk( - content="", - ) - tool_call_chunk.tool_calls = [ - { - "name": self.force_use_tool.tool_name, - "args": self.force_use_tool.args, - "id": str(uuid4()), - } - ] - else: - # if tool calling is supported, first try the raw message - # to see if we don't need to use any tools - prompt_builder.update_system_prompt( - default_build_system_message(self.prompt_config) - ) - prompt_builder.update_user_prompt( - default_build_user_message( - self.question, self.prompt_config, self.latest_query_files - ) - ) - prompt = prompt_builder.build() - final_tool_definitions = [ - tool.tool_definition() - for tool in filter_tools_for_force_tool_use( - self.tools, self.force_use_tool - ) - ] - - for message in self.llm.stream( - prompt=prompt, - tools=final_tool_definitions if final_tool_definitions else None, - tool_choice="required" if self.force_use_tool.force_use else None, - structured_response_format=self.answer_style_config.structured_response_format, - ): - if isinstance(message, AIMessageChunk) and ( - message.tool_call_chunks or message.tool_calls - ): - if tool_call_chunk is None: - tool_call_chunk = message - else: - tool_call_chunk += message # type: ignore - else: - if message.content: - if self.is_cancelled: - return - yield cast(str, message.content) - if ( - message.additional_kwargs.get("usage_metadata", {}).get("stop") - == "length" - ): - yield StreamStopInfo( - stop_reason=StreamStopReason.CONTEXT_LENGTH - ) - - if not tool_call_chunk: - return # no tool call needed - - # if we have a tool call, we need to call the tool - tool_call_requests = tool_call_chunk.tool_calls - for tool_call_request in tool_call_requests: - known_tools_by_name = [ - tool for tool in self.tools if tool.name == tool_call_request["name"] - ] - - if not known_tools_by_name: - logger.error( - "Tool call requested with unknown name field. 
\n" - f"self.tools: {self.tools}" - f"tool_call_request: {tool_call_request}" - ) - if self.tools: - tool = self.tools[0] - else: - continue - else: - tool = known_tools_by_name[0] - tool_args = ( - self.force_use_tool.args - if self.force_use_tool.tool_name == tool.name - and self.force_use_tool.args - else tool_call_request["args"] - ) + def _get_tools_list(self) -> list[Tool]: + if not self.force_use_tool.force_use: + return self.tools - tool_runner = ToolRunner(tool, tool_args) - yield tool_runner.kickoff() - yield from tool_runner.tool_responses() + tool = next( + (t for t in self.tools if t.name == self.force_use_tool.tool_name), None + ) + if tool is None: + raise RuntimeError(f"Tool '{self.force_use_tool.tool_name}' not found") - tool_call_summary = ToolCallSummary( - tool_call_request=tool_call_chunk, - tool_call_result=build_tool_message( - tool_call_request, tool_runner.tool_message_content() - ), + logger.info( + f"Forcefully using tool='{tool.name}'" + + ( + f" with args='{self.force_use_tool.args}'" + if self.force_use_tool.args is not None + else "" ) + ) + return [tool] - if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}: - self._update_prompt_builder_for_search_tool(prompt_builder, []) - elif tool.name == ImageGenerationTool._NAME: - img_urls = [ - img_generation_result["url"] - for img_generation_result in tool_runner.tool_final_result().tool_result - ] - prompt_builder.update_user_prompt( - build_image_generation_user_prompt( - query=self.question, img_urls=img_urls - ) - ) - yield tool_runner.tool_final_result() - if not self.skip_gen_ai_answer_generation: - prompt = prompt_builder.build(tool_call_summary=tool_call_summary) - - yield from self._process_llm_stream( - prompt=prompt, - # as of now, we don't support multiple tool calls in sequence, which is why - # we don't need to pass this in here - # tools=[tool.tool_definition() for tool in self.tools], - ) + def _handle_specified_tool_call( + self, llm_calls: list[LLMCall], tool: Tool, tool_args: dict + ) -> AnswerStream: + current_llm_call = llm_calls[-1] - return + # make a dummy tool handler + tool_handler = ToolResponseHandler([tool]) - # This method processes the LLM stream and yields the content or stop information - def _process_llm_stream( - self, - prompt: Any, - tools: list[dict] | None = None, - tool_choice: ToolChoiceOptions | None = None, - ) -> Iterator[str | StreamStopInfo]: - for message in self.llm.stream( - prompt=prompt, - tools=tools, - tool_choice=tool_choice, - structured_response_format=self.answer_style_config.structured_response_format, - ): - if isinstance(message, AIMessageChunk): - if message.content: - if self.is_cancelled: - return StreamStopInfo(stop_reason=StreamStopReason.CANCELLED) - yield cast(str, message.content) - - if ( - message.additional_kwargs.get("usage_metadata", {}).get("stop") - == "length" - ): - yield StreamStopInfo(stop_reason=StreamStopReason.CONTEXT_LENGTH) - - def _raw_output_for_non_explicit_tool_calling_llms( - self, - ) -> Iterator[ - str | StreamStopInfo | ToolCallKickoff | ToolResponse | ToolCallFinalResult - ]: - prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config) - chosen_tool_and_args: tuple[Tool, dict] | None = None - - if self.force_use_tool.force_use: - # if we are forcing a tool, we don't need to check which tools to run - tool = next( - iter( - [ - tool - for tool in self.tools - if tool.name == self.force_use_tool.tool_name - ] - ), - None, - ) - if not tool: - raise RuntimeError(f"Tool 
'{self.force_use_tool.tool_name}' not found") + dummy_tool_call_chunk = AIMessageChunk(content="") + dummy_tool_call_chunk.tool_calls = [ + ToolCall(name=tool.name, args=tool_args, id=str(uuid4())) + ] - tool_args = ( - self.force_use_tool.args - if self.force_use_tool.args is not None - else tool.get_args_for_non_tool_calling_llm( - query=self.question, - history=self.message_history, - llm=self.llm, - force_run=True, - ) - ) - - if tool_args is None: - raise RuntimeError(f"Tool '{tool.name}' did not return args") + response_handler_manager = LLMResponseHandlerManager( + tool_handler, DummyAnswerResponseHandler(), self.is_cancelled + ) + yield from response_handler_manager.handle_llm_response( + iter([dummy_tool_call_chunk]) + ) - chosen_tool_and_args = (tool, tool_args) + new_llm_call = response_handler_manager.next_llm_call(current_llm_call) + if new_llm_call: + yield from self._get_response(llm_calls + [new_llm_call]) else: - tool_options = check_which_tools_should_run_for_non_tool_calling_llm( - tools=self.tools, - query=self.question, - history=self.message_history, - llm=self.llm, - ) + raise RuntimeError("Tool call handler did not return a new LLM call") - available_tools_and_args = [ - (self.tools[ind], args) - for ind, args in enumerate(tool_options) - if args is not None - ] + def _get_response(self, llm_calls: list[LLMCall]) -> AnswerStream: + current_llm_call = llm_calls[-1] - logger.info( - f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}" + # handle the case where no decision has to be made; we simply run the tool + if ( + current_llm_call.force_use_tool.force_use + and current_llm_call.force_use_tool.args is not None + ): + tool_name, tool_args = ( + current_llm_call.force_use_tool.tool_name, + current_llm_call.force_use_tool.args, ) - - chosen_tool_and_args = ( - select_single_tool_for_non_tool_calling_llm( - tools_and_args=available_tools_and_args, - history=self.message_history, - query=self.question, - llm=self.llm, - ) - if available_tools_and_args - else None + tool = next( + (t for t in current_llm_call.tools if t.name == tool_name), None ) + if not tool: + raise RuntimeError(f"Tool '{tool_name}' not found") - logger.notice(f"Chosen tool: {chosen_tool_and_args}") + yield from self._handle_specified_tool_call(llm_calls, tool, tool_args) + return - if not chosen_tool_and_args: - if self.skip_gen_ai_answer_generation: - raise ValueError( - "skip_gen_ai_answer_generation is True, but no tool was chosen; no answer will be generated" - ) - prompt_builder.update_system_prompt( - default_build_system_message(self.prompt_config) - ) - prompt_builder.update_user_prompt( - default_build_user_message( - self.question, self.prompt_config, self.latest_query_files + # special pre-logic for non-tool calling LLM case + if not self.using_tool_calling_llm and current_llm_call.tools: + chosen_tool_and_args = ( + ToolResponseHandler.get_tool_call_for_non_tool_calling_llm( + current_llm_call, self.llm ) ) - prompt = prompt_builder.build() - yield from self._process_llm_stream( - prompt=prompt, - tools=None, - ) + if chosen_tool_and_args: + tool, tool_args = chosen_tool_and_args + yield from self._handle_specified_tool_call(llm_calls, tool, tool_args) + return + + # if we're skipping gen ai answer generation, we should break + # out unless we're forcing a tool call. If we don't, we might generate an + # answer, which is a no-no! 
+ if ( + self.skip_gen_ai_answer_generation + and not current_llm_call.force_use_tool.force_use + ): return - tool, tool_args = chosen_tool_and_args - tool_runner = ToolRunner(tool, tool_args) - yield tool_runner.kickoff() + # set up "handlers" to listen to the LLM response stream and + # feed back the processed results + handle tool call requests + # + figure out what the next LLM call should be + tool_call_handler = ToolResponseHandler(current_llm_call.tools) - if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}: - final_context_documents = None - for response in tool_runner.tool_responses(): - if response.id == FINAL_CONTEXT_DOCUMENTS_ID: - final_context_documents = cast(list[LlmDoc], response.response) - yield response - - if final_context_documents is None: - raise RuntimeError( - f"{tool.name} did not return final context documents" - ) + search_result = SearchTool.get_search_result(current_llm_call) or [] + citation_response_handler = CitationResponseHandler( + context_docs=search_result, + doc_id_to_rank_map=map_document_id_order(search_result), + ) - self._update_prompt_builder_for_search_tool( - prompt_builder, final_context_documents - ) - elif tool.name == ImageGenerationTool._NAME: - img_urls = [] - for response in tool_runner.tool_responses(): - if response.id == IMAGE_GENERATION_RESPONSE_ID: - img_generation_response = cast( - list[ImageGenerationResponse], response.response - ) - img_urls = [img.url for img in img_generation_response] - - yield response - - prompt_builder.update_user_prompt( - build_image_generation_user_prompt( - query=self.question, - img_urls=img_urls, - ) - ) - else: - prompt_builder.update_user_prompt( - HumanMessage( - content=build_user_message_for_custom_tool_for_non_tool_calling_llm( - self.question, - tool.name, - *tool_runner.tool_responses(), - ) - ) - ) - final = tool_runner.tool_final_result() + response_handler_manager = LLMResponseHandlerManager( + tool_call_handler, citation_response_handler, self.is_cancelled + ) - yield final - if not self.skip_gen_ai_answer_generation: - prompt = prompt_builder.build() + # DEBUG: good breakpoint + stream = self.llm.stream( + prompt=current_llm_call.prompt_builder.build(), + tools=[tool.tool_definition() for tool in current_llm_call.tools] or None, + tool_choice=( + "required" + if current_llm_call.tools and current_llm_call.force_use_tool.force_use + else None + ), + structured_response_format=self.answer_style_config.structured_response_format, + ) + yield from response_handler_manager.handle_llm_response(stream) - yield from self._process_llm_stream(prompt=prompt, tools=None) + new_llm_call = response_handler_manager.next_llm_call(current_llm_call) + if new_llm_call: + yield from self._get_response(llm_calls + [new_llm_call]) @property def processed_streamed_output(self) -> AnswerStream: @@ -498,94 +239,30 @@ def processed_streamed_output(self) -> AnswerStream: yield from self._processed_stream return - output_generator = ( - self._raw_output_for_explicit_tool_calling_llms() - if explicit_tool_calling_supported( - self.llm.config.model_provider, self.llm.config.model_name - ) - and not self.skip_explicit_tool_calling - else self._raw_output_for_non_explicit_tool_calling_llms() + prompt_builder = AnswerPromptBuilder( + user_message=default_build_user_message( + user_query=self.question, + prompt_config=self.prompt_config, + files=self.latest_query_files, + ), + message_history=self.message_history, + llm_config=self.llm.config, + single_message_history=self.single_message_history, + ) + 
prompt_builder.update_system_prompt( + default_build_system_message(self.prompt_config) + ) + llm_call = LLMCall( + prompt_builder=prompt_builder, + tools=self._get_tools_list(), + force_use_tool=self.force_use_tool, + files=self.latest_query_files, + tool_call_info=[], + using_tool_calling_llm=self.using_tool_calling_llm, ) - - def _process_stream( - stream: Iterator[ToolCallKickoff | ToolResponse | str | StreamStopInfo], - ) -> AnswerStream: - message = None - - # special things we need to keep track of for the SearchTool - # raw results that will be displayed to the user - search_results: list[LlmDoc] | None = None - # processed docs to feed into the LLM - final_context_docs: list[LlmDoc] | None = None - - for message in stream: - if isinstance(message, ToolCallKickoff) or isinstance( - message, ToolCallFinalResult - ): - yield message - elif isinstance(message, ToolResponse): - if message.id == SEARCH_RESPONSE_SUMMARY_ID: - # We don't need to run section merging in this flow, this variable is only used - # below to specify the ordering of the documents for the purpose of matching - # citations to the right search documents. The deduplication logic is more lightweight - # there and we don't need to do it twice - search_results = [ - llm_doc_from_inference_section(section) - for section in cast( - SearchResponseSummary, message.response - ).top_sections - ] - elif message.id == FINAL_CONTEXT_DOCUMENTS_ID: - final_context_docs = cast(list[LlmDoc], message.response) - yield message - - elif ( - message.id == SEARCH_DOC_CONTENT_ID - and not self._return_contexts - ): - continue - - yield message - else: - # assumes all tool responses will come first, then the final answer - break - - if not self.skip_gen_ai_answer_generation: - process_answer_stream_fn = _get_answer_stream_processor( - context_docs=final_context_docs or [], - # if doc selection is enabled, then search_results will be None, - # so we need to use the final_context_docs - doc_id_to_rank_map=map_document_id_order( - search_results or final_context_docs or [] - ), - answer_style_configs=self.answer_style_config, - ) - - stream_stop_info = None - - def _stream() -> Iterator[str]: - nonlocal stream_stop_info - for item in itertools.chain([message], stream): - if isinstance(item, StreamStopInfo): - stream_stop_info = item - return - - # this should never happen, but we're seeing weird behavior here so handling for now - if not isinstance(item, str): - logger.error( - f"Received non-string item in answer stream: {item}. Skipping." 
- ) - continue - - yield item - - yield from process_answer_stream_fn(_stream()) - - if stream_stop_info: - yield stream_stop_info processed_stream = [] - for processed_packet in _process_stream(output_generator): + for processed_packet in self._get_response([llm_call]): processed_stream.append(processed_packet) yield processed_packet @@ -609,7 +286,6 @@ def citations(self) -> list[CitationInfo]: return citations - @property def is_cancelled(self) -> bool: if self._is_cancelled: return True diff --git a/backend/danswer/llm/answering/llm_response_handler.py b/backend/danswer/llm/answering/llm_response_handler.py new file mode 100644 index 00000000000..6578e808952 --- /dev/null +++ b/backend/danswer/llm/answering/llm_response_handler.py @@ -0,0 +1,83 @@ +from collections.abc import Callable +from collections.abc import Generator +from collections.abc import Iterator +from typing import TYPE_CHECKING + +from langchain_core.messages import BaseMessage +from pydantic.v1 import BaseModel as BaseModel__v1 + +from danswer.chat.models import CitationInfo +from danswer.chat.models import DanswerAnswerPiece +from danswer.chat.models import StreamStopInfo +from danswer.chat.models import StreamStopReason +from danswer.file_store.models import InMemoryChatFile +from danswer.llm.answering.prompts.build import AnswerPromptBuilder +from danswer.tools.force import ForceUseTool +from danswer.tools.models import ToolCallFinalResult +from danswer.tools.models import ToolCallKickoff +from danswer.tools.models import ToolResponse +from danswer.tools.tool import Tool + + +if TYPE_CHECKING: + from danswer.llm.answering.stream_processing.citation_response_handler import ( + AnswerResponseHandler, + ) + from danswer.llm.answering.tool.tool_response_handler import ToolResponseHandler + + +ResponsePart = ( + DanswerAnswerPiece + | CitationInfo + | ToolCallKickoff + | ToolResponse + | ToolCallFinalResult + | StreamStopInfo +) + + +class LLMCall(BaseModel__v1): + prompt_builder: AnswerPromptBuilder + tools: list[Tool] + force_use_tool: ForceUseTool + files: list[InMemoryChatFile] + tool_call_info: list[ToolCallKickoff | ToolResponse | ToolCallFinalResult] + using_tool_calling_llm: bool + + class Config: + arbitrary_types_allowed = True + + +class LLMResponseHandlerManager: + def __init__( + self, + tool_handler: "ToolResponseHandler", + answer_handler: "AnswerResponseHandler", + is_cancelled: Callable[[], bool], + ): + self.tool_handler = tool_handler + self.answer_handler = answer_handler + self.is_cancelled = is_cancelled + + def handle_llm_response( + self, + stream: Iterator[BaseMessage], + ) -> Generator[ResponsePart, None, None]: + all_messages: list[BaseMessage] = [] + for message in stream: + # tool handler doesn't do anything until the full message is received + # NOTE: still need to run list() to get this to run + list(self.tool_handler.handle_response_part(message, all_messages)) + yield from self.answer_handler.handle_response_part(message, all_messages) + all_messages.append(message) + + if self.is_cancelled(): + yield StreamStopInfo(stop_reason=StreamStopReason.CANCELLED) + return + + # potentially give back all info on the selected tool call + its result + yield from self.tool_handler.handle_response_part(None, all_messages) + yield from self.answer_handler.handle_response_part(None, all_messages) + + def next_llm_call(self, llm_call: LLMCall) -> LLMCall | None: + return self.tool_handler.next_llm_call(llm_call) diff --git a/backend/danswer/llm/answering/prompts/build.py 
b/backend/danswer/llm/answering/prompts/build.py index f53d4481f6e..b5b774f522d 100644 --- a/backend/danswer/llm/answering/prompts/build.py +++ b/backend/danswer/llm/answering/prompts/build.py @@ -12,12 +12,12 @@ from danswer.llm.interfaces import LLMConfig from danswer.llm.utils import build_content_with_imgs from danswer.llm.utils import check_message_tokens +from danswer.llm.utils import message_to_prompt_and_imgs from danswer.llm.utils import translate_history_to_basemessages from danswer.natural_language_processing.utils import get_tokenizer from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT from danswer.prompts.prompt_utils import add_date_time_to_prompt from danswer.prompts.prompt_utils import drop_messages_history_overflow -from danswer.tools.message import ToolCallSummary def default_build_system_message( @@ -54,18 +54,14 @@ def default_build_user_message( class AnswerPromptBuilder: def __init__( - self, message_history: list[PreviousMessage], llm_config: LLMConfig + self, + user_message: HumanMessage, + message_history: list[PreviousMessage], + llm_config: LLMConfig, + single_message_history: str | None = None, ) -> None: self.max_tokens = compute_max_llm_input_tokens(llm_config) - ( - self.message_history, - self.history_token_cnts, - ) = translate_history_to_basemessages(message_history) - - self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = None - self.user_message_and_token_cnt: tuple[HumanMessage, int] | None = None - llm_tokenizer = get_tokenizer( provider_type=llm_config.model_provider, model_name=llm_config.model_name, @@ -74,6 +70,24 @@ def __init__( Callable[[str], list[int]], llm_tokenizer.encode ) + self.raw_message_history = message_history + ( + self.message_history, + self.history_token_cnts, + ) = translate_history_to_basemessages(message_history) + + # for cases where like the QA flow where we want to condense the chat history + # into a single message rather than a sequence of User / Assistant messages + self.single_message_history = single_message_history + + self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = None + self.user_message_and_token_cnt = ( + user_message, + check_message_tokens(user_message, self.llm_tokenizer_encode_func), + ) + + self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = [] + def update_system_prompt(self, system_message: SystemMessage | None) -> None: if not system_message: self.system_message_and_token_cnt = None @@ -85,18 +99,21 @@ def update_system_prompt(self, system_message: SystemMessage | None) -> None: ) def update_user_prompt(self, user_message: HumanMessage) -> None: - if not user_message: - self.user_message_and_token_cnt = None - return - self.user_message_and_token_cnt = ( user_message, check_message_tokens(user_message, self.llm_tokenizer_encode_func), ) - def build( - self, tool_call_summary: ToolCallSummary | None = None - ) -> list[BaseMessage]: + def append_message(self, message: BaseMessage) -> None: + """Append a new message to the message history.""" + token_count = check_message_tokens(message, self.llm_tokenizer_encode_func) + self.new_messages_and_token_cnts.append((message, token_count)) + + def get_user_message_content(self) -> str: + query, _ = message_to_prompt_and_imgs(self.user_message_and_token_cnt[0]) + return query + + def build(self) -> list[BaseMessage]: if not self.user_message_and_token_cnt: raise ValueError("User message must be set before building prompt") @@ -113,25 +130,8 @@ def build( 
final_messages_with_tokens.append(self.user_message_and_token_cnt) - if tool_call_summary: - final_messages_with_tokens.append( - ( - tool_call_summary.tool_call_request, - check_message_tokens( - tool_call_summary.tool_call_request, - self.llm_tokenizer_encode_func, - ), - ) - ) - final_messages_with_tokens.append( - ( - tool_call_summary.tool_call_result, - check_message_tokens( - tool_call_summary.tool_call_result, - self.llm_tokenizer_encode_func, - ), - ) - ) + if self.new_messages_and_token_cnts: + final_messages_with_tokens.extend(self.new_messages_and_token_cnts) return drop_messages_history_overflow( final_messages_with_tokens, self.max_tokens diff --git a/backend/danswer/llm/answering/prompts/citations_prompt.py b/backend/danswer/llm/answering/prompts/citations_prompt.py index a2248da0585..b7ca7797e88 100644 --- a/backend/danswer/llm/answering/prompts/citations_prompt.py +++ b/backend/danswer/llm/answering/prompts/citations_prompt.py @@ -6,7 +6,6 @@ from danswer.db.models import Persona from danswer.db.persona import get_default_prompt__read_only from danswer.db.search_settings import get_multilingual_expansion -from danswer.file_store.utils import InMemoryChatFile from danswer.llm.answering.models import PromptConfig from danswer.llm.factory import get_llms_for_persona from danswer.llm.factory import get_main_llm_from_tuple @@ -14,6 +13,7 @@ from danswer.llm.utils import build_content_with_imgs from danswer.llm.utils import check_number_of_tokens from danswer.llm.utils import get_max_input_tokens +from danswer.llm.utils import message_to_prompt_and_imgs from danswer.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT from danswer.prompts.constants import DEFAULT_IGNORE_STATEMENT from danswer.prompts.direct_qa_prompts import CITATIONS_PROMPT @@ -132,10 +132,9 @@ def build_citations_system_message( def build_citations_user_message( - question: str, + message: HumanMessage, prompt_config: PromptConfig, context_docs: list[LlmDoc] | list[InferenceChunk], - files: list[InMemoryChatFile], all_doc_useful: bool, history_message: str = "", ) -> HumanMessage: @@ -149,6 +148,7 @@ def build_citations_user_message( if history_message else "" ) + query, img_urls = message_to_prompt_and_imgs(message) if context_docs: context_docs_str = build_complete_context_str(context_docs) @@ -158,20 +158,22 @@ def build_citations_user_message( optional_ignore_statement=optional_ignore, context_docs_str=context_docs_str, task_prompt=task_prompt_with_reminder, - user_query=question, + user_query=query, history_block=history_block, ) else: # if no context docs provided, assume we're in the tool calling flow user_prompt = CITATIONS_PROMPT_FOR_TOOL_CALLING.format( task_prompt=task_prompt_with_reminder, - user_query=question, + user_query=query, history_block=history_block, ) user_prompt = user_prompt.strip() user_msg = HumanMessage( - content=build_content_with_imgs(user_prompt, files) if files else user_prompt + content=build_content_with_imgs(user_prompt, img_urls=img_urls) + if img_urls + else user_prompt ) return user_msg diff --git a/backend/danswer/llm/answering/prompts/quotes_prompt.py b/backend/danswer/llm/answering/prompts/quotes_prompt.py index 07abc4356b6..42f736b627d 100644 --- a/backend/danswer/llm/answering/prompts/quotes_prompt.py +++ b/backend/danswer/llm/answering/prompts/quotes_prompt.py @@ -5,6 +5,7 @@ from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE from danswer.db.search_settings import get_multilingual_expansion from danswer.llm.answering.models import PromptConfig +from 
danswer.llm.utils import message_to_prompt_and_imgs from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK from danswer.prompts.direct_qa_prompts import JSON_PROMPT @@ -75,7 +76,7 @@ def _build_strong_llm_quotes_prompt( def build_quotes_user_message( - question: str, + message: HumanMessage, context_docs: list[LlmDoc] | list[InferenceChunk], history_str: str, prompt: PromptConfig, @@ -86,28 +87,10 @@ def build_quotes_user_message( else _build_strong_llm_quotes_prompt ) - return prompt_builder( - question=question, - context_docs=context_docs, - history_str=history_str, - prompt=prompt, - ) - - -def build_quotes_prompt( - question: str, - context_docs: list[LlmDoc] | list[InferenceChunk], - history_str: str, - prompt: PromptConfig, -) -> HumanMessage: - prompt_builder = ( - _build_weak_llm_quotes_prompt - if QA_PROMPT_OVERRIDE == "weak" - else _build_strong_llm_quotes_prompt - ) + query, _ = message_to_prompt_and_imgs(message) return prompt_builder( - question=question, + question=query, context_docs=context_docs, history_str=history_str, prompt=prompt, diff --git a/backend/danswer/llm/answering/prune_and_merge.py b/backend/danswer/llm/answering/prune_and_merge.py index 0193de1f2aa..690a5d2280d 100644 --- a/backend/danswer/llm/answering/prune_and_merge.py +++ b/backend/danswer/llm/answering/prune_and_merge.py @@ -19,7 +19,7 @@ from danswer.prompts.prompt_utils import build_doc_context_str from danswer.search.models import InferenceChunk from danswer.search.models import InferenceSection -from danswer.tools.search.search_utils import section_to_dict +from danswer.tools.tool_implementations.search.search_utils import section_to_dict from danswer.utils.logger import setup_logger diff --git a/backend/danswer/llm/answering/stream_processing/citation_processing.py b/backend/danswer/llm/answering/stream_processing/citation_processing.py index f1e5489550d..950ad207878 100644 --- a/backend/danswer/llm/answering/stream_processing/citation_processing.py +++ b/backend/danswer/llm/answering/stream_processing/citation_processing.py @@ -1,12 +1,10 @@ import re -from collections.abc import Iterator +from collections.abc import Generator -from danswer.chat.models import AnswerQuestionStreamReturn from danswer.chat.models import CitationInfo from danswer.chat.models import DanswerAnswerPiece from danswer.chat.models import LlmDoc from danswer.configs.chat_configs import STOP_STREAM_PAT -from danswer.llm.answering.models import StreamProcessor from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping from danswer.prompts.constants import TRIPLE_BACKTICK from danswer.utils.logger import setup_logger @@ -19,128 +17,104 @@ def in_code_block(llm_text: str) -> bool: return count % 2 != 0 -def extract_citations_from_stream( - tokens: Iterator[str], - context_docs: list[LlmDoc], - doc_id_to_rank_map: DocumentIdOrderMapping, - stop_stream: str | None = STOP_STREAM_PAT, -) -> Iterator[DanswerAnswerPiece | CitationInfo]: - """ - Key aspects: - - 1. Stream Processing: - - Processes tokens one by one, allowing for real-time handling of large texts. - - 2. Citation Detection: - - Uses regex to find citations in the format [number]. - - Example: [1], [2], etc. - - 3. Citation Mapping: - - Maps detected citation numbers to actual document ranks using doc_id_to_rank_map. - - Example: [1] might become [3] if doc_id_to_rank_map maps it to 3. - - 4. Citation Formatting: - - Replaces citations with properly formatted versions. 
- - Adds links if available: [[1]](https://example.com) - - Handles cases where links are not available: [[1]]() - - 5. Duplicate Handling: - - Skips consecutive citations of the same document to avoid redundancy. - - 6. Output Generation: - - Yields DanswerAnswerPiece objects for regular text. - - Yields CitationInfo objects for each unique citation encountered. - - 7. Context Awareness: - - Uses context_docs to access document information for citations. - - This function effectively processes a stream of text, identifies and reformats citations, - and provides both the processed text and citation information as output. - """ - order_mapping = doc_id_to_rank_map.order_mapping - llm_out = "" - max_citation_num = len(context_docs) - citation_order = [] - curr_segment = "" - cited_inds = set() - hold = "" - - raw_out = "" - current_citations: list[int] = [] - past_cite_count = 0 - for raw_token in tokens: - raw_out += raw_token - if stop_stream: - next_hold = hold + raw_token - if stop_stream in next_hold: - break - if next_hold == stop_stream[: len(next_hold)]: - hold = next_hold - continue +class CitationProcessor: + def __init__( + self, + context_docs: list[LlmDoc], + doc_id_to_rank_map: DocumentIdOrderMapping, + stop_stream: str | None = STOP_STREAM_PAT, + ): + self.context_docs = context_docs + self.doc_id_to_rank_map = doc_id_to_rank_map + self.stop_stream = stop_stream + self.order_mapping = doc_id_to_rank_map.order_mapping + self.llm_out = "" + self.max_citation_num = len(context_docs) + self.citation_order: list[int] = [] + self.curr_segment = "" + self.cited_inds: set[int] = set() + self.hold = "" + self.current_citations: list[int] = [] + self.past_cite_count = 0 + + def process_token( + self, token: str | None + ) -> Generator[DanswerAnswerPiece | CitationInfo, None, None]: + # None -> end of stream + if token is None: + yield DanswerAnswerPiece(answer_piece=self.curr_segment) + return + + if self.stop_stream: + next_hold = self.hold + token + if self.stop_stream in next_hold: + return + if next_hold == self.stop_stream[: len(next_hold)]: + self.hold = next_hold + return token = next_hold - hold = "" - else: - token = raw_token + self.hold = "" - curr_segment += token - llm_out += token + self.curr_segment += token + self.llm_out += token # Handle code blocks without language tags - if "`" in curr_segment: - if curr_segment.endswith("`"): - continue - elif "```" in curr_segment: - piece_that_comes_after = curr_segment.split("```")[1][0] - if piece_that_comes_after == "\n" and in_code_block(llm_out): - curr_segment = curr_segment.replace("```", "```plaintext") + if "`" in self.curr_segment: + if self.curr_segment.endswith("`"): + return + elif "```" in self.curr_segment: + piece_that_comes_after = self.curr_segment.split("```")[1][0] + if piece_that_comes_after == "\n" and in_code_block(self.llm_out): + self.curr_segment = self.curr_segment.replace("```", "```plaintext") citation_pattern = r"\[(\d+)\]" - - citations_found = list(re.finditer(citation_pattern, curr_segment)) + citations_found = list(re.finditer(citation_pattern, self.curr_segment)) possible_citation_pattern = r"(\[\d*$)" # [1, [, etc - possible_citation_found = re.search(possible_citation_pattern, curr_segment) + possible_citation_found = re.search( + possible_citation_pattern, self.curr_segment + ) - # `past_cite_count`: number of characters since past citation - # 5 to ensure a citation hasn't occured - if len(citations_found) == 0 and len(llm_out) - past_cite_count > 5: - current_citations = [] + if 
len(citations_found) == 0 and len(self.llm_out) - self.past_cite_count > 5: + self.current_citations = [] - if citations_found and not in_code_block(llm_out): + result = "" # Initialize result here + if citations_found and not in_code_block(self.llm_out): last_citation_end = 0 length_to_add = 0 while len(citations_found) > 0: citation = citations_found.pop(0) numerical_value = int(citation.group(1)) - if 1 <= numerical_value <= max_citation_num: - context_llm_doc = context_docs[numerical_value - 1] - real_citation_num = order_mapping[context_llm_doc.document_id] + if 1 <= numerical_value <= self.max_citation_num: + context_llm_doc = self.context_docs[numerical_value - 1] + real_citation_num = self.order_mapping[context_llm_doc.document_id] - if real_citation_num not in citation_order: - citation_order.append(real_citation_num) + if real_citation_num not in self.citation_order: + self.citation_order.append(real_citation_num) - target_citation_num = citation_order.index(real_citation_num) + 1 + target_citation_num = ( + self.citation_order.index(real_citation_num) + 1 + ) # Skip consecutive citations of the same work - if target_citation_num in current_citations: + if target_citation_num in self.current_citations: start, end = citation.span() real_start = length_to_add + start diff = end - start - curr_segment = ( - curr_segment[: length_to_add + start] - + curr_segment[real_start + diff :] + self.curr_segment = ( + self.curr_segment[: length_to_add + start] + + self.curr_segment[real_start + diff :] ) length_to_add -= diff continue # Handle edge case where LLM outputs citation itself - # by allowing it to generate citations on its own. - if curr_segment.startswith("[["): - match = re.match(r"\[\[(\d+)\]\]", curr_segment) + if self.curr_segment.startswith("[["): + match = re.match(r"\[\[(\d+)\]\]", self.curr_segment) if match: try: doc_id = int(match.group(1)) - context_llm_doc = context_docs[doc_id - 1] + context_llm_doc = self.context_docs[doc_id - 1] yield CitationInfo( citation_num=target_citation_num, document_id=context_llm_doc.document_id, @@ -150,75 +124,57 @@ def extract_citations_from_stream( f"Manual LLM citation didn't properly cite documents {e}" ) else: - # Will continue attempt on next loops logger.warning( "Manual LLM citation wasn't able to close brackets" ) - continue link = context_llm_doc.link # Replace the citation in the current segment start, end = citation.span() - curr_segment = ( - curr_segment[: start + length_to_add] + self.curr_segment = ( + self.curr_segment[: start + length_to_add] + f"[{target_citation_num}]" - + curr_segment[end + length_to_add :] + + self.curr_segment[end + length_to_add :] ) - past_cite_count = len(llm_out) - current_citations.append(target_citation_num) + self.past_cite_count = len(self.llm_out) + self.current_citations.append(target_citation_num) - if target_citation_num not in cited_inds: - cited_inds.add(target_citation_num) + if target_citation_num not in self.cited_inds: + self.cited_inds.add(target_citation_num) yield CitationInfo( citation_num=target_citation_num, document_id=context_llm_doc.document_id, ) if link: - prev_length = len(curr_segment) - curr_segment = ( - curr_segment[: start + length_to_add] + prev_length = len(self.curr_segment) + self.curr_segment = ( + self.curr_segment[: start + length_to_add] + f"[[{target_citation_num}]]({link})" - + curr_segment[end + length_to_add :] + + self.curr_segment[end + length_to_add :] ) - length_to_add += len(curr_segment) - prev_length - + length_to_add += len(self.curr_segment) - 
prev_length else: - prev_length = len(curr_segment) - curr_segment = ( - curr_segment[: start + length_to_add] + prev_length = len(self.curr_segment) + self.curr_segment = ( + self.curr_segment[: start + length_to_add] + f"[[{target_citation_num}]]()" - + curr_segment[end + length_to_add :] + + self.curr_segment[end + length_to_add :] ) - length_to_add += len(curr_segment) - prev_length + length_to_add += len(self.curr_segment) - prev_length last_citation_end = end + length_to_add if last_citation_end > 0: - yield DanswerAnswerPiece(answer_piece=curr_segment[:last_citation_end]) - curr_segment = curr_segment[last_citation_end:] - if possible_citation_found: - continue - yield DanswerAnswerPiece(answer_piece=curr_segment) - curr_segment = "" - - if curr_segment: - yield DanswerAnswerPiece(answer_piece=curr_segment) - - -def build_citation_processor( - context_docs: list[LlmDoc], doc_id_to_rank_map: DocumentIdOrderMapping -) -> StreamProcessor: - def stream_processor( - tokens: Iterator[str], - ) -> AnswerQuestionStreamReturn: - yield from extract_citations_from_stream( - tokens=tokens, - context_docs=context_docs, - doc_id_to_rank_map=doc_id_to_rank_map, - ) + result += self.curr_segment[:last_citation_end] + self.curr_segment = self.curr_segment[last_citation_end:] + + if not possible_citation_found: + result += self.curr_segment + self.curr_segment = "" - return stream_processor + if result: + yield DanswerAnswerPiece(answer_piece=result) diff --git a/backend/danswer/llm/answering/stream_processing/citation_response_handler.py b/backend/danswer/llm/answering/stream_processing/citation_response_handler.py new file mode 100644 index 00000000000..07a342fbd7e --- /dev/null +++ b/backend/danswer/llm/answering/stream_processing/citation_response_handler.py @@ -0,0 +1,61 @@ +import abc +from collections.abc import Generator + +from langchain_core.messages import BaseMessage + +from danswer.chat.models import CitationInfo +from danswer.chat.models import LlmDoc +from danswer.llm.answering.llm_response_handler import ResponsePart +from danswer.llm.answering.stream_processing.citation_processing import ( + CitationProcessor, +) +from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping + + +class AnswerResponseHandler(abc.ABC): + @abc.abstractmethod + def handle_response_part( + self, + response_item: BaseMessage | None, + previous_response_items: list[BaseMessage], + ) -> Generator[ResponsePart, None, None]: + raise NotImplementedError + + +class DummyAnswerResponseHandler(AnswerResponseHandler): + def handle_response_part( + self, + response_item: BaseMessage | None, + previous_response_items: list[BaseMessage], + ) -> Generator[ResponsePart, None, None]: + # This is a dummy handler that returns nothing + yield from [] + + +class CitationResponseHandler(AnswerResponseHandler): + def __init__( + self, context_docs: list[LlmDoc], doc_id_to_rank_map: DocumentIdOrderMapping + ): + self.context_docs = context_docs + self.doc_id_to_rank_map = doc_id_to_rank_map + self.citation_processor = CitationProcessor( + context_docs=self.context_docs, + doc_id_to_rank_map=self.doc_id_to_rank_map, + ) + self.processed_text = "" + self.citations: list[CitationInfo] = [] + + def handle_response_part( + self, + response_item: BaseMessage | None, + previous_response_items: list[BaseMessage], + ) -> Generator[ResponsePart, None, None]: + if response_item is None: + return + + content = ( + response_item.content if isinstance(response_item.content, str) else "" + ) + + # Process the new content 
through the citation processor + yield from self.citation_processor.process_token(content) diff --git a/backend/danswer/llm/answering/tool/tool_response_handler.py b/backend/danswer/llm/answering/tool/tool_response_handler.py new file mode 100644 index 00000000000..6c4fec77941 --- /dev/null +++ b/backend/danswer/llm/answering/tool/tool_response_handler.py @@ -0,0 +1,205 @@ +from collections.abc import Generator + +from langchain_core.messages import AIMessageChunk +from langchain_core.messages import BaseMessage +from langchain_core.messages import ToolCall + +from danswer.llm.answering.llm_response_handler import LLMCall +from danswer.llm.answering.llm_response_handler import ResponsePart +from danswer.llm.interfaces import LLM +from danswer.tools.force import ForceUseTool +from danswer.tools.message import build_tool_message +from danswer.tools.message import ToolCallSummary +from danswer.tools.models import ToolCallFinalResult +from danswer.tools.models import ToolCallKickoff +from danswer.tools.models import ToolResponse +from danswer.tools.tool import Tool +from danswer.tools.tool_runner import ( + check_which_tools_should_run_for_non_tool_calling_llm, +) +from danswer.tools.tool_runner import ToolRunner +from danswer.tools.tool_selection import select_single_tool_for_non_tool_calling_llm +from danswer.utils.logger import setup_logger + + +logger = setup_logger() + + +class ToolResponseHandler: + def __init__(self, tools: list[Tool]): + self.tools = tools + + self.tool_call_chunk: AIMessageChunk | None = None + self.tool_call_requests: list[ToolCall] = [] + + self.tool_runner: ToolRunner | None = None + self.tool_call_summary: ToolCallSummary | None = None + + self.tool_kickoff: ToolCallKickoff | None = None + self.tool_responses: list[ToolResponse] = [] + self.tool_final_result: ToolCallFinalResult | None = None + + @classmethod + def get_tool_call_for_non_tool_calling_llm( + cls, llm_call: LLMCall, llm: LLM + ) -> tuple[Tool, dict] | None: + if llm_call.force_use_tool.force_use: + # if we are forcing a tool, we don't need to check which tools to run + tool = next( + ( + t + for t in llm_call.tools + if t.name == llm_call.force_use_tool.tool_name + ), + None, + ) + if not tool: + raise RuntimeError( + f"Tool '{llm_call.force_use_tool.tool_name}' not found" + ) + + tool_args = ( + llm_call.force_use_tool.args + if llm_call.force_use_tool.args is not None + else tool.get_args_for_non_tool_calling_llm( + query=llm_call.prompt_builder.get_user_message_content(), + history=llm_call.prompt_builder.raw_message_history, + llm=llm, + force_run=True, + ) + ) + + if tool_args is None: + raise RuntimeError(f"Tool '{tool.name}' did not return args") + + return (tool, tool_args) + else: + tool_options = check_which_tools_should_run_for_non_tool_calling_llm( + tools=llm_call.tools, + query=llm_call.prompt_builder.get_user_message_content(), + history=llm_call.prompt_builder.raw_message_history, + llm=llm, + ) + + available_tools_and_args = [ + (llm_call.tools[ind], args) + for ind, args in enumerate(tool_options) + if args is not None + ] + + logger.info( + f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}" + ) + + chosen_tool_and_args = ( + select_single_tool_for_non_tool_calling_llm( + tools_and_args=available_tools_and_args, + history=llm_call.prompt_builder.raw_message_history, + query=llm_call.prompt_builder.get_user_message_content(), + llm=llm, + ) + if available_tools_and_args + else None + ) + + logger.notice(f"Chosen tool: 
{chosen_tool_and_args}") + return chosen_tool_and_args + + def _handle_tool_call(self) -> Generator[ResponsePart, None, None]: + if not self.tool_call_chunk or not self.tool_call_chunk.tool_calls: + return + + self.tool_call_requests = self.tool_call_chunk.tool_calls + + selected_tool: Tool | None = None + selected_tool_call_request: ToolCall | None = None + for tool_call_request in self.tool_call_requests: + known_tools_by_name = [ + tool for tool in self.tools if tool.name == tool_call_request["name"] + ] + + if not known_tools_by_name: + logger.error( + "Tool call requested with unknown name field. \n" + f"self.tools: {self.tools}\n" + f"tool_call_request: {tool_call_request}" + ) + continue + else: + selected_tool = known_tools_by_name[0] + selected_tool_call_request = tool_call_request + + if selected_tool and selected_tool_call_request: + break + + if not selected_tool or not selected_tool_call_request: + return + + self.tool_runner = ToolRunner(selected_tool, selected_tool_call_request["args"]) + self.tool_call_summary = ToolCallSummary( + tool_call_request=self.tool_call_chunk, + tool_call_result=build_tool_message( + selected_tool_call_request, self.tool_runner.tool_message_content() + ), + ) + + self.tool_kickoff = self.tool_runner.kickoff() + yield self.tool_kickoff + + for response in self.tool_runner.tool_responses(): + self.tool_responses.append(response) + yield response + + self.tool_final_result = self.tool_runner.tool_final_result() + yield self.tool_final_result + + def handle_response_part( + self, + response_item: BaseMessage | None, + previous_response_items: list[BaseMessage], + ) -> Generator[ResponsePart, None, None]: + if response_item is None: + yield from self._handle_tool_call() + + if isinstance(response_item, AIMessageChunk) and ( + response_item.tool_call_chunks or response_item.tool_calls + ): + if self.tool_call_chunk is None: + self.tool_call_chunk = response_item + else: + self.tool_call_chunk += response_item # type: ignore + + return + + def next_llm_call(self, current_llm_call: LLMCall) -> LLMCall | None: + if ( + self.tool_runner is None + or self.tool_call_summary is None + or self.tool_kickoff is None + or self.tool_final_result is None + ): + return None + + tool_runner = self.tool_runner + new_prompt_builder = tool_runner.tool.build_next_prompt( + prompt_builder=current_llm_call.prompt_builder, + tool_call_summary=self.tool_call_summary, + tool_responses=self.tool_responses, + using_tool_calling_llm=current_llm_call.using_tool_calling_llm, + ) + return LLMCall( + prompt_builder=new_prompt_builder, + tools=[], # for now, only allow one tool call per response + force_use_tool=ForceUseTool( + force_use=False, + tool_name="", + args=None, + ), + files=current_llm_call.files, + using_tool_calling_llm=current_llm_call.using_tool_calling_llm, + tool_call_info=[ + self.tool_kickoff, + *self.tool_responses, + self.tool_final_result, + ], + ) diff --git a/backend/danswer/llm/utils.py b/backend/danswer/llm/utils.py index bad18214b95..af480f83955 100644 --- a/backend/danswer/llm/utils.py +++ b/backend/danswer/llm/utils.py @@ -203,6 +203,28 @@ def build_content_with_imgs( ) +def message_to_prompt_and_imgs(message: BaseMessage) -> tuple[str, list[str]]: + if isinstance(message.content, str): + return message.content, [] + + imgs = [] + texts = [] + for part in message.content: + if isinstance(part, dict): + if part.get("type") == "image_url": + img_url = part.get("image_url", {}).get("url") + if img_url: + imgs.append(img_url) + elif part.get("type") == "text": + text = 
part.get("text") + if text: + texts.append(text) + else: + texts.append(part) + + return "".join(texts), imgs + + def dict_based_prompt_to_langchain_prompt( messages: list[dict[str, str]] ) -> list[BaseMessage]: diff --git a/backend/danswer/one_shot_answer/answer_question.py b/backend/danswer/one_shot_answer/answer_question.py index 1bfac570aee..9ece5f4bba2 100644 --- a/backend/danswer/one_shot_answer/answer_question.py +++ b/backend/danswer/one_shot_answer/answer_question.py @@ -52,12 +52,16 @@ from danswer.server.query_and_chat.models import ChatMessageDetail from danswer.server.utils import get_json_line from danswer.tools.force import ForceUseTool -from danswer.tools.search.search_tool import SEARCH_DOC_CONTENT_ID -from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID -from danswer.tools.search.search_tool import SearchResponseSummary -from danswer.tools.search.search_tool import SearchTool -from danswer.tools.search.search_tool import SECTION_RELEVANCE_LIST_ID -from danswer.tools.tool import ToolResponse +from danswer.tools.models import ToolResponse +from danswer.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID +from danswer.tools.tool_implementations.search.search_tool import ( + SEARCH_RESPONSE_SUMMARY_ID, +) +from danswer.tools.tool_implementations.search.search_tool import SearchResponseSummary +from danswer.tools.tool_implementations.search.search_tool import SearchTool +from danswer.tools.tool_implementations.search.search_tool import ( + SECTION_RELEVANCE_LIST_ID, +) from danswer.tools.tool_runner import ToolCallKickoff from danswer.utils.logger import setup_logger from danswer.utils.timing import log_generator_function_time @@ -202,30 +206,33 @@ def stream_answer_objects( max_tokens=max_document_tokens, ) + answer_config = AnswerStyleConfig( + citation_config=CitationConfig() if use_citations else None, + quotes_config=QuotesConfig() if not use_citations else None, + document_pruning_config=document_pruning_config, + ) + search_tool = SearchTool( db_session=db_session, user=user, - evaluation_type=LLMEvaluationType.SKIP - if DISABLE_LLM_DOC_RELEVANCE - else query_req.evaluation_type, + evaluation_type=( + LLMEvaluationType.SKIP + if DISABLE_LLM_DOC_RELEVANCE + else query_req.evaluation_type + ), persona=persona, retrieval_options=query_req.retrieval_options, prompt_config=prompt_config, llm=llm, fast_llm=fast_llm, pruning_config=document_pruning_config, + answer_style_config=answer_config, bypass_acl=bypass_acl, chunks_above=query_req.chunks_above, chunks_below=query_req.chunks_below, full_doc=query_req.full_doc, ) - answer_config = AnswerStyleConfig( - citation_config=CitationConfig() if use_citations else None, - quotes_config=QuotesConfig() if not use_citations else None, - document_pruning_config=document_pruning_config, - ) - answer = Answer( question=query_msg.message, answer_style_config=answer_config, diff --git a/backend/danswer/server/features/persona/models.py b/backend/danswer/server/features/persona/models.py index 0e69777a02c..5fa99952b1f 100644 --- a/backend/danswer/server/features/persona/models.py +++ b/backend/danswer/server/features/persona/models.py @@ -9,7 +9,7 @@ from danswer.search.enums import RecencyBiasSetting from danswer.server.features.document_set.models import DocumentSet from danswer.server.features.prompt.models import PromptSnapshot -from danswer.server.features.tool.api import ToolSnapshot +from danswer.server.features.tool.models import ToolSnapshot from danswer.server.models import 
MinimalUserSnapshot from danswer.utils.logger import setup_logger diff --git a/backend/danswer/server/features/tool/api.py b/backend/danswer/server/features/tool/api.py index 7e15c048826..48f857780ba 100644 --- a/backend/danswer/server/features/tool/api.py +++ b/backend/danswer/server/features/tool/api.py @@ -18,10 +18,16 @@ from danswer.server.features.tool.models import CustomToolCreate from danswer.server.features.tool.models import CustomToolUpdate from danswer.server.features.tool.models import ToolSnapshot -from danswer.tools.custom.openapi_parsing import MethodSpec -from danswer.tools.custom.openapi_parsing import openapi_to_method_specs -from danswer.tools.custom.openapi_parsing import validate_openapi_schema -from danswer.tools.images.image_generation_tool import ImageGenerationTool +from danswer.tools.tool_implementations.custom.openapi_parsing import MethodSpec +from danswer.tools.tool_implementations.custom.openapi_parsing import ( + openapi_to_method_specs, +) +from danswer.tools.tool_implementations.custom.openapi_parsing import ( + validate_openapi_schema, +) +from danswer.tools.tool_implementations.images.image_generation_tool import ( + ImageGenerationTool, +) from danswer.tools.utils import is_image_generation_available router = APIRouter(prefix="/tool") diff --git a/backend/danswer/tools/base_tool.py b/backend/danswer/tools/base_tool.py new file mode 100644 index 00000000000..73902504462 --- /dev/null +++ b/backend/danswer/tools/base_tool.py @@ -0,0 +1,59 @@ +from typing import cast +from typing import TYPE_CHECKING + +from langchain_core.messages import HumanMessage + +from danswer.llm.utils import message_to_prompt_and_imgs +from danswer.tools.tool import Tool + +if TYPE_CHECKING: + from danswer.llm.answering.prompts.build import AnswerPromptBuilder + from danswer.tools.tool_implementations.custom.custom_tool import ( + CustomToolCallSummary, + ) + from danswer.tools.message import ToolCallSummary + from danswer.tools.models import ToolResponse + + +def build_user_message_for_non_tool_calling_llm( + message: HumanMessage, + tool_name: str, + *args: "ToolResponse", +) -> str: + query, _ = message_to_prompt_and_imgs(message) + + tool_run_summary = cast("CustomToolCallSummary", args[0].response).tool_result + return f""" +Here's the result from the {tool_name} tool: + +{tool_run_summary} + +Now respond to the following: + +{query} +""".strip() + + +class BaseTool(Tool): + def build_next_prompt( + self, + prompt_builder: "AnswerPromptBuilder", + tool_call_summary: "ToolCallSummary", + tool_responses: list["ToolResponse"], + using_tool_calling_llm: bool, + ) -> "AnswerPromptBuilder": + if using_tool_calling_llm: + prompt_builder.append_message(tool_call_summary.tool_call_request) + prompt_builder.append_message(tool_call_summary.tool_call_result) + else: + prompt_builder.update_user_prompt( + HumanMessage( + content=build_user_message_for_non_tool_calling_llm( + prompt_builder.user_message_and_token_cnt[0], + self.name, + *tool_responses, + ) + ) + ) + + return prompt_builder diff --git a/backend/danswer/tools/built_in_tools.py b/backend/danswer/tools/built_in_tools.py index 99b2ae3bbb6..fb64381f1d0 100644 --- a/backend/danswer/tools/built_in_tools.py +++ b/backend/danswer/tools/built_in_tools.py @@ -9,9 +9,13 @@ from danswer.db.models import Persona from danswer.db.models import Tool as ToolDBModel -from danswer.tools.images.image_generation_tool import ImageGenerationTool -from danswer.tools.internet_search.internet_search_tool import InternetSearchTool -from 
danswer.tools.search.search_tool import SearchTool +from danswer.tools.tool_implementations.images.image_generation_tool import ( + ImageGenerationTool, +) +from danswer.tools.tool_implementations.internet_search.internet_search_tool import ( + InternetSearchTool, +) +from danswer.tools.tool_implementations.search.search_tool import SearchTool from danswer.tools.tool import Tool from danswer.utils.logger import setup_logger diff --git a/backend/danswer/tools/custom/custom_tool_prompt_builder.py b/backend/danswer/tools/custom/custom_tool_prompt_builder.py deleted file mode 100644 index 8016363acc9..00000000000 --- a/backend/danswer/tools/custom/custom_tool_prompt_builder.py +++ /dev/null @@ -1,21 +0,0 @@ -from typing import cast - -from danswer.tools.custom.custom_tool import CustomToolCallSummary -from danswer.tools.models import ToolResponse - - -def build_user_message_for_custom_tool_for_non_tool_calling_llm( - query: str, - tool_name: str, - *args: ToolResponse, -) -> str: - tool_run_summary = cast(CustomToolCallSummary, args[0].response).tool_result - return f""" -Here's the result from the {tool_name} tool: - -{tool_run_summary} - -Now respond to the following: - -{query} -""".strip() diff --git a/backend/danswer/tools/tool.py b/backend/danswer/tools/tool.py index 29e5311fc15..1b1c43ab8da 100644 --- a/backend/danswer/tools/tool.py +++ b/backend/danswer/tools/tool.py @@ -1,11 +1,17 @@ import abc from collections.abc import Generator from typing import Any +from typing import TYPE_CHECKING from danswer.key_value_store.interface import JSON_ro from danswer.llm.answering.models import PreviousMessage from danswer.llm.interfaces import LLM -from danswer.tools.models import ToolResponse + + +if TYPE_CHECKING: + from danswer.llm.answering.prompts.build import AnswerPromptBuilder + from danswer.tools.message import ToolCallSummary + from danswer.tools.models import ToolResponse class Tool(abc.ABC): @@ -32,7 +38,7 @@ def tool_definition(self) -> dict: @abc.abstractmethod def build_tool_message_content( - self, *args: ToolResponse + self, *args: "ToolResponse" ) -> str | list[str | dict[str, Any]]: raise NotImplementedError @@ -51,13 +57,26 @@ def get_args_for_non_tool_calling_llm( """Actual execution of the tool""" @abc.abstractmethod - def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]: + def run(self, **kwargs: Any) -> Generator["ToolResponse", None, None]: raise NotImplementedError @abc.abstractmethod - def final_result(self, *args: ToolResponse) -> JSON_ro: + def final_result(self, *args: "ToolResponse") -> JSON_ro: """ This is the "final summary" result of the tool. It is the result that will be stored in the database. """ raise NotImplementedError + + """Some tools may want to modify the prompt based on the tool call summary and tool responses. 
+ Default behavior is to continue with just the raw tool call request/result passed to the LLM.""" + + @abc.abstractmethod + def build_next_prompt( + self, + prompt_builder: "AnswerPromptBuilder", + tool_call_summary: "ToolCallSummary", + tool_responses: list["ToolResponse"], + using_tool_calling_llm: bool, + ) -> "AnswerPromptBuilder": + raise NotImplementedError diff --git a/backend/danswer/tools/custom/base_tool_types.py b/backend/danswer/tools/tool_implementations/custom/base_tool_types.py similarity index 100% rename from backend/danswer/tools/custom/base_tool_types.py rename to backend/danswer/tools/tool_implementations/custom/base_tool_types.py diff --git a/backend/danswer/tools/custom/custom_tool.py b/backend/danswer/tools/tool_implementations/custom/custom_tool.py similarity index 88% rename from backend/danswer/tools/custom/custom_tool.py rename to backend/danswer/tools/tool_implementations/custom/custom_tool.py index ee431af70e1..a1fb4bb699e 100644 --- a/backend/danswer/tools/custom/custom_tool.py +++ b/backend/danswer/tools/tool_implementations/custom/custom_tool.py @@ -11,24 +11,34 @@ from danswer.key_value_store.interface import JSON_ro from danswer.llm.answering.models import PreviousMessage from danswer.llm.interfaces import LLM -from danswer.tools.custom.base_tool_types import ToolResultType -from danswer.tools.custom.custom_tool_prompts import ( - SHOULD_USE_CUSTOM_TOOL_SYSTEM_PROMPT, -) -from danswer.tools.custom.custom_tool_prompts import SHOULD_USE_CUSTOM_TOOL_USER_PROMPT -from danswer.tools.custom.custom_tool_prompts import TOOL_ARG_SYSTEM_PROMPT -from danswer.tools.custom.custom_tool_prompts import TOOL_ARG_USER_PROMPT -from danswer.tools.custom.custom_tool_prompts import USE_TOOL -from danswer.tools.custom.openapi_parsing import MethodSpec -from danswer.tools.custom.openapi_parsing import openapi_to_method_specs -from danswer.tools.custom.openapi_parsing import openapi_to_url -from danswer.tools.custom.openapi_parsing import REQUEST_BODY -from danswer.tools.custom.openapi_parsing import validate_openapi_schema +from danswer.tools.base_tool import BaseTool from danswer.tools.models import CHAT_SESSION_ID_PLACEHOLDER from danswer.tools.models import DynamicSchemaInfo from danswer.tools.models import MESSAGE_ID_PLACEHOLDER -from danswer.tools.tool import Tool -from danswer.tools.tool import ToolResponse +from danswer.tools.models import ToolResponse +from danswer.tools.tool_implementations.custom.base_tool_types import ToolResultType +from danswer.tools.tool_implementations.custom.custom_tool_prompts import ( + SHOULD_USE_CUSTOM_TOOL_SYSTEM_PROMPT, +) +from danswer.tools.tool_implementations.custom.custom_tool_prompts import ( + SHOULD_USE_CUSTOM_TOOL_USER_PROMPT, +) +from danswer.tools.tool_implementations.custom.custom_tool_prompts import ( + TOOL_ARG_SYSTEM_PROMPT, +) +from danswer.tools.tool_implementations.custom.custom_tool_prompts import ( + TOOL_ARG_USER_PROMPT, +) +from danswer.tools.tool_implementations.custom.custom_tool_prompts import USE_TOOL +from danswer.tools.tool_implementations.custom.openapi_parsing import MethodSpec +from danswer.tools.tool_implementations.custom.openapi_parsing import ( + openapi_to_method_specs, +) +from danswer.tools.tool_implementations.custom.openapi_parsing import openapi_to_url +from danswer.tools.tool_implementations.custom.openapi_parsing import REQUEST_BODY +from danswer.tools.tool_implementations.custom.openapi_parsing import ( + validate_openapi_schema, +) from danswer.utils.headers import header_list_to_header_dict from 
danswer.utils.headers import HeaderItemDict from danswer.utils.logger import setup_logger @@ -43,7 +53,7 @@ class CustomToolCallSummary(BaseModel): tool_result: ToolResultType -class CustomTool(Tool): +class CustomTool(BaseTool): def __init__( self, method_spec: MethodSpec, diff --git a/backend/danswer/tools/custom/custom_tool_prompts.py b/backend/danswer/tools/tool_implementations/custom/custom_tool_prompts.py similarity index 100% rename from backend/danswer/tools/custom/custom_tool_prompts.py rename to backend/danswer/tools/tool_implementations/custom/custom_tool_prompts.py diff --git a/backend/danswer/tools/custom/openapi_parsing.py b/backend/danswer/tools/tool_implementations/custom/openapi_parsing.py similarity index 100% rename from backend/danswer/tools/custom/openapi_parsing.py rename to backend/danswer/tools/tool_implementations/custom/openapi_parsing.py diff --git a/backend/danswer/tools/images/image_generation_tool.py b/backend/danswer/tools/tool_implementations/images/image_generation_tool.py similarity index 86% rename from backend/danswer/tools/images/image_generation_tool.py rename to backend/danswer/tools/tool_implementations/images/image_generation_tool.py index 3584d50f77e..6fb06fb534a 100644 --- a/backend/danswer/tools/images/image_generation_tool.py +++ b/backend/danswer/tools/tool_implementations/images/image_generation_tool.py @@ -11,12 +11,17 @@ from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF from danswer.key_value_store.interface import JSON_ro from danswer.llm.answering.models import PreviousMessage +from danswer.llm.answering.prompts.build import AnswerPromptBuilder from danswer.llm.interfaces import LLM from danswer.llm.utils import build_content_with_imgs from danswer.llm.utils import message_to_string from danswer.prompts.constants import GENERAL_SEP_PAT +from danswer.tools.message import ToolCallSummary +from danswer.tools.models import ToolResponse from danswer.tools.tool import Tool -from danswer.tools.tool import ToolResponse +from danswer.tools.tool_implementations.images.prompt import ( + build_image_generation_user_prompt, +) from danswer.utils.headers import build_llm_extra_headers from danswer.utils.logger import setup_logger from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel @@ -258,3 +263,34 @@ def final_result(self, *args: ToolResponse) -> JSON_ro: image_generation_response.model_dump() for image_generation_response in image_generation_responses ] + + def build_next_prompt( + self, + prompt_builder: AnswerPromptBuilder, + tool_call_summary: ToolCallSummary, + tool_responses: list[ToolResponse], + using_tool_calling_llm: bool, + ) -> AnswerPromptBuilder: + img_generation_response = cast( + list[ImageGenerationResponse] | None, + next( + ( + response.response + for response in tool_responses + if response.id == IMAGE_GENERATION_RESPONSE_ID + ), + None, + ), + ) + if img_generation_response is None: + raise ValueError("No image generation response found") + + img_urls = [img.url for img in img_generation_response] + prompt_builder.update_user_prompt( + build_image_generation_user_prompt( + query=prompt_builder.get_user_message_content(), + img_urls=img_urls, + ) + ) + + return prompt_builder diff --git a/backend/danswer/tools/images/prompt.py b/backend/danswer/tools/tool_implementations/images/prompt.py similarity index 100% rename from backend/danswer/tools/images/prompt.py rename to backend/danswer/tools/tool_implementations/images/prompt.py diff --git 
a/backend/danswer/tools/internet_search/internet_search_tool.py b/backend/danswer/tools/tool_implementations/internet_search/internet_search_tool.py similarity index 81% rename from backend/danswer/tools/internet_search/internet_search_tool.py rename to backend/danswer/tools/tool_implementations/internet_search/internet_search_tool.py index 70b4483b996..12142bc4852 100644 --- a/backend/danswer/tools/internet_search/internet_search_tool.py +++ b/backend/danswer/tools/tool_implementations/internet_search/internet_search_tool.py @@ -11,18 +11,31 @@ from danswer.configs.constants import DocumentSource from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF from danswer.key_value_store.interface import JSON_ro +from danswer.llm.answering.models import AnswerStyleConfig from danswer.llm.answering.models import PreviousMessage +from danswer.llm.answering.models import PromptConfig +from danswer.llm.answering.prompts.build import AnswerPromptBuilder from danswer.llm.interfaces import LLM from danswer.llm.utils import message_to_string from danswer.prompts.chat_prompts import INTERNET_SEARCH_QUERY_REPHRASE from danswer.prompts.constants import GENERAL_SEP_PAT from danswer.search.models import SearchDoc from danswer.secondary_llm_flows.query_expansion import history_based_query_rephrase -from danswer.tools.internet_search.models import InternetSearchResponse -from danswer.tools.internet_search.models import InternetSearchResult -from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID +from danswer.tools.message import ToolCallSummary +from danswer.tools.models import ToolResponse from danswer.tools.tool import Tool -from danswer.tools.tool import ToolResponse +from danswer.tools.tool_implementations.internet_search.models import ( + InternetSearchResponse, +) +from danswer.tools.tool_implementations.internet_search.models import ( + InternetSearchResult, +) +from danswer.tools.tool_implementations.search_like_tool_utils import ( + build_next_prompt_for_search_like_tool, +) +from danswer.tools.tool_implementations.search_like_tool_utils import ( + FINAL_CONTEXT_DOCUMENTS_ID, +) from danswer.utils.logger import setup_logger logger = setup_logger() @@ -97,8 +110,17 @@ class InternetSearchTool(Tool): _DISPLAY_NAME = "[Beta] Internet Search Tool" _DESCRIPTION = "Perform an internet search for up-to-date information." 
- def __init__(self, api_key: str, num_results: int = 10) -> None: + def __init__( + self, + api_key: str, + answer_style_config: AnswerStyleConfig, + prompt_config: PromptConfig, + num_results: int = 10, + ) -> None: self.api_key = api_key + self.answer_style_config = answer_style_config + self.prompt_config = prompt_config + self.host = "https://api.bing.microsoft.com/v7.0" self.headers = { "Ocp-Apim-Subscription-Key": api_key, @@ -231,3 +253,19 @@ def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]: def final_result(self, *args: ToolResponse) -> JSON_ro: search_response = cast(InternetSearchResponse, args[0].response) return search_response.model_dump() + + def build_next_prompt( + self, + prompt_builder: AnswerPromptBuilder, + tool_call_summary: ToolCallSummary, + tool_responses: list[ToolResponse], + using_tool_calling_llm: bool, + ) -> AnswerPromptBuilder: + return build_next_prompt_for_search_like_tool( + prompt_builder=prompt_builder, + tool_call_summary=tool_call_summary, + tool_responses=tool_responses, + using_tool_calling_llm=using_tool_calling_llm, + answer_style_config=self.answer_style_config, + prompt_config=self.prompt_config, + ) diff --git a/backend/danswer/tools/internet_search/models.py b/backend/danswer/tools/tool_implementations/internet_search/models.py similarity index 100% rename from backend/danswer/tools/internet_search/models.py rename to backend/danswer/tools/tool_implementations/internet_search/models.py diff --git a/backend/danswer/tools/search/search_tool.py b/backend/danswer/tools/tool_implementations/search/search_tool.py similarity index 87% rename from backend/danswer/tools/search/search_tool.py rename to backend/danswer/tools/tool_implementations/search/search_tool.py index 96ab7b843f6..6eda3013ab3 100644 --- a/backend/danswer/tools/search/search_tool.py +++ b/backend/danswer/tools/tool_implementations/search/search_tool.py @@ -17,10 +17,13 @@ from danswer.db.models import Persona from danswer.db.models import User from danswer.key_value_store.interface import JSON_ro +from danswer.llm.answering.llm_response_handler import LLMCall +from danswer.llm.answering.models import AnswerStyleConfig from danswer.llm.answering.models import ContextualPruningConfig from danswer.llm.answering.models import DocumentPruningConfig from danswer.llm.answering.models import PreviousMessage from danswer.llm.answering.models import PromptConfig +from danswer.llm.answering.prompts.build import AnswerPromptBuilder from danswer.llm.answering.prompts.citations_prompt import compute_max_llm_input_tokens from danswer.llm.answering.prune_and_merge import prune_and_merge_sections from danswer.llm.answering.prune_and_merge import prune_sections @@ -35,9 +38,16 @@ from danswer.search.pipeline import SearchPipeline from danswer.secondary_llm_flows.choose_search import check_if_need_search from danswer.secondary_llm_flows.query_expansion import history_based_query_rephrase -from danswer.tools.search.search_utils import llm_doc_to_dict +from danswer.tools.message import ToolCallSummary +from danswer.tools.models import ToolResponse from danswer.tools.tool import Tool -from danswer.tools.tool import ToolResponse +from danswer.tools.tool_implementations.search.search_utils import llm_doc_to_dict +from danswer.tools.tool_implementations.search_like_tool_utils import ( + build_next_prompt_for_search_like_tool, +) +from danswer.tools.tool_implementations.search_like_tool_utils import ( + FINAL_CONTEXT_DOCUMENTS_ID, +) from danswer.utils.logger import setup_logger logger = 
setup_logger() @@ -45,7 +55,6 @@ SEARCH_RESPONSE_SUMMARY_ID = "search_response_summary" SEARCH_DOC_CONTENT_ID = "search_doc_content" SECTION_RELEVANCE_LIST_ID = "section_relevance_list" -FINAL_CONTEXT_DOCUMENTS_ID = "final_context_documents" SEARCH_EVALUATION_ID = "llm_doc_eval" @@ -85,6 +94,7 @@ def __init__( llm: LLM, fast_llm: LLM, pruning_config: DocumentPruningConfig, + answer_style_config: AnswerStyleConfig, evaluation_type: LLMEvaluationType, # if specified, will not actually run a search and will instead return these # sections. Used when the user selects specific docs to talk to @@ -136,6 +146,7 @@ def __init__( num_chunk_multiple = self.chunks_above + self.chunks_below + 1 + self.answer_style_config = answer_style_config self.contextual_pruning_config = ( ContextualPruningConfig.from_doc_pruning_config( num_chunk_multiple=num_chunk_multiple, doc_pruning_config=pruning_config @@ -353,4 +364,36 @@ def final_result(self, *args: ToolResponse) -> JSON_ro: # NOTE: need to do this json.loads(doc.json()) stuff because there are some # subfields that are not serializable by default (datetime) # this forces pydantic to make them JSON serializable for us - return [json.loads(doc.json()) for doc in final_docs] + return [json.loads(doc.model_dump_json()) for doc in final_docs] + + def build_next_prompt( + self, + prompt_builder: AnswerPromptBuilder, + tool_call_summary: ToolCallSummary, + tool_responses: list[ToolResponse], + using_tool_calling_llm: bool, + ) -> AnswerPromptBuilder: + return build_next_prompt_for_search_like_tool( + prompt_builder=prompt_builder, + tool_call_summary=tool_call_summary, + tool_responses=tool_responses, + using_tool_calling_llm=using_tool_calling_llm, + answer_style_config=self.answer_style_config, + prompt_config=self.prompt_config, + ) + + """Other utility functions""" + + @classmethod + def get_search_result(cls, llm_call: LLMCall) -> list[LlmDoc] | None: + if not llm_call.tool_call_info: + return None + + for yield_item in llm_call.tool_call_info: + if ( + isinstance(yield_item, ToolResponse) + and yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID + ): + return cast(list[LlmDoc], yield_item.response) + + return None diff --git a/backend/danswer/tools/search/search_utils.py b/backend/danswer/tools/tool_implementations/search/search_utils.py similarity index 100% rename from backend/danswer/tools/search/search_utils.py rename to backend/danswer/tools/tool_implementations/search/search_utils.py diff --git a/backend/danswer/tools/tool_implementations/search_like_tool_utils.py b/backend/danswer/tools/tool_implementations/search_like_tool_utils.py new file mode 100644 index 00000000000..6701f1602ea --- /dev/null +++ b/backend/danswer/tools/tool_implementations/search_like_tool_utils.py @@ -0,0 +1,71 @@ +from typing import cast + +from danswer.chat.models import LlmDoc +from danswer.llm.answering.models import AnswerStyleConfig +from danswer.llm.answering.models import PromptConfig +from danswer.llm.answering.prompts.build import AnswerPromptBuilder +from danswer.llm.answering.prompts.citations_prompt import ( + build_citations_system_message, +) +from danswer.llm.answering.prompts.citations_prompt import build_citations_user_message +from danswer.llm.answering.prompts.quotes_prompt import build_quotes_user_message +from danswer.tools.message import ToolCallSummary +from danswer.tools.models import ToolResponse + + +FINAL_CONTEXT_DOCUMENTS_ID = "final_context_documents" + + +def build_next_prompt_for_search_like_tool( + prompt_builder: AnswerPromptBuilder, + 
tool_call_summary: ToolCallSummary, + tool_responses: list[ToolResponse], + using_tool_calling_llm: bool, + answer_style_config: AnswerStyleConfig, + prompt_config: PromptConfig, +) -> AnswerPromptBuilder: + if not using_tool_calling_llm: + final_context_docs_response = next( + response + for response in tool_responses + if response.id == FINAL_CONTEXT_DOCUMENTS_ID + ) + final_context_documents = cast( + list[LlmDoc], final_context_docs_response.response + ) + else: + # if using tool calling llm, then the final context documents are the tool responses + final_context_documents = [] + + if answer_style_config.citation_config: + prompt_builder.update_system_prompt( + build_citations_system_message(prompt_config) + ) + prompt_builder.update_user_prompt( + build_citations_user_message( + message=prompt_builder.user_message_and_token_cnt[0], + prompt_config=prompt_config, + context_docs=final_context_documents, + all_doc_useful=( + answer_style_config.citation_config.all_docs_useful + if answer_style_config.citation_config + else False + ), + history_message=prompt_builder.single_message_history or "", + ) + ) + elif answer_style_config.quotes_config: + prompt_builder.update_user_prompt( + build_quotes_user_message( + message=prompt_builder.user_message_and_token_cnt[0], + context_docs=final_context_documents, + history_str=prompt_builder.single_message_history or "", + prompt=prompt_config, + ) + ) + + if using_tool_calling_llm: + prompt_builder.append_message(tool_call_summary.tool_call_request) + prompt_builder.append_message(tool_call_summary.tool_call_result) + + return prompt_builder diff --git a/backend/danswer/tools/tool_runner.py b/backend/danswer/tools/tool_runner.py index 58b94bdb0c8..fb3eb8b9932 100644 --- a/backend/danswer/tools/tool_runner.py +++ b/backend/danswer/tools/tool_runner.py @@ -6,8 +6,8 @@ from danswer.llm.interfaces import LLM from danswer.tools.models import ToolCallFinalResult from danswer.tools.models import ToolCallKickoff +from danswer.tools.models import ToolResponse from danswer.tools.tool import Tool -from danswer.tools.tool import ToolResponse from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel diff --git a/backend/ee/danswer/server/query_and_chat/utils.py b/backend/ee/danswer/server/query_and_chat/utils.py index a2f7253517a..be5507b01c2 100644 --- a/backend/ee/danswer/server/query_and_chat/utils.py +++ b/backend/ee/danswer/server/query_and_chat/utils.py @@ -12,7 +12,7 @@ from danswer.db.models import User from danswer.db.persona import get_prompts_by_ids from danswer.one_shot_answer.models import PersonaConfig -from danswer.tools.custom.custom_tool import ( +from danswer.tools.tool_implementations.custom.custom_tool import ( build_custom_tools_from_openapi_schema_and_headers, ) diff --git a/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py b/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py index 10d1950ae03..0ed40c758d0 100644 --- a/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py +++ b/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py @@ -142,6 +142,9 @@ def test_using_reference_docs_with_simple_with_history_api_flow(reset: None) -> assert response.status_code == 200 response_json = response.json() + # make sure there is an answer + assert response_json["answer"] + # since we only gave it one search doc, all responses should only contain that doc assert response_json["final_context_doc_indices"] == [0] assert response_json["llm_selected_doc_indices"] == [0] diff --git 
a/backend/tests/unit/danswer/llm/answering/conftest.py b/backend/tests/unit/danswer/llm/answering/conftest.py new file mode 100644 index 00000000000..a0077b53917 --- /dev/null +++ b/backend/tests/unit/danswer/llm/answering/conftest.py @@ -0,0 +1,113 @@ +import json +from datetime import datetime +from unittest.mock import MagicMock + +import pytest +from langchain_core.messages import SystemMessage + +from danswer.chat.models import LlmDoc +from danswer.configs.constants import DocumentSource +from danswer.llm.answering.models import AnswerStyleConfig +from danswer.llm.answering.models import CitationConfig +from danswer.llm.answering.models import PromptConfig +from danswer.llm.answering.prompts.build import AnswerPromptBuilder +from danswer.llm.interfaces import LLMConfig +from danswer.tools.models import ToolResponse +from danswer.tools.tool_implementations.search.search_tool import SearchTool +from danswer.tools.tool_implementations.search_like_tool_utils import ( + FINAL_CONTEXT_DOCUMENTS_ID, +) + +QUERY = "Test question" +DEFAULT_SEARCH_ARGS = {"query": "search"} + + +@pytest.fixture +def answer_style_config() -> AnswerStyleConfig: + return AnswerStyleConfig(citation_config=CitationConfig()) + + +@pytest.fixture +def prompt_config() -> PromptConfig: + return PromptConfig( + system_prompt="System prompt", + task_prompt="Task prompt", + datetime_aware=False, + include_citations=True, + ) + + +@pytest.fixture +def mock_llm() -> MagicMock: + mock_llm_obj = MagicMock() + mock_llm_obj.config = LLMConfig( + model_provider="openai", + model_name="gpt-4o", + temperature=0.0, + api_key=None, + api_base=None, + api_version=None, + ) + return mock_llm_obj + + +@pytest.fixture +def mock_search_results() -> list[LlmDoc]: + return [ + LlmDoc( + content="Search result 1", + source_type=DocumentSource.WEB, + metadata={"id": "doc1"}, + document_id="doc1", + blurb="Blurb 1", + semantic_identifier="Semantic ID 1", + updated_at=datetime(2023, 1, 1), + link="https://example.com/doc1", + source_links={0: "https://example.com/doc1"}, + ), + LlmDoc( + content="Search result 2", + source_type=DocumentSource.WEB, + metadata={"id": "doc2"}, + document_id="doc2", + blurb="Blurb 2", + semantic_identifier="Semantic ID 2", + updated_at=datetime(2023, 1, 2), + link="https://example.com/doc2", + source_links={0: "https://example.com/doc2"}, + ), + ] + + +@pytest.fixture +def mock_search_tool(mock_search_results: list[LlmDoc]) -> MagicMock: + mock_tool = MagicMock(spec=SearchTool) + mock_tool.name = "search" + mock_tool.build_tool_message_content.return_value = "search_response" + mock_tool.get_args_for_non_tool_calling_llm.return_value = DEFAULT_SEARCH_ARGS + mock_tool.final_result.return_value = [ + json.loads(doc.model_dump_json()) for doc in mock_search_results + ] + mock_tool.run.return_value = [ + ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=mock_search_results) + ] + mock_tool.tool_definition.return_value = { + "type": "function", + "function": { + "name": "search", + "description": "Search for information", + "parameters": { + "type": "object", + "properties": { + "query": {"type": "string", "description": "The search query"}, + }, + "required": ["query"], + }, + }, + } + mock_post_search_tool_prompt_builder = MagicMock(spec=AnswerPromptBuilder) + mock_post_search_tool_prompt_builder.build.return_value = [ + SystemMessage(content="Updated system prompt"), + ] + mock_tool.build_next_prompt.return_value = mock_post_search_tool_prompt_builder + return mock_tool diff --git 
a/backend/tests/unit/danswer/llm/answering/stream_processing/test_citation_processing.py b/backend/tests/unit/danswer/llm/answering/stream_processing/test_citation_processing.py index 12e3254d6d6..e6a5fe1f027 100644 --- a/backend/tests/unit/danswer/llm/answering/stream_processing/test_citation_processing.py +++ b/backend/tests/unit/danswer/llm/answering/stream_processing/test_citation_processing.py @@ -7,7 +7,7 @@ from danswer.chat.models import LlmDoc from danswer.configs.constants import DocumentSource from danswer.llm.answering.stream_processing.citation_processing import ( - extract_citations_from_stream, + CitationProcessor, ) from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping @@ -70,14 +70,16 @@ def process_text( ) -> tuple[str, list[CitationInfo]]: mock_docs, mock_doc_id_to_rank_map = mock_data mapping = DocumentIdOrderMapping(order_mapping=mock_doc_id_to_rank_map) - result = list( - extract_citations_from_stream( - tokens=iter(tokens), - context_docs=mock_docs, - doc_id_to_rank_map=mapping, - stop_stream=None, - ) + processor = CitationProcessor( + context_docs=mock_docs, + doc_id_to_rank_map=mapping, + stop_stream=None, ) + result: list[DanswerAnswerPiece | CitationInfo] = [] + for token in tokens: + result.extend(processor.process_token(token)) + result.extend(processor.process_token(None)) + final_answer_text = "" citations = [] for piece in result: diff --git a/backend/tests/unit/danswer/llm/answering/test_answer.py b/backend/tests/unit/danswer/llm/answering/test_answer.py new file mode 100644 index 00000000000..95dde97f1f4 --- /dev/null +++ b/backend/tests/unit/danswer/llm/answering/test_answer.py @@ -0,0 +1,422 @@ +import json +from typing import cast +from unittest.mock import MagicMock +from unittest.mock import Mock + +import pytest +from langchain_core.messages import AIMessageChunk +from langchain_core.messages import BaseMessage +from langchain_core.messages import HumanMessage +from langchain_core.messages import SystemMessage +from langchain_core.messages import ToolCall +from langchain_core.messages import ToolCallChunk + +from danswer.chat.models import CitationInfo +from danswer.chat.models import DanswerAnswerPiece +from danswer.chat.models import LlmDoc +from danswer.chat.models import StreamStopInfo +from danswer.chat.models import StreamStopReason +from danswer.llm.answering.answer import Answer +from danswer.llm.answering.models import AnswerStyleConfig +from danswer.llm.answering.models import CitationConfig +from danswer.llm.answering.models import PromptConfig +from danswer.llm.interfaces import LLM +from danswer.tools.force import ForceUseTool +from danswer.tools.models import ToolCallFinalResult +from danswer.tools.models import ToolCallKickoff +from danswer.tools.models import ToolResponse +from tests.unit.danswer.llm.answering.conftest import DEFAULT_SEARCH_ARGS +from tests.unit.danswer.llm.answering.conftest import QUERY + + +@pytest.fixture +def answer_instance( + mock_llm: LLM, answer_style_config: AnswerStyleConfig, prompt_config: PromptConfig +) -> Answer: + return Answer( + question=QUERY, + answer_style_config=answer_style_config, + llm=mock_llm, + prompt_config=prompt_config, + force_use_tool=ForceUseTool(force_use=False, tool_name="", args=None), + ) + + +def test_basic_answer(answer_instance: Answer) -> None: + mock_llm = cast(Mock, answer_instance.llm) + mock_llm.stream.return_value = [ + AIMessageChunk(content="This is a "), + AIMessageChunk(content="mock answer."), + ] + + output = 
list(answer_instance.processed_streamed_output) + assert len(output) == 2 + assert isinstance(output[0], DanswerAnswerPiece) + assert isinstance(output[1], DanswerAnswerPiece) + + full_answer = "".join( + piece.answer_piece + for piece in output + if isinstance(piece, DanswerAnswerPiece) and piece.answer_piece is not None + ) + assert full_answer == "This is a mock answer." + + assert answer_instance.llm_answer == "This is a mock answer." + assert answer_instance.citations == [] + + assert mock_llm.stream.call_count == 1 + mock_llm.stream.assert_called_once_with( + prompt=[ + SystemMessage(content="System prompt"), + HumanMessage(content="Task prompt\n\nQUERY:\nTest question"), + ], + tools=None, + tool_choice=None, + structured_response_format=None, + ) + + +@pytest.mark.parametrize( + "force_use_tool, expected_tool_args", + [ + ( + ForceUseTool(force_use=False, tool_name="", args=None), + DEFAULT_SEARCH_ARGS, + ), + ( + ForceUseTool( + force_use=True, tool_name="search", args={"query": "forced search"} + ), + {"query": "forced search"}, + ), + ], +) +def test_answer_with_search_call( + answer_instance: Answer, + mock_search_results: list[LlmDoc], + mock_search_tool: MagicMock, + force_use_tool: ForceUseTool, + expected_tool_args: dict, +) -> None: + answer_instance.tools = [mock_search_tool] + answer_instance.force_use_tool = force_use_tool + + # Set up the LLM mock to return search results and then an answer + mock_llm = cast(Mock, answer_instance.llm) + + stream_side_effect: list[list[BaseMessage]] = [] + + if not force_use_tool.force_use: + tool_call_chunk = AIMessageChunk(content="") + tool_call_chunk.tool_calls = [ + ToolCall( + id="search", + name="search", + args=expected_tool_args, + ) + ] + tool_call_chunk.tool_call_chunks = [ + ToolCallChunk( + id="search", + name="search", + args=json.dumps(expected_tool_args), + index=0, + ) + ] + stream_side_effect.append([tool_call_chunk]) + + stream_side_effect.append( + [ + AIMessageChunk(content="Based on the search results, "), + AIMessageChunk(content="the answer is abc[1]. "), + AIMessageChunk(content="This is some other stuff."), + ], + ) + mock_llm.stream.side_effect = stream_side_effect + + # Process the output + output = list(answer_instance.processed_streamed_output) + print(output) + + # Updated assertions + assert len(output) == 7 + assert output[0] == ToolCallKickoff( + tool_name="search", tool_args=expected_tool_args + ) + assert output[1] == ToolResponse( + id="final_context_documents", + response=mock_search_results, + ) + assert output[2] == ToolCallFinalResult( + tool_name="search", + tool_args=expected_tool_args, + tool_result=[json.loads(doc.model_dump_json()) for doc in mock_search_results], + ) + assert output[3] == DanswerAnswerPiece(answer_piece="Based on the search results, ") + expected_citation = CitationInfo(citation_num=1, document_id="doc1") + assert output[4] == expected_citation + assert output[5] == DanswerAnswerPiece( + answer_piece="the answer is abc[[1]](https://example.com/doc1). " + ) + assert output[6] == DanswerAnswerPiece(answer_piece="This is some other stuff.") + + expected_answer = ( + "Based on the search results, " + "the answer is abc[[1]](https://example.com/doc1). " + "This is some other stuff." 
+ ) + full_answer = "".join( + piece.answer_piece + for piece in output + if isinstance(piece, DanswerAnswerPiece) and piece.answer_piece is not None + ) + assert full_answer == expected_answer + + assert answer_instance.llm_answer == expected_answer + assert len(answer_instance.citations) == 1 + assert answer_instance.citations[0] == expected_citation + + # Verify LLM calls + if not force_use_tool.force_use: + assert mock_llm.stream.call_count == 2 + first_call, second_call = mock_llm.stream.call_args_list + + # First call should include the search tool definition + assert len(first_call.kwargs["tools"]) == 1 + assert ( + first_call.kwargs["tools"][0] + == mock_search_tool.tool_definition.return_value + ) + + # Second call should not include tools (as we're just generating the final answer) + assert "tools" not in second_call.kwargs or not second_call.kwargs["tools"] + # Second call should use the returned prompt from build_next_prompt + assert ( + second_call.kwargs["prompt"] + == mock_search_tool.build_next_prompt.return_value.build.return_value + ) + + # Verify that tool_definition was called on the mock_search_tool + mock_search_tool.tool_definition.assert_called_once() + else: + assert mock_llm.stream.call_count == 1 + + call = mock_llm.stream.call_args_list[0] + assert ( + call.kwargs["prompt"] + == mock_search_tool.build_next_prompt.return_value.build.return_value + ) + + +def test_answer_with_search_no_tool_calling( + answer_instance: Answer, + mock_search_results: list[LlmDoc], + mock_search_tool: MagicMock, +) -> None: + answer_instance.tools = [mock_search_tool] + + # Set up the LLM mock to return an answer + mock_llm = cast(Mock, answer_instance.llm) + mock_llm.stream.return_value = [ + AIMessageChunk(content="Based on the search results, "), + AIMessageChunk(content="the answer is abc[1]. "), + AIMessageChunk(content="This is some other stuff."), + ] + + # Force non-tool calling behavior + answer_instance.using_tool_calling_llm = False + + # Process the output + output = list(answer_instance.processed_streamed_output) + + # Assertions + assert len(output) == 7 + assert output[0] == ToolCallKickoff( + tool_name="search", tool_args=DEFAULT_SEARCH_ARGS + ) + assert output[1] == ToolResponse( + id="final_context_documents", + response=mock_search_results, + ) + assert output[2] == ToolCallFinalResult( + tool_name="search", + tool_args=DEFAULT_SEARCH_ARGS, + tool_result=[json.loads(doc.model_dump_json()) for doc in mock_search_results], + ) + assert output[3] == DanswerAnswerPiece(answer_piece="Based on the search results, ") + expected_citation = CitationInfo(citation_num=1, document_id="doc1") + assert output[4] == expected_citation + assert output[5] == DanswerAnswerPiece( + answer_piece="the answer is abc[[1]](https://example.com/doc1). " + ) + assert output[6] == DanswerAnswerPiece(answer_piece="This is some other stuff.") + + expected_answer = ( + "Based on the search results, " + "the answer is abc[[1]](https://example.com/doc1). " + "This is some other stuff." 
+ ) + assert answer_instance.llm_answer == expected_answer + assert len(answer_instance.citations) == 1 + assert answer_instance.citations[0] == expected_citation + + # Verify LLM calls + assert mock_llm.stream.call_count == 1 + call_args = mock_llm.stream.call_args + + # Verify that no tools were passed to the LLM + assert "tools" not in call_args.kwargs or not call_args.kwargs["tools"] + + # Verify that the prompt was built correctly + assert ( + call_args.kwargs["prompt"] + == mock_search_tool.build_next_prompt.return_value.build.return_value + ) + + # Verify that get_args_for_non_tool_calling_llm was called on the mock_search_tool + mock_search_tool.get_args_for_non_tool_calling_llm.assert_called_once_with( + f"Task prompt\n\nQUERY:\n{QUERY}", [], answer_instance.llm + ) + + # Verify that the search tool's run method was called + mock_search_tool.run.assert_called_once() + + +def test_answer_with_search_call_quotes_enabled( + answer_instance: Answer, + mock_search_results: list[LlmDoc], + mock_search_tool: MagicMock, +) -> None: + answer_instance.tools = [mock_search_tool] + answer_instance.force_use_tool = ForceUseTool( + force_use=False, tool_name="", args=None + ) + answer_instance.answer_style_config.citation_config = CitationConfig( + use_quotes=True + ) + + # Set up the LLM mock to return search results and then an answer + mock_llm = cast(Mock, answer_instance.llm) + + tool_call_chunk = AIMessageChunk(content="") + tool_call_chunk.tool_calls = [ + ToolCall( + id="search", + name="search", + args=DEFAULT_SEARCH_ARGS, + ) + ] + tool_call_chunk.tool_call_chunks = [ + ToolCallChunk( + id="search", + name="search", + args=json.dumps(DEFAULT_SEARCH_ARGS), + index=0, + ) + ] + + mock_llm.stream.side_effect = [ + [tool_call_chunk], + [ + AIMessageChunk(content="Answer"), + ], + ] + + # Process the output + output = list(answer_instance.processed_streamed_output) + + # Assertions + assert len(output) == 7 + assert output[0] == ToolCallKickoff( + tool_name="search", tool_args=DEFAULT_SEARCH_ARGS + ) + assert output[1] == ToolResponse( + id="final_context_documents", + response=mock_search_results, + ) + assert output[2] == ToolCallFinalResult( + tool_name="search", + tool_args=DEFAULT_SEARCH_ARGS, + tool_result=[json.loads(doc.model_dump_json()) for doc in mock_search_results], + ) + assert output[3] == DanswerAnswerPiece(answer_piece="Based on the search results, ") + expected_citation = CitationInfo(citation_num=1, document_id="doc1") + assert output[4] == expected_citation + assert output[5] == DanswerAnswerPiece( + answer_piece='the answer is "abc"[[1]](https://example.com/doc1). ' + ) + assert output[6] == DanswerAnswerPiece(answer_piece="This is some other stuff.") + + expected_answer = ( + "Based on the search results, " + 'the answer is "abc"[[1]](https://example.com/doc1). ' + "This is some other stuff." 
+ ) + full_answer = "".join( + piece.answer_piece + for piece in output + if isinstance(piece, DanswerAnswerPiece) and piece.answer_piece is not None + ) + assert full_answer == expected_answer + + assert answer_instance.llm_answer == expected_answer + assert len(answer_instance.citations) == 1 + assert answer_instance.citations[0] == expected_citation + + # Verify LLM calls + assert mock_llm.stream.call_count == 2 + first_call, second_call = mock_llm.stream.call_args_list + + # First call should include the search tool definition + assert len(first_call.kwargs["tools"]) == 1 + assert ( + first_call.kwargs["tools"][0] == mock_search_tool.tool_definition.return_value + ) + + # Second call should not include tools (as we're just generating the final answer) + assert "tools" not in second_call.kwargs or not second_call.kwargs["tools"] + # Second call should use the returned prompt from build_next_prompt + assert ( + second_call.kwargs["prompt"] + == mock_search_tool.build_next_prompt.return_value.build.return_value + ) + + # Verify that tool_definition was called on the mock_search_tool + mock_search_tool.tool_definition.assert_called_once() + + +def test_is_cancelled(answer_instance: Answer) -> None: + # Set up the LLM mock to return multiple chunks + mock_llm = Mock() + answer_instance.llm = mock_llm + mock_llm.stream.return_value = [ + AIMessageChunk(content="This is the "), + AIMessageChunk(content="first part."), + AIMessageChunk(content="This should not be seen."), + ] + + # Create a mutable object to control is_connected behavior + connection_status = {"connected": True} + answer_instance.is_connected = lambda: connection_status["connected"] + + # Process the output + output = [] + for i, chunk in enumerate(answer_instance.processed_streamed_output): + output.append(chunk) + # Simulate disconnection after the second chunk + if i == 1: + connection_status["connected"] = False + + assert len(output) == 3 + assert output[0] == DanswerAnswerPiece(answer_piece="This is the ") + assert output[1] == DanswerAnswerPiece(answer_piece="first part.") + assert output[2] == StreamStopInfo(stop_reason=StreamStopReason.CANCELLED) + + # Verify that the stream was cancelled + assert answer_instance.is_cancelled() is True + + # Verify that the final answer only contains the streamed parts + assert answer_instance.llm_answer == "This is the first part." 
+ + # Verify LLM calls + mock_llm.stream.assert_called_once() diff --git a/backend/tests/unit/danswer/llm/answering/test_skip_gen_ai.py b/backend/tests/unit/danswer/llm/answering/test_skip_gen_ai.py index 998b2932cbb..7bd4a498bd7 100644 --- a/backend/tests/unit/danswer/llm/answering/test_skip_gen_ai.py +++ b/backend/tests/unit/danswer/llm/answering/test_skip_gen_ai.py @@ -6,8 +6,11 @@ from pytest_mock import MockerFixture from danswer.llm.answering.answer import Answer +from danswer.llm.answering.models import AnswerStyleConfig +from danswer.llm.answering.models import PromptConfig from danswer.one_shot_answer.answer_question import AnswerObjectIterator from danswer.tools.force import ForceUseTool +from danswer.tools.tool_implementations.search.search_tool import SearchTool from tests.regression.answer_quality.run_qa import _process_and_write_query_results @@ -24,39 +27,43 @@ }, ], ) -def test_skip_gen_ai_answer_generation_flag(config: dict[str, Any]) -> None: - search_tool = Mock() - search_tool.name = "search" - search_tool.run = Mock() - search_tool.run.return_value = [Mock()] +def test_skip_gen_ai_answer_generation_flag( + config: dict[str, Any], + mock_search_tool: SearchTool, + answer_style_config: AnswerStyleConfig, + prompt_config: PromptConfig, +) -> None: + question = config["question"] + skip_gen_ai_answer_generation = config["skip_gen_ai_answer_generation"] + mock_llm = Mock() mock_llm.config = Mock() mock_llm.config.model_name = "gpt-4o-mini" mock_llm.stream = Mock() mock_llm.stream.return_value = [Mock()] answer = Answer( - question=config["question"], - answer_style_config=Mock(), - prompt_config=Mock(), + question=question, + answer_style_config=answer_style_config, + prompt_config=prompt_config, llm=mock_llm, single_message_history="history", - tools=[search_tool], + tools=[mock_search_tool], force_use_tool=( ForceUseTool( - tool_name=search_tool.name, - args={"query": config["question"]}, + tool_name=mock_search_tool.name, + args={"query": question}, force_use=True, ) ), skip_explicit_tool_calling=True, return_contexts=True, - skip_gen_ai_answer_generation=config["skip_gen_ai_answer_generation"], + skip_gen_ai_answer_generation=skip_gen_ai_answer_generation, ) count = 0 for _ in cast(AnswerObjectIterator, answer.processed_streamed_output): count += 1 - assert count == 2 - if not config["skip_gen_ai_answer_generation"]: + assert count == (3 if skip_gen_ai_answer_generation else 4) + if not skip_gen_ai_answer_generation: mock_llm.stream.assert_called_once() else: mock_llm.stream.assert_not_called() diff --git a/backend/tests/unit/danswer/tools/custom/test_custom_tools.py b/backend/tests/unit/danswer/tools/custom/test_custom_tools.py index 6139f41e62a..f56336809e4 100644 --- a/backend/tests/unit/danswer/tools/custom/test_custom_tools.py +++ b/backend/tests/unit/danswer/tools/custom/test_custom_tools.py @@ -5,14 +5,18 @@ import pytest -from danswer.tools.custom.custom_tool import ( +from danswer.tools.models import DynamicSchemaInfo +from danswer.tools.models import ToolResponse +from danswer.tools.tool_implementations.custom.custom_tool import ( build_custom_tools_from_openapi_schema_and_headers, ) -from danswer.tools.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID -from danswer.tools.custom.custom_tool import CustomToolCallSummary -from danswer.tools.custom.custom_tool import validate_openapi_schema -from danswer.tools.models import DynamicSchemaInfo -from danswer.tools.tool import ToolResponse +from danswer.tools.tool_implementations.custom.custom_tool import ( + 
CUSTOM_TOOL_RESPONSE_ID, +) +from danswer.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary +from danswer.tools.tool_implementations.custom.custom_tool import ( + validate_openapi_schema, +) from danswer.utils.headers import HeaderItemDict @@ -78,7 +82,7 @@ def setUp(self) -> None: chat_session_id=uuid.uuid4(), message_id=20 ) - @patch("danswer.tools.custom.custom_tool.requests.request") + @patch("danswer.tools.tool_implementations.custom.custom_tool.requests.request") def test_custom_tool_run_get(self, mock_request: unittest.mock.MagicMock) -> None: """ Test the GET method of a custom tool. @@ -106,7 +110,7 @@ def test_custom_tool_run_get(self, mock_request: unittest.mock.MagicMock) -> Non "Tool name in response does not match expected value", ) - @patch("danswer.tools.custom.custom_tool.requests.request") + @patch("danswer.tools.tool_implementations.custom.custom_tool.requests.request") def test_custom_tool_run_post(self, mock_request: unittest.mock.MagicMock) -> None: """ Test the POST method of a custom tool. @@ -136,7 +140,7 @@ def test_custom_tool_run_post(self, mock_request: unittest.mock.MagicMock) -> No "Tool name in response does not match expected value", ) - @patch("danswer.tools.custom.custom_tool.requests.request") + @patch("danswer.tools.tool_implementations.custom.custom_tool.requests.request") def test_custom_tool_with_headers( self, mock_request: unittest.mock.MagicMock ) -> None: @@ -164,7 +168,7 @@ def test_custom_tool_with_headers( "GET", expected_url, json=None, headers=expected_headers ) - @patch("danswer.tools.custom.custom_tool.requests.request") + @patch("danswer.tools.tool_implementations.custom.custom_tool.requests.request") def test_custom_tool_with_empty_headers( self, mock_request: unittest.mock.MagicMock ) -> None: From 200bb96853d6d96a99093f6e915fe9721ab5c6b3 Mon Sep 17 00:00:00 2001 From: Weves Date: Thu, 31 Oct 2024 13:46:00 -0700 Subject: [PATCH 02/10] Add quote support --- backend/danswer/llm/answering/answer.py | 24 +- .../llm/answering/llm_response_handler.py | 2 +- ..._handler.py => answer_response_handler.py} | 30 ++ .../stream_processing/quotes_processing.py | 164 +++---- .../test_quote_processing.py | 447 +++++++++--------- .../unit/danswer/llm/answering/test_answer.py | 78 ++- 6 files changed, 383 insertions(+), 362 deletions(-) rename backend/danswer/llm/answering/stream_processing/{citation_response_handler.py => answer_response_handler.py} (69%) diff --git a/backend/danswer/llm/answering/answer.py b/backend/danswer/llm/answering/answer.py index 0aea52c303b..9b48412e44e 100644 --- a/backend/danswer/llm/answering/answer.py +++ b/backend/danswer/llm/answering/answer.py @@ -18,12 +18,15 @@ from danswer.llm.answering.prompts.build import AnswerPromptBuilder from danswer.llm.answering.prompts.build import default_build_system_message from danswer.llm.answering.prompts.build import default_build_user_message -from danswer.llm.answering.stream_processing.citation_response_handler import ( +from danswer.llm.answering.stream_processing.answer_response_handler import ( CitationResponseHandler, ) -from danswer.llm.answering.stream_processing.citation_response_handler import ( +from danswer.llm.answering.stream_processing.answer_response_handler import ( DummyAnswerResponseHandler, ) +from danswer.llm.answering.stream_processing.answer_response_handler import ( + QuotesResponseHandler, +) from danswer.llm.answering.stream_processing.utils import map_document_id_order from danswer.llm.answering.tool.tool_response_handler import 
ToolResponseHandler from danswer.llm.interfaces import LLM @@ -207,13 +210,20 @@ def _get_response(self, llm_calls: list[LLMCall]) -> AnswerStream: tool_call_handler = ToolResponseHandler(current_llm_call.tools) search_result = SearchTool.get_search_result(current_llm_call) or [] - citation_response_handler = CitationResponseHandler( - context_docs=search_result, - doc_id_to_rank_map=map_document_id_order(search_result), - ) + if self.answer_style_config.citation_config: + answer_handler = CitationResponseHandler( + context_docs=search_result, + doc_id_to_rank_map=map_document_id_order(search_result), + ) + elif self.answer_style_config.quotes_config: + answer_handler = QuotesResponseHandler( + context_docs=search_result, + ) + else: + raise ValueError("No answer style config provided") response_handler_manager = LLMResponseHandlerManager( - tool_call_handler, citation_response_handler, self.is_cancelled + tool_call_handler, answer_handler, self.is_cancelled ) # DEBUG: good breakpoint diff --git a/backend/danswer/llm/answering/llm_response_handler.py b/backend/danswer/llm/answering/llm_response_handler.py index 6578e808952..b013e9cf892 100644 --- a/backend/danswer/llm/answering/llm_response_handler.py +++ b/backend/danswer/llm/answering/llm_response_handler.py @@ -20,7 +20,7 @@ if TYPE_CHECKING: - from danswer.llm.answering.stream_processing.citation_response_handler import ( + from danswer.llm.answering.stream_processing.answer_response_handler import ( AnswerResponseHandler, ) from danswer.llm.answering.tool.tool_response_handler import ToolResponseHandler diff --git a/backend/danswer/llm/answering/stream_processing/citation_response_handler.py b/backend/danswer/llm/answering/stream_processing/answer_response_handler.py similarity index 69% rename from backend/danswer/llm/answering/stream_processing/citation_response_handler.py rename to backend/danswer/llm/answering/stream_processing/answer_response_handler.py index 07a342fbd7e..80a1446163b 100644 --- a/backend/danswer/llm/answering/stream_processing/citation_response_handler.py +++ b/backend/danswer/llm/answering/stream_processing/answer_response_handler.py @@ -9,6 +9,9 @@ from danswer.llm.answering.stream_processing.citation_processing import ( CitationProcessor, ) +from danswer.llm.answering.stream_processing.quotes_processing import ( + QuotesProcessor, +) from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping @@ -59,3 +62,30 @@ def handle_response_part( # Process the new content through the citation processor yield from self.citation_processor.process_token(content) + + +class QuotesResponseHandler(AnswerResponseHandler): + def __init__( + self, + context_docs: list[LlmDoc], + is_json_prompt: bool = True, + ): + self.quotes_processor = QuotesProcessor( + context_docs=context_docs, + is_json_prompt=is_json_prompt, + ) + + def handle_response_part( + self, + response_item: BaseMessage | None, + previous_response_items: list[BaseMessage], + ) -> Generator[ResponsePart, None, None]: + if response_item is None: + yield from self.quotes_processor.process_token(None) + return + + content = ( + response_item.content if isinstance(response_item.content, str) else "" + ) + + yield from self.quotes_processor.process_token(content) diff --git a/backend/danswer/llm/answering/stream_processing/quotes_processing.py b/backend/danswer/llm/answering/stream_processing/quotes_processing.py index 501a56b5aa7..e152463aef5 100644 --- a/backend/danswer/llm/answering/stream_processing/quotes_processing.py +++ 
b/backend/danswer/llm/answering/stream_processing/quotes_processing.py @@ -1,14 +1,11 @@ import math import re -from collections.abc import Callable from collections.abc import Generator -from collections.abc import Iterator from json import JSONDecodeError from typing import Optional import regex -from danswer.chat.models import AnswerQuestionStreamReturn from danswer.chat.models import DanswerAnswer from danswer.chat.models import DanswerAnswerPiece from danswer.chat.models import DanswerQuote @@ -157,7 +154,7 @@ def separate_answer_quotes( return _extract_answer_quotes_freeform(clean_up_code_blocks(answer_raw)) -def process_answer( +def _process_answer( answer_raw: str, docs: list[LlmDoc], is_json_prompt: bool = True, @@ -195,7 +192,7 @@ def _stream_json_answer_end(answer_so_far: str, next_token: str) -> bool: def _extract_quotes_from_completed_token_stream( model_output: str, context_docs: list[LlmDoc], is_json_prompt: bool = True ) -> DanswerQuotes: - answer, quotes = process_answer(model_output, context_docs, is_json_prompt) + answer, quotes = _process_answer(model_output, context_docs, is_json_prompt) if answer: logger.notice(answer) elif model_output: @@ -204,94 +201,101 @@ def _extract_quotes_from_completed_token_stream( return quotes -def process_model_tokens( - tokens: Iterator[str], - context_docs: list[LlmDoc], - is_json_prompt: bool = True, -) -> Generator[DanswerAnswerPiece | DanswerQuotes, None, None]: - """Used in the streaming case to process the model output - into an Answer and Quotes - - Yields Answer tokens back out in a dict for streaming to frontend - When Answer section ends, yields dict with answer_finished key - Collects all the tokens at the end to form the complete model output""" - quote_pat = f"\n{QUOTE_PAT}" - # Sometimes worse model outputs new line instead of : - quote_loose = f"\n{quote_pat[:-1]}\n" - # Sometime model outputs two newlines before quote section - quote_pat_full = f"\n{quote_pat}" - model_output: str = "" - found_answer_start = False if is_json_prompt else True - found_answer_end = False - hold_quote = "" - - for token in tokens: - model_previous = model_output - model_output += token - - if not found_answer_start: - m = answer_pattern.search(model_output) +class QuotesProcessor: + def __init__( + self, + context_docs: list[LlmDoc], + is_json_prompt: bool = True, + ): + self.context_docs = context_docs + self.is_json_prompt = is_json_prompt + + self.found_answer_start = False if is_json_prompt else True + self.found_answer_end = False + self.hold_quote = "" + self.model_output = "" + self.hold = "" + + def process_token( + self, token: str | None + ) -> Generator[DanswerAnswerPiece | DanswerQuotes, None, None]: + # None -> end of stream + if token is None: + if self.model_output: + yield _extract_quotes_from_completed_token_stream( + model_output=self.model_output, + context_docs=self.context_docs, + is_json_prompt=self.is_json_prompt, + ) + return + + model_previous = self.model_output + self.model_output += token + + if not self.found_answer_start: + m = answer_pattern.search(self.model_output) if m: - found_answer_start = True + self.found_answer_start = True - # Prevent heavy cases of hallucinations where model is never providing a JSON - # We want to quickly update the user - not stream forever - if is_json_prompt and len(model_output) > 70: + # Prevent heavy cases of hallucinations + if self.is_json_prompt and len(self.model_output) > 70: logger.warning("LLM did not produce json as prompted") - found_answer_end = True - continue - - 
remaining = model_output[m.end() :] + self.found_answer_end = True + return + + remaining = self.model_output[m.end() :] + + # Look for an unescaped quote, which means the answer is entirely contained + # in this toekn e.g. if the token is `{"answer": "blah", "qu` + quote_indices = [i for i, char in enumerate(remaining) if char == '"'] + for quote_idx in quote_indices: + # Check if quote is escaped by counting backslashes before it + num_backslashes = 0 + pos = quote_idx - 1 + while pos >= 0 and remaining[pos] == "\\": + num_backslashes += 1 + pos -= 1 + # If even number of backslashes, quote is not escaped + if num_backslashes % 2 == 0: + yield DanswerAnswerPiece(answer_piece=remaining[:quote_idx]) + return + + # If no unescaped quote found, yield the remaining string if len(remaining) > 0: yield DanswerAnswerPiece(answer_piece=remaining) - continue + return - if found_answer_start and not found_answer_end: - if is_json_prompt and _stream_json_answer_end(model_previous, token): - found_answer_end = True + if self.found_answer_start and not self.found_answer_end: + if self.is_json_prompt and _stream_json_answer_end(model_previous, token): + self.found_answer_end = True - # return the remaining part of the answer e.g. token might be 'd.", ' and we should yield 'd.' if token: try: answer_token_section = token.index('"') yield DanswerAnswerPiece( - answer_piece=hold_quote + token[:answer_token_section] + answer_piece=self.hold_quote + token[:answer_token_section] ) except ValueError: logger.error("Quotation mark not found in token") - yield DanswerAnswerPiece(answer_piece=hold_quote + token) + yield DanswerAnswerPiece(answer_piece=self.hold_quote + token) yield DanswerAnswerPiece(answer_piece=None) - continue - elif not is_json_prompt: - if quote_pat in hold_quote + token or quote_loose in hold_quote + token: - found_answer_end = True + return + + elif not self.is_json_prompt: + quote_pat = f"\n{QUOTE_PAT}" + quote_loose = f"\n{quote_pat[:-1]}\n" + quote_pat_full = f"\n{quote_pat}" + + if ( + quote_pat in self.hold_quote + token + or quote_loose in self.hold_quote + token + ): + self.found_answer_end = True yield DanswerAnswerPiece(answer_piece=None) - continue - if hold_quote + token in quote_pat_full: - hold_quote += token - continue - yield DanswerAnswerPiece(answer_piece=hold_quote + token) - hold_quote = "" - - logger.debug(f"Raw Model QnA Output: {model_output}") - - yield _extract_quotes_from_completed_token_stream( - model_output=model_output, - context_docs=context_docs, - is_json_prompt=is_json_prompt, - ) - - -def build_quotes_processor( - context_docs: list[LlmDoc], is_json_prompt: bool -) -> Callable[[Iterator[str]], AnswerQuestionStreamReturn]: - def stream_processor( - tokens: Iterator[str], - ) -> AnswerQuestionStreamReturn: - yield from process_model_tokens( - tokens=tokens, - context_docs=context_docs, - is_json_prompt=is_json_prompt, - ) - - return stream_processor + return + if self.hold_quote + token in quote_pat_full: + self.hold_quote += token + return + + yield DanswerAnswerPiece(answer_piece=self.hold_quote + token) + self.hold_quote = "" diff --git a/backend/tests/unit/danswer/llm/answering/stream_processing/test_quote_processing.py b/backend/tests/unit/danswer/llm/answering/stream_processing/test_quote_processing.py index e80c5c4f657..c154c5a5b0c 100644 --- a/backend/tests/unit/danswer/llm/answering/stream_processing/test_quote_processing.py +++ b/backend/tests/unit/danswer/llm/answering/stream_processing/test_quote_processing.py @@ -6,7 +6,7 @@ from 
danswer.chat.models import LlmDoc from danswer.configs.constants import DocumentSource from danswer.llm.answering.stream_processing.quotes_processing import ( - process_model_tokens, + QuotesProcessor, ) mock_docs = [ @@ -25,179 +25,202 @@ ] -tokens_with_quotes = [ - "{", - "\n ", - '"answer": "Yes', - ", Danswer allows", - " customized prompts. This", - " feature", - " is currently being", - " developed and implemente", - "d to", - " improve", - " the accuracy", - " of", - " Language", - " Models (", - "LL", - "Ms) for", - " different", - " companies", - ".", - " The custom", - "ized prompts feature", - " woul", - "d allow users to ad", - "d person", - "alized prom", - "pts through", - " an", - " interface or", - " metho", - "d,", - " which would then be used to", - " train", - " the LLM.", - " This enhancement", - " aims to make", - " Danswer more", - " adaptable to", - " different", - " business", - " contexts", - " by", - " tail", - "oring it", - " to the specific language", - " an", - "d terminology", - " used within", - " a", - " company.", - " Additionally", - ",", - " Danswer already", - " supports creating", - " custom AI", - " Assistants with", - " different", - " prom", - "pts and backing", - " knowledge", - " sets", - ",", - " which", - " is", - " a form", - " of prompt", - " customization. However, it", - "'s important to nLogging Details LiteLLM-Success Call: Noneote that some", - " aspects", - " of prompt", - " customization,", - " such as for", - " Sl", - "ack", - "b", - "ots, may", - " still", - " be in", - " development or have", - ' limitations.",', - '\n "quotes": [', - '\n "We', - " woul", - "d like to ad", - "d customized prompts for", - " different", - " companies to improve the accuracy of", - " Language", - " Model", - " (LLM)", - '.",\n "A', - " new", - " feature that", - " allows users to add personalize", - "d prompts.", - " This would involve", - " creating", - " an interface or method for", - " users to input", - " their", - " own", - " prom", - "pts,", - " which would then be used to", - ' train the LLM.",', - '\n "Create', - " custom AI Assistants with", - " different prompts and backing knowledge", - ' sets.",', - '\n "This', - " PR", - " fixes", - " https", - "://github.com/dan", - "swer-ai/dan", - "swer/issues/1", - "584", - " by", - " setting", - " the system", - " default", - " prompt for", - " sl", - "ackbots const", - "rained by", - " ", - "document sets", - ".", - " It", - " probably", - " isn", - "'t ideal", - " -", - " it", - " might", - " be pref", - "erable to be", - " able to select", - " a prompt for", - " the", - " slackbot from", - " the", - " admin", - " panel", - " -", - " but it sol", - "ves the immediate problem", - " of", - " the slack", - " listener", - " cr", - "ashing when", - " configure", - "d this", - ' way."\n ]', - "\n}", - "", -] +def _process_tokens( + processor: QuotesProcessor, tokens: list[str] +) -> tuple[str, list[str]]: + """Process a list of tokens and return the answer and quotes. 
+ + Args: + processor: QuotesProcessor instance + tokens: List of tokens to process + + Returns: + Tuple of (answer_text, list_of_quotes) + """ + answer = "" + quotes: list[str] = [] + + # need to add a None to the end to simulate the end of the stream + for token in tokens + [None]: + for output in processor.process_token(token): + if isinstance(output, DanswerAnswerPiece): + if output.answer_piece: + answer += output.answer_piece + elif isinstance(output, DanswerQuotes): + quotes.extend(q.quote for q in output.quotes) + + return answer, quotes def test_process_model_tokens_answer() -> None: - gen = process_model_tokens(tokens=iter(tokens_with_quotes), context_docs=mock_docs) + tokens_with_quotes = [ + "{", + "\n ", + '"answer": "Yes', + ", Danswer allows", + " customized prompts. This", + " feature", + " is currently being", + " developed and implemente", + "d to", + " improve", + " the accuracy", + " of", + " Language", + " Models (", + "LL", + "Ms) for", + " different", + " companies", + ".", + " The custom", + "ized prompts feature", + " woul", + "d allow users to ad", + "d person", + "alized prom", + "pts through", + " an", + " interface or", + " metho", + "d,", + " which would then be used to", + " train", + " the LLM.", + " This enhancement", + " aims to make", + " Danswer more", + " adaptable to", + " different", + " business", + " contexts", + " by", + " tail", + "oring it", + " to the specific language", + " an", + "d terminology", + " used within", + " a", + " company.", + " Additionally", + ",", + " Danswer already", + " supports creating", + " custom AI", + " Assistants with", + " different", + " prom", + "pts and backing", + " knowledge", + " sets", + ",", + " which", + " is", + " a form", + " of prompt", + " customization. However, it", + "'s important to nLogging Details LiteLLM-Success Call: Noneote that some", + " aspects", + " of prompt", + " customization,", + " such as for", + " Sl", + "ack", + "b", + "ots, may", + " still", + " be in", + " development or have", + ' limitations.",', + '\n "quotes": [', + '\n "We', + " woul", + "d like to ad", + "d customized prompts for", + " different", + " companies to improve the accuracy of", + " Language", + " Model", + " (LLM)", + '.",\n "A', + " new", + " feature that", + " allows users to add personalize", + "d prompts.", + " This would involve", + " creating", + " an interface or method for", + " users to input", + " their", + " own", + " prom", + "pts,", + " which would then be used to", + ' train the LLM.",', + '\n "Create', + " custom AI Assistants with", + " different prompts and backing knowledge", + ' sets.",', + '\n "This', + " PR", + " fixes", + " https", + "://github.com/dan", + "swer-ai/dan", + "swer/issues/1", + "584", + " by", + " setting", + " the system", + " default", + " prompt for", + " sl", + "ackbots const", + "rained by", + " ", + "document sets", + ".", + " It", + " probably", + " isn", + "'t ideal", + " -", + " it", + " might", + " be pref", + "erable to be", + " able to select", + " a prompt for", + " the", + " slackbot from", + " the", + " admin", + " panel", + " -", + " but it sol", + "ves the immediate problem", + " of", + " the slack", + " listener", + " cr", + "ashing when", + " configure", + "d this", + ' way."\n ]', + "\n}", + "", + ] + + processor = QuotesProcessor(context_docs=mock_docs) + answer, quotes = _process_tokens(processor, tokens_with_quotes) s_json = "".join(tokens_with_quotes) j = json.loads(s_json) expected_answer = j["answer"] - actual = "" - for o in gen: - if isinstance(o, 
DanswerAnswerPiece): - if o.answer_piece: - actual += o.answer_piece - - assert expected_answer == actual + assert expected_answer == answer + # NOTE: no quotes, since the docs don't match the quotes + assert len(quotes) == 0 def test_simple_json_answer() -> None: @@ -214,16 +237,11 @@ def test_simple_json_answer() -> None: "\n", "```", ] - gen = process_model_tokens(tokens=iter(tokens), context_docs=mock_docs) - - expected_answer = "This is a simple answer." - actual = "".join( - o.answer_piece - for o in gen - if isinstance(o, DanswerAnswerPiece) and o.answer_piece - ) + processor = QuotesProcessor(context_docs=mock_docs) + answer, quotes = _process_tokens(processor, tokens) - assert expected_answer == actual + assert "This is a simple answer." == answer + assert len(quotes) == 0 def test_json_answer_with_quotes() -> None: @@ -242,16 +260,21 @@ def test_json_answer_with_quotes() -> None: "\n", "```", ] - gen = process_model_tokens(tokens=iter(tokens), context_docs=mock_docs) + processor = QuotesProcessor(context_docs=mock_docs) + answer, quotes = _process_tokens(processor, tokens) + + assert "This is a split answer." == answer + assert len(quotes) == 0 - expected_answer = "This is a split answer." - actual = "".join( - o.answer_piece - for o in gen - if isinstance(o, DanswerAnswerPiece) and o.answer_piece - ) - assert expected_answer == actual +def test_json_answer_with_quotes_one_chunk() -> None: + tokens = ['```json\n{"answer": "z",\n"quotes": ["Document"]\n}\n```'] + processor = QuotesProcessor(context_docs=mock_docs) + answer, quotes = _process_tokens(processor, tokens) + + assert "z" == answer + assert len(quotes) == 1 + assert quotes[0] == "Document" def test_json_answer_split_tokens() -> None: @@ -271,16 +294,11 @@ def test_json_answer_split_tokens() -> None: "\n", "```", ] - gen = process_model_tokens(tokens=iter(tokens), context_docs=mock_docs) - - expected_answer = "This is a split answer." - actual = "".join( - o.answer_piece - for o in gen - if isinstance(o, DanswerAnswerPiece) and o.answer_piece - ) + processor = QuotesProcessor(context_docs=mock_docs) + answer, quotes = _process_tokens(processor, tokens) - assert expected_answer == actual + assert "This is a split answer." == answer + assert len(quotes) == 0 def test_lengthy_prefixed_json_with_quotes() -> None: @@ -298,23 +316,12 @@ def test_lengthy_prefixed_json_with_quotes() -> None: "\n", "```", ] + processor = QuotesProcessor(context_docs=mock_docs) + answer, quotes = _process_tokens(processor, tokens) - gen = process_model_tokens(tokens=iter(tokens), context_docs=mock_docs) - - actual_answer = "" - actual_count = 0 - for o in gen: - if isinstance(o, DanswerAnswerPiece): - if o.answer_piece: - actual_answer += o.answer_piece - continue - - if isinstance(o, DanswerQuotes): - for q in o.quotes: - assert q.quote == "Document" - actual_count += 1 - assert "This is a simple answer." == actual_answer - assert 1 == actual_count + assert "This is a simple answer." 
== answer + assert len(quotes) == 1 + assert quotes[0] == "Document" def test_prefixed_json_with_quotes() -> None: @@ -331,21 +338,9 @@ def test_prefixed_json_with_quotes() -> None: "\n", "```", ] + processor = QuotesProcessor(context_docs=mock_docs) + answer, quotes = _process_tokens(processor, tokens) - gen = process_model_tokens(tokens=iter(tokens), context_docs=mock_docs) - - actual_answer = "" - actual_count = 0 - for o in gen: - if isinstance(o, DanswerAnswerPiece): - if o.answer_piece: - actual_answer += o.answer_piece - continue - - if isinstance(o, DanswerQuotes): - for q in o.quotes: - assert q.quote == "Document" - actual_count += 1 - - assert "This is a simple answer." == actual_answer - assert 1 == actual_count + assert "This is a simple answer." == answer + assert len(quotes) == 1 + assert quotes[0] == "Document" diff --git a/backend/tests/unit/danswer/llm/answering/test_answer.py b/backend/tests/unit/danswer/llm/answering/test_answer.py index 95dde97f1f4..f772f157204 100644 --- a/backend/tests/unit/danswer/llm/answering/test_answer.py +++ b/backend/tests/unit/danswer/llm/answering/test_answer.py @@ -13,13 +13,15 @@ from danswer.chat.models import CitationInfo from danswer.chat.models import DanswerAnswerPiece +from danswer.chat.models import DanswerQuote +from danswer.chat.models import DanswerQuotes from danswer.chat.models import LlmDoc from danswer.chat.models import StreamStopInfo from danswer.chat.models import StreamStopReason from danswer.llm.answering.answer import Answer from danswer.llm.answering.models import AnswerStyleConfig -from danswer.llm.answering.models import CitationConfig from danswer.llm.answering.models import PromptConfig +from danswer.llm.answering.models import QuotesConfig from danswer.llm.interfaces import LLM from danswer.tools.force import ForceUseTool from danswer.tools.models import ToolCallFinalResult @@ -291,9 +293,7 @@ def test_answer_with_search_call_quotes_enabled( answer_instance.force_use_tool = ForceUseTool( force_use=False, tool_name="", args=None ) - answer_instance.answer_style_config.citation_config = CitationConfig( - use_quotes=True - ) + answer_instance.answer_style_config.quotes_config = QuotesConfig() # Set up the LLM mock to return search results and then an answer mock_llm = cast(Mock, answer_instance.llm) @@ -315,10 +315,21 @@ def test_answer_with_search_call_quotes_enabled( ) ] + # needs to be short due to the "anti-hallucination" check in QuotesProcessor + answer_content = "z" + quote_content = mock_search_results[0].content mock_llm.stream.side_effect = [ [tool_call_chunk], [ - AIMessageChunk(content="Answer"), + AIMessageChunk( + content=( + '{"answer": "' + + answer_content + + '", "quotes": ["' + + quote_content + + '"]}' + ) + ), ], ] @@ -326,7 +337,7 @@ def test_answer_with_search_call_quotes_enabled( output = list(answer_instance.processed_streamed_output) # Assertions - assert len(output) == 7 + assert len(output) == 5 assert output[0] == ToolCallKickoff( tool_name="search", tool_args=DEFAULT_SEARCH_ARGS ) @@ -339,50 +350,21 @@ def test_answer_with_search_call_quotes_enabled( tool_args=DEFAULT_SEARCH_ARGS, tool_result=[json.loads(doc.model_dump_json()) for doc in mock_search_results], ) - assert output[3] == DanswerAnswerPiece(answer_piece="Based on the search results, ") - expected_citation = CitationInfo(citation_num=1, document_id="doc1") - assert output[4] == expected_citation - assert output[5] == DanswerAnswerPiece( - answer_piece='the answer is "abc"[[1]](https://example.com/doc1). 
' - ) - assert output[6] == DanswerAnswerPiece(answer_piece="This is some other stuff.") - - expected_answer = ( - "Based on the search results, " - 'the answer is "abc"[[1]](https://example.com/doc1). ' - "This is some other stuff." - ) - full_answer = "".join( - piece.answer_piece - for piece in output - if isinstance(piece, DanswerAnswerPiece) and piece.answer_piece is not None - ) - assert full_answer == expected_answer - - assert answer_instance.llm_answer == expected_answer - assert len(answer_instance.citations) == 1 - assert answer_instance.citations[0] == expected_citation - - # Verify LLM calls - assert mock_llm.stream.call_count == 2 - first_call, second_call = mock_llm.stream.call_args_list - - # First call should include the search tool definition - assert len(first_call.kwargs["tools"]) == 1 - assert ( - first_call.kwargs["tools"][0] == mock_search_tool.tool_definition.return_value - ) - - # Second call should not include tools (as we're just generating the final answer) - assert "tools" not in second_call.kwargs or not second_call.kwargs["tools"] - # Second call should use the returned prompt from build_next_prompt - assert ( - second_call.kwargs["prompt"] - == mock_search_tool.build_next_prompt.return_value.build.return_value + assert output[3] == DanswerAnswerPiece(answer_piece=answer_content) + assert output[4] == DanswerQuotes( + quotes=[ + DanswerQuote( + quote=quote_content, + document_id=mock_search_results[0].document_id, + link=mock_search_results[0].link, + source_type=mock_search_results[0].source_type, + semantic_identifier=mock_search_results[0].semantic_identifier, + blurb=mock_search_results[0].blurb, + ) + ] ) - # Verify that tool_definition was called on the mock_search_tool - mock_search_tool.tool_definition.assert_called_once() + assert answer_instance.llm_answer == answer_content def test_is_cancelled(answer_instance: Answer) -> None: From 3cd4ed5052277428dc06343f53e0e6486af26208 Mon Sep 17 00:00:00 2001 From: Weves Date: Thu, 31 Oct 2024 14:35:54 -0700 Subject: [PATCH 03/10] Testing --- backend/danswer/chat/process_message.py | 5 ++++- .../llm/answering/llm_response_handler.py | 3 +-- .../server/query_and_chat/chat_backend.py | 17 ++++++++++------- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index 4ff30dd3c04..bf058d376b6 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -18,6 +18,7 @@ from danswer.chat.models import MessageSpecificCitations from danswer.chat.models import QADocsResponse from danswer.chat.models import StreamingError +from danswer.chat.models import StreamStopInfo from danswer.configs.app_configs import AZURE_DALLE_API_BASE from danswer.configs.app_configs import AZURE_DALLE_API_KEY from danswer.configs.app_configs import AZURE_DALLE_API_VERSION @@ -278,6 +279,7 @@ def _get_force_search_settings( | CustomToolResponse | MessageSpecificCitations | MessageResponseIDInfo + | StreamStopInfo ) ChatPacketStream = Iterator[ChatPacket] @@ -803,7 +805,8 @@ def stream_chat_message_objects( response=custom_tool_response.tool_result, tool_name=custom_tool_response.tool_name, ) - + elif isinstance(packet, StreamStopInfo): + pass else: if isinstance(packet, ToolCallFinalResult): tool_result = packet diff --git a/backend/danswer/llm/answering/llm_response_handler.py b/backend/danswer/llm/answering/llm_response_handler.py index b013e9cf892..9c5d8234ffd 100644 --- 
a/backend/danswer/llm/answering/llm_response_handler.py +++ b/backend/danswer/llm/answering/llm_response_handler.py @@ -9,7 +9,6 @@ from danswer.chat.models import CitationInfo from danswer.chat.models import DanswerAnswerPiece from danswer.chat.models import StreamStopInfo -from danswer.chat.models import StreamStopReason from danswer.file_store.models import InMemoryChatFile from danswer.llm.answering.prompts.build import AnswerPromptBuilder from danswer.tools.force import ForceUseTool @@ -72,7 +71,7 @@ def handle_llm_response( all_messages.append(message) if self.is_cancelled(): - yield StreamStopInfo(stop_reason=StreamStopReason.CANCELLED) + # yield StreamStopInfo(stop_reason=StreamStopReason.CANCELLED) return # potentially give back all info on the selected tool call + its result diff --git a/backend/danswer/server/query_and_chat/chat_backend.py b/backend/danswer/server/query_and_chat/chat_backend.py index ac64231dcf2..a3f93287921 100644 --- a/backend/danswer/server/query_and_chat/chat_backend.py +++ b/backend/danswer/server/query_and_chat/chat_backend.py @@ -283,13 +283,16 @@ def delete_chat_session_by_id( raise HTTPException(status_code=400, detail=str(e)) -async def is_disconnected(request: Request) -> Callable[[], bool]: +async def is_connected(request: Request) -> Callable[[], bool]: main_loop = asyncio.get_event_loop() - def is_disconnected_sync() -> bool: + def is_connected_sync() -> bool: future = asyncio.run_coroutine_threadsafe(request.is_disconnected(), main_loop) try: - return not future.result(timeout=0.01) + logger.debug("Checking if connected") + is_connected = not future.result(timeout=0.01) + logger.debug(f"Is connected: {is_connected}") + return is_connected except asyncio.TimeoutError: logger.error("Asyncio timed out") return True @@ -300,7 +303,7 @@ def is_disconnected_sync() -> bool: ) return True - return is_disconnected_sync + return is_connected_sync @router.post("/send-message") @@ -309,7 +312,7 @@ def handle_new_chat_message( request: Request, user: User | None = Depends(current_user), _: None = Depends(check_token_rate_limits), - is_disconnected_func: Callable[[], bool] = Depends(is_disconnected), + is_connected_func: Callable[[], bool] = Depends(is_connected), ) -> StreamingResponse: """ This endpoint is both used for all the following purposes: @@ -325,7 +328,7 @@ def handle_new_chat_message( request (Request): The current HTTP request context. user (User | None): The current user, obtained via dependency injection. _ (None): Rate limit check is run if user/group/global rate limits are enabled. - is_disconnected_func (Callable[[], bool]): Function to check client disconnection, + is_connected_func (Callable[[], bool]): Function to check client disconnection, used to stop the streaming response if the client disconnects. 
Returns: @@ -354,7 +357,7 @@ def stream_generator() -> Generator[str, None, None]: custom_tool_additional_headers=get_custom_tool_additional_request_headers( request.headers ), - is_connected=is_disconnected_func, + is_connected=is_connected_func, ): yield json.dumps(packet) if isinstance(packet, dict) else packet From 24e34019ce25314c5e749d38dd0895a1c3d5141e Mon Sep 17 00:00:00 2001 From: Weves Date: Thu, 31 Oct 2024 15:37:59 -0700 Subject: [PATCH 04/10] More testing --- .../llm/answering/llm_response_handler.py | 3 +- .../server/query_and_chat/chat_backend.py | 45 ++++++++++++------- 2 files changed, 32 insertions(+), 16 deletions(-) diff --git a/backend/danswer/llm/answering/llm_response_handler.py b/backend/danswer/llm/answering/llm_response_handler.py index 9c5d8234ffd..b013e9cf892 100644 --- a/backend/danswer/llm/answering/llm_response_handler.py +++ b/backend/danswer/llm/answering/llm_response_handler.py @@ -9,6 +9,7 @@ from danswer.chat.models import CitationInfo from danswer.chat.models import DanswerAnswerPiece from danswer.chat.models import StreamStopInfo +from danswer.chat.models import StreamStopReason from danswer.file_store.models import InMemoryChatFile from danswer.llm.answering.prompts.build import AnswerPromptBuilder from danswer.tools.force import ForceUseTool @@ -71,7 +72,7 @@ def handle_llm_response( all_messages.append(message) if self.is_cancelled(): - # yield StreamStopInfo(stop_reason=StreamStopReason.CANCELLED) + yield StreamStopInfo(stop_reason=StreamStopReason.CANCELLED) return # potentially give back all info on the selected tool call + its result diff --git a/backend/danswer/server/query_and_chat/chat_backend.py b/backend/danswer/server/query_and_chat/chat_backend.py index a3f93287921..9a475f89615 100644 --- a/backend/danswer/server/query_and_chat/chat_backend.py +++ b/backend/danswer/server/query_and_chat/chat_backend.py @@ -1,5 +1,6 @@ import asyncio import io +import time import uuid from collections.abc import Callable from collections.abc import Generator @@ -20,7 +21,6 @@ from danswer.auth.users import current_user from danswer.chat.chat_utils import create_chat_chain from danswer.chat.chat_utils import extract_headers -from danswer.chat.process_message import stream_chat_message from danswer.configs.app_configs import WEB_DOMAIN from danswer.configs.constants import FileOrigin from danswer.configs.constants import MessageType @@ -73,7 +73,7 @@ from danswer.server.query_and_chat.models import SearchFeedbackRequest from danswer.server.query_and_chat.models import UpdateChatSessionThreadRequest from danswer.server.query_and_chat.token_limit import check_token_rate_limits -from danswer.utils.headers import get_custom_tool_additional_request_headers +from danswer.server.utils import get_json_line from danswer.utils.logger import setup_logger @@ -347,24 +347,39 @@ def handle_new_chat_message( def stream_generator() -> Generator[str, None, None]: try: - for packet in stream_chat_message( - new_msg_req=chat_message_req, - user=user, - use_existing_user_message=chat_message_req.use_existing_user_message, - litellm_additional_headers=extract_headers( - request.headers, LITELLM_PASS_THROUGH_HEADERS - ), - custom_tool_additional_headers=get_custom_tool_additional_request_headers( - request.headers - ), - is_connected=is_connected_func, - ): - yield json.dumps(packet) if isinstance(packet, dict) else packet + # for packet in stream_chat_message( + # new_msg_req=chat_message_req, + # user=user, + # 
use_existing_user_message=chat_message_req.use_existing_user_message, + # litellm_additional_headers=extract_headers( + # request.headers, LITELLM_PASS_THROUGH_HEADERS + # ), + # custom_tool_additional_headers=get_custom_tool_additional_request_headers( + # request.headers + # ), + # is_connected=None, + # ): + # yield json.dumps(packet) if isinstance(packet, dict) else packet + yield get_json_line( + {"user_message_id": 1289, "reserved_assistant_message_id": 1290} + ) + for _ in range(50): + logger.debug("Yielding answer piece") + yield get_json_line({"answer_piece": "hello"}) + is_connected = is_connected_func() + logger.debug(f"Is connected: {is_connected}") + time.sleep(1) except Exception as e: logger.exception(f"Error in chat message streaming: {e}") yield json.dumps({"error": str(e)}) + except GeneratorExit: + logger.debug("GeneratorExit") + + finally: + logger.debug("Stream generator finished") + return StreamingResponse(stream_generator(), media_type="text/event-stream") From 2132a430cc64abd869632c0f55a35bdc42b30be9 Mon Sep 17 00:00:00 2001 From: Weves Date: Thu, 31 Oct 2024 16:26:19 -0700 Subject: [PATCH 05/10] Fix image generation slowness --- backend/danswer/llm/answering/answer.py | 5 +++ .../llm/answering/llm_response_handler.py | 2 + .../answering/tool/tool_response_handler.py | 16 ++++--- .../server/query_and_chat/chat_backend.py | 45 +++++++------------ .../images/image_generation_tool.py | 5 ++- 5 files changed, 37 insertions(+), 36 deletions(-) diff --git a/backend/danswer/llm/answering/answer.py b/backend/danswer/llm/answering/answer.py index 9b48412e44e..c447db452ff 100644 --- a/backend/danswer/llm/answering/answer.py +++ b/backend/danswer/llm/answering/answer.py @@ -18,6 +18,9 @@ from danswer.llm.answering.prompts.build import AnswerPromptBuilder from danswer.llm.answering.prompts.build import default_build_system_message from danswer.llm.answering.prompts.build import default_build_user_message +from danswer.llm.answering.stream_processing.answer_response_handler import ( + AnswerResponseHandler, +) from danswer.llm.answering.stream_processing.answer_response_handler import ( CitationResponseHandler, ) @@ -210,6 +213,8 @@ def _get_response(self, llm_calls: list[LLMCall]) -> AnswerStream: tool_call_handler = ToolResponseHandler(current_llm_call.tools) search_result = SearchTool.get_search_result(current_llm_call) or [] + + answer_handler: AnswerResponseHandler if self.answer_style_config.citation_config: answer_handler = CitationResponseHandler( context_docs=search_result, diff --git a/backend/danswer/llm/answering/llm_response_handler.py b/backend/danswer/llm/answering/llm_response_handler.py index b013e9cf892..70ce84572d0 100644 --- a/backend/danswer/llm/answering/llm_response_handler.py +++ b/backend/danswer/llm/answering/llm_response_handler.py @@ -8,6 +8,7 @@ from danswer.chat.models import CitationInfo from danswer.chat.models import DanswerAnswerPiece +from danswer.chat.models import DanswerQuotes from danswer.chat.models import StreamStopInfo from danswer.chat.models import StreamStopReason from danswer.file_store.models import InMemoryChatFile @@ -29,6 +30,7 @@ ResponsePart = ( DanswerAnswerPiece | CitationInfo + | DanswerQuotes | ToolCallKickoff | ToolResponse | ToolCallFinalResult diff --git a/backend/danswer/llm/answering/tool/tool_response_handler.py b/backend/danswer/llm/answering/tool/tool_response_handler.py index 6c4fec77941..08e7284f790 100644 --- a/backend/danswer/llm/answering/tool/tool_response_handler.py +++ 
b/backend/danswer/llm/answering/tool/tool_response_handler.py @@ -135,14 +135,9 @@ def _handle_tool_call(self) -> Generator[ResponsePart, None, None]: if not selected_tool or not selected_tool_call_request: return + logger.info(f"Selected tool: {selected_tool.name}") + logger.debug(f"Selected tool call request: {selected_tool_call_request}") self.tool_runner = ToolRunner(selected_tool, selected_tool_call_request["args"]) - self.tool_call_summary = ToolCallSummary( - tool_call_request=self.tool_call_chunk, - tool_call_result=build_tool_message( - tool_call_request, self.tool_runner.tool_message_content() - ), - ) - self.tool_kickoff = self.tool_runner.kickoff() yield self.tool_kickoff @@ -153,6 +148,13 @@ def _handle_tool_call(self) -> Generator[ResponsePart, None, None]: self.tool_final_result = self.tool_runner.tool_final_result() yield self.tool_final_result + self.tool_call_summary = ToolCallSummary( + tool_call_request=self.tool_call_chunk, + tool_call_result=build_tool_message( + selected_tool_call_request, self.tool_runner.tool_message_content() + ), + ) + def handle_response_part( self, response_item: BaseMessage | None, diff --git a/backend/danswer/server/query_and_chat/chat_backend.py b/backend/danswer/server/query_and_chat/chat_backend.py index 9a475f89615..073231e2eec 100644 --- a/backend/danswer/server/query_and_chat/chat_backend.py +++ b/backend/danswer/server/query_and_chat/chat_backend.py @@ -1,6 +1,6 @@ import asyncio import io -import time +import json import uuid from collections.abc import Callable from collections.abc import Generator @@ -21,6 +21,7 @@ from danswer.auth.users import current_user from danswer.chat.chat_utils import create_chat_chain from danswer.chat.chat_utils import extract_headers +from danswer.chat.process_message import stream_chat_message from danswer.configs.app_configs import WEB_DOMAIN from danswer.configs.constants import FileOrigin from danswer.configs.constants import MessageType @@ -73,7 +74,7 @@ from danswer.server.query_and_chat.models import SearchFeedbackRequest from danswer.server.query_and_chat.models import UpdateChatSessionThreadRequest from danswer.server.query_and_chat.token_limit import check_token_rate_limits -from danswer.server.utils import get_json_line +from danswer.utils.headers import get_custom_tool_additional_request_headers from danswer.utils.logger import setup_logger @@ -289,9 +290,7 @@ async def is_connected(request: Request) -> Callable[[], bool]: def is_connected_sync() -> bool: future = asyncio.run_coroutine_threadsafe(request.is_disconnected(), main_loop) try: - logger.debug("Checking if connected") is_connected = not future.result(timeout=0.01) - logger.debug(f"Is connected: {is_connected}") return is_connected except asyncio.TimeoutError: logger.error("Asyncio timed out") @@ -343,32 +342,22 @@ def handle_new_chat_message( ): raise HTTPException(status_code=400, detail="Empty chat message is invalid") - import json - def stream_generator() -> Generator[str, None, None]: try: - # for packet in stream_chat_message( - # new_msg_req=chat_message_req, - # user=user, - # use_existing_user_message=chat_message_req.use_existing_user_message, - # litellm_additional_headers=extract_headers( - # request.headers, LITELLM_PASS_THROUGH_HEADERS - # ), - # custom_tool_additional_headers=get_custom_tool_additional_request_headers( - # request.headers - # ), - # is_connected=None, - # ): - # yield json.dumps(packet) if isinstance(packet, dict) else packet - yield get_json_line( - {"user_message_id": 1289, 
"reserved_assistant_message_id": 1290} - ) - for _ in range(50): - logger.debug("Yielding answer piece") - yield get_json_line({"answer_piece": "hello"}) - is_connected = is_connected_func() - logger.debug(f"Is connected: {is_connected}") - time.sleep(1) + for packet in stream_chat_message( + new_msg_req=chat_message_req, + user=user, + use_existing_user_message=chat_message_req.use_existing_user_message, + litellm_additional_headers=extract_headers( + request.headers, LITELLM_PASS_THROUGH_HEADERS + ), + custom_tool_additional_headers=get_custom_tool_additional_request_headers( + request.headers + ), + is_connected=is_connected_func, + ): + logger.debug(f"Yielding packet: {packet}") + yield json.dumps(packet) if isinstance(packet, dict) else packet except Exception as e: logger.exception(f"Error in chat message streaming: {e}") diff --git a/backend/danswer/tools/tool_implementations/images/image_generation_tool.py b/backend/danswer/tools/tool_implementations/images/image_generation_tool.py index 6fb06fb534a..3da53751812 100644 --- a/backend/danswer/tools/tool_implementations/images/image_generation_tool.py +++ b/backend/danswer/tools/tool_implementations/images/image_generation_tool.py @@ -117,7 +117,10 @@ def tool_definition(self) -> dict: }, "shape": { "type": "string", - "description": "Optional. Image shape: 'square', 'portrait', or 'landscape'", + "description": ( + "Optional - only specify if you want a specific shape." + " Image shape: 'square', 'portrait', or 'landscape'." + ), "enum": [shape.value for shape in ImageShape], }, }, From 3620266bddfbf1fca309ff2fe97f72bda7462979 Mon Sep 17 00:00:00 2001 From: Weves Date: Thu, 31 Oct 2024 16:28:21 -0700 Subject: [PATCH 06/10] Remove unused exception --- backend/danswer/server/query_and_chat/chat_backend.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/backend/danswer/server/query_and_chat/chat_backend.py b/backend/danswer/server/query_and_chat/chat_backend.py index 073231e2eec..2b764de729d 100644 --- a/backend/danswer/server/query_and_chat/chat_backend.py +++ b/backend/danswer/server/query_and_chat/chat_backend.py @@ -363,9 +363,6 @@ def stream_generator() -> Generator[str, None, None]: logger.exception(f"Error in chat message streaming: {e}") yield json.dumps({"error": str(e)}) - except GeneratorExit: - logger.debug("GeneratorExit") - finally: logger.debug("Stream generator finished") From 98660be16459038b438d12616bd6f00dde418b95 Mon Sep 17 00:00:00 2001 From: Weves Date: Thu, 31 Oct 2024 16:35:19 -0700 Subject: [PATCH 07/10] Fix UT --- backend/tests/unit/danswer/llm/answering/test_answer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/backend/tests/unit/danswer/llm/answering/test_answer.py b/backend/tests/unit/danswer/llm/answering/test_answer.py index f772f157204..96c791cd47b 100644 --- a/backend/tests/unit/danswer/llm/answering/test_answer.py +++ b/backend/tests/unit/danswer/llm/answering/test_answer.py @@ -293,6 +293,7 @@ def test_answer_with_search_call_quotes_enabled( answer_instance.force_use_tool = ForceUseTool( force_use=False, tool_name="", args=None ) + answer_instance.answer_style_config.citation_config = None answer_instance.answer_style_config.quotes_config = QuotesConfig() # Set up the LLM mock to return search results and then an answer From 18b4a8a26331bc013b49e486e2bf82c5ce4bfe73 Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Thu, 31 Oct 2024 17:59:53 -0700 Subject: [PATCH 08/10] fix stop generating --- backend/danswer/llm/answering/llm_response_handler.py | 7 +++---- 1 file changed, 3 insertions(+), 4 
deletions(-) diff --git a/backend/danswer/llm/answering/llm_response_handler.py b/backend/danswer/llm/answering/llm_response_handler.py index 70ce84572d0..f8426844244 100644 --- a/backend/danswer/llm/answering/llm_response_handler.py +++ b/backend/danswer/llm/answering/llm_response_handler.py @@ -67,16 +67,15 @@ def handle_llm_response( ) -> Generator[ResponsePart, None, None]: all_messages: list[BaseMessage] = [] for message in stream: + if self.is_cancelled(): + yield StreamStopInfo(stop_reason=StreamStopReason.CANCELLED) + return # tool handler doesn't do anything until the full message is received # NOTE: still need to run list() to get this to run list(self.tool_handler.handle_response_part(message, all_messages)) yield from self.answer_handler.handle_response_part(message, all_messages) all_messages.append(message) - if self.is_cancelled(): - yield StreamStopInfo(stop_reason=StreamStopReason.CANCELLED) - return - # potentially give back all info on the selected tool call + its result yield from self.tool_handler.handle_response_part(None, all_messages) yield from self.answer_handler.handle_response_part(None, all_messages) From 7db0de9505c3510a4db76e98a47d5b079056dc93 Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Thu, 31 Oct 2024 18:47:50 -0700 Subject: [PATCH 09/10] minor typo --- .../llm/answering/stream_processing/quotes_processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/danswer/llm/answering/stream_processing/quotes_processing.py b/backend/danswer/llm/answering/stream_processing/quotes_processing.py index e152463aef5..33913133164 100644 --- a/backend/danswer/llm/answering/stream_processing/quotes_processing.py +++ b/backend/danswer/llm/answering/stream_processing/quotes_processing.py @@ -246,7 +246,7 @@ def process_token( remaining = self.model_output[m.end() :] # Look for an unescaped quote, which means the answer is entirely contained - # in this toekn e.g. if the token is `{"answer": "blah", "qu` + # in this token e.g. 
if the token is `{"answer": "blah", "qu` quote_indices = [i for i, char in enumerate(remaining) if char == '"'] for quote_idx in quote_indices: # Check if quote is escaped by counting backslashes before it From 5fbcc70518bd5d1be00d6595f3fc690f81c52f21 Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Fri, 1 Nov 2024 12:34:06 -0700 Subject: [PATCH 10/10] minor logging updates for clarity --- backend/danswer/chat/process_message.py | 1 + backend/danswer/server/query_and_chat/chat_backend.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index bf058d376b6..0394f34b828 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -836,6 +836,7 @@ def stream_chat_message_objects( # Post-LLM answer processing try: + logger.debug("Post-LLM answer processing") message_specific_citations: MessageSpecificCitations | None = None if reference_db_search_docs: message_specific_citations = _translate_citations( diff --git a/backend/danswer/server/query_and_chat/chat_backend.py b/backend/danswer/server/query_and_chat/chat_backend.py index 2b764de729d..e14f6f25c5d 100644 --- a/backend/danswer/server/query_and_chat/chat_backend.py +++ b/backend/danswer/server/query_and_chat/chat_backend.py @@ -356,7 +356,6 @@ def stream_generator() -> Generator[str, None, None]: ), is_connected=is_connected_func, ): - logger.debug(f"Yielding packet: {packet}") yield json.dumps(packet) if isinstance(packet, dict) else packet except Exception as e: