From 6311b70cc6750dd43ea7aacc5bc676de385f5aab Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Mon, 16 Dec 2024 11:23:01 -0800 Subject: [PATCH 01/78] initial onyx changes --- .../answer_query/graph_builder.py | 100 ++++ .../answer_query/nodes/answer_check.py | 30 ++ .../answer_query/nodes/answer_generation.py | 32 ++ .../answer_query/nodes/format_answer.py | 16 + .../onyx/agent_search/answer_query/states.py | 45 ++ backend/onyx/agent_search/core_state.py | 15 + .../onyx/agent_search/deep_answer/edges.py | 0 .../agent_search/deep_answer/graph_builder.py | 0 .../deep_answer/nodes/answer_generation.py | 114 +++++ .../deep_answer/nodes/deep_decomp.py | 78 ++++ .../nodes/entity_term_extraction.py | 40 ++ .../nodes/sub_qa_level_aggregator.py | 30 ++ .../deep_answer/nodes/sub_qa_manager.py | 19 + .../onyx/agent_search/deep_answer/states.py | 0 .../agent_search/expanded_retrieval/edges.py | 44 ++ .../expanded_retrieval/graph_builder.py | 88 ++++ .../expanded_retrieval/nodes/doc_reranking.py | 11 + .../expanded_retrieval/nodes/doc_retrieval.py | 47 ++ .../nodes/doc_verification.py | 60 +++ .../nodes/verification_kickoff.py | 27 ++ .../expanded_retrieval/prompts.py | 0 .../agent_search/expanded_retrieval/states.py | 36 ++ backend/onyx/agent_search/main/edges.py | 61 +++ .../onyx/agent_search/main/graph_builder.py | 98 ++++ .../agent_search/main/nodes/base_decomp.py | 31 ++ .../main/nodes/generate_initial_answer.py | 53 +++ backend/onyx/agent_search/main/states.py | 37 ++ backend/onyx/agent_search/run_graph.py | 27 ++ .../agent_search/shared_graph_utils/models.py | 12 + .../shared_graph_utils/operators.py | 9 + .../shared_graph_utils/prompts.py | 427 ++++++++++++++++++ .../agent_search/shared_graph_utils/utils.py | 101 +++++ backend/requirements/default.txt | 13 +- 33 files changed, 1697 insertions(+), 4 deletions(-) create mode 100644 backend/onyx/agent_search/answer_query/graph_builder.py create mode 100644 backend/onyx/agent_search/answer_query/nodes/answer_check.py create mode 100644 backend/onyx/agent_search/answer_query/nodes/answer_generation.py create mode 100644 backend/onyx/agent_search/answer_query/nodes/format_answer.py create mode 100644 backend/onyx/agent_search/answer_query/states.py create mode 100644 backend/onyx/agent_search/core_state.py create mode 100644 backend/onyx/agent_search/deep_answer/edges.py create mode 100644 backend/onyx/agent_search/deep_answer/graph_builder.py create mode 100644 backend/onyx/agent_search/deep_answer/nodes/answer_generation.py create mode 100644 backend/onyx/agent_search/deep_answer/nodes/deep_decomp.py create mode 100644 backend/onyx/agent_search/deep_answer/nodes/entity_term_extraction.py create mode 100644 backend/onyx/agent_search/deep_answer/nodes/sub_qa_level_aggregator.py create mode 100644 backend/onyx/agent_search/deep_answer/nodes/sub_qa_manager.py create mode 100644 backend/onyx/agent_search/deep_answer/states.py create mode 100644 backend/onyx/agent_search/expanded_retrieval/edges.py create mode 100644 backend/onyx/agent_search/expanded_retrieval/graph_builder.py create mode 100644 backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py create mode 100644 backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py create mode 100644 backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py create mode 100644 backend/onyx/agent_search/expanded_retrieval/nodes/verification_kickoff.py create mode 100644 backend/onyx/agent_search/expanded_retrieval/prompts.py create mode 100644 
backend/onyx/agent_search/expanded_retrieval/states.py create mode 100644 backend/onyx/agent_search/main/edges.py create mode 100644 backend/onyx/agent_search/main/graph_builder.py create mode 100644 backend/onyx/agent_search/main/nodes/base_decomp.py create mode 100644 backend/onyx/agent_search/main/nodes/generate_initial_answer.py create mode 100644 backend/onyx/agent_search/main/states.py create mode 100644 backend/onyx/agent_search/run_graph.py create mode 100644 backend/onyx/agent_search/shared_graph_utils/models.py create mode 100644 backend/onyx/agent_search/shared_graph_utils/operators.py create mode 100644 backend/onyx/agent_search/shared_graph_utils/prompts.py create mode 100644 backend/onyx/agent_search/shared_graph_utils/utils.py diff --git a/backend/onyx/agent_search/answer_query/graph_builder.py b/backend/onyx/agent_search/answer_query/graph_builder.py new file mode 100644 index 0000000000..e52bfe28d6 --- /dev/null +++ b/backend/onyx/agent_search/answer_query/graph_builder.py @@ -0,0 +1,100 @@ +from langgraph.graph import END +from langgraph.graph import START +from langgraph.graph import StateGraph + +from onyx.agent_search.answer_query.nodes.answer_check import answer_check +from onyx.agent_search.answer_query.nodes.answer_generation import answer_generation +from onyx.agent_search.answer_query.nodes.format_answer import format_answer +from onyx.agent_search.answer_query.states import AnswerQueryInput +from onyx.agent_search.answer_query.states import AnswerQueryOutput +from onyx.agent_search.answer_query.states import AnswerQueryState +from onyx.agent_search.expanded_retrieval.graph_builder import ( + expanded_retrieval_graph_builder, +) + + +def answer_query_graph_builder() -> StateGraph: + graph = StateGraph( + state_schema=AnswerQueryState, + input=AnswerQueryInput, + output=AnswerQueryOutput, + ) + + ### Add nodes ### + + expanded_retrieval = expanded_retrieval_graph_builder().compile() + graph.add_node( + node="expanded_retrieval_for_initial_decomp", + action=expanded_retrieval, + ) + graph.add_node( + node="answer_check", + action=answer_check, + ) + graph.add_node( + node="answer_generation", + action=answer_generation, + ) + graph.add_node( + node="format_answer", + action=format_answer, + ) + + ### Add edges ### + + graph.add_edge( + start_key=START, + end_key="expanded_retrieval_for_initial_decomp", + ) + graph.add_edge( + start_key="expanded_retrieval_for_initial_decomp", + end_key="answer_generation", + ) + graph.add_edge( + start_key="answer_generation", + end_key="answer_check", + ) + graph.add_edge( + start_key="answer_check", + end_key="format_answer", + ) + graph.add_edge( + start_key="format_answer", + end_key=END, + ) + + return graph + + +if __name__ == "__main__": + from onyx.db.engine import get_session_context_manager + from onyx.llm.factory import get_default_llms + from onyx.context.search.models import SearchRequest + + graph = answer_query_graph_builder() + compiled_graph = graph.compile() + primary_llm, fast_llm = get_default_llms() + search_request = SearchRequest( + query="Who made Excel and what other products did they make?", + ) + with get_session_context_manager() as db_session: + inputs = AnswerQueryInput( + search_request=search_request, + primary_llm=primary_llm, + fast_llm=fast_llm, + db_session=db_session, + query_to_answer="Who made Excel?", + ) + output = compiled_graph.invoke( + input=inputs, + # debug=True, + # subgraphs=True, + ) + print(output) + # for namespace, chunk in compiled_graph.stream( + # input=inputs, + # # 
debug=True, + # subgraphs=True, + # ): + # print(namespace) + # print(chunk) diff --git a/backend/onyx/agent_search/answer_query/nodes/answer_check.py b/backend/onyx/agent_search/answer_query/nodes/answer_check.py new file mode 100644 index 0000000000..8b58129c47 --- /dev/null +++ b/backend/onyx/agent_search/answer_query/nodes/answer_check.py @@ -0,0 +1,30 @@ +from langchain_core.messages import HumanMessage +from langchain_core.messages import merge_message_runs + +from onyx.agent_search.answer_query.states import AnswerQueryState +from onyx.agent_search.answer_query.states import QACheckOutput +from onyx.agent_search.shared_graph_utils.prompts import BASE_CHECK_PROMPT + + +def answer_check(state: AnswerQueryState) -> QACheckOutput: + msg = [ + HumanMessage( + content=BASE_CHECK_PROMPT.format( + question=state["search_request"].query, + base_answer=state["answer"], + ) + ) + ] + + fast_llm = state["fast_llm"] + response = list( + fast_llm.stream( + prompt=msg, + ) + ) + + response_str = merge_message_runs(response, chunk_separator="")[0].content + + return QACheckOutput( + answer_quality=response_str, + ) diff --git a/backend/onyx/agent_search/answer_query/nodes/answer_generation.py b/backend/onyx/agent_search/answer_query/nodes/answer_generation.py new file mode 100644 index 0000000000..c23f77ee70 --- /dev/null +++ b/backend/onyx/agent_search/answer_query/nodes/answer_generation.py @@ -0,0 +1,32 @@ +from langchain_core.messages import HumanMessage +from langchain_core.messages import merge_message_runs + +from onyx.agent_search.answer_query.states import AnswerQueryState +from onyx.agent_search.answer_query.states import QAGenerationOutput +from onyx.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT +from onyx.agent_search.shared_graph_utils.utils import format_docs + + +def answer_generation(state: AnswerQueryState) -> QAGenerationOutput: + query = state["query_to_answer"] + docs = state["reordered_documents"] + + print(f"Number of verified retrieval docs: {len(docs)}") + + msg = [ + HumanMessage( + content=BASE_RAG_PROMPT.format(question=query, context=format_docs(docs)) + ) + ] + + fast_llm = state["fast_llm"] + response = list( + fast_llm.stream( + prompt=msg, + ) + ) + + answer_str = merge_message_runs(response, chunk_separator="")[0].content + return QAGenerationOutput( + answer=answer_str, + ) diff --git a/backend/onyx/agent_search/answer_query/nodes/format_answer.py b/backend/onyx/agent_search/answer_query/nodes/format_answer.py new file mode 100644 index 0000000000..8359baec9b --- /dev/null +++ b/backend/onyx/agent_search/answer_query/nodes/format_answer.py @@ -0,0 +1,16 @@ +from onyx.agent_search.answer_query.states import AnswerQueryOutput +from onyx.agent_search.answer_query.states import AnswerQueryState +from onyx.agent_search.answer_query.states import SearchAnswerResults + + +def format_answer(state: AnswerQueryState) -> AnswerQueryOutput: + return AnswerQueryOutput( + decomp_answer_results=[ + SearchAnswerResults( + query=state["query_to_answer"], + quality=state["answer_quality"], + answer=state["answer"], + documents=state["reordered_documents"], + ) + ], + ) diff --git a/backend/onyx/agent_search/answer_query/states.py b/backend/onyx/agent_search/answer_query/states.py new file mode 100644 index 0000000000..9f8fe12ab6 --- /dev/null +++ b/backend/onyx/agent_search/answer_query/states.py @@ -0,0 +1,45 @@ +from typing import Annotated +from typing import TypedDict + +from pydantic import BaseModel + +from onyx.agent_search.core_state import PrimaryState 
+from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections +from onyx.context.search.models import InferenceSection + + +class SearchAnswerResults(BaseModel): + query: str + answer: str + quality: str + documents: Annotated[list[InferenceSection], dedup_inference_sections] + + +class QACheckOutput(TypedDict, total=False): + answer_quality: str + + +class QAGenerationOutput(TypedDict, total=False): + answer: str + + +class ExpandedRetrievalOutput(TypedDict): + reordered_documents: Annotated[list[InferenceSection], dedup_inference_sections] + + +class AnswerQueryState( + PrimaryState, + QACheckOutput, + QAGenerationOutput, + ExpandedRetrievalOutput, + total=True, +): + query_to_answer: str + + +class AnswerQueryInput(PrimaryState, total=True): + query_to_answer: str + + +class AnswerQueryOutput(TypedDict): + decomp_answer_results: list[SearchAnswerResults] diff --git a/backend/onyx/agent_search/core_state.py b/backend/onyx/agent_search/core_state.py new file mode 100644 index 0000000000..fcd8bddf3e --- /dev/null +++ b/backend/onyx/agent_search/core_state.py @@ -0,0 +1,15 @@ +from typing import TypedDict + +from sqlalchemy.orm import Session + +from onyx.context.search.models import SearchRequest +from onyx.llm.interfaces import LLM + + +class PrimaryState(TypedDict, total=False): + search_request: SearchRequest + primary_llm: LLM + fast_llm: LLM + # a single session for the entire agent search + # is fine if we are only reading + db_session: Session diff --git a/backend/onyx/agent_search/deep_answer/edges.py b/backend/onyx/agent_search/deep_answer/edges.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/backend/onyx/agent_search/deep_answer/graph_builder.py b/backend/onyx/agent_search/deep_answer/graph_builder.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/backend/onyx/agent_search/deep_answer/nodes/answer_generation.py b/backend/onyx/agent_search/deep_answer/nodes/answer_generation.py new file mode 100644 index 0000000000..f0a94b398a --- /dev/null +++ b/backend/onyx/agent_search/deep_answer/nodes/answer_generation.py @@ -0,0 +1,114 @@ +from typing import Any + +from langchain_core.messages import HumanMessage + +from onyx.agent_search.main.states import MainState +from onyx.agent_search.shared_graph_utils.prompts import COMBINED_CONTEXT +from onyx.agent_search.shared_graph_utils.prompts import MODIFIED_RAG_PROMPT +from onyx.agent_search.shared_graph_utils.utils import format_docs +from onyx.agent_search.shared_graph_utils.utils import normalize_whitespace + + +# aggregate sub questions and answers +def deep_answer_generation(state: MainState) -> dict[str, Any]: + """ + Generate answer + + Args: + state (messages): The current state + + Returns: + dict: The updated state with re-phrased question + """ + print("---DEEP GENERATE---") + + question = state["original_question"] + docs = state["deduped_retrieval_docs"] + + deep_answer_context = state["core_answer_dynamic_context"] + + print(f"Number of verified retrieval docs - deep: {len(docs)}") + + combined_context = normalize_whitespace( + COMBINED_CONTEXT.format( + deep_answer_context=deep_answer_context, formated_docs=format_docs(docs) + ) + ) + + msg = [ + HumanMessage( + content=MODIFIED_RAG_PROMPT.format( + question=question, combined_context=combined_context + ) + ) + ] + + # Grader + model = state["fast_llm"] + response = model.invoke(msg) + + return { + "deep_answer": response.content, + } + + +def final_stuff(state: MainState) -> dict[str, Any]: + """ + Invokes the agent 
model to generate a response based on the current state. Given + the question, it will decide to retrieve using the retriever tool, or simply end. + + Args: + state (messages): The current state + + Returns: + dict: The updated state with the agent response appended to messages + """ + print("---FINAL---") + + messages = state["log_messages"] + time_ordered_messages = [x.pretty_repr() for x in messages] + time_ordered_messages.sort() + + print("Message Log:") + print("\n".join(time_ordered_messages)) + + initial_sub_qas = state["initial_sub_qas"] + initial_sub_qa_list = [] + for initial_sub_qa in initial_sub_qas: + if initial_sub_qa["sub_answer_check"] == "yes": + initial_sub_qa_list.append( + f' Question:\n {initial_sub_qa["sub_question"]}\n --\n Answer:\n {initial_sub_qa["sub_answer"]}\n -----' + ) + + initial_sub_qa_context = "\n".join(initial_sub_qa_list) + + base_answer = state["base_answer"] + + print(f"Final Base Answer:\n{base_answer}") + print("--------------------------------") + print(f"Initial Answered Sub Questions:\n{initial_sub_qa_context}") + print("--------------------------------") + + if not state.get("deep_answer"): + print("No Deep Answer was required") + return {} + + deep_answer = state["deep_answer"] + sub_qas = state["sub_qas"] + sub_qa_list = [] + for sub_qa in sub_qas: + if sub_qa["sub_answer_check"] == "yes": + sub_qa_list.append( + f' Question:\n {sub_qa["sub_question"]}\n --\n Answer:\n {sub_qa["sub_answer"]}\n -----' + ) + + sub_qa_context = "\n".join(sub_qa_list) + + print(f"Final Base Answer:\n{base_answer}") + print("--------------------------------") + print(f"Final Deep Answer:\n{deep_answer}") + print("--------------------------------") + print("Sub Questions and Answers:") + print(sub_qa_context) + + return {} diff --git a/backend/onyx/agent_search/deep_answer/nodes/deep_decomp.py b/backend/onyx/agent_search/deep_answer/nodes/deep_decomp.py new file mode 100644 index 0000000000..786b2774fc --- /dev/null +++ b/backend/onyx/agent_search/deep_answer/nodes/deep_decomp.py @@ -0,0 +1,78 @@ +import json +import re +from datetime import datetime +from typing import Any + +from langchain_core.messages import HumanMessage + +from onyx.agent_search.main.states import MainState +from onyx.agent_search.shared_graph_utils.prompts import DEEP_DECOMPOSE_PROMPT +from onyx.agent_search.shared_graph_utils.utils import format_entity_term_extraction +from onyx.agent_search.shared_graph_utils.utils import generate_log_message + + +def decompose(state: MainState) -> dict[str, Any]: + """ """ + + node_start_time = datetime.now() + + question = state["original_question"] + base_answer = state["base_answer"] + + # get the entity term extraction dict and properly format it + entity_term_extraction_dict = state["retrieved_entities_relationships"][ + "retrieved_entities_relationships" + ] + + entity_term_extraction_str = format_entity_term_extraction( + entity_term_extraction_dict + ) + + initial_question_answers = state["initial_sub_qas"] + + addressed_question_list = [ + x["sub_question"] + for x in initial_question_answers + if x["sub_answer_check"] == "yes" + ] + failed_question_list = [ + x["sub_question"] + for x in initial_question_answers + if x["sub_answer_check"] == "no" + ] + + msg = [ + HumanMessage( + content=DEEP_DECOMPOSE_PROMPT.format( + question=question, + entity_term_extraction_str=entity_term_extraction_str, + base_answer=base_answer, + answered_sub_questions="\n - ".join(addressed_question_list), + failed_sub_questions="\n - ".join(failed_question_list), + ), 
+ ) + ] + + # Grader + model = state["fast_llm"] + response = model.invoke(msg) + + cleaned_response = re.sub(r"```json\n|\n```", "", response.pretty_repr()) + parsed_response = json.loads(cleaned_response) + + sub_questions_dict = {} + for sub_question_nr, sub_question_dict in enumerate( + parsed_response["sub_questions"] + ): + sub_question_dict["answered"] = False + sub_question_dict["verified"] = False + sub_questions_dict[sub_question_nr] = sub_question_dict + + return { + "decomposed_sub_questions_dict": sub_questions_dict, + "log_messages": generate_log_message( + message="deep - decompose", + node_start_time=node_start_time, + graph_start_time=state["graph_start_time"], + ), + } diff --git a/backend/onyx/agent_search/deep_answer/nodes/entity_term_extraction.py b/backend/onyx/agent_search/deep_answer/nodes/entity_term_extraction.py new file mode 100644 index 0000000000..865a78f0a7 --- /dev/null +++ b/backend/onyx/agent_search/deep_answer/nodes/entity_term_extraction.py @@ -0,0 +1,40 @@ +import json +import re +from typing import Any + +from langchain_core.messages import HumanMessage +from langchain_core.messages import merge_message_runs + +from onyx.agent_search.main.states import MainState +from onyx.agent_search.shared_graph_utils.prompts import ENTITY_TERM_PROMPT +from onyx.agent_search.shared_graph_utils.utils import format_docs + + +def entity_term_extraction(state: MainState) -> dict[str, Any]: + """Extract entities and terms from the question and context""" + + question = state["original_question"] + docs = state["deduped_retrieval_docs"] + + doc_context = format_docs(docs) + + msg = [ + HumanMessage( + content=ENTITY_TERM_PROMPT.format(question=question, context=doc_context), + ) + ] + fast_llm = state["fast_llm"] + # Grader + llm_response_list = list( + fast_llm.stream( + prompt=msg, + ) + ) + llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content + + cleaned_response = re.sub(r"```json\n|\n```", "", llm_response) + parsed_response = json.loads(cleaned_response) + + return { + "retrieved_entities_relationships": parsed_response, + } diff --git a/backend/onyx/agent_search/deep_answer/nodes/sub_qa_level_aggregator.py b/backend/onyx/agent_search/deep_answer/nodes/sub_qa_level_aggregator.py new file mode 100644 index 0000000000..5805b3c632 --- /dev/null +++ b/backend/onyx/agent_search/deep_answer/nodes/sub_qa_level_aggregator.py @@ -0,0 +1,30 @@ +from typing import Any + +from onyx.agent_search.main.states import MainState + + +# aggregate sub questions and answers +def sub_qa_level_aggregator(state: MainState) -> dict[str, Any]: + sub_qas = state["sub_qas"] + + dynamic_context_list = [ + "Below you will find useful information to answer the original question:" + ] + checked_sub_qas = [] + + for core_answer_sub_qa in sub_qas: + question = core_answer_sub_qa["sub_question"] + answer = core_answer_sub_qa["sub_answer"] + verified = core_answer_sub_qa["sub_answer_check"] + + if verified == "yes": + dynamic_context_list.append( + f"Question:\n{question}\n\nAnswer:\n{answer}\n\n---\n\n" + ) + checked_sub_qas.append({"sub_question": question, "sub_answer": answer}) + dynamic_context = "\n".join(dynamic_context_list) + + return { + "core_answer_dynamic_context": dynamic_context, + "checked_sub_qas": checked_sub_qas, + } diff --git a/backend/onyx/agent_search/deep_answer/nodes/sub_qa_manager.py b/backend/onyx/agent_search/deep_answer/nodes/sub_qa_manager.py new file mode 100644 index 0000000000..58b4262cdc --- /dev/null +++ 
b/backend/onyx/agent_search/deep_answer/nodes/sub_qa_manager.py @@ -0,0 +1,19 @@ +from typing import Any + +from onyx.agent_search.main.states import MainState + + +def sub_qa_manager(state: MainState) -> dict[str, Any]: + """ """ + + sub_questions_dict = state["decomposed_sub_questions_dict"] + + sub_questions = {} + + for sub_question_nr, sub_question_dict in sub_questions_dict.items(): + sub_questions[sub_question_nr] = sub_question_dict["sub_question"] + + return { + "sub_questions": sub_questions, + "num_new_question_iterations": 0, + } diff --git a/backend/onyx/agent_search/deep_answer/states.py b/backend/onyx/agent_search/deep_answer/states.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/backend/onyx/agent_search/expanded_retrieval/edges.py b/backend/onyx/agent_search/expanded_retrieval/edges.py new file mode 100644 index 0000000000..2c63125bb9 --- /dev/null +++ b/backend/onyx/agent_search/expanded_retrieval/edges.py @@ -0,0 +1,44 @@ +from collections.abc import Hashable + +from langchain_core.messages import HumanMessage +from langchain_core.messages import merge_message_runs +from langgraph.types import Send + +from onyx.agent_search.expanded_retrieval.nodes.doc_retrieval import RetrieveInput +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput +from onyx.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI +from onyx.llm.interfaces import LLM + + +def parallel_retrieval_edge(state: ExpandedRetrievalInput) -> list[Send | Hashable]: + print(f"parallel_retrieval_edge state: {state.keys()}") + + # This should be better... + question = state.get("query_to_answer") or state["search_request"].query + llm: LLM = state["fast_llm"] + + msg = [ + HumanMessage( + content=REWRITE_PROMPT_MULTI.format(question=question), + ) + ] + llm_response_list = list( + llm.stream( + prompt=msg, + ) + ) + llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content + + print(f"llm_response: {llm_response}") + + rewritten_queries = llm_response.split("\n") + + print(f"rewritten_queries: {rewritten_queries}") + + return [ + Send( + "doc_retrieval", + RetrieveInput(query_to_retrieve=query, **state), + ) + for query in rewritten_queries + ] diff --git a/backend/onyx/agent_search/expanded_retrieval/graph_builder.py b/backend/onyx/agent_search/expanded_retrieval/graph_builder.py new file mode 100644 index 0000000000..1928e93450 --- /dev/null +++ b/backend/onyx/agent_search/expanded_retrieval/graph_builder.py @@ -0,0 +1,88 @@ +from langgraph.graph import END +from langgraph.graph import START +from langgraph.graph import StateGraph + +from onyx.agent_search.expanded_retrieval.edges import parallel_retrieval_edge +from onyx.agent_search.expanded_retrieval.nodes.doc_reranking import doc_reranking +from onyx.agent_search.expanded_retrieval.nodes.doc_retrieval import doc_retrieval +from onyx.agent_search.expanded_retrieval.nodes.doc_verification import ( + doc_verification, +) +from onyx.agent_search.expanded_retrieval.nodes.verification_kickoff import ( + verification_kickoff, +) +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalOutput +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState + + +def expanded_retrieval_graph_builder() -> StateGraph: + graph = StateGraph( + state_schema=ExpandedRetrievalState, + input=ExpandedRetrievalInput, + output=ExpandedRetrievalOutput, + ) + + ### Add nodes ### + + 
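+ # The nodes below implement the expanded-retrieval flow: retrieval fans out per rewritten query, each retrieved document is verified in parallel, and the verified set is reranked.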
graph.add_node( + node="doc_retrieval", + action=doc_retrieval, + ) + graph.add_node( + node="verification_kickoff", + action=verification_kickoff, + ) + graph.add_node( + node="doc_verification", + action=doc_verification, + ) + graph.add_node( + node="doc_reranking", + action=doc_reranking, + ) + + ### Add edges ### + + graph.add_conditional_edges( + source=START, + path=parallel_retrieval_edge, + path_map=["doc_retrieval"], + ) + graph.add_edge( + start_key="doc_retrieval", + end_key="verification_kickoff", + ) + graph.add_edge( + start_key="doc_verification", + end_key="doc_reranking", + ) + graph.add_edge( + start_key="doc_reranking", + end_key=END, + ) + + return graph + + +if __name__ == "__main__": + from onyx.db.engine import get_session_context_manager + from onyx.llm.factory import get_default_llms + from onyx.context.search.models import SearchRequest + + graph = expanded_retrieval_graph_builder() + compiled_graph = graph.compile() + primary_llm, fast_llm = get_default_llms() + search_request = SearchRequest( + query="Who made Excel and what other products did they make?", + ) + with get_session_context_manager() as db_session: + inputs = ExpandedRetrievalInput( + search_request=search_request, + primary_llm=primary_llm, + fast_llm=fast_llm, + db_session=db_session, + query_to_answer="Who made Excel?", + ) + for thing in compiled_graph.stream(inputs, debug=True): + print(thing) diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py new file mode 100644 index 0000000000..1ac3620351 --- /dev/null +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py @@ -0,0 +1,11 @@ +from onyx.agent_search.expanded_retrieval.states import DocRerankingOutput +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState + + +def doc_reranking(state: ExpandedRetrievalState) -> DocRerankingOutput: + print(f"doc_reranking state: {state.keys()}") + + verified_documents = state["verified_documents"] + reranked_documents = verified_documents + + return DocRerankingOutput(reranked_documents=reranked_documents) diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py new file mode 100644 index 0000000000..8d61249948 --- /dev/null +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py @@ -0,0 +1,47 @@ +from onyx.agent_search.expanded_retrieval.states import DocRetrievalOutput +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState +from onyx.context.search.models import InferenceSection +from onyx.context.search.models import SearchRequest +from onyx.context.search.pipeline import SearchPipeline +from onyx.db.engine import get_session_context_manager + + +class RetrieveInput(ExpandedRetrievalState): + query_to_retrieve: str + + +def doc_retrieval(state: RetrieveInput) -> DocRetrievalOutput: + # def doc_retrieval(state: RetrieveInput) -> Command[Literal["doc_verification"]]: + """ + Retrieve documents + + Args: + state (dict): The current graph state + + Returns: + state (dict): New key added to state, documents, that contains retrieved documents + """ + print(f"doc_retrieval state: {state.keys()}") + + documents: list[InferenceSection] = [] + llm = state["primary_llm"] + fast_llm = state["fast_llm"] + # use the query rewritten for this branch, not the original search request, + # so that each parallel retrieval actually runs its own query + query_to_retrieve = state["query_to_retrieve"] + with get_session_context_manager() as db_session: + documents = SearchPipeline( + search_request=SearchRequest( + query=query_to_retrieve, + ), + user=None, + llm=llm, + fast_llm=fast_llm, + db_session=db_session, + ).reranked_sections + + print(f"retrieved documents: {len(documents)}") + return DocRetrievalOutput( + retrieved_documents=documents, + ) diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py new file mode 100644 index 0000000000..f3f993e87b --- /dev/null +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py @@ -0,0 +1,60 @@ +from langchain_core.messages import HumanMessage +from langchain_core.messages import merge_message_runs + +from onyx.agent_search.expanded_retrieval.states import DocVerificationOutput +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState +from onyx.agent_search.shared_graph_utils.models import BinaryDecision +from onyx.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT +from onyx.context.search.models import InferenceSection + + +class DocVerificationInput(ExpandedRetrievalState, total=True): + doc_to_verify: InferenceSection + + +def doc_verification(state: DocVerificationInput) -> DocVerificationOutput: + """ + Check whether the document is relevant for the original user question + + Args: + state (DocVerificationInput): The current state + + Returns: + dict: The updated state with the final decision + """ + + print(f"doc_verification state: {state.keys()}") + + original_query = state["search_request"].query + doc_to_verify = state["doc_to_verify"] + document_content = doc_to_verify.combined_content + + msg = [ + HumanMessage( + content=VERIFIER_PROMPT.format( + question=original_query, document_content=document_content + ) + ) + ] + + fast_llm = state["fast_llm"] + response = list( + fast_llm.stream( + prompt=msg, + ) + ) + + response_string = merge_message_runs(response, chunk_separator="")[0].content + # Convert string response to proper dictionary format + decision_dict = {"decision": response_string.lower()} + formatted_response = BinaryDecision.model_validate(decision_dict) + + print(f"Verdict: {formatted_response.decision}") + + verified_documents = [] + if formatted_response.decision == "yes": + verified_documents.append(doc_to_verify) + + return DocVerificationOutput( + verified_documents=verified_documents, + ) diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/verification_kickoff.py b/backend/onyx/agent_search/expanded_retrieval/nodes/verification_kickoff.py new file mode 100644 index 0000000000..d40bf6f0da --- /dev/null +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/verification_kickoff.py @@ -0,0 +1,27 @@ +from typing import Literal + +from langgraph.types import Command +from langgraph.types import Send + +from onyx.agent_search.expanded_retrieval.nodes.doc_verification import ( + DocVerificationInput, +) +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState + + +def verification_kickoff( + state: ExpandedRetrievalState, +) -> Command[Literal["doc_verification"]]: + print(f"verification_kickoff state: {state.keys()}") + + documents = state["retrieved_documents"] + return Command( + update={}, + goto=[ + Send( + node="doc_verification", + arg=DocVerificationInput(doc_to_verify=doc, **state), + ) + for doc in documents + ], + ) diff --git a/backend/onyx/agent_search/expanded_retrieval/prompts.py
b/backend/onyx/agent_search/expanded_retrieval/prompts.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/backend/onyx/agent_search/expanded_retrieval/states.py b/backend/onyx/agent_search/expanded_retrieval/states.py new file mode 100644 index 0000000000..a0f726b7f8 --- /dev/null +++ b/backend/onyx/agent_search/expanded_retrieval/states.py @@ -0,0 +1,36 @@ +from typing import Annotated +from typing import TypedDict + +from onyx.agent_search.core_state import PrimaryState +from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections +from onyx.context.search.models import InferenceSection + + +class DocRetrievalOutput(TypedDict, total=False): + retrieved_documents: Annotated[list[InferenceSection], dedup_inference_sections] + + +class DocVerificationOutput(TypedDict, total=False): + verified_documents: Annotated[list[InferenceSection], dedup_inference_sections] + + +class DocRerankingOutput(TypedDict, total=False): + reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections] + + +class ExpandedRetrievalState( + PrimaryState, + DocRetrievalOutput, + DocVerificationOutput, + DocRerankingOutput, + total=True, +): + query_to_answer: str + + +class ExpandedRetrievalInput(PrimaryState, total=True): + query_to_answer: str + + +class ExpandedRetrievalOutput(TypedDict): + reordered_documents: Annotated[list[InferenceSection], dedup_inference_sections] diff --git a/backend/onyx/agent_search/main/edges.py b/backend/onyx/agent_search/main/edges.py new file mode 100644 index 0000000000..953b0a9627 --- /dev/null +++ b/backend/onyx/agent_search/main/edges.py @@ -0,0 +1,61 @@ +from collections.abc import Hashable + +from langgraph.types import Send + +from onyx.agent_search.answer_query.states import AnswerQueryInput +from onyx.agent_search.main.states import MainState + + +def parallelize_decompozed_answer_queries(state: MainState) -> list[Send | Hashable]: + return [ + Send( + "answer_query", + AnswerQueryInput( + **state, + query_to_answer=query, + ), + ) + for query in state["initial_decomp_queries"] + ] + + +# def continue_to_answer_sub_questions(state: QAState) -> Union[Hashable, list[Hashable]]: +# # Routes re-written queries to the (parallel) retrieval steps +# # Notice the 'Send()' API that takes care of the parallelization +# return [ +# Send( +# "sub_answers_graph", +# ResearchQAState( +# sub_question=sub_question["sub_question_str"], +# sub_question_nr=sub_question["sub_question_nr"], +# graph_start_time=state["graph_start_time"], +# primary_llm=state["primary_llm"], +# fast_llm=state["fast_llm"], +# ), +# ) +# for sub_question in state["sub_questions"] +# ] + + +# def continue_to_deep_answer(state: QAState) -> Union[Hashable, list[Hashable]]: +# print("---GO TO DEEP ANSWER OR END---") + +# base_answer = state["base_answer"] + +# question = state["original_question"] + +# BASE_CHECK_MESSAGE = [ +# HumanMessage( +# content=BASE_CHECK_PROMPT.format(question=question, base_answer=base_answer) +# ) +# ] + +# model = state["fast_llm"] +# response = model.invoke(BASE_CHECK_MESSAGE) + +# print(f"CAN WE CONTINUE W/O GENERATING A DEEP ANSWER? 
- {response.pretty_repr()}") + +# if response.pretty_repr() == "no": +# return "decompose" +# else: +# return "end" diff --git a/backend/onyx/agent_search/main/graph_builder.py b/backend/onyx/agent_search/main/graph_builder.py new file mode 100644 index 0000000000..449ffb89df --- /dev/null +++ b/backend/onyx/agent_search/main/graph_builder.py @@ -0,0 +1,98 @@ +from langgraph.graph import END +from langgraph.graph import START +from langgraph.graph import StateGraph + +from onyx.agent_search.answer_query.graph_builder import answer_query_graph_builder +from onyx.agent_search.expanded_retrieval.graph_builder import ( + expanded_retrieval_graph_builder, +) +from onyx.agent_search.main.edges import parallelize_decompozed_answer_queries +from onyx.agent_search.main.nodes.base_decomp import main_decomp_base +from onyx.agent_search.main.nodes.generate_initial_answer import ( + generate_initial_answer, +) +from onyx.agent_search.main.states import MainInput +from onyx.agent_search.main.states import MainState + + +def main_graph_builder() -> StateGraph: + graph = StateGraph( + state_schema=MainState, + input=MainInput, + ) + + ### Add nodes ### + + graph.add_node( + node="base_decomp", + action=main_decomp_base, + ) + answer_query_subgraph = answer_query_graph_builder().compile() + graph.add_node( + node="answer_query", + action=answer_query_subgraph, + ) + expanded_retrieval_subgraph = expanded_retrieval_graph_builder().compile() + graph.add_node( + node="expanded_retrieval", + action=expanded_retrieval_subgraph, + ) + graph.add_node( + node="generate_initial_answer", + action=generate_initial_answer, + ) + + ### Add edges ### + graph.add_edge( + start_key=START, + end_key="expanded_retrieval", + ) + + graph.add_edge( + start_key=START, + end_key="base_decomp", + ) + graph.add_conditional_edges( + source="base_decomp", + path=parallelize_decompozed_answer_queries, + path_map=["answer_query"], + ) + graph.add_edge( + start_key=["answer_query", "expanded_retrieval"], + end_key="generate_initial_answer", + ) + graph.add_edge( + start_key="generate_initial_answer", + end_key=END, + ) + + return graph + + +if __name__ == "__main__": + from onyx.db.engine import get_session_context_manager + from onyx.llm.factory import get_default_llms + from onyx.context.search.models import SearchRequest + + graph = main_graph_builder() + compiled_graph = graph.compile() + primary_llm, fast_llm = get_default_llms() + search_request = SearchRequest( + query="If i am familiar with the function that I need, how can I type it into a cell?", + ) + with get_session_context_manager() as db_session: + inputs = MainInput( + search_request=search_request, + primary_llm=primary_llm, + fast_llm=fast_llm, + db_session=db_session, + ) + for thing in compiled_graph.stream( + input=inputs, + # stream_mode="debug", + # debug=True, + subgraphs=True, + ): + # print(thing) + print() + print() diff --git a/backend/onyx/agent_search/main/nodes/base_decomp.py b/backend/onyx/agent_search/main/nodes/base_decomp.py new file mode 100644 index 0000000000..28e93c6cbc --- /dev/null +++ b/backend/onyx/agent_search/main/nodes/base_decomp.py @@ -0,0 +1,31 @@ +from langchain_core.messages import HumanMessage + +from onyx.agent_search.main.states import BaseDecompOutput +from onyx.agent_search.main.states import MainState +from onyx.agent_search.shared_graph_utils.prompts import INITIAL_DECOMPOSITION_PROMPT +from onyx.agent_search.shared_graph_utils.utils import clean_and_parse_list_string + + +def main_decomp_base(state: MainState) -> 
BaseDecompOutput: + question = state["search_request"].query + + msg = [ + HumanMessage( + content=INITIAL_DECOMPOSITION_PROMPT.format(question=question), + ) + ] + + # Get the decomposed sub-questions in a defined format + model = state["fast_llm"] + response = model.invoke(msg) + + # Use .content; pretty_repr() would prepend a message-type header to the text + content = response.content + list_of_subquestions = clean_and_parse_list_string(content) + + decomp_list: list[str] = [ + sub_question["sub_question"].strip() for sub_question in list_of_subquestions + ] + + return BaseDecompOutput( + initial_decomp_queries=decomp_list, + ) diff --git a/backend/onyx/agent_search/main/nodes/generate_initial_answer.py b/backend/onyx/agent_search/main/nodes/generate_initial_answer.py new file mode 100644 index 0000000000..5671b2352f --- /dev/null +++ b/backend/onyx/agent_search/main/nodes/generate_initial_answer.py @@ -0,0 +1,53 @@ +from langchain_core.messages import HumanMessage + +from onyx.agent_search.main.states import InitialAnswerOutput +from onyx.agent_search.main.states import MainState +from onyx.agent_search.shared_graph_utils.prompts import INITIAL_RAG_PROMPT +from onyx.agent_search.shared_graph_utils.utils import format_docs + + +def generate_initial_answer(state: MainState) -> InitialAnswerOutput: + print("---GENERATE INITIAL---") + + question = state["search_request"].query + docs = state["documents"] + + decomp_answer_results = state["decomp_answer_results"] + + good_qa_list: list[str] = [] + + _SUB_QUESTION_ANSWER_TEMPLATE = """ + Sub-Question:\n - {sub_question}\n --\nAnswer:\n - {sub_answer}\n\n + """ + for decomp_answer_result in decomp_answer_results: + if ( + decomp_answer_result.quality.lower() == "yes" + and len(decomp_answer_result.answer) > 0 + and decomp_answer_result.answer != "I don't know" + ): + good_qa_list.append( + _SUB_QUESTION_ANSWER_TEMPLATE.format( + sub_question=decomp_answer_result.query, + sub_answer=decomp_answer_result.answer, + ) + ) + + sub_question_answer_str = "\n\n------\n\n".join(good_qa_list) + + msg = [ + HumanMessage( + content=INITIAL_RAG_PROMPT.format( + question=question, + context=format_docs(docs), + answered_sub_questions=sub_question_answer_str, + ) + ) + ] + + # Generate the initial answer with the fast LLM + model = state["fast_llm"] + response = model.invoke(msg) + # Use .content; pretty_repr() would leak a message-type header into the answer + answer = response.content + + print(answer) + return InitialAnswerOutput(initial_answer=answer) diff --git a/backend/onyx/agent_search/main/states.py b/backend/onyx/agent_search/main/states.py new file mode 100644 index 0000000000..3b753ff847 --- /dev/null +++ b/backend/onyx/agent_search/main/states.py @@ -0,0 +1,37 @@ +from operator import add +from typing import Annotated +from typing import TypedDict + +from onyx.agent_search.answer_query.states import SearchAnswerResults +from onyx.agent_search.core_state import PrimaryState +from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections +from onyx.context.search.models import InferenceSection + + +class BaseDecompOutput(TypedDict, total=False): + initial_decomp_queries: list[str] + + +class InitialAnswerOutput(TypedDict, total=False): + initial_answer: str + + +class MainState( + PrimaryState, + BaseDecompOutput, + InitialAnswerOutput, + total=True, +): + documents: Annotated[list[InferenceSection], dedup_inference_sections] + decomp_answer_results: Annotated[list[SearchAnswerResults], add] + + +class MainInput(PrimaryState, total=True): + pass + + +class MainOutput(TypedDict): + """ + This is not used because defining the output only matters for filtering the output of + a .invoke() call but we are streaming so we
just yield the entire state. + """ diff --git a/backend/onyx/agent_search/run_graph.py b/backend/onyx/agent_search/run_graph.py new file mode 100644 index 0000000000..98ed0ff8e6 --- /dev/null +++ b/backend/onyx/agent_search/run_graph.py @@ -0,0 +1,27 @@ +from onyx.agent_search.primary_graph.graph_builder import build_core_graph +from onyx.llm.answering.answer import AnswerStream +from onyx.llm.interfaces import LLM +from onyx.tools.tool import Tool + + +def run_graph( + query: str, + llm: LLM, + tools: list[Tool], +) -> AnswerStream: + graph = build_core_graph() + + inputs = { + "original_query": query, + "messages": [], + "tools": tools, + "llm": llm, + } + compiled_graph = graph.compile() + output = compiled_graph.invoke(input=inputs) + yield from output + + +if __name__ == "__main__": + pass + # run_graph("What is the capital of France?", llm, []) diff --git a/backend/onyx/agent_search/shared_graph_utils/models.py b/backend/onyx/agent_search/shared_graph_utils/models.py new file mode 100644 index 0000000000..162d651fe5 --- /dev/null +++ b/backend/onyx/agent_search/shared_graph_utils/models.py @@ -0,0 +1,12 @@ +from typing import Literal + +from pydantic import BaseModel + + +# Pydantic models for structured outputs +class RewrittenQueries(BaseModel): + rewritten_queries: list[str] + + +class BinaryDecision(BaseModel): + decision: Literal["yes", "no"] diff --git a/backend/onyx/agent_search/shared_graph_utils/operators.py b/backend/onyx/agent_search/shared_graph_utils/operators.py new file mode 100644 index 0000000000..d75eb54cd5 --- /dev/null +++ b/backend/onyx/agent_search/shared_graph_utils/operators.py @@ -0,0 +1,9 @@ +from onyx.chat.prune_and_merge import _merge_sections +from onyx.context.search.models import InferenceSection + + +def dedup_inference_sections( + list1: list[InferenceSection], list2: list[InferenceSection] +) -> list[InferenceSection]: + deduped = _merge_sections(list1 + list2) + return deduped diff --git a/backend/onyx/agent_search/shared_graph_utils/prompts.py b/backend/onyx/agent_search/shared_graph_utils/prompts.py new file mode 100644 index 0000000000..a3eeba29fb --- /dev/null +++ b/backend/onyx/agent_search/shared_graph_utils/prompts.py @@ -0,0 +1,427 @@ +REWRITE_PROMPT_MULTI_ORIGINAL = """ \n + Please convert an initial user question into 2-3 more appropriate short and pointed search queries for retrieval from a + document store. Particularly, try to think about resolving ambiguities and making the search queries more specific, + enabling the system to search more broadly. + Also, try to make the search queries not redundant, i.e. not too similar! \n\n + Here is the initial question: + \n ------- \n + {question} + \n ------- \n + Formulate the queries separated by '--' (Do not say 'Query 1: ...', just write the query text): """ + +REWRITE_PROMPT_MULTI = """ \n + Please create a list of 2-3 sample documents that could answer an original question. Each document + should be about as long as the original question. \n + Here is the initial question: + \n ------- \n + {question} + \n ------- \n + Formulate the sample documents separated by '--' (Do not say 'Document 1: ...', just write the text): """ + +BASE_RAG_PROMPT = """ \n + You are an assistant for question-answering tasks. Use the context provided below - and only the + provided context - to answer the question. If you don't know the answer or if the provided context is + empty, just say "I don't know". Do not use your internal knowledge!
+ + Again, only use the provided context and do not use your internal knowledge! If you cannot answer the + question based on the context, say "I don't know". It is a matter of life and death that you do NOT + use your internal knowledge, just the provided information! + + Use three sentences maximum and keep the answer concise.\nQuestion:\n {question} \nContext:\n {context} \n\n + \n\n + Answer:""" + +BASE_CHECK_PROMPT = """ \n + Please check whether 1) the suggested answer seems to fully address the original question AND 2) the + original question requests a simple, factual answer, and there are no ambiguities, judgements, + aggregations, or any other complications that may require extra context. (I.e., if the question is + somewhat addressed, but the answer would benefit from more context, then answer with 'no'.) + + Please only answer with 'yes' or 'no' \n + Here is the initial question: + \n ------- \n + {question} + \n ------- \n + Here is the proposed answer: + \n ------- \n + {base_answer} + \n ------- \n + Please answer with yes or no:""" + +VERIFIER_PROMPT = """ \n + Please check whether the document seems to be relevant for answering the question. Please + only answer with 'yes' or 'no' \n + Here is the initial question: + \n ------- \n + {question} + \n ------- \n + Here is the document text: + \n ------- \n + {document_content} + \n ------- \n + Please answer with yes or no:""" + +INITIAL_DECOMPOSITION_PROMPT_BASIC = """ \n + Please decompose an initial user question into not more than 4 appropriate sub-questions that help to + answer the original question. The purpose for this decomposition is to isolate individual entities + (i.e., 'compare sales of company A and company B' -> 'what are sales for company A' + 'what are sales + for company B'), split ambiguous terms (i.e., 'what is our success with company A' -> 'what are our + sales with company A' + 'what is our market share with company A' + 'is company A a reference customer + for us'), etc. Each sub-question should realistically be answerable by a good RAG system. \n + + Here is the initial question: + \n ------- \n + {question} + \n ------- \n + + Please formulate your answer as a list of sub-questions: + + Answer: + """ + +REWRITE_PROMPT_SINGLE = """ \n + Please convert an initial user question into a more appropriate search query for retrieval from a + document store. \n + Here is the initial question: + \n ------- \n + {question} + \n ------- \n + + Formulate the query: """ + +MODIFIED_RAG_PROMPT = """You are an assistant for question-answering tasks. Use the context provided below + - and only this context - to answer the question. If you don't know the answer, just say "I don't know". + Use three sentences maximum and keep the answer concise. + Also pay particular attention to the sub-questions and their answers, as they may enrich the answer. + Again, only use the provided context and do not use your internal knowledge! If you cannot answer the + question based on the context, say "I don't know". It is a matter of life and death that you do NOT + use your internal knowledge, just the provided information! + + \nQuestion: {question} + \nContext: {combined_context} \n + + Answer:""" + +ORIG_DEEP_DECOMPOSE_PROMPT = """ \n + An initial user question needs to be answered. An initial answer has been provided but it wasn't quite + good enough. Also, some sub-questions had been answered and this information has been used to provide + the initial answer.
Some other sub-questions may have been suggested based on little knowledge, but they + were not directly answerable. Also, some entities, relationships and terms are given to you so that + you have an idea of what the available data looks like. + + Your role is to generate 3-5 new sub-questions that would help to answer the initial question, + considering: + + 1) The initial question + 2) The initial answer that was found to be unsatisfactory + 3) The sub-questions that were answered + 4) The sub-questions that were suggested but not answered + 5) The entities, relationships and terms that were extracted from the context + + The individual questions should be answerable by a good RAG system. + So a good idea would be to use the sub-questions to resolve ambiguities and/or to separate the + question for different entities that may be involved in the original question, but in a way that does + not duplicate questions that were already tried. + + Additional Guidelines: + - The sub-questions should be specific to the question and provide richer context for the question, + resolve ambiguities, or address shortcomings of the initial answer + - Each sub-question - when answered - should be relevant for the answer to the original question + - The sub-questions should be free from comparisons, ambiguities, judgements, aggregations, or any + other complications that may require extra context. + - The sub-questions MUST have the full context of the original question so that each can be executed by + a RAG system independently without the original question available + (Example: + - initial question: "What is the capital of France?" + - bad sub-question: "What is the name of the river there?" + - good sub-question: "What is the name of the river that flows through Paris?") + - For each sub-question, please provide a short explanation for why it is a good sub-question. So + generate a list of dictionaries with the following format: + [{{"sub_question": , "explanation": , "search_term": }}, ...] + + \n\n + Here is the initial question: + \n ------- \n + {question} + \n ------- \n + + Here is the initial sub-optimal answer: + \n ------- \n + {base_answer} + \n ------- \n + + Here are the sub-questions that were answered: + \n ------- \n + {answered_sub_questions} + \n ------- \n + + Here are the sub-questions that were suggested but not answered: + \n ------- \n + {failed_sub_questions} + \n ------- \n + + And here are the entities, relationships and terms extracted from the context: + \n ------- \n + {entity_term_extraction_str} + \n ------- \n + + Please generate the list of good, fully contextualized sub-questions that would help to address the + main question. Again, please find questions that are NOT overlapping too much with the already answered + sub-questions or those that already were suggested and failed. + In other words - what can we try in addition to what has been tried so far? + + Please think through it step by step and then generate the list of json dictionaries with the following + format: + + {{"sub_questions": [{{"sub_question": , + "explanation": , + "search_term": }}, + ...]}} """ + +DEEP_DECOMPOSE_PROMPT = """ \n + An initial user question needs to be answered. An initial answer has been provided but it wasn't quite + good enough. Also, some sub-questions had been answered and this information has been used to provide + the initial answer. Some other sub-questions may have been suggested based on little knowledge, but they + were not directly answerable.
Also, some entities, relationships and terms are given to you so that + you have an idea of what the available data looks like. + + Your role is to generate 4-6 new sub-questions that would help to answer the initial question, + considering: + + 1) The initial question + 2) The initial answer that was found to be unsatisfactory + 3) The sub-questions that were answered + 4) The sub-questions that were suggested but not answered + 5) The entities, relationships and terms that were extracted from the context + + The individual questions should be answerable by a good RAG system. + So a good idea would be to use the sub-questions to resolve ambiguities and/or to separate the + question for different entities that may be involved in the original question, but in a way that does + not duplicate questions that were already tried. + + Additional Guidelines: + - The sub-questions should be specific to the question and provide richer context for the question, + resolve ambiguities, or address shortcomings of the initial answer + - Each sub-question - when answered - should be relevant for the answer to the original question + - The sub-questions should be free from comparisons, ambiguities, judgements, aggregations, or any + other complications that may require extra context. + - The sub-questions MUST have the full context of the original question so that each can be executed by + a RAG system independently without the original question available + (Example: + - initial question: "What is the capital of France?" + - bad sub-question: "What is the name of the river there?" + - good sub-question: "What is the name of the river that flows through Paris?") + - For each sub-question, please also provide a search term that can be used to retrieve relevant + documents from a document store. + \n\n + Here is the initial question: + \n ------- \n + {question} + \n ------- \n + + Here is the initial sub-optimal answer: + \n ------- \n + {base_answer} + \n ------- \n + + Here are the sub-questions that were answered: + \n ------- \n + {answered_sub_questions} + \n ------- \n + + Here are the sub-questions that were suggested but not answered: + \n ------- \n + {failed_sub_questions} + \n ------- \n + + And here are the entities, relationships and terms extracted from the context: + \n ------- \n + {entity_term_extraction_str} + \n ------- \n + + Please generate the list of good, fully contextualized sub-questions that would help to address the + main question. Again, please find questions that are NOT overlapping too much with the already answered + sub-questions or those that already were suggested and failed. + In other words - what can we try in addition to what has been tried so far? + + Generate the list of json dictionaries with the following format: + + {{"sub_questions": [{{"sub_question": , + "search_term": }}, + ...]}} """ + +DECOMPOSE_PROMPT = """ \n + For an initial user question, please generate 5-10 individual sub-questions whose answers would help + \n to answer the initial question. The individual questions should be answerable by a good RAG system. + So a good idea would be to \n use the sub-questions to resolve ambiguities and/or to separate the + question for different entities that may be involved in the original question. + + In order to arrive at meaningful sub-questions, please also consider the context retrieved from the + document store, expressed as entities, relationships and terms.
You can also think about the types + mentioned in brackets. + + Guidelines: + - The sub-questions should be specific to the question and provide richer context for the question, + and/or resolve ambiguities + - Each sub-question - when answered - should be relevant for the answer to the original question + - The sub-questions should be free from comparisons, ambiguities, judgements, aggregations, or any + other complications that may require extra context. + - The sub-questions MUST have the full context of the original question so that each can be executed by + a RAG system independently without the original question available + (Example: + - initial question: "What is the capital of France?" + - bad sub-question: "What is the name of the river there?" + - good sub-question: "What is the name of the river that flows through Paris?") + - For each sub-question, please provide a short explanation for why it is a good sub-question. So + generate a list of dictionaries with the following format: + [{{"sub_question": , "explanation": , "search_term": }}, ...] + + \n\n + Here is the initial question: + \n ------- \n + {question} + \n ------- \n + + And here are the entities, relationships and terms extracted from the context: + \n ------- \n + {entity_term_extraction_str} + \n ------- \n + + Please generate the list of good, fully contextualized sub-questions that would help to address the + main question. Don't be too specific unless the original question is specific. + Please think through it step by step and then generate the list of json dictionaries with the following + format: + {{"sub_questions": [{{"sub_question": , + "explanation": , + "search_term": }}, + ...]}} """ + +#### Consolidations +COMBINED_CONTEXT = """------- + Below you will find useful information to answer the original question. First, you see a number of + sub-questions with their answers. This information should be considered to be more focused and + somewhat more specific to the original question as it tries to contextualize facts. + After that you will see the documents that were considered to be relevant to answer the original question. + + Here are the sub-questions and their answers: + \n\n {deep_answer_context} \n\n + \n\n Here are the documents that were considered to be relevant to answer the original question: + \n\n {formated_docs} \n\n + ---------------- + """ + +SUB_QUESTION_EXPLANATION_RANKER_PROMPT = """------- + Below you will find a question that we ultimately want to answer (the original question) and a list of + motivations in arbitrary order for generated sub-questions that are supposed to help us answer the + original question. The motivations are formatted as : . + (Again, the numbering is arbitrary and does not necessarily mean that 1 is the most relevant + motivation and 2 is less relevant.) + + Please rank the motivations in order of relevance for answering the original question. Also, try to + ensure that the top questions do not duplicate too much, i.e. that they are not too similar. + Ultimately, create a list with the motivation numbers where the number of the most relevant + motivations comes first. + + Here is the original question: + \n\n {original_question} \n\n + \n\n Here is the list of sub-question motivations: + \n\n {sub_question_explanations} \n\n + ---------------- + + Please think step by step and then generate the ranked list of motivations.
+
+    Please format your answer as a json object in the following format:
+    {{"reasonning": <your reasoning>,
+    "ranked_motivations": <ranked list of motivation numbers>}}
+    """
+
+
+INITIAL_DECOMPOSITION_PROMPT = """ \n
+    Please decompose an initial user question into 2 or 3 appropriate sub-questions that help to
+    answer the original question. The purpose of this decomposition is to isolate individual entities
+    (i.e., 'compare sales of company A and company B' -> 'what are sales for company A' + 'what are sales
+    for company B'), split ambiguous terms (i.e., 'what is our success with company A' -> 'what are our
+    sales with company A' + 'what is our market share with company A' + 'is company A a reference customer
+    for us'), etc. Each sub-question should realistically be answerable by a good RAG system. \n
+
+    For each sub-question, please also create one search term that can be used to retrieve relevant
+    documents from a document store.
+
+    Here is the initial question:
+    \n ------- \n
+    {question}
+    \n ------- \n
+
+    Please formulate your answer as a list of json objects with the following format:
+
+    [{{"sub_question": <sub-question>, "search_term": <search term>}}, ...]
+
+    Answer:
+    """
+
+INITIAL_RAG_PROMPT = """ \n
+    You are an assistant for question-answering tasks. Use the information provided below - and only the
+    provided information - to answer the provided question.
+
+    The information provided below consists of:
+    1) a number of answered sub-questions - these are very important(!) and definitely should be
+    considered to answer the question.
+    2) a number of documents that were also deemed relevant for the question.
+
+    If you don't know the answer or if the provided information is empty or insufficient, just say
+    "I don't know". Do not use your internal knowledge!
+
+    Again, only use the provided information and do not use your internal knowledge! It is a matter of life
+    and death that you do NOT use your internal knowledge, just the provided information!
+
+    Try to keep your answer concise.
+
+    And here is the question and the provided information:
+    \n
+    \nQuestion:\n {question}
+
+    \nAnswered Sub-questions:\n {answered_sub_questions}
+
+    \nContext:\n {context} \n\n
+    \n\n
+
+    Answer:"""
+
+ENTITY_TERM_PROMPT = """ \n
+    Based on the original question and the context retrieved from a dataset, please generate a list of
+    entities (e.g. companies, organizations, industries, products, locations, etc.), terms and concepts
+    (e.g. sales, revenue, etc.) that are relevant for the question, plus their relations to each other.
+
+    \n\n
+    Here is the original question:
+    \n ------- \n
+    {question}
+    \n ------- \n
+    And here is the context retrieved:
+    \n ------- \n
+    {context}
+    \n ------- \n
+
+    Please format your answer as a json object in the following format:
+
+    {{"retrieved_entities_relationships": {{
+        "entities": [{{
+            "entity_name": <entity name>,
+            "entity_type": <entity type>
+        }}],
+        "relationships": [{{
+            "name": <relationship name>,
+            "type": <relationship type>,
+            "entities": [<entity name 1>, <entity name 2>]
+        }}],
+        "terms": [{{
+            "term_name": <term name>,
+            "term_type": <term type>,
+            "similar_to": <list of similar terms>
+        }}]
+    }}
+    }}
+    """
diff --git a/backend/onyx/agent_search/shared_graph_utils/utils.py b/backend/onyx/agent_search/shared_graph_utils/utils.py
new file mode 100644
index 0000000000..a435860320
--- /dev/null
+++ b/backend/onyx/agent_search/shared_graph_utils/utils.py
@@ -0,0 +1,101 @@
+import ast
+import json
+import re
+from collections.abc import Sequence
+from datetime import datetime
+from datetime import timedelta
+from typing import Any
+
+from onyx.context.search.models import InferenceSection
+
+
+def normalize_whitespace(text: str) -> str:
+    """Normalize whitespace in text to single spaces and strip leading/trailing whitespace."""
+    import re
+
+    return re.sub(r"\s+", " ", text.strip())
+
+
+# Post-processing
+def format_docs(docs: Sequence[InferenceSection]) -> str:
+    return "\n\n".join(doc.combined_content for doc in docs)
+
+
+def clean_and_parse_list_string(json_string: str) -> list[dict]:
+    # Remove any prefixes/labels before the actual JSON content
+    json_string = re.sub(r"^.*?(?=\[)", "", json_string, flags=re.DOTALL)
+
+    # Remove markdown code block markers and any newline prefixes
+    cleaned_string = re.sub(r"```json\n|\n```", "", json_string)
+    cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ")
+    cleaned_string = " ".join(cleaned_string.split())
+
+    # Try parsing with json.loads first, fall back to ast.literal_eval
+    try:
+        return json.loads(cleaned_string)
+    except json.JSONDecodeError:
+        try:
+            return ast.literal_eval(cleaned_string)
+        except (ValueError, SyntaxError) as e:
+            raise ValueError(f"Failed to parse JSON string: {cleaned_string}") from e
+
+
+def clean_and_parse_json_string(json_string: str) -> dict[str, Any]:
+    # Remove markdown code block markers and any newline prefixes
+    cleaned_string = re.sub(r"```json\n|\n```", "", json_string)
+    cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ")
+    cleaned_string = " ".join(cleaned_string.split())
+    # Parse the cleaned string into a Python dictionary
+    return json.loads(cleaned_string)
+
+
+def format_entity_term_extraction(entity_term_extraction_dict: dict[str, Any]) -> str:
+    entities = entity_term_extraction_dict["entities"]
+    terms = entity_term_extraction_dict["terms"]
+    relationships = entity_term_extraction_dict["relationships"]
+
+    entity_strs = ["\nEntities:\n"]
+    for entity in entities:
+        entity_str = f"{entity['entity_name']} ({entity['entity_type']})"
+        entity_strs.append(entity_str)
+
+    entity_str = "\n - ".join(entity_strs)
+
+    relationship_strs = ["\n\nRelationships:\n"]
+    for relationship in relationships:
+        relationship_str = f"{relationship['name']} ({relationship['type']}): {relationship['entities']}"
+        relationship_strs.append(relationship_str)
+
+    relationship_str = "\n - ".join(relationship_strs)
+
+    term_strs = ["\n\nTerms:\n"]
+    for term in terms:
+        term_str = f"{term['term_name']} ({term['term_type']}): similar to {term['similar_to']}"
+        term_strs.append(term_str)
+
+    term_str = "\n - ".join(term_strs)
+
+    return "\n".join(entity_strs + relationship_strs + term_strs)
+
+
+def 
_format_time_delta(time: timedelta) -> str: + seconds_from_start = f"{((time).seconds):03d}" + microseconds_from_start = f"{((time).microseconds):06d}" + return f"{seconds_from_start}.{microseconds_from_start}" + + +def generate_log_message( + message: str, + node_start_time: datetime, + graph_start_time: datetime | None = None, +) -> str: + current_time = datetime.now() + + if graph_start_time is not None: + graph_time_str = _format_time_delta(current_time - graph_start_time) + else: + graph_time_str = "N/A" + + node_time_str = _format_time_delta(current_time - node_start_time) + + return f"{graph_time_str} ({node_time_str} s): {message}" diff --git a/backend/requirements/default.txt b/backend/requirements/default.txt index 3a4996d901..01a99c975f 100644 --- a/backend/requirements/default.txt +++ b/backend/requirements/default.txt @@ -26,10 +26,15 @@ huggingface-hub==0.20.1 jira==3.5.1 jsonref==1.1.0 trafilatura==1.12.2 -langchain==0.1.17 -langchain-core==0.1.50 -langchain-text-splitters==0.0.1 -litellm==1.54.1 +langchain==0.3.7 +langchain-core==0.3.24 +langchain-openai==0.2.9 +langchain-text-splitters==0.3.2 +langchainhub==0.1.21 +langgraph==0.2.59 +langgraph-checkpoint==2.0.5 +langgraph-sdk==0.1.44 +litellm==1.53.1 lxml==5.3.0 lxml_html_clean==0.2.2 llama-index==0.9.45 From 11ce2a62abf893ec1857e7a9bf4830b4a4b23fef Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Mon, 16 Dec 2024 12:24:17 -0800 Subject: [PATCH 02/78] fix: update staged changes --- backend/onyx/agent_search/run_graph.py | 6 +++--- backend/onyx/tools/message.py | 5 +++++ 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/backend/onyx/agent_search/run_graph.py b/backend/onyx/agent_search/run_graph.py index 98ed0ff8e6..9a93dbba64 100644 --- a/backend/onyx/agent_search/run_graph.py +++ b/backend/onyx/agent_search/run_graph.py @@ -1,5 +1,5 @@ -from onyx.agent_search.primary_graph.graph_builder import build_core_graph -from onyx.llm.answering.answer import AnswerStream +from onyx.agent_search.main.graph_builder import main_graph_builder +from onyx.chat.answer import AnswerStream from onyx.llm.interfaces import LLM from onyx.tools.tool import Tool @@ -9,7 +9,7 @@ def run_graph( llm: LLM, tools: list[Tool], ) -> AnswerStream: - graph = build_core_graph() + graph = main_graph_builder() inputs = { "original_query": query, diff --git a/backend/onyx/tools/message.py b/backend/onyx/tools/message.py index d559011162..659f38731e 100644 --- a/backend/onyx/tools/message.py +++ b/backend/onyx/tools/message.py @@ -25,6 +25,11 @@ class ToolCallSummary(BaseModel__v1): tool_call_request: AIMessage tool_call_result: ToolMessage + # This is a workaround to allow arbitrary types in the model + # TODO: Remove this once we have a better solution + class Config: + arbitrary_types_allowed = True + def tool_call_tokens( tool_call_summary: ToolCallSummary, llm_tokenizer: BaseTokenizer From 82914ad365cc94ec5adfae517301dce3d9987ae6 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Mon, 16 Dec 2024 13:26:09 -0800 Subject: [PATCH 03/78] fixed key issue --- .../onyx/agent_search/answer_query/nodes/answer_generation.py | 2 +- backend/onyx/agent_search/answer_query/nodes/format_answer.py | 2 +- backend/onyx/agent_search/answer_query/states.py | 2 +- backend/onyx/agent_search/expanded_retrieval/states.py | 2 +- backend/onyx/agent_search/main/graph_builder.py | 1 - 5 files changed, 4 insertions(+), 5 deletions(-) diff --git a/backend/onyx/agent_search/answer_query/nodes/answer_generation.py 
b/backend/onyx/agent_search/answer_query/nodes/answer_generation.py
index c23f77ee70..18c0862e23 100644
--- a/backend/onyx/agent_search/answer_query/nodes/answer_generation.py
+++ b/backend/onyx/agent_search/answer_query/nodes/answer_generation.py
@@ -9,7 +9,7 @@ def answer_generation(state: AnswerQueryState) -> QAGenerationOutput:
     query = state["query_to_answer"]
-    docs = state["reordered_documents"]
+    docs = state["reranked_documents"]

     print(f"Number of verified retrieval docs: {len(docs)}")

diff --git a/backend/onyx/agent_search/answer_query/nodes/format_answer.py b/backend/onyx/agent_search/answer_query/nodes/format_answer.py
index 8359baec9b..51f7dbad5b 100644
--- a/backend/onyx/agent_search/answer_query/nodes/format_answer.py
+++ b/backend/onyx/agent_search/answer_query/nodes/format_answer.py
@@ -10,7 +10,7 @@ def format_answer(state: AnswerQueryState) -> AnswerQueryOutput:
             query=state["query_to_answer"],
             quality=state["answer_quality"],
             answer=state["answer"],
-            documents=state["reordered_documents"],
+            documents=state["reranked_documents"],
         )
     ],
 )
diff --git a/backend/onyx/agent_search/answer_query/states.py b/backend/onyx/agent_search/answer_query/states.py
index 9f8fe12ab6..d2dd1f12c6 100644
--- a/backend/onyx/agent_search/answer_query/states.py
+++ b/backend/onyx/agent_search/answer_query/states.py
@@ -24,7 +24,7 @@ class QAGenerationOutput(TypedDict, total=False):

 class ExpandedRetrievalOutput(TypedDict):
-    reordered_documents: Annotated[list[InferenceSection], dedup_inference_sections]
+    reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections]

 class AnswerQueryState(
diff --git a/backend/onyx/agent_search/expanded_retrieval/states.py b/backend/onyx/agent_search/expanded_retrieval/states.py
index a0f726b7f8..54fa6023cc 100644
--- a/backend/onyx/agent_search/expanded_retrieval/states.py
+++ b/backend/onyx/agent_search/expanded_retrieval/states.py
@@ -33,4 +33,4 @@ class ExpandedRetrievalInput(PrimaryState, total=True):

 class ExpandedRetrievalOutput(TypedDict):
-    reordered_documents: Annotated[list[InferenceSection], dedup_inference_sections]
+    reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections]
diff --git a/backend/onyx/agent_search/main/graph_builder.py b/backend/onyx/agent_search/main/graph_builder.py
index 449ffb89df..930d7a745f 100644
--- a/backend/onyx/agent_search/main/graph_builder.py
+++ b/backend/onyx/agent_search/main/graph_builder.py
@@ -95,4 +95,3 @@ def main_graph_builder() -> StateGraph:
     ):
         # print(thing)
         print()
-        print()

From ff03d717f37f4dc3f3027b338b5e365f223cb112 Mon Sep 17 00:00:00 2001
From: hagen-danswer
Date: Tue, 17 Dec 2024 12:36:28 -0800
Subject: [PATCH 04/78] brought over joachim changes

---
 .../answer_query/nodes/answer_check.py        |  6 ++--
 .../agent_search/expanded_retrieval/edges.py  |  6 ++--
 .../expanded_retrieval/nodes/doc_retrieval.py | 28 ++++++++-----------
 .../shared_graph_utils/prompts.py             | 18 +++++++++++-
 4 files changed, 35 insertions(+), 23 deletions(-)

diff --git a/backend/onyx/agent_search/answer_query/nodes/answer_check.py b/backend/onyx/agent_search/answer_query/nodes/answer_check.py
index 8b58129c47..9505967306 100644
--- a/backend/onyx/agent_search/answer_query/nodes/answer_check.py
+++ b/backend/onyx/agent_search/answer_query/nodes/answer_check.py
@@ -3,14 +3,14 @@
 from onyx.agent_search.answer_query.states import AnswerQueryState
 from onyx.agent_search.answer_query.states import QACheckOutput
-from onyx.agent_search.shared_graph_utils.prompts import BASE_CHECK_PROMPT
+from 
onyx.agent_search.shared_graph_utils.prompts import SUB_CHECK_PROMPT def answer_check(state: AnswerQueryState) -> QACheckOutput: msg = [ HumanMessage( - content=BASE_CHECK_PROMPT.format( - question=state["search_request"].query, + content=SUB_CHECK_PROMPT.format( + question=state["query_to_answer"], base_answer=state["answer"], ) ) diff --git a/backend/onyx/agent_search/expanded_retrieval/edges.py b/backend/onyx/agent_search/expanded_retrieval/edges.py index 2c63125bb9..063befe85a 100644 --- a/backend/onyx/agent_search/expanded_retrieval/edges.py +++ b/backend/onyx/agent_search/expanded_retrieval/edges.py @@ -6,7 +6,7 @@ from onyx.agent_search.expanded_retrieval.nodes.doc_retrieval import RetrieveInput from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput -from onyx.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI +from onyx.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI_ORIGINAL from onyx.llm.interfaces import LLM @@ -19,7 +19,7 @@ def parallel_retrieval_edge(state: ExpandedRetrievalInput) -> list[Send | Hashab msg = [ HumanMessage( - content=REWRITE_PROMPT_MULTI.format(question=question), + content=REWRITE_PROMPT_MULTI_ORIGINAL.format(question=question), ) ] llm_response_list = list( @@ -31,7 +31,7 @@ def parallel_retrieval_edge(state: ExpandedRetrievalInput) -> list[Send | Hashab print(f"llm_response: {llm_response}") - rewritten_queries = llm_response.split("\n") + rewritten_queries = llm_response.split("--") print(f"rewritten_queries: {rewritten_queries}") diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py index 8d61249948..af38e5f490 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py @@ -3,7 +3,6 @@ from onyx.context.search.models import InferenceSection from onyx.context.search.models import SearchRequest from onyx.context.search.pipeline import SearchPipeline -from onyx.db.engine import get_session_context_manager class RetrieveInput(ExpandedRetrievalState): @@ -23,25 +22,22 @@ def doc_retrieval(state: RetrieveInput) -> DocRetrievalOutput: """ print(f"doc_retrieval state: {state.keys()}") - state["query_to_retrieve"] - documents: list[InferenceSection] = [] llm = state["primary_llm"] fast_llm = state["fast_llm"] - # db_session = state["db_session"] - query_to_retrieve = state["search_request"].query - with get_session_context_manager() as db_session1: - documents = SearchPipeline( - search_request=SearchRequest( - query=query_to_retrieve, - ), - user=None, - llm=llm, - fast_llm=fast_llm, - db_session=db_session1, - ).reranked_sections + query_to_retrieve = state["query_to_retrieve"] + + documents = SearchPipeline( + search_request=SearchRequest( + query=query_to_retrieve, + ), + user=None, + llm=llm, + fast_llm=fast_llm, + db_session=state["db_session"], + ).reranked_sections print(f"retrieved documents: {len(documents)}") return DocRetrievalOutput( - retrieved_documents=documents, + retrieved_documents=documents[:4], ) diff --git a/backend/onyx/agent_search/shared_graph_utils/prompts.py b/backend/onyx/agent_search/shared_graph_utils/prompts.py index a3eeba29fb..229a980762 100644 --- a/backend/onyx/agent_search/shared_graph_utils/prompts.py +++ b/backend/onyx/agent_search/shared_graph_utils/prompts.py @@ -32,6 +32,21 @@ \n\n Answer:""" +SUB_CHECK_PROMPT = """ \n + Your task is to see whether a given answer addresses a 
given question.
+    Please do not use any internal knowledge you may have - just focus on whether the answer
+    as given seems to address the question as given.
+    Here is the question:
+    \n ------- \n
+    {question}
+    \n ------- \n
+    Here is the suggested answer:
+    \n ------- \n
+    {base_answer}
+    \n ------- \n
+    Please answer with yes or no:"""
+
+
 BASE_CHECK_PROMPT = """ \n
     Please check whether 1) the suggested answer seems to fully address the original question AND 2) the
     original question requests a simple, factual answer, and there are no ambiguities, judgements,
@@ -50,7 +65,8 @@
     Please answer with yes or no:"""

 VERIFIER_PROMPT = """ \n
-    Please check whether the document seems to be relevant for the answer of the question. Please
+    Please check whether the document provided below seems to be relevant
+    for answering the provided question. Please
     only answer with 'yes' or 'no' \n
     Here is the initial question:
     \n ------- \n

From 1f88b60abd9b5d42864a0333da303eb30a20bd4b Mon Sep 17 00:00:00 2001
From: hagen-danswer
Date: Tue, 17 Dec 2024 14:05:51 -0800
Subject: [PATCH 05/78] Now using result objects

---
 .../onyx/agent_search/answer_query/edges.py   | 16 ++++++++++
 .../answer_query/graph_builder.py             | 14 +++++----
 .../answer_query/nodes/answer_check.py        |  2 +-
 .../answer_query/nodes/answer_generation.py   |  2 +-
 .../answer_query/nodes/format_answer.py       |  2 +-
 .../onyx/agent_search/answer_query/states.py  | 15 ++++------
 .../agent_search/expanded_retrieval/edges.py  |  2 +-
 .../expanded_retrieval/graph_builder.py       | 11 ++++++-
 .../nodes/format_results.py                   | 15 ++++++++++
 .../agent_search/expanded_retrieval/states.py | 29 +++++++++++++++----
 backend/onyx/agent_search/main/edges.py       |  2 +-
 .../main/nodes/generate_initial_answer.py     |  2 +-
 12 files changed, 84 insertions(+), 28 deletions(-)
 create mode 100644 backend/onyx/agent_search/answer_query/edges.py
 create mode 100644 backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py

diff --git a/backend/onyx/agent_search/answer_query/edges.py b/backend/onyx/agent_search/answer_query/edges.py
new file mode 100644
index 0000000000..15f60f2bdf
--- /dev/null
+++ b/backend/onyx/agent_search/answer_query/edges.py
@@ -0,0 +1,16 @@
+from collections.abc import Hashable
+
+from langgraph.types import Send
+
+from onyx.agent_search.answer_query.states import AnswerQueryInput
+from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput
+
+
+def send_to_expanded_retrieval(state: AnswerQueryInput) -> Send | Hashable:
+    return Send(
+        "expanded_retrieval",
+        ExpandedRetrievalInput(
+            **state,
+            starting_query=state["starting_query"],
+        ),
+    )
diff --git a/backend/onyx/agent_search/answer_query/graph_builder.py b/backend/onyx/agent_search/answer_query/graph_builder.py
index e52bfe28d6..53f647eac8 100644
--- a/backend/onyx/agent_search/answer_query/graph_builder.py
+++ b/backend/onyx/agent_search/answer_query/graph_builder.py
@@ -2,6 +2,7 @@
 from langgraph.graph import START
 from langgraph.graph import StateGraph

+from onyx.agent_search.answer_query.edges import send_to_expanded_retrieval
 from onyx.agent_search.answer_query.nodes.answer_check import answer_check
 from onyx.agent_search.answer_query.nodes.answer_generation import answer_generation
 from onyx.agent_search.answer_query.nodes.format_answer import format_answer
@@ -24,7 +25,7 @@ def answer_query_graph_builder() -> StateGraph:
     expanded_retrieval = expanded_retrieval_graph_builder().compile()
     graph.add_node(
-        node="expanded_retrieval_for_initial_decomp",
+        
node="decomped_expanded_retrieval", action=expanded_retrieval, ) graph.add_node( @@ -42,12 +43,13 @@ def answer_query_graph_builder() -> StateGraph: ### Add edges ### - graph.add_edge( - start_key=START, - end_key="expanded_retrieval_for_initial_decomp", + graph.add_conditional_edges( + source=START, + path=send_to_expanded_retrieval, + path_map=["decomped_expanded_retrieval"], ) graph.add_edge( - start_key="expanded_retrieval_for_initial_decomp", + start_key="decomped_expanded_retrieval", end_key="answer_generation", ) graph.add_edge( @@ -83,7 +85,7 @@ def answer_query_graph_builder() -> StateGraph: primary_llm=primary_llm, fast_llm=fast_llm, db_session=db_session, - query_to_answer="Who made Excel?", + question_to_answer="Who made Excel?", ) output = compiled_graph.invoke( input=inputs, diff --git a/backend/onyx/agent_search/answer_query/nodes/answer_check.py b/backend/onyx/agent_search/answer_query/nodes/answer_check.py index 9505967306..f06b2071f9 100644 --- a/backend/onyx/agent_search/answer_query/nodes/answer_check.py +++ b/backend/onyx/agent_search/answer_query/nodes/answer_check.py @@ -10,7 +10,7 @@ def answer_check(state: AnswerQueryState) -> QACheckOutput: msg = [ HumanMessage( content=SUB_CHECK_PROMPT.format( - question=state["query_to_answer"], + question=state["question_to_answer"], base_answer=state["answer"], ) ) diff --git a/backend/onyx/agent_search/answer_query/nodes/answer_generation.py b/backend/onyx/agent_search/answer_query/nodes/answer_generation.py index 18c0862e23..3de9c403d2 100644 --- a/backend/onyx/agent_search/answer_query/nodes/answer_generation.py +++ b/backend/onyx/agent_search/answer_query/nodes/answer_generation.py @@ -8,7 +8,7 @@ def answer_generation(state: AnswerQueryState) -> QAGenerationOutput: - query = state["query_to_answer"] + query = state["question_to_answer"] docs = state["reranked_documents"] print(f"Number of verified retrieval docs: {len(docs)}") diff --git a/backend/onyx/agent_search/answer_query/nodes/format_answer.py b/backend/onyx/agent_search/answer_query/nodes/format_answer.py index 51f7dbad5b..061000701b 100644 --- a/backend/onyx/agent_search/answer_query/nodes/format_answer.py +++ b/backend/onyx/agent_search/answer_query/nodes/format_answer.py @@ -7,7 +7,7 @@ def format_answer(state: AnswerQueryState) -> AnswerQueryOutput: return AnswerQueryOutput( decomp_answer_results=[ SearchAnswerResults( - query=state["query_to_answer"], + question=state["question_to_answer"], quality=state["answer_quality"], answer=state["answer"], documents=state["reranked_documents"], diff --git a/backend/onyx/agent_search/answer_query/states.py b/backend/onyx/agent_search/answer_query/states.py index d2dd1f12c6..8c24623ee0 100644 --- a/backend/onyx/agent_search/answer_query/states.py +++ b/backend/onyx/agent_search/answer_query/states.py @@ -4,14 +4,16 @@ from pydantic import BaseModel from onyx.agent_search.core_state import PrimaryState +from onyx.agent_search.expanded_retrieval.states import RetrievalResult from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections from onyx.context.search.models import InferenceSection class SearchAnswerResults(BaseModel): - query: str + question: str answer: str quality: str + retrieval_results: list[RetrievalResult] documents: Annotated[list[InferenceSection], dedup_inference_sections] @@ -23,23 +25,18 @@ class QAGenerationOutput(TypedDict, total=False): answer: str -class ExpandedRetrievalOutput(TypedDict): - reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections] - - 
class AnswerQueryState( PrimaryState, QACheckOutput, QAGenerationOutput, - ExpandedRetrievalOutput, total=True, ): - query_to_answer: str + question: str class AnswerQueryInput(PrimaryState, total=True): - query_to_answer: str + question: str class AnswerQueryOutput(TypedDict): - decomp_answer_results: list[SearchAnswerResults] + answer_results: list[SearchAnswerResults] diff --git a/backend/onyx/agent_search/expanded_retrieval/edges.py b/backend/onyx/agent_search/expanded_retrieval/edges.py index 063befe85a..085479a4e4 100644 --- a/backend/onyx/agent_search/expanded_retrieval/edges.py +++ b/backend/onyx/agent_search/expanded_retrieval/edges.py @@ -14,7 +14,7 @@ def parallel_retrieval_edge(state: ExpandedRetrievalInput) -> list[Send | Hashab print(f"parallel_retrieval_edge state: {state.keys()}") # This should be better... - question = state.get("query_to_answer") or state["search_request"].query + question = state.get("question_to_answer") or state["search_request"].query llm: LLM = state["fast_llm"] msg = [ diff --git a/backend/onyx/agent_search/expanded_retrieval/graph_builder.py b/backend/onyx/agent_search/expanded_retrieval/graph_builder.py index 1928e93450..7f94c1ef78 100644 --- a/backend/onyx/agent_search/expanded_retrieval/graph_builder.py +++ b/backend/onyx/agent_search/expanded_retrieval/graph_builder.py @@ -8,6 +8,7 @@ from onyx.agent_search.expanded_retrieval.nodes.doc_verification import ( doc_verification, ) +from onyx.agent_search.expanded_retrieval.nodes.format_results import format_results from onyx.agent_search.expanded_retrieval.nodes.verification_kickoff import ( verification_kickoff, ) @@ -41,6 +42,10 @@ def expanded_retrieval_graph_builder() -> StateGraph: node="doc_reranking", action=doc_reranking, ) + graph.add_node( + node="format_results", + action=format_results, + ) ### Add edges ### @@ -59,6 +64,10 @@ def expanded_retrieval_graph_builder() -> StateGraph: ) graph.add_edge( start_key="doc_reranking", + end_key="format_results", + ) + graph.add_edge( + start_key="format_results", end_key=END, ) @@ -82,7 +91,7 @@ def expanded_retrieval_graph_builder() -> StateGraph: primary_llm=primary_llm, fast_llm=fast_llm, db_session=db_session, - query_to_answer="Who made Excel?", + question_to_answer="Who made Excel?", ) for thing in compiled_graph.stream(inputs, debug=True): print(thing) diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py b/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py new file mode 100644 index 0000000000..36883eb6bd --- /dev/null +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py @@ -0,0 +1,15 @@ +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalOutput +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState +from onyx.agent_search.expanded_retrieval.states import RetrievalResult + + +def format_results(state: ExpandedRetrievalState) -> ExpandedRetrievalOutput: + return ExpandedRetrievalOutput( + retrieval_results=[ + RetrievalResult( + starting_query=state["starting_query"], + expanded_retrieval_results=state["expanded_retrieval_results"], + documents=state["reranked_documents"], + ) + ], + ) diff --git a/backend/onyx/agent_search/expanded_retrieval/states.py b/backend/onyx/agent_search/expanded_retrieval/states.py index 54fa6023cc..697639cc4a 100644 --- a/backend/onyx/agent_search/expanded_retrieval/states.py +++ b/backend/onyx/agent_search/expanded_retrieval/states.py @@ -1,12 +1,29 @@ +from operator import add from typing import 
Annotated from typing import TypedDict +from pydantic import BaseModel + from onyx.agent_search.core_state import PrimaryState from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections from onyx.context.search.models import InferenceSection +class ExpandedRetrievalResult(BaseModel): + expanded_query: str + expanded_retrieval_documents: Annotated[ + list[InferenceSection], dedup_inference_sections + ] + + +class RetrievalResult(BaseModel): + starting_query: str + expanded_retrieval_results: list[ExpandedRetrievalResult] + documents: Annotated[list[InferenceSection], dedup_inference_sections] + + class DocRetrievalOutput(TypedDict, total=False): + expanded_retrieval_results: Annotated[list[ExpandedRetrievalResult], add] retrieved_documents: Annotated[list[InferenceSection], dedup_inference_sections] @@ -18,6 +35,10 @@ class DocRerankingOutput(TypedDict, total=False): reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections] +class ExpandedRetrievalOutput(TypedDict): + retrieval_results: Annotated[list[RetrievalResult], add] + + class ExpandedRetrievalState( PrimaryState, DocRetrievalOutput, @@ -25,12 +46,8 @@ class ExpandedRetrievalState( DocRerankingOutput, total=True, ): - query_to_answer: str + starting_query: str class ExpandedRetrievalInput(PrimaryState, total=True): - query_to_answer: str - - -class ExpandedRetrievalOutput(TypedDict): - reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections] + starting_query: str diff --git a/backend/onyx/agent_search/main/edges.py b/backend/onyx/agent_search/main/edges.py index 953b0a9627..0ec4c0f4a6 100644 --- a/backend/onyx/agent_search/main/edges.py +++ b/backend/onyx/agent_search/main/edges.py @@ -12,7 +12,7 @@ def parallelize_decompozed_answer_queries(state: MainState) -> list[Send | Hasha "answer_query", AnswerQueryInput( **state, - query_to_answer=query, + question_to_answer=query, ), ) for query in state["initial_decomp_queries"] diff --git a/backend/onyx/agent_search/main/nodes/generate_initial_answer.py b/backend/onyx/agent_search/main/nodes/generate_initial_answer.py index 5671b2352f..a6476477ae 100644 --- a/backend/onyx/agent_search/main/nodes/generate_initial_answer.py +++ b/backend/onyx/agent_search/main/nodes/generate_initial_answer.py @@ -27,7 +27,7 @@ def generate_initial_answer(state: MainState) -> InitialAnswerOutput: ): good_qa_list.append( _SUB_QUESTION_ANSWER_TEMPLATE.format( - sub_question=decomp_answer_result.query, + sub_question=decomp_answer_result.question, sub_answer=decomp_answer_result.answer, ) ) From 2f2b9a862ace2d637192fadbb1bb30deba1c29cb Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Tue, 17 Dec 2024 15:11:54 -0800 Subject: [PATCH 06/78] fixed expanded retrieval subgraph --- backend/onyx/agent_search/answer_query/edges.py | 4 ++-- .../onyx/agent_search/answer_query/graph_builder.py | 4 ++-- .../agent_search/answer_query/nodes/answer_check.py | 2 +- .../answer_query/nodes/answer_generation.py | 4 ++-- .../agent_search/answer_query/nodes/format_answer.py | 5 +++-- backend/onyx/agent_search/answer_query/states.py | 6 ++++-- backend/onyx/agent_search/expanded_retrieval/edges.py | 2 +- .../agent_search/expanded_retrieval/graph_builder.py | 10 +++++++--- .../expanded_retrieval/nodes/doc_retrieval.py | 6 ++++++ .../expanded_retrieval/nodes/format_results.py | 10 ++-------- .../onyx/agent_search/expanded_retrieval/states.py | 11 ++++++----- backend/onyx/agent_search/main/edges.py | 2 +- 12 files changed, 37 insertions(+), 29 deletions(-) diff --git 
a/backend/onyx/agent_search/answer_query/edges.py b/backend/onyx/agent_search/answer_query/edges.py index 15f60f2bdf..c538ef8958 100644 --- a/backend/onyx/agent_search/answer_query/edges.py +++ b/backend/onyx/agent_search/answer_query/edges.py @@ -8,9 +8,9 @@ def send_to_expanded_retrieval(state: AnswerQueryInput) -> Send | Hashable: return Send( - "expanded_retrieval", + "decomped_expanded_retrieval", ExpandedRetrievalInput( **state, - starting_query=state["starting_query"], + starting_query=state["question"], ), ) diff --git a/backend/onyx/agent_search/answer_query/graph_builder.py b/backend/onyx/agent_search/answer_query/graph_builder.py index 53f647eac8..27d89af084 100644 --- a/backend/onyx/agent_search/answer_query/graph_builder.py +++ b/backend/onyx/agent_search/answer_query/graph_builder.py @@ -77,7 +77,7 @@ def answer_query_graph_builder() -> StateGraph: compiled_graph = graph.compile() primary_llm, fast_llm = get_default_llms() search_request = SearchRequest( - query="Who made Excel and what other products did they make?", + query="what can you do with onyx or danswer?", ) with get_session_context_manager() as db_session: inputs = AnswerQueryInput( @@ -85,7 +85,7 @@ def answer_query_graph_builder() -> StateGraph: primary_llm=primary_llm, fast_llm=fast_llm, db_session=db_session, - question_to_answer="Who made Excel?", + question="what can you do with onyx?", ) output = compiled_graph.invoke( input=inputs, diff --git a/backend/onyx/agent_search/answer_query/nodes/answer_check.py b/backend/onyx/agent_search/answer_query/nodes/answer_check.py index f06b2071f9..c035f309fe 100644 --- a/backend/onyx/agent_search/answer_query/nodes/answer_check.py +++ b/backend/onyx/agent_search/answer_query/nodes/answer_check.py @@ -10,7 +10,7 @@ def answer_check(state: AnswerQueryState) -> QACheckOutput: msg = [ HumanMessage( content=SUB_CHECK_PROMPT.format( - question=state["question_to_answer"], + question=state["question"], base_answer=state["answer"], ) ) diff --git a/backend/onyx/agent_search/answer_query/nodes/answer_generation.py b/backend/onyx/agent_search/answer_query/nodes/answer_generation.py index 3de9c403d2..d35d55673a 100644 --- a/backend/onyx/agent_search/answer_query/nodes/answer_generation.py +++ b/backend/onyx/agent_search/answer_query/nodes/answer_generation.py @@ -8,8 +8,8 @@ def answer_generation(state: AnswerQueryState) -> QAGenerationOutput: - query = state["question_to_answer"] - docs = state["reranked_documents"] + query = state["question"] + docs = state["documents"] print(f"Number of verified retrieval docs: {len(docs)}") diff --git a/backend/onyx/agent_search/answer_query/nodes/format_answer.py b/backend/onyx/agent_search/answer_query/nodes/format_answer.py index 061000701b..5a7fffddaf 100644 --- a/backend/onyx/agent_search/answer_query/nodes/format_answer.py +++ b/backend/onyx/agent_search/answer_query/nodes/format_answer.py @@ -7,10 +7,11 @@ def format_answer(state: AnswerQueryState) -> AnswerQueryOutput: return AnswerQueryOutput( decomp_answer_results=[ SearchAnswerResults( - question=state["question_to_answer"], + question=state["question"], quality=state["answer_quality"], answer=state["answer"], - documents=state["reranked_documents"], + expanded_retrieval_results=state["expanded_retrieval_results"], + documents=state["documents"], ) ], ) diff --git a/backend/onyx/agent_search/answer_query/states.py b/backend/onyx/agent_search/answer_query/states.py index 8c24623ee0..f0249b4fe7 100644 --- a/backend/onyx/agent_search/answer_query/states.py +++ 
b/backend/onyx/agent_search/answer_query/states.py @@ -4,7 +4,8 @@ from pydantic import BaseModel from onyx.agent_search.core_state import PrimaryState -from onyx.agent_search.expanded_retrieval.states import RetrievalResult +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalOutput +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalResult from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections from onyx.context.search.models import InferenceSection @@ -13,7 +14,7 @@ class SearchAnswerResults(BaseModel): question: str answer: str quality: str - retrieval_results: list[RetrievalResult] + expanded_retrieval_results: list[ExpandedRetrievalResult] documents: Annotated[list[InferenceSection], dedup_inference_sections] @@ -29,6 +30,7 @@ class AnswerQueryState( PrimaryState, QACheckOutput, QAGenerationOutput, + ExpandedRetrievalOutput, total=True, ): question: str diff --git a/backend/onyx/agent_search/expanded_retrieval/edges.py b/backend/onyx/agent_search/expanded_retrieval/edges.py index 085479a4e4..19a321bd72 100644 --- a/backend/onyx/agent_search/expanded_retrieval/edges.py +++ b/backend/onyx/agent_search/expanded_retrieval/edges.py @@ -14,7 +14,7 @@ def parallel_retrieval_edge(state: ExpandedRetrievalInput) -> list[Send | Hashab print(f"parallel_retrieval_edge state: {state.keys()}") # This should be better... - question = state.get("question_to_answer") or state["search_request"].query + question = state.get("question") or state["search_request"].query llm: LLM = state["fast_llm"] msg = [ diff --git a/backend/onyx/agent_search/expanded_retrieval/graph_builder.py b/backend/onyx/agent_search/expanded_retrieval/graph_builder.py index 7f94c1ef78..c2bfd1e346 100644 --- a/backend/onyx/agent_search/expanded_retrieval/graph_builder.py +++ b/backend/onyx/agent_search/expanded_retrieval/graph_builder.py @@ -83,7 +83,7 @@ def expanded_retrieval_graph_builder() -> StateGraph: compiled_graph = graph.compile() primary_llm, fast_llm = get_default_llms() search_request = SearchRequest( - query="Who made Excel and what other products did they make?", + query="what can you do with onyx or danswer?", ) with get_session_context_manager() as db_session: inputs = ExpandedRetrievalInput( @@ -91,7 +91,11 @@ def expanded_retrieval_graph_builder() -> StateGraph: primary_llm=primary_llm, fast_llm=fast_llm, db_session=db_session, - question_to_answer="Who made Excel?", + question="what can you do with onyx?", ) - for thing in compiled_graph.stream(inputs, debug=True): + for thing in compiled_graph.stream( + input=inputs, + # debug=True, + subgraphs=True, + ): print(thing) diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py index af38e5f490..118aaa776c 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py @@ -1,4 +1,5 @@ from onyx.agent_search.expanded_retrieval.states import DocRetrievalOutput +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalResult from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState from onyx.context.search.models import InferenceSection from onyx.context.search.models import SearchRequest @@ -38,6 +39,11 @@ def doc_retrieval(state: RetrieveInput) -> DocRetrievalOutput: ).reranked_sections print(f"retrieved documents: {len(documents)}") + expanded_retrieval_result = 
ExpandedRetrievalResult( + expanded_query=query_to_retrieve, + expanded_retrieval_documents=documents[:4], + ) return DocRetrievalOutput( + expanded_retrieval_results=[expanded_retrieval_result], retrieved_documents=documents[:4], ) diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py b/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py index 36883eb6bd..2a9620a0a9 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py @@ -1,15 +1,9 @@ from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalOutput from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState -from onyx.agent_search.expanded_retrieval.states import RetrievalResult def format_results(state: ExpandedRetrievalState) -> ExpandedRetrievalOutput: return ExpandedRetrievalOutput( - retrieval_results=[ - RetrievalResult( - starting_query=state["starting_query"], - expanded_retrieval_results=state["expanded_retrieval_results"], - documents=state["reranked_documents"], - ) - ], + expanded_retrieval_results=state["expanded_retrieval_results"], + documents=state["reranked_documents"], ) diff --git a/backend/onyx/agent_search/expanded_retrieval/states.py b/backend/onyx/agent_search/expanded_retrieval/states.py index 697639cc4a..e6c3e8945e 100644 --- a/backend/onyx/agent_search/expanded_retrieval/states.py +++ b/backend/onyx/agent_search/expanded_retrieval/states.py @@ -16,10 +16,10 @@ class ExpandedRetrievalResult(BaseModel): ] -class RetrievalResult(BaseModel): - starting_query: str - expanded_retrieval_results: list[ExpandedRetrievalResult] - documents: Annotated[list[InferenceSection], dedup_inference_sections] +# class RetrievalResult(BaseModel): +# starting_query: str +# expanded_retrieval_results: list[ExpandedRetrievalResult] +# documents: Annotated[list[InferenceSection], dedup_inference_sections] class DocRetrievalOutput(TypedDict, total=False): @@ -36,7 +36,8 @@ class DocRerankingOutput(TypedDict, total=False): class ExpandedRetrievalOutput(TypedDict): - retrieval_results: Annotated[list[RetrievalResult], add] + expanded_retrieval_results: list[ExpandedRetrievalResult] + documents: Annotated[list[InferenceSection], dedup_inference_sections] class ExpandedRetrievalState( diff --git a/backend/onyx/agent_search/main/edges.py b/backend/onyx/agent_search/main/edges.py index 0ec4c0f4a6..caaf9fe412 100644 --- a/backend/onyx/agent_search/main/edges.py +++ b/backend/onyx/agent_search/main/edges.py @@ -12,7 +12,7 @@ def parallelize_decompozed_answer_queries(state: MainState) -> list[Send | Hasha "answer_query", AnswerQueryInput( **state, - question_to_answer=query, + question=query, ), ) for query in state["initial_decomp_queries"] From 442c94727e92dfd39d6d0e1d90bc2d624c671f83 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Tue, 17 Dec 2024 15:16:36 -0800 Subject: [PATCH 07/78] got answer subgraph working --- .../agent_search/answer_query/graph_builder.py | 15 +++++---------- .../answer_query/nodes/format_answer.py | 2 +- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/backend/onyx/agent_search/answer_query/graph_builder.py b/backend/onyx/agent_search/answer_query/graph_builder.py index 27d89af084..1036f425b7 100644 --- a/backend/onyx/agent_search/answer_query/graph_builder.py +++ b/backend/onyx/agent_search/answer_query/graph_builder.py @@ -87,16 +87,11 @@ def answer_query_graph_builder() -> StateGraph: db_session=db_session, 
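+            # "question" holds the single decomposed sub-question for this run;
+            # the original user query stays available via search_request.query.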
question="what can you do with onyx?", ) - output = compiled_graph.invoke( + for thing in compiled_graph.stream( input=inputs, # debug=True, # subgraphs=True, - ) - print(output) - # for namespace, chunk in compiled_graph.stream( - # input=inputs, - # # debug=True, - # subgraphs=True, - # ): - # print(namespace) - # print(chunk) + ): + print(thing) + # output = compiled_graph.invoke(inputs) + # print(output) diff --git a/backend/onyx/agent_search/answer_query/nodes/format_answer.py b/backend/onyx/agent_search/answer_query/nodes/format_answer.py index 5a7fffddaf..4220c2cc1e 100644 --- a/backend/onyx/agent_search/answer_query/nodes/format_answer.py +++ b/backend/onyx/agent_search/answer_query/nodes/format_answer.py @@ -5,7 +5,7 @@ def format_answer(state: AnswerQueryState) -> AnswerQueryOutput: return AnswerQueryOutput( - decomp_answer_results=[ + answer_results=[ SearchAnswerResults( question=state["question"], quality=state["answer_quality"], From d66180fe13ad48995cb32ec472b0da69353743ba Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Wed, 18 Dec 2024 07:33:40 -0800 Subject: [PATCH 08/78] Cleanup --- .../answer_query/nodes/answer_generation.py | 4 ++-- .../answer_query/nodes/format_answer.py | 16 +++++++--------- backend/onyx/agent_search/answer_query/states.py | 10 ++++------ .../agent_search/expanded_retrieval/states.py | 6 ------ backend/onyx/agent_search/main/edges.py | 4 ++-- .../onyx/agent_search/main/nodes/base_decomp.py | 2 +- backend/onyx/agent_search/main/states.py | 2 +- 7 files changed, 17 insertions(+), 27 deletions(-) diff --git a/backend/onyx/agent_search/answer_query/nodes/answer_generation.py b/backend/onyx/agent_search/answer_query/nodes/answer_generation.py index d35d55673a..f036267c3b 100644 --- a/backend/onyx/agent_search/answer_query/nodes/answer_generation.py +++ b/backend/onyx/agent_search/answer_query/nodes/answer_generation.py @@ -8,14 +8,14 @@ def answer_generation(state: AnswerQueryState) -> QAGenerationOutput: - query = state["question"] + question = state["question"] docs = state["documents"] print(f"Number of verified retrieval docs: {len(docs)}") msg = [ HumanMessage( - content=BASE_RAG_PROMPT.format(question=query, context=format_docs(docs)) + content=BASE_RAG_PROMPT.format(question=question, context=format_docs(docs)) ) ] diff --git a/backend/onyx/agent_search/answer_query/nodes/format_answer.py b/backend/onyx/agent_search/answer_query/nodes/format_answer.py index 4220c2cc1e..2bf618c571 100644 --- a/backend/onyx/agent_search/answer_query/nodes/format_answer.py +++ b/backend/onyx/agent_search/answer_query/nodes/format_answer.py @@ -5,13 +5,11 @@ def format_answer(state: AnswerQueryState) -> AnswerQueryOutput: return AnswerQueryOutput( - answer_results=[ - SearchAnswerResults( - question=state["question"], - quality=state["answer_quality"], - answer=state["answer"], - expanded_retrieval_results=state["expanded_retrieval_results"], - documents=state["documents"], - ) - ], + answer_result=SearchAnswerResults( + question=state["question"], + quality=state["answer_quality"], + answer=state["answer"], + expanded_retrieval_results=state["expanded_retrieval_results"], + documents=state["documents"], + ), ) diff --git a/backend/onyx/agent_search/answer_query/states.py b/backend/onyx/agent_search/answer_query/states.py index f0249b4fe7..7e6cf160a9 100644 --- a/backend/onyx/agent_search/answer_query/states.py +++ b/backend/onyx/agent_search/answer_query/states.py @@ -1,4 +1,3 @@ -from typing import Annotated from typing import TypedDict from pydantic import 
BaseModel @@ -6,7 +5,6 @@ from onyx.agent_search.core_state import PrimaryState from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalOutput from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalResult -from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections from onyx.context.search.models import InferenceSection @@ -15,7 +13,7 @@ class SearchAnswerResults(BaseModel): answer: str quality: str expanded_retrieval_results: list[ExpandedRetrievalResult] - documents: Annotated[list[InferenceSection], dedup_inference_sections] + documents: list[InferenceSection] class QACheckOutput(TypedDict, total=False): @@ -28,9 +26,9 @@ class QAGenerationOutput(TypedDict, total=False): class AnswerQueryState( PrimaryState, - QACheckOutput, - QAGenerationOutput, ExpandedRetrievalOutput, + QAGenerationOutput, + QACheckOutput, total=True, ): question: str @@ -41,4 +39,4 @@ class AnswerQueryInput(PrimaryState, total=True): class AnswerQueryOutput(TypedDict): - answer_results: list[SearchAnswerResults] + answer_result: SearchAnswerResults diff --git a/backend/onyx/agent_search/expanded_retrieval/states.py b/backend/onyx/agent_search/expanded_retrieval/states.py index e6c3e8945e..238b294a79 100644 --- a/backend/onyx/agent_search/expanded_retrieval/states.py +++ b/backend/onyx/agent_search/expanded_retrieval/states.py @@ -16,12 +16,6 @@ class ExpandedRetrievalResult(BaseModel): ] -# class RetrievalResult(BaseModel): -# starting_query: str -# expanded_retrieval_results: list[ExpandedRetrievalResult] -# documents: Annotated[list[InferenceSection], dedup_inference_sections] - - class DocRetrievalOutput(TypedDict, total=False): expanded_retrieval_results: Annotated[list[ExpandedRetrievalResult], add] retrieved_documents: Annotated[list[InferenceSection], dedup_inference_sections] diff --git a/backend/onyx/agent_search/main/edges.py b/backend/onyx/agent_search/main/edges.py index caaf9fe412..7d58dbdacd 100644 --- a/backend/onyx/agent_search/main/edges.py +++ b/backend/onyx/agent_search/main/edges.py @@ -12,10 +12,10 @@ def parallelize_decompozed_answer_queries(state: MainState) -> list[Send | Hasha "answer_query", AnswerQueryInput( **state, - question=query, + question=question, ), ) - for query in state["initial_decomp_queries"] + for question in state["initial_decomp_questions"] ] diff --git a/backend/onyx/agent_search/main/nodes/base_decomp.py b/backend/onyx/agent_search/main/nodes/base_decomp.py index 28e93c6cbc..e8af64a9fb 100644 --- a/backend/onyx/agent_search/main/nodes/base_decomp.py +++ b/backend/onyx/agent_search/main/nodes/base_decomp.py @@ -27,5 +27,5 @@ def main_decomp_base(state: MainState) -> BaseDecompOutput: ] return BaseDecompOutput( - initial_decomp_queries=decomp_list, + initial_decomp_questions=decomp_list, ) diff --git a/backend/onyx/agent_search/main/states.py b/backend/onyx/agent_search/main/states.py index 3b753ff847..1c3fcc41d1 100644 --- a/backend/onyx/agent_search/main/states.py +++ b/backend/onyx/agent_search/main/states.py @@ -9,7 +9,7 @@ class BaseDecompOutput(TypedDict, total=False): - initial_decomp_queries: list[str] + initial_decomp_questions: list[str] class InitialAnswerOutput(TypedDict, total=False): From e76cbec53cd78fac6c5719e1a82690de5516eaed Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Wed, 18 Dec 2024 08:43:54 -0800 Subject: [PATCH 09/78] main graph works --- .../answer_query/nodes/format_answer.py | 16 +++++++++------- backend/onyx/agent_search/answer_query/states.py | 10 +++++++++- 
backend/onyx/agent_search/main/graph_builder.py | 11 ++++++++++- .../agent_search/main/nodes/ingest_answers.py | 15 +++++++++++++++ backend/onyx/agent_search/main/states.py | 16 ++++++++++------ 5 files changed, 53 insertions(+), 15 deletions(-) create mode 100644 backend/onyx/agent_search/main/nodes/ingest_answers.py diff --git a/backend/onyx/agent_search/answer_query/nodes/format_answer.py b/backend/onyx/agent_search/answer_query/nodes/format_answer.py index 2bf618c571..4220c2cc1e 100644 --- a/backend/onyx/agent_search/answer_query/nodes/format_answer.py +++ b/backend/onyx/agent_search/answer_query/nodes/format_answer.py @@ -5,11 +5,13 @@ def format_answer(state: AnswerQueryState) -> AnswerQueryOutput: return AnswerQueryOutput( - answer_result=SearchAnswerResults( - question=state["question"], - quality=state["answer_quality"], - answer=state["answer"], - expanded_retrieval_results=state["expanded_retrieval_results"], - documents=state["documents"], - ), + answer_results=[ + SearchAnswerResults( + question=state["question"], + quality=state["answer_quality"], + answer=state["answer"], + expanded_retrieval_results=state["expanded_retrieval_results"], + documents=state["documents"], + ) + ], ) diff --git a/backend/onyx/agent_search/answer_query/states.py b/backend/onyx/agent_search/answer_query/states.py index 7e6cf160a9..f622db8215 100644 --- a/backend/onyx/agent_search/answer_query/states.py +++ b/backend/onyx/agent_search/answer_query/states.py @@ -1,3 +1,5 @@ +from operator import add +from typing import Annotated from typing import TypedDict from pydantic import BaseModel @@ -39,4 +41,10 @@ class AnswerQueryInput(PrimaryState, total=True): class AnswerQueryOutput(TypedDict): - answer_result: SearchAnswerResults + """ + This is a list of results even though each call of this subgraph only returns one result. + This is because if we parallelize the answer query subgraph, there will be multiple + results in a list so the add operator is used to add them together. 
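+    (Illustrative example: if the main graph fans out three answer_query runs and
+    each returns one SearchAnswerResults, the `add` reducer merges them into a
+    single three-element answer_results list.)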
+ """ + + answer_results: Annotated[list[SearchAnswerResults], add] diff --git a/backend/onyx/agent_search/main/graph_builder.py b/backend/onyx/agent_search/main/graph_builder.py index 930d7a745f..dbe02194a1 100644 --- a/backend/onyx/agent_search/main/graph_builder.py +++ b/backend/onyx/agent_search/main/graph_builder.py @@ -11,6 +11,7 @@ from onyx.agent_search.main.nodes.generate_initial_answer import ( generate_initial_answer, ) +from onyx.agent_search.main.nodes.ingest_answers import ingest_answers from onyx.agent_search.main.states import MainInput from onyx.agent_search.main.states import MainState @@ -41,6 +42,10 @@ def main_graph_builder() -> StateGraph: node="generate_initial_answer", action=generate_initial_answer, ) + graph.add_node( + node="ingest_answers", + action=ingest_answers, + ) ### Add edges ### graph.add_edge( @@ -59,6 +64,10 @@ def main_graph_builder() -> StateGraph: ) graph.add_edge( start_key=["answer_query", "expanded_retrieval"], + end_key="ingest_answers", + ) + graph.add_edge( + start_key="ingest_answers", end_key="generate_initial_answer", ) graph.add_edge( @@ -78,7 +87,7 @@ def main_graph_builder() -> StateGraph: compiled_graph = graph.compile() primary_llm, fast_llm = get_default_llms() search_request = SearchRequest( - query="If i am familiar with the function that I need, how can I type it into a cell?", + query="what can you do with onyx or danswer?", ) with get_session_context_manager() as db_session: inputs = MainInput( diff --git a/backend/onyx/agent_search/main/nodes/ingest_answers.py b/backend/onyx/agent_search/main/nodes/ingest_answers.py new file mode 100644 index 0000000000..f761a85b1a --- /dev/null +++ b/backend/onyx/agent_search/main/nodes/ingest_answers.py @@ -0,0 +1,15 @@ +from onyx.agent_search.answer_query.states import AnswerQueryOutput +from onyx.agent_search.main.states import DecompAnswersOutput +from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections + + +def ingest_answers(state: AnswerQueryOutput) -> DecompAnswersOutput: + documents = [] + for answer_result in state["answer_results"]: + documents.extend(answer_result.documents) + return DecompAnswersOutput( + # Deduping is done by the documents operator for the main graph + # so we might not need to dedup here + documents=dedup_inference_sections(documents, []), + decomp_answer_results=state["answer_results"].answer_results, + ) diff --git a/backend/onyx/agent_search/main/states.py b/backend/onyx/agent_search/main/states.py index 1c3fcc41d1..c28220a967 100644 --- a/backend/onyx/agent_search/main/states.py +++ b/backend/onyx/agent_search/main/states.py @@ -8,25 +8,29 @@ from onyx.context.search.models import InferenceSection -class BaseDecompOutput(TypedDict, total=False): +class BaseDecompOutput(TypedDict): initial_decomp_questions: list[str] -class InitialAnswerOutput(TypedDict, total=False): +class InitialAnswerOutput(TypedDict): initial_answer: str +class DecompAnswersOutput(TypedDict): + documents: Annotated[list[InferenceSection], dedup_inference_sections] + decomp_answer_results: Annotated[list[SearchAnswerResults], add] + + class MainState( PrimaryState, BaseDecompOutput, InitialAnswerOutput, - total=True, + DecompAnswersOutput, ): - documents: Annotated[list[InferenceSection], dedup_inference_sections] - decomp_answer_results: Annotated[list[SearchAnswerResults], add] + pass -class MainInput(PrimaryState, total=True): +class MainInput(PrimaryState): pass From fd694bea8faab8b2693908f483cbdec2a8eb546f Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: 
Wed, 18 Dec 2024 08:47:43 -0800 Subject: [PATCH 10/78] query->question --- .../{answer_query => answer_question}/edges.py | 2 +- .../graph_builder.py | 14 +++++++------- .../nodes/answer_check.py | 4 ++-- .../nodes/answer_generation.py | 4 ++-- .../nodes/format_answer.py | 6 +++--- .../{answer_query => answer_question}/states.py | 0 backend/onyx/agent_search/main/edges.py | 2 +- backend/onyx/agent_search/main/graph_builder.py | 2 +- .../onyx/agent_search/main/nodes/ingest_answers.py | 2 +- backend/onyx/agent_search/main/states.py | 2 +- 10 files changed, 19 insertions(+), 19 deletions(-) rename backend/onyx/agent_search/{answer_query => answer_question}/edges.py (85%) rename backend/onyx/agent_search/{answer_query => answer_question}/graph_builder.py (81%) rename backend/onyx/agent_search/{answer_query => answer_question}/nodes/answer_check.py (83%) rename backend/onyx/agent_search/{answer_query => answer_question}/nodes/answer_generation.py (85%) rename backend/onyx/agent_search/{answer_query => answer_question}/nodes/format_answer.py (67%) rename backend/onyx/agent_search/{answer_query => answer_question}/states.py (100%) diff --git a/backend/onyx/agent_search/answer_query/edges.py b/backend/onyx/agent_search/answer_question/edges.py similarity index 85% rename from backend/onyx/agent_search/answer_query/edges.py rename to backend/onyx/agent_search/answer_question/edges.py index c538ef8958..45c24137d3 100644 --- a/backend/onyx/agent_search/answer_query/edges.py +++ b/backend/onyx/agent_search/answer_question/edges.py @@ -2,7 +2,7 @@ from langgraph.types import Send -from onyx.agent_search.answer_query.states import AnswerQueryInput +from onyx.agent_search.answer_question.states import AnswerQueryInput from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput diff --git a/backend/onyx/agent_search/answer_query/graph_builder.py b/backend/onyx/agent_search/answer_question/graph_builder.py similarity index 81% rename from backend/onyx/agent_search/answer_query/graph_builder.py rename to backend/onyx/agent_search/answer_question/graph_builder.py index 1036f425b7..75563e6abf 100644 --- a/backend/onyx/agent_search/answer_query/graph_builder.py +++ b/backend/onyx/agent_search/answer_question/graph_builder.py @@ -2,13 +2,13 @@ from langgraph.graph import START from langgraph.graph import StateGraph -from onyx.agent_search.answer_query.edges import send_to_expanded_retrieval -from onyx.agent_search.answer_query.nodes.answer_check import answer_check -from onyx.agent_search.answer_query.nodes.answer_generation import answer_generation -from onyx.agent_search.answer_query.nodes.format_answer import format_answer -from onyx.agent_search.answer_query.states import AnswerQueryInput -from onyx.agent_search.answer_query.states import AnswerQueryOutput -from onyx.agent_search.answer_query.states import AnswerQueryState +from onyx.agent_search.answer_question.edges import send_to_expanded_retrieval +from onyx.agent_search.answer_question.nodes.answer_check import answer_check +from onyx.agent_search.answer_question.nodes.answer_generation import answer_generation +from onyx.agent_search.answer_question.nodes.format_answer import format_answer +from onyx.agent_search.answer_question.states import AnswerQueryInput +from onyx.agent_search.answer_question.states import AnswerQueryOutput +from onyx.agent_search.answer_question.states import AnswerQueryState from onyx.agent_search.expanded_retrieval.graph_builder import ( expanded_retrieval_graph_builder, ) diff --git 
a/backend/onyx/agent_search/answer_query/nodes/answer_check.py b/backend/onyx/agent_search/answer_question/nodes/answer_check.py similarity index 83% rename from backend/onyx/agent_search/answer_query/nodes/answer_check.py rename to backend/onyx/agent_search/answer_question/nodes/answer_check.py index c035f309fe..008001b620 100644 --- a/backend/onyx/agent_search/answer_query/nodes/answer_check.py +++ b/backend/onyx/agent_search/answer_question/nodes/answer_check.py @@ -1,8 +1,8 @@ from langchain_core.messages import HumanMessage from langchain_core.messages import merge_message_runs -from onyx.agent_search.answer_query.states import AnswerQueryState -from onyx.agent_search.answer_query.states import QACheckOutput +from onyx.agent_search.answer_question.states import AnswerQueryState +from onyx.agent_search.answer_question.states import QACheckOutput from onyx.agent_search.shared_graph_utils.prompts import SUB_CHECK_PROMPT diff --git a/backend/onyx/agent_search/answer_query/nodes/answer_generation.py b/backend/onyx/agent_search/answer_question/nodes/answer_generation.py similarity index 85% rename from backend/onyx/agent_search/answer_query/nodes/answer_generation.py rename to backend/onyx/agent_search/answer_question/nodes/answer_generation.py index f036267c3b..f01f5baeac 100644 --- a/backend/onyx/agent_search/answer_query/nodes/answer_generation.py +++ b/backend/onyx/agent_search/answer_question/nodes/answer_generation.py @@ -1,8 +1,8 @@ from langchain_core.messages import HumanMessage from langchain_core.messages import merge_message_runs -from onyx.agent_search.answer_query.states import AnswerQueryState -from onyx.agent_search.answer_query.states import QAGenerationOutput +from onyx.agent_search.answer_question.states import AnswerQueryState +from onyx.agent_search.answer_question.states import QAGenerationOutput from onyx.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT from onyx.agent_search.shared_graph_utils.utils import format_docs diff --git a/backend/onyx/agent_search/answer_query/nodes/format_answer.py b/backend/onyx/agent_search/answer_question/nodes/format_answer.py similarity index 67% rename from backend/onyx/agent_search/answer_query/nodes/format_answer.py rename to backend/onyx/agent_search/answer_question/nodes/format_answer.py index 4220c2cc1e..fff4f940db 100644 --- a/backend/onyx/agent_search/answer_query/nodes/format_answer.py +++ b/backend/onyx/agent_search/answer_question/nodes/format_answer.py @@ -1,6 +1,6 @@ -from onyx.agent_search.answer_query.states import AnswerQueryOutput -from onyx.agent_search.answer_query.states import AnswerQueryState -from onyx.agent_search.answer_query.states import SearchAnswerResults +from onyx.agent_search.answer_question.states import AnswerQueryOutput +from onyx.agent_search.answer_question.states import AnswerQueryState +from onyx.agent_search.answer_question.states import SearchAnswerResults def format_answer(state: AnswerQueryState) -> AnswerQueryOutput: diff --git a/backend/onyx/agent_search/answer_query/states.py b/backend/onyx/agent_search/answer_question/states.py similarity index 100% rename from backend/onyx/agent_search/answer_query/states.py rename to backend/onyx/agent_search/answer_question/states.py diff --git a/backend/onyx/agent_search/main/edges.py b/backend/onyx/agent_search/main/edges.py index 7d58dbdacd..4f2468b8c4 100644 --- a/backend/onyx/agent_search/main/edges.py +++ b/backend/onyx/agent_search/main/edges.py @@ -2,7 +2,7 @@ from langgraph.types import Send -from 
onyx.agent_search.answer_query.states import AnswerQueryInput +from onyx.agent_search.answer_question.states import AnswerQueryInput from onyx.agent_search.main.states import MainState diff --git a/backend/onyx/agent_search/main/graph_builder.py b/backend/onyx/agent_search/main/graph_builder.py index dbe02194a1..d91aecc598 100644 --- a/backend/onyx/agent_search/main/graph_builder.py +++ b/backend/onyx/agent_search/main/graph_builder.py @@ -2,7 +2,7 @@ from langgraph.graph import START from langgraph.graph import StateGraph -from onyx.agent_search.answer_query.graph_builder import answer_query_graph_builder +from onyx.agent_search.answer_question.graph_builder import answer_query_graph_builder from onyx.agent_search.expanded_retrieval.graph_builder import ( expanded_retrieval_graph_builder, ) diff --git a/backend/onyx/agent_search/main/nodes/ingest_answers.py b/backend/onyx/agent_search/main/nodes/ingest_answers.py index f761a85b1a..8a59afdbaf 100644 --- a/backend/onyx/agent_search/main/nodes/ingest_answers.py +++ b/backend/onyx/agent_search/main/nodes/ingest_answers.py @@ -1,4 +1,4 @@ -from onyx.agent_search.answer_query.states import AnswerQueryOutput +from onyx.agent_search.answer_question.states import AnswerQueryOutput from onyx.agent_search.main.states import DecompAnswersOutput from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections diff --git a/backend/onyx/agent_search/main/states.py b/backend/onyx/agent_search/main/states.py index c28220a967..679230cd32 100644 --- a/backend/onyx/agent_search/main/states.py +++ b/backend/onyx/agent_search/main/states.py @@ -2,7 +2,7 @@ from typing import Annotated from typing import TypedDict -from onyx.agent_search.answer_query.states import SearchAnswerResults +from onyx.agent_search.answer_question.states import SearchAnswerResults from onyx.agent_search.core_state import PrimaryState from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections from onyx.context.search.models import InferenceSection From 8399d2ee0aab2bd65a7fd40ab47cda3ea18717bf Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Wed, 18 Dec 2024 09:27:47 -0800 Subject: [PATCH 11/78] mypy fixed --- .../agent_search/answer_question/edges.py | 3 +- backend/onyx/agent_search/core_state.py | 12 ++++++++ .../expanded_retrieval/graph_builder.py | 2 +- backend/onyx/agent_search/main/edges.py | 3 +- .../onyx/agent_search/main/graph_builder.py | 28 +++++++++++++------ .../agent_search/main/nodes/base_decomp.py | 6 ++-- .../main/nodes/generate_initial_answer.py | 10 ++++--- .../agent_search/main/nodes/ingest_answers.py | 8 +++--- .../main/nodes/ingest_initial_retrieval.py | 9 ++++++ backend/onyx/agent_search/main/states.py | 21 ++++++++++---- 10 files changed, 74 insertions(+), 28 deletions(-) create mode 100644 backend/onyx/agent_search/main/nodes/ingest_initial_retrieval.py diff --git a/backend/onyx/agent_search/answer_question/edges.py b/backend/onyx/agent_search/answer_question/edges.py index 45c24137d3..05de589900 100644 --- a/backend/onyx/agent_search/answer_question/edges.py +++ b/backend/onyx/agent_search/answer_question/edges.py @@ -3,6 +3,7 @@ from langgraph.types import Send from onyx.agent_search.answer_question.states import AnswerQueryInput +from onyx.agent_search.core_state import extract_primary_fields from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput @@ -10,7 +11,7 @@ def send_to_expanded_retrieval(state: AnswerQueryInput) -> Send | Hashable: return Send( "decomped_expanded_retrieval", 
ExpandedRetrievalInput( - **state, + **extract_primary_fields(state), starting_query=state["question"], ), ) diff --git a/backend/onyx/agent_search/core_state.py b/backend/onyx/agent_search/core_state.py index fcd8bddf3e..ee490e0a33 100644 --- a/backend/onyx/agent_search/core_state.py +++ b/backend/onyx/agent_search/core_state.py @@ -1,4 +1,5 @@ from typing import TypedDict +from typing import TypeVar from sqlalchemy.orm import Session @@ -13,3 +14,14 @@ class PrimaryState(TypedDict, total=False): # a single session for the entire agent search # is fine if we are only reading db_session: Session + + +# This ensures that the state passed in extends the PrimaryState +T = TypeVar("T", bound=PrimaryState) + + +def extract_primary_fields(state: T) -> PrimaryState: + filtered_dict = { + k: v for k, v in state.items() if k in PrimaryState.__annotations__ + } + return PrimaryState(**dict(filtered_dict)) # type: ignore diff --git a/backend/onyx/agent_search/expanded_retrieval/graph_builder.py b/backend/onyx/agent_search/expanded_retrieval/graph_builder.py index c2bfd1e346..a160678196 100644 --- a/backend/onyx/agent_search/expanded_retrieval/graph_builder.py +++ b/backend/onyx/agent_search/expanded_retrieval/graph_builder.py @@ -91,7 +91,7 @@ def expanded_retrieval_graph_builder() -> StateGraph: primary_llm=primary_llm, fast_llm=fast_llm, db_session=db_session, - question="what can you do with onyx?", + starting_query="what can you do with onyx?", ) for thing in compiled_graph.stream( input=inputs, diff --git a/backend/onyx/agent_search/main/edges.py b/backend/onyx/agent_search/main/edges.py index 4f2468b8c4..c0730c1537 100644 --- a/backend/onyx/agent_search/main/edges.py +++ b/backend/onyx/agent_search/main/edges.py @@ -3,6 +3,7 @@ from langgraph.types import Send from onyx.agent_search.answer_question.states import AnswerQueryInput +from onyx.agent_search.core_state import extract_primary_fields from onyx.agent_search.main.states import MainState @@ -11,7 +12,7 @@ def parallelize_decompozed_answer_queries(state: MainState) -> list[Send | Hasha Send( "answer_query", AnswerQueryInput( - **state, + **extract_primary_fields(state), question=question, ), ) diff --git a/backend/onyx/agent_search/main/graph_builder.py b/backend/onyx/agent_search/main/graph_builder.py index d91aecc598..971398f9c9 100644 --- a/backend/onyx/agent_search/main/graph_builder.py +++ b/backend/onyx/agent_search/main/graph_builder.py @@ -12,6 +12,9 @@ generate_initial_answer, ) from onyx.agent_search.main.nodes.ingest_answers import ingest_answers +from onyx.agent_search.main.nodes.ingest_initial_retrieval import ( + ingest_initial_retrieval, +) from onyx.agent_search.main.states import MainInput from onyx.agent_search.main.states import MainState @@ -35,22 +38,30 @@ def main_graph_builder() -> StateGraph: ) expanded_retrieval_subgraph = expanded_retrieval_graph_builder().compile() graph.add_node( - node="expanded_retrieval", + node="initial_retrieval", action=expanded_retrieval_subgraph, ) - graph.add_node( - node="generate_initial_answer", - action=generate_initial_answer, - ) graph.add_node( node="ingest_answers", action=ingest_answers, ) + graph.add_node( + node="ingest_initial_retrieval", + action=ingest_initial_retrieval, + ) + graph.add_node( + node="generate_initial_answer", + action=generate_initial_answer, + ) ### Add edges ### graph.add_edge( start_key=START, - end_key="expanded_retrieval", + end_key="initial_retrieval", + ) + graph.add_edge( + start_key="initial_retrieval", + end_key="ingest_initial_retrieval", ) 
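To make the intent of the new extract_primary_fields helper concrete, here is a minimal, self-contained sketch of the same key-filtering idea. It uses stand-in object values instead of the real SearchRequest/LLM/Session types, so it illustrates the pattern rather than reproducing Onyx code:

from typing import TypedDict, TypeVar


class PrimaryState(TypedDict, total=False):
    search_request: object
    primary_llm: object
    fast_llm: object
    db_session: object


T = TypeVar("T", bound=PrimaryState)


def extract_primary_fields(state: T) -> PrimaryState:
    # Keep only the keys declared on PrimaryState; subclass-specific keys
    # such as "question" are dropped so they do not leak into the input
    # state of a subgraph that never declared them.
    return PrimaryState(
        **{k: v for k, v in state.items() if k in PrimaryState.__annotations__}
    )


class AnswerQuestionInput(PrimaryState):
    question: str


state: AnswerQuestionInput = {
    "search_request": object(),
    "question": "Who made Excel?",
}
assert extract_primary_fields(state) == {"search_request": state["search_request"]}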
graph.add_edge( @@ -63,11 +74,12 @@ def main_graph_builder() -> StateGraph: path_map=["answer_query"], ) graph.add_edge( - start_key=["answer_query", "expanded_retrieval"], + start_key="answer_query", end_key="ingest_answers", ) + graph.add_edge( - start_key="ingest_answers", + start_key=["ingest_answers", "ingest_initial_retrieval"], end_key="generate_initial_answer", ) graph.add_edge( diff --git a/backend/onyx/agent_search/main/nodes/base_decomp.py b/backend/onyx/agent_search/main/nodes/base_decomp.py index e8af64a9fb..05b095794b 100644 --- a/backend/onyx/agent_search/main/nodes/base_decomp.py +++ b/backend/onyx/agent_search/main/nodes/base_decomp.py @@ -1,12 +1,12 @@ from langchain_core.messages import HumanMessage -from onyx.agent_search.main.states import BaseDecompOutput +from onyx.agent_search.main.states import BaseDecompUpdate from onyx.agent_search.main.states import MainState from onyx.agent_search.shared_graph_utils.prompts import INITIAL_DECOMPOSITION_PROMPT from onyx.agent_search.shared_graph_utils.utils import clean_and_parse_list_string -def main_decomp_base(state: MainState) -> BaseDecompOutput: +def main_decomp_base(state: MainState) -> BaseDecompUpdate: question = state["search_request"].query msg = [ @@ -26,6 +26,6 @@ def main_decomp_base(state: MainState) -> BaseDecompOutput: sub_question["sub_question"].strip() for sub_question in list_of_subquestions ] - return BaseDecompOutput( + return BaseDecompUpdate( initial_decomp_questions=decomp_list, ) diff --git a/backend/onyx/agent_search/main/nodes/generate_initial_answer.py b/backend/onyx/agent_search/main/nodes/generate_initial_answer.py index a6476477ae..828472d6ea 100644 --- a/backend/onyx/agent_search/main/nodes/generate_initial_answer.py +++ b/backend/onyx/agent_search/main/nodes/generate_initial_answer.py @@ -1,16 +1,18 @@ from langchain_core.messages import HumanMessage -from onyx.agent_search.main.states import InitialAnswerOutput +from onyx.agent_search.main.states import InitialAnswerUpdate from onyx.agent_search.main.states import MainState from onyx.agent_search.shared_graph_utils.prompts import INITIAL_RAG_PROMPT from onyx.agent_search.shared_graph_utils.utils import format_docs -def generate_initial_answer(state: MainState) -> InitialAnswerOutput: +def generate_initial_answer(state: MainState) -> InitialAnswerUpdate: print("---GENERATE INITIAL---") question = state["search_request"].query docs = state["documents"] + all_original_question_documents = state["all_original_question_documents"] + combined_docs = docs + all_original_question_documents decomp_answer_results = state["decomp_answer_results"] @@ -38,7 +40,7 @@ def generate_initial_answer(state: MainState) -> InitialAnswerOutput: HumanMessage( content=INITIAL_RAG_PROMPT.format( question=question, - context=format_docs(docs), + context=format_docs(combined_docs), answered_sub_questions=sub_question_answer_str, ) ) @@ -50,4 +52,4 @@ def generate_initial_answer(state: MainState) -> InitialAnswerOutput: answer = response.pretty_repr() print(answer) - return InitialAnswerOutput(initial_answer=answer) + return InitialAnswerUpdate(initial_answer=answer) diff --git a/backend/onyx/agent_search/main/nodes/ingest_answers.py b/backend/onyx/agent_search/main/nodes/ingest_answers.py index 8a59afdbaf..2662951cce 100644 --- a/backend/onyx/agent_search/main/nodes/ingest_answers.py +++ b/backend/onyx/agent_search/main/nodes/ingest_answers.py @@ -1,15 +1,15 @@ from onyx.agent_search.answer_question.states import AnswerQueryOutput -from onyx.agent_search.main.states 
import DecompAnswersOutput +from onyx.agent_search.main.states import DecompAnswersUpdate from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections -def ingest_answers(state: AnswerQueryOutput) -> DecompAnswersOutput: +def ingest_answers(state: AnswerQueryOutput) -> DecompAnswersUpdate: documents = [] for answer_result in state["answer_results"]: documents.extend(answer_result.documents) - return DecompAnswersOutput( + return DecompAnswersUpdate( # Deduping is done by the documents operator for the main graph # so we might not need to dedup here documents=dedup_inference_sections(documents, []), - decomp_answer_results=state["answer_results"].answer_results, + decomp_answer_results=state["answer_results"], ) diff --git a/backend/onyx/agent_search/main/nodes/ingest_initial_retrieval.py b/backend/onyx/agent_search/main/nodes/ingest_initial_retrieval.py new file mode 100644 index 0000000000..e3a96e0b8e --- /dev/null +++ b/backend/onyx/agent_search/main/nodes/ingest_initial_retrieval.py @@ -0,0 +1,9 @@ +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalOutput +from onyx.agent_search.main.states import ExpandedRetrievalUpdate + + +def ingest_initial_retrieval(state: ExpandedRetrievalOutput) -> ExpandedRetrievalUpdate: + return ExpandedRetrievalUpdate( + all_original_question_documents=state["documents"], + original_question_retrieval_results=state["expanded_retrieval_results"], + ) diff --git a/backend/onyx/agent_search/main/states.py b/backend/onyx/agent_search/main/states.py index 679230cd32..7e2c14d2bc 100644 --- a/backend/onyx/agent_search/main/states.py +++ b/backend/onyx/agent_search/main/states.py @@ -4,28 +4,37 @@ from onyx.agent_search.answer_question.states import SearchAnswerResults from onyx.agent_search.core_state import PrimaryState +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalResult from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections from onyx.context.search.models import InferenceSection -class BaseDecompOutput(TypedDict): +class BaseDecompUpdate(TypedDict): initial_decomp_questions: list[str] -class InitialAnswerOutput(TypedDict): +class InitialAnswerUpdate(TypedDict): initial_answer: str -class DecompAnswersOutput(TypedDict): +class DecompAnswersUpdate(TypedDict): documents: Annotated[list[InferenceSection], dedup_inference_sections] decomp_answer_results: Annotated[list[SearchAnswerResults], add] +class ExpandedRetrievalUpdate(TypedDict): + all_original_question_documents: Annotated[ + list[InferenceSection], dedup_inference_sections + ] + original_question_retrieval_results: list[ExpandedRetrievalResult] + + class MainState( PrimaryState, - BaseDecompOutput, - InitialAnswerOutput, - DecompAnswersOutput, + BaseDecompUpdate, + InitialAnswerUpdate, + DecompAnswersUpdate, + ExpandedRetrievalUpdate, ): pass From 50a216f554459d9c67c267b908d92018e1e66742 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Wed, 18 Dec 2024 09:56:34 -0800 Subject: [PATCH 12/78] naming and comments --- .../agent_search/answer_question/edges.py | 4 +- .../answer_question/graph_builder.py | 14 +++--- .../answer_question/nodes/answer_check.py | 8 +-- .../nodes/answer_generation.py | 8 +-- .../answer_question/nodes/format_answer.py | 8 +-- .../agent_search/answer_question/states.py | 31 +++++++++--- .../agent_search/expanded_retrieval/edges.py | 8 ++- .../expanded_retrieval/nodes/doc_reranking.py | 6 +-- .../expanded_retrieval/nodes/doc_retrieval.py | 23 ++++----- .../nodes/doc_verification.py | 6 
+-- .../agent_search/expanded_retrieval/states.py | 49 +++++++++++++------ backend/onyx/agent_search/main/edges.py | 4 +- .../agent_search/main/nodes/ingest_answers.py | 4 +- backend/onyx/agent_search/main/states.py | 13 +++++ 14 files changed, 116 insertions(+), 70 deletions(-) diff --git a/backend/onyx/agent_search/answer_question/edges.py b/backend/onyx/agent_search/answer_question/edges.py index 05de589900..ec32f1c852 100644 --- a/backend/onyx/agent_search/answer_question/edges.py +++ b/backend/onyx/agent_search/answer_question/edges.py @@ -2,12 +2,12 @@ from langgraph.types import Send -from onyx.agent_search.answer_question.states import AnswerQueryInput +from onyx.agent_search.answer_question.states import AnswerQuestionInput from onyx.agent_search.core_state import extract_primary_fields from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput -def send_to_expanded_retrieval(state: AnswerQueryInput) -> Send | Hashable: +def send_to_expanded_retrieval(state: AnswerQuestionInput) -> Send | Hashable: return Send( "decomped_expanded_retrieval", ExpandedRetrievalInput( diff --git a/backend/onyx/agent_search/answer_question/graph_builder.py b/backend/onyx/agent_search/answer_question/graph_builder.py index 75563e6abf..291d9deb6a 100644 --- a/backend/onyx/agent_search/answer_question/graph_builder.py +++ b/backend/onyx/agent_search/answer_question/graph_builder.py @@ -6,9 +6,9 @@ from onyx.agent_search.answer_question.nodes.answer_check import answer_check from onyx.agent_search.answer_question.nodes.answer_generation import answer_generation from onyx.agent_search.answer_question.nodes.format_answer import format_answer -from onyx.agent_search.answer_question.states import AnswerQueryInput -from onyx.agent_search.answer_question.states import AnswerQueryOutput -from onyx.agent_search.answer_question.states import AnswerQueryState +from onyx.agent_search.answer_question.states import AnswerQuestionInput +from onyx.agent_search.answer_question.states import AnswerQuestionOutput +from onyx.agent_search.answer_question.states import AnswerQuestionState from onyx.agent_search.expanded_retrieval.graph_builder import ( expanded_retrieval_graph_builder, ) @@ -16,9 +16,9 @@ def answer_query_graph_builder() -> StateGraph: graph = StateGraph( - state_schema=AnswerQueryState, - input=AnswerQueryInput, - output=AnswerQueryOutput, + state_schema=AnswerQuestionState, + input=AnswerQuestionInput, + output=AnswerQuestionOutput, ) ### Add nodes ### @@ -80,7 +80,7 @@ def answer_query_graph_builder() -> StateGraph: query="what can you do with onyx or danswer?", ) with get_session_context_manager() as db_session: - inputs = AnswerQueryInput( + inputs = AnswerQuestionInput( search_request=search_request, primary_llm=primary_llm, fast_llm=fast_llm, diff --git a/backend/onyx/agent_search/answer_question/nodes/answer_check.py b/backend/onyx/agent_search/answer_question/nodes/answer_check.py index 008001b620..b04953dd0b 100644 --- a/backend/onyx/agent_search/answer_question/nodes/answer_check.py +++ b/backend/onyx/agent_search/answer_question/nodes/answer_check.py @@ -1,12 +1,12 @@ from langchain_core.messages import HumanMessage from langchain_core.messages import merge_message_runs -from onyx.agent_search.answer_question.states import AnswerQueryState -from onyx.agent_search.answer_question.states import QACheckOutput +from onyx.agent_search.answer_question.states import AnswerQuestionState +from onyx.agent_search.answer_question.states import QACheckUpdate from 
onyx.agent_search.shared_graph_utils.prompts import SUB_CHECK_PROMPT -def answer_check(state: AnswerQueryState) -> QACheckOutput: +def answer_check(state: AnswerQuestionState) -> QACheckUpdate: msg = [ HumanMessage( content=SUB_CHECK_PROMPT.format( @@ -25,6 +25,6 @@ def answer_check(state: AnswerQueryState) -> QACheckOutput: response_str = merge_message_runs(response, chunk_separator="")[0].content - return QACheckOutput( + return QACheckUpdate( answer_quality=response_str, ) diff --git a/backend/onyx/agent_search/answer_question/nodes/answer_generation.py b/backend/onyx/agent_search/answer_question/nodes/answer_generation.py index f01f5baeac..d47d1aaf77 100644 --- a/backend/onyx/agent_search/answer_question/nodes/answer_generation.py +++ b/backend/onyx/agent_search/answer_question/nodes/answer_generation.py @@ -1,13 +1,13 @@ from langchain_core.messages import HumanMessage from langchain_core.messages import merge_message_runs -from onyx.agent_search.answer_question.states import AnswerQueryState -from onyx.agent_search.answer_question.states import QAGenerationOutput +from onyx.agent_search.answer_question.states import AnswerQuestionState +from onyx.agent_search.answer_question.states import QAGenerationUpdate from onyx.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT from onyx.agent_search.shared_graph_utils.utils import format_docs -def answer_generation(state: AnswerQueryState) -> QAGenerationOutput: +def answer_generation(state: AnswerQuestionState) -> QAGenerationUpdate: question = state["question"] docs = state["documents"] @@ -27,6 +27,6 @@ def answer_generation(state: AnswerQueryState) -> QAGenerationOutput: ) answer_str = merge_message_runs(response, chunk_separator="")[0].content - return QAGenerationOutput( + return QAGenerationUpdate( answer=answer_str, ) diff --git a/backend/onyx/agent_search/answer_question/nodes/format_answer.py b/backend/onyx/agent_search/answer_question/nodes/format_answer.py index fff4f940db..216100a94c 100644 --- a/backend/onyx/agent_search/answer_question/nodes/format_answer.py +++ b/backend/onyx/agent_search/answer_question/nodes/format_answer.py @@ -1,10 +1,10 @@ -from onyx.agent_search.answer_question.states import AnswerQueryOutput -from onyx.agent_search.answer_question.states import AnswerQueryState +from onyx.agent_search.answer_question.states import AnswerQuestionOutput +from onyx.agent_search.answer_question.states import AnswerQuestionState from onyx.agent_search.answer_question.states import SearchAnswerResults -def format_answer(state: AnswerQueryState) -> AnswerQueryOutput: - return AnswerQueryOutput( +def format_answer(state: AnswerQuestionState) -> AnswerQuestionOutput: + return AnswerQuestionOutput( answer_results=[ SearchAnswerResults( question=state["question"], diff --git a/backend/onyx/agent_search/answer_question/states.py b/backend/onyx/agent_search/answer_question/states.py index f622db8215..06cbe3ba83 100644 --- a/backend/onyx/agent_search/answer_question/states.py +++ b/backend/onyx/agent_search/answer_question/states.py @@ -10,6 +10,9 @@ from onyx.context.search.models import InferenceSection +### Models ### + + class SearchAnswerResults(BaseModel): question: str answer: str @@ -18,29 +21,43 @@ class SearchAnswerResults(BaseModel): documents: list[InferenceSection] -class QACheckOutput(TypedDict, total=False): +### States ### + +## Update States + + +class QACheckUpdate(TypedDict): answer_quality: str -class QAGenerationOutput(TypedDict, total=False): +class QAGenerationUpdate(TypedDict): answer: str -class 
AnswerQueryState( +## Graph State + + +class AnswerQuestionState( PrimaryState, ExpandedRetrievalOutput, - QAGenerationOutput, - QACheckOutput, + QAGenerationUpdate, + QACheckUpdate, total=True, ): question: str -class AnswerQueryInput(PrimaryState, total=True): +## Input State + + +class AnswerQuestionInput(PrimaryState): question: str -class AnswerQueryOutput(TypedDict): +## Graph Output State + + +class AnswerQuestionOutput(TypedDict): """ This is a list of results even though each call of this subgraph only returns one result. This is because if we parallelize the answer query subgraph, there will be multiple diff --git a/backend/onyx/agent_search/expanded_retrieval/edges.py b/backend/onyx/agent_search/expanded_retrieval/edges.py index 19a321bd72..1c62ba7dd9 100644 --- a/backend/onyx/agent_search/expanded_retrieval/edges.py +++ b/backend/onyx/agent_search/expanded_retrieval/edges.py @@ -4,7 +4,8 @@ from langchain_core.messages import merge_message_runs from langgraph.types import Send -from onyx.agent_search.expanded_retrieval.nodes.doc_retrieval import RetrieveInput +from onyx.agent_search.core_state import extract_primary_fields +from onyx.agent_search.expanded_retrieval.nodes.doc_retrieval import RetrievalInput from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput from onyx.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI_ORIGINAL from onyx.llm.interfaces import LLM @@ -38,7 +39,10 @@ def parallel_retrieval_edge(state: ExpandedRetrievalInput) -> list[Send | Hashab return [ Send( "doc_retrieval", - RetrieveInput(query_to_retrieve=query, **state), + RetrievalInput( + query_to_retrieve=query, + **extract_primary_fields(state), + ), ) for query in rewritten_queries ] diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py index 1ac3620351..925b7c7f44 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py @@ -1,11 +1,11 @@ -from onyx.agent_search.expanded_retrieval.states import DocRerankingOutput +from onyx.agent_search.expanded_retrieval.states import DocRerankingUpdate from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState -def doc_reranking(state: ExpandedRetrievalState) -> DocRerankingOutput: +def doc_reranking(state: ExpandedRetrievalState) -> DocRerankingUpdate: print(f"doc_reranking state: {state.keys()}") verified_documents = state["verified_documents"] reranked_documents = verified_documents - return DocRerankingOutput(reranked_documents=reranked_documents) + return DocRerankingUpdate(reranked_documents=reranked_documents) diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py index 118aaa776c..a141bfcaac 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py @@ -1,34 +1,29 @@ -from onyx.agent_search.expanded_retrieval.states import DocRetrievalOutput +from onyx.agent_search.expanded_retrieval.states import DocRetrievalUpdate from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalResult -from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState +from onyx.agent_search.expanded_retrieval.states import RetrievalInput from onyx.context.search.models import InferenceSection from 
onyx.context.search.models import SearchRequest from onyx.context.search.pipeline import SearchPipeline -class RetrieveInput(ExpandedRetrievalState): - query_to_retrieve: str - - -def doc_retrieval(state: RetrieveInput) -> DocRetrievalOutput: +def doc_retrieval(state: RetrievalInput) -> DocRetrievalUpdate: # def doc_retrieval(state: RetrieveInput) -> Command[Literal["doc_verification"]]: """ Retrieve documents Args: - state (dict): The current graph state + state (RetrievalInput): Primary state + the query to retrieve - Returns: - state (dict): New key added to state, documents, that contains retrieved documents + Updates: + expanded_retrieval_results: list[ExpandedRetrievalResult] + retrieved_documents: list[InferenceSection] """ - print(f"doc_retrieval state: {state.keys()}") - documents: list[InferenceSection] = [] llm = state["primary_llm"] fast_llm = state["fast_llm"] query_to_retrieve = state["query_to_retrieve"] - documents = SearchPipeline( + documents: list[InferenceSection] = SearchPipeline( search_request=SearchRequest( query=query_to_retrieve, ), @@ -43,7 +38,7 @@ def doc_retrieval(state: RetrieveInput) -> DocRetrievalOutput: expanded_query=query_to_retrieve, expanded_retrieval_documents=documents[:4], ) - return DocRetrievalOutput( + return DocRetrievalUpdate( expanded_retrieval_results=[expanded_retrieval_result], retrieved_documents=documents[:4], ) diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py index f3f993e87b..741c445e2c 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py @@ -1,7 +1,7 @@ from langchain_core.messages import HumanMessage from langchain_core.messages import merge_message_runs -from onyx.agent_search.expanded_retrieval.states import DocVerificationOutput +from onyx.agent_search.expanded_retrieval.states import DocVerificationUpdate from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState from onyx.agent_search.shared_graph_utils.models import BinaryDecision from onyx.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT @@ -12,7 +12,7 @@ class DocVerificationInput(ExpandedRetrievalState, total=True): doc_to_verify: InferenceSection -def doc_verification(state: DocVerificationInput) -> DocVerificationOutput: +def doc_verification(state: DocVerificationInput) -> DocVerificationUpdate: """ Check whether the document is relevant for the original user question @@ -55,6 +55,6 @@ def doc_verification(state: DocVerificationInput) -> DocVerificationOutput: if formatted_response.decision == "yes": verified_documents.append(doc_to_verify) - return DocVerificationOutput( + return DocVerificationUpdate( verified_documents=verified_documents, ) diff --git a/backend/onyx/agent_search/expanded_retrieval/states.py b/backend/onyx/agent_search/expanded_retrieval/states.py index 238b294a79..68c0b5889b 100644 --- a/backend/onyx/agent_search/expanded_retrieval/states.py +++ b/backend/onyx/agent_search/expanded_retrieval/states.py @@ -9,40 +9,57 @@ from onyx.context.search.models import InferenceSection +### Models ### + + class ExpandedRetrievalResult(BaseModel): expanded_query: str - expanded_retrieval_documents: Annotated[ - list[InferenceSection], dedup_inference_sections - ] + expanded_retrieval_documents: list[InferenceSection] -class DocRetrievalOutput(TypedDict, total=False): - expanded_retrieval_results: 
Annotated[list[ExpandedRetrievalResult], add] - retrieved_documents: Annotated[list[InferenceSection], dedup_inference_sections] +### States ### +## Update States -class DocVerificationOutput(TypedDict, total=False): +class DocVerificationUpdate(TypedDict): verified_documents: Annotated[list[InferenceSection], dedup_inference_sections] -class DocRerankingOutput(TypedDict, total=False): +class DocRerankingUpdate(TypedDict): reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections] -class ExpandedRetrievalOutput(TypedDict): - expanded_retrieval_results: list[ExpandedRetrievalResult] - documents: Annotated[list[InferenceSection], dedup_inference_sections] +class DocRetrievalUpdate(TypedDict): + expanded_retrieval_results: Annotated[list[ExpandedRetrievalResult], add] + retrieved_documents: Annotated[list[InferenceSection], dedup_inference_sections] + + +## Graph State class ExpandedRetrievalState( PrimaryState, - DocRetrievalOutput, - DocVerificationOutput, - DocRerankingOutput, - total=True, + DocRetrievalUpdate, + DocVerificationUpdate, + DocRerankingUpdate, ): starting_query: str -class ExpandedRetrievalInput(PrimaryState, total=True): +## Graph Output State + + +class ExpandedRetrievalOutput(TypedDict): + expanded_retrieval_results: list[ExpandedRetrievalResult] + documents: Annotated[list[InferenceSection], dedup_inference_sections] + + +## Input States + + +class ExpandedRetrievalInput(PrimaryState): starting_query: str + + +class RetrievalInput(PrimaryState): + query_to_retrieve: str diff --git a/backend/onyx/agent_search/main/edges.py b/backend/onyx/agent_search/main/edges.py index c0730c1537..484c0c354a 100644 --- a/backend/onyx/agent_search/main/edges.py +++ b/backend/onyx/agent_search/main/edges.py @@ -2,7 +2,7 @@ from langgraph.types import Send -from onyx.agent_search.answer_question.states import AnswerQueryInput +from onyx.agent_search.answer_question.states import AnswerQuestionInput from onyx.agent_search.core_state import extract_primary_fields from onyx.agent_search.main.states import MainState @@ -11,7 +11,7 @@ def parallelize_decompozed_answer_queries(state: MainState) -> list[Send | Hasha return [ Send( "answer_query", - AnswerQueryInput( + AnswerQuestionInput( **extract_primary_fields(state), question=question, ), diff --git a/backend/onyx/agent_search/main/nodes/ingest_answers.py b/backend/onyx/agent_search/main/nodes/ingest_answers.py index 2662951cce..c86f3f3104 100644 --- a/backend/onyx/agent_search/main/nodes/ingest_answers.py +++ b/backend/onyx/agent_search/main/nodes/ingest_answers.py @@ -1,9 +1,9 @@ -from onyx.agent_search.answer_question.states import AnswerQueryOutput +from onyx.agent_search.answer_question.states import AnswerQuestionOutput from onyx.agent_search.main.states import DecompAnswersUpdate from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections -def ingest_answers(state: AnswerQueryOutput) -> DecompAnswersUpdate: +def ingest_answers(state: AnswerQuestionOutput) -> DecompAnswersUpdate: documents = [] for answer_result in state["answer_results"]: documents.extend(answer_result.documents) diff --git a/backend/onyx/agent_search/main/states.py b/backend/onyx/agent_search/main/states.py index 7e2c14d2bc..afed369e59 100644 --- a/backend/onyx/agent_search/main/states.py +++ b/backend/onyx/agent_search/main/states.py @@ -8,6 +8,10 @@ from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections from onyx.context.search.models import InferenceSection +### States ### + +## Update States 
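The Update States here all lean on the same Annotated reducer convention: the second argument to Annotated is the function LangGraph uses to merge a node's returned value into the value already held in the graph state (operator.add concatenates lists, dedup_inference_sections concatenates and de-duplicates). A rough, framework-free sketch of that merge rule, with a simple dedup-by-id standing in for dedup_inference_sections and assuming reducers are resolved from the type metadata roughly like this:

from operator import add
from typing import Annotated, TypedDict, get_type_hints


def dedup_by_id(existing: list[dict], new: list[dict]) -> list[dict]:
    seen = {d["id"] for d in existing}
    return existing + [d for d in new if d["id"] not in seen]


class State(TypedDict):
    documents: Annotated[list[dict], dedup_by_id]
    decomp_answer_results: Annotated[list[str], add]


def merge(state: dict, update: dict) -> dict:
    # For each updated key, look up the reducer attached via Annotated
    # and let it combine the existing and new values.
    hints = get_type_hints(State, include_extras=True)
    merged = dict(state)
    for key, value in update.items():
        reducer = hints[key].__metadata__[0]
        merged[key] = reducer(merged.get(key, []), value)
    return merged


state = {"documents": [{"id": "a"}], "decomp_answer_results": ["answer 1"]}
update = {"documents": [{"id": "a"}, {"id": "b"}], "decomp_answer_results": ["answer 2"]}
# documents are deduped to ids a and b; answers are simply concatenated
print(merge(state, update))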
+ class BaseDecompUpdate(TypedDict): initial_decomp_questions: list[str] @@ -29,6 +33,9 @@ class ExpandedRetrievalUpdate(TypedDict): original_question_retrieval_results: list[ExpandedRetrievalResult] +## Graph State + + class MainState( PrimaryState, BaseDecompUpdate, @@ -39,10 +46,16 @@ class MainState( pass +## Input States + + class MainInput(PrimaryState): pass +## Graph Output State + + class MainOutput(TypedDict): """ This is not used because defining the output only matters for filtering the output of From 9d3220fcfc3355531c51a7b92477357995c951f3 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Wed, 18 Dec 2024 10:17:07 -0800 Subject: [PATCH 13/78] explicitly ingest state from retrieval --- .../onyx/agent_search/answer_question/graph_builder.py | 9 +++++++++ .../answer_question/nodes/ingest_retrieval.py | 9 +++++++++ backend/onyx/agent_search/answer_question/states.py | 9 +++++++-- 3 files changed, 25 insertions(+), 2 deletions(-) create mode 100644 backend/onyx/agent_search/answer_question/nodes/ingest_retrieval.py diff --git a/backend/onyx/agent_search/answer_question/graph_builder.py b/backend/onyx/agent_search/answer_question/graph_builder.py index 291d9deb6a..0aebb045de 100644 --- a/backend/onyx/agent_search/answer_question/graph_builder.py +++ b/backend/onyx/agent_search/answer_question/graph_builder.py @@ -6,6 +6,7 @@ from onyx.agent_search.answer_question.nodes.answer_check import answer_check from onyx.agent_search.answer_question.nodes.answer_generation import answer_generation from onyx.agent_search.answer_question.nodes.format_answer import format_answer +from onyx.agent_search.answer_question.nodes.ingest_retrieval import ingest_retrieval from onyx.agent_search.answer_question.states import AnswerQuestionInput from onyx.agent_search.answer_question.states import AnswerQuestionOutput from onyx.agent_search.answer_question.states import AnswerQuestionState @@ -40,6 +41,10 @@ def answer_query_graph_builder() -> StateGraph: node="format_answer", action=format_answer, ) + graph.add_node( + node="ingest_retrieval", + action=ingest_retrieval, + ) ### Add edges ### @@ -50,6 +55,10 @@ def answer_query_graph_builder() -> StateGraph: ) graph.add_edge( start_key="decomped_expanded_retrieval", + end_key="ingest_retrieval", + ) + graph.add_edge( + start_key="ingest_retrieval", end_key="answer_generation", ) graph.add_edge( diff --git a/backend/onyx/agent_search/answer_question/nodes/ingest_retrieval.py b/backend/onyx/agent_search/answer_question/nodes/ingest_retrieval.py new file mode 100644 index 0000000000..7ee1ae75ef --- /dev/null +++ b/backend/onyx/agent_search/answer_question/nodes/ingest_retrieval.py @@ -0,0 +1,9 @@ +from onyx.agent_search.answer_question.states import RetrievalIngestionUpdate +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalOutput + + +def ingest_retrieval(state: ExpandedRetrievalOutput) -> RetrievalIngestionUpdate: + return RetrievalIngestionUpdate( + documents=state["documents"], + expanded_retrieval_results=state["expanded_retrieval_results"], + ) diff --git a/backend/onyx/agent_search/answer_question/states.py b/backend/onyx/agent_search/answer_question/states.py index 06cbe3ba83..a0a4295da0 100644 --- a/backend/onyx/agent_search/answer_question/states.py +++ b/backend/onyx/agent_search/answer_question/states.py @@ -5,8 +5,8 @@ from pydantic import BaseModel from onyx.agent_search.core_state import PrimaryState -from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalOutput from 
onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalResult +from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections from onyx.context.search.models import InferenceSection @@ -34,14 +34,19 @@ class QAGenerationUpdate(TypedDict): answer: str +class RetrievalIngestionUpdate(TypedDict): + documents: Annotated[list[InferenceSection], dedup_inference_sections] + expanded_retrieval_results: list[ExpandedRetrievalResult] + + ## Graph State class AnswerQuestionState( PrimaryState, - ExpandedRetrievalOutput, QAGenerationUpdate, QACheckUpdate, + RetrievalIngestionUpdate, total=True, ): question: str From 0c75ca05799a77c179cf7398744c5610c97dcb7e Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Wed, 18 Dec 2024 11:08:43 -0800 Subject: [PATCH 14/78] renames --- .../onyx/agent_search/answer_question/nodes/answer_check.py | 4 ++-- .../onyx/agent_search/answer_question/nodes/format_answer.py | 4 ++-- backend/onyx/agent_search/answer_question/states.py | 4 ++-- backend/onyx/agent_search/expanded_retrieval/states.py | 4 ++-- backend/onyx/agent_search/main/states.py | 4 ++-- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/backend/onyx/agent_search/answer_question/nodes/answer_check.py b/backend/onyx/agent_search/answer_question/nodes/answer_check.py index b04953dd0b..83cc46280f 100644 --- a/backend/onyx/agent_search/answer_question/nodes/answer_check.py +++ b/backend/onyx/agent_search/answer_question/nodes/answer_check.py @@ -23,8 +23,8 @@ def answer_check(state: AnswerQuestionState) -> QACheckUpdate: ) ) - response_str = merge_message_runs(response, chunk_separator="")[0].content + quality_str = merge_message_runs(response, chunk_separator="")[0].content return QACheckUpdate( - answer_quality=response_str, + answer_quality=quality_str, ) diff --git a/backend/onyx/agent_search/answer_question/nodes/format_answer.py b/backend/onyx/agent_search/answer_question/nodes/format_answer.py index 216100a94c..c789729472 100644 --- a/backend/onyx/agent_search/answer_question/nodes/format_answer.py +++ b/backend/onyx/agent_search/answer_question/nodes/format_answer.py @@ -1,12 +1,12 @@ from onyx.agent_search.answer_question.states import AnswerQuestionOutput from onyx.agent_search.answer_question.states import AnswerQuestionState -from onyx.agent_search.answer_question.states import SearchAnswerResults +from onyx.agent_search.answer_question.states import QuestionAnswerResults def format_answer(state: AnswerQuestionState) -> AnswerQuestionOutput: return AnswerQuestionOutput( answer_results=[ - SearchAnswerResults( + QuestionAnswerResults( question=state["question"], quality=state["answer_quality"], answer=state["answer"], diff --git a/backend/onyx/agent_search/answer_question/states.py b/backend/onyx/agent_search/answer_question/states.py index a0a4295da0..e216ac2049 100644 --- a/backend/onyx/agent_search/answer_question/states.py +++ b/backend/onyx/agent_search/answer_question/states.py @@ -13,7 +13,7 @@ ### Models ### -class SearchAnswerResults(BaseModel): +class QuestionAnswerResults(BaseModel): question: str answer: str quality: str @@ -69,4 +69,4 @@ class AnswerQuestionOutput(TypedDict): results in a list so the add operator is used to add them together. 
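For example, if the main graph fans this subgraph out over three sub-questions in parallel, the add operator combines the three one-element lists into a single three-element answer_results list.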
""" - answer_results: Annotated[list[SearchAnswerResults], add] + answer_results: Annotated[list[QuestionAnswerResults], add] diff --git a/backend/onyx/agent_search/expanded_retrieval/states.py b/backend/onyx/agent_search/expanded_retrieval/states.py index 68c0b5889b..81b96d95f1 100644 --- a/backend/onyx/agent_search/expanded_retrieval/states.py +++ b/backend/onyx/agent_search/expanded_retrieval/states.py @@ -43,7 +43,7 @@ class ExpandedRetrievalState( DocVerificationUpdate, DocRerankingUpdate, ): - starting_query: str + question: str ## Graph Output State @@ -58,7 +58,7 @@ class ExpandedRetrievalOutput(TypedDict): class ExpandedRetrievalInput(PrimaryState): - starting_query: str + question: str class RetrievalInput(PrimaryState): diff --git a/backend/onyx/agent_search/main/states.py b/backend/onyx/agent_search/main/states.py index afed369e59..a6c296109f 100644 --- a/backend/onyx/agent_search/main/states.py +++ b/backend/onyx/agent_search/main/states.py @@ -2,7 +2,7 @@ from typing import Annotated from typing import TypedDict -from onyx.agent_search.answer_question.states import SearchAnswerResults +from onyx.agent_search.answer_question.states import QuestionAnswerResults from onyx.agent_search.core_state import PrimaryState from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalResult from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections @@ -23,7 +23,7 @@ class InitialAnswerUpdate(TypedDict): class DecompAnswersUpdate(TypedDict): documents: Annotated[list[InferenceSection], dedup_inference_sections] - decomp_answer_results: Annotated[list[SearchAnswerResults], add] + decomp_answer_results: Annotated[list[QuestionAnswerResults], add] class ExpandedRetrievalUpdate(TypedDict): From bca02ebec6a0c24d97d39a509c530865a0998433 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Wed, 18 Dec 2024 12:44:28 -0800 Subject: [PATCH 15/78] figured it out --- .../onyx/agent_search/answer_question/edges.py | 2 +- .../answer_question/nodes/ingest_retrieval.py | 6 ++++-- .../agent_search/answer_question/states.py | 7 +++---- .../expanded_retrieval/graph_builder.py | 2 +- .../expanded_retrieval/nodes/doc_reranking.py | 2 -- .../expanded_retrieval/nodes/doc_retrieval.py | 10 ++++------ .../nodes/doc_verification.py | 17 ++++------------- .../expanded_retrieval/nodes/format_results.py | 7 +++++-- .../nodes/verification_kickoff.py | 6 +++++- .../agent_search/expanded_retrieval/states.py | 18 +++++++++++++----- .../onyx/agent_search/main/graph_builder.py | 8 ++++---- .../main/nodes/ingest_initial_retrieval.py | 8 ++++++-- backend/onyx/agent_search/main/states.py | 4 ++-- 13 files changed, 52 insertions(+), 45 deletions(-) diff --git a/backend/onyx/agent_search/answer_question/edges.py b/backend/onyx/agent_search/answer_question/edges.py index ec32f1c852..261821f8cc 100644 --- a/backend/onyx/agent_search/answer_question/edges.py +++ b/backend/onyx/agent_search/answer_question/edges.py @@ -12,6 +12,6 @@ def send_to_expanded_retrieval(state: AnswerQuestionInput) -> Send | Hashable: "decomped_expanded_retrieval", ExpandedRetrievalInput( **extract_primary_fields(state), - starting_query=state["question"], + question=state["question"], ), ) diff --git a/backend/onyx/agent_search/answer_question/nodes/ingest_retrieval.py b/backend/onyx/agent_search/answer_question/nodes/ingest_retrieval.py index 7ee1ae75ef..f20ec7d86d 100644 --- a/backend/onyx/agent_search/answer_question/nodes/ingest_retrieval.py +++ b/backend/onyx/agent_search/answer_question/nodes/ingest_retrieval.py @@ 
-4,6 +4,8 @@ def ingest_retrieval(state: ExpandedRetrievalOutput) -> RetrievalIngestionUpdate: return RetrievalIngestionUpdate( - documents=state["documents"], - expanded_retrieval_results=state["expanded_retrieval_results"], + expanded_retrieval_results=state[ + "expanded_retrieval_result" + ].expanded_queries_results, + documents=state["expanded_retrieval_result"].all_documents, ) diff --git a/backend/onyx/agent_search/answer_question/states.py b/backend/onyx/agent_search/answer_question/states.py index e216ac2049..898a035b7b 100644 --- a/backend/onyx/agent_search/answer_question/states.py +++ b/backend/onyx/agent_search/answer_question/states.py @@ -5,7 +5,7 @@ from pydantic import BaseModel from onyx.agent_search.core_state import PrimaryState -from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalResult +from onyx.agent_search.expanded_retrieval.states import QueryResult from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections from onyx.context.search.models import InferenceSection @@ -17,7 +17,7 @@ class QuestionAnswerResults(BaseModel): question: str answer: str quality: str - expanded_retrieval_results: list[ExpandedRetrievalResult] + expanded_retrieval_results: list[QueryResult] documents: list[InferenceSection] @@ -35,8 +35,8 @@ class QAGenerationUpdate(TypedDict): class RetrievalIngestionUpdate(TypedDict): + expanded_retrieval_results: list[QueryResult] documents: Annotated[list[InferenceSection], dedup_inference_sections] - expanded_retrieval_results: list[ExpandedRetrievalResult] ## Graph State @@ -47,7 +47,6 @@ class AnswerQuestionState( QAGenerationUpdate, QACheckUpdate, RetrievalIngestionUpdate, - total=True, ): question: str diff --git a/backend/onyx/agent_search/expanded_retrieval/graph_builder.py b/backend/onyx/agent_search/expanded_retrieval/graph_builder.py index a160678196..c2bfd1e346 100644 --- a/backend/onyx/agent_search/expanded_retrieval/graph_builder.py +++ b/backend/onyx/agent_search/expanded_retrieval/graph_builder.py @@ -91,7 +91,7 @@ def expanded_retrieval_graph_builder() -> StateGraph: primary_llm=primary_llm, fast_llm=fast_llm, db_session=db_session, - starting_query="what can you do with onyx?", + question="what can you do with onyx?", ) for thing in compiled_graph.stream( input=inputs, diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py index 925b7c7f44..6f8e3df063 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py @@ -3,8 +3,6 @@ def doc_reranking(state: ExpandedRetrievalState) -> DocRerankingUpdate: - print(f"doc_reranking state: {state.keys()}") - verified_documents = state["verified_documents"] reranked_documents = verified_documents diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py index a141bfcaac..c0b60ef38d 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py @@ -1,5 +1,5 @@ from onyx.agent_search.expanded_retrieval.states import DocRetrievalUpdate -from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalResult +from onyx.agent_search.expanded_retrieval.states import QueryResult from onyx.agent_search.expanded_retrieval.states import RetrievalInput from onyx.context.search.models import 
InferenceSection from onyx.context.search.models import SearchRequest @@ -7,7 +7,6 @@ def doc_retrieval(state: RetrievalInput) -> DocRetrievalUpdate: - # def doc_retrieval(state: RetrieveInput) -> Command[Literal["doc_verification"]]: """ Retrieve documents @@ -33,10 +32,9 @@ def doc_retrieval(state: RetrievalInput) -> DocRetrievalUpdate: db_session=state["db_session"], ).reranked_sections - print(f"retrieved documents: {len(documents)}") - expanded_retrieval_result = ExpandedRetrievalResult( - expanded_query=query_to_retrieve, - expanded_retrieval_documents=documents[:4], + expanded_retrieval_result = QueryResult( + query=query_to_retrieve, + documents_for_query=documents[:4], ) return DocRetrievalUpdate( expanded_retrieval_results=[expanded_retrieval_result], diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py index 741c445e2c..3abebfcf2e 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py @@ -1,15 +1,10 @@ from langchain_core.messages import HumanMessage from langchain_core.messages import merge_message_runs +from onyx.agent_search.expanded_retrieval.states import DocVerificationInput from onyx.agent_search.expanded_retrieval.states import DocVerificationUpdate -from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState from onyx.agent_search.shared_graph_utils.models import BinaryDecision from onyx.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT -from onyx.context.search.models import InferenceSection - - -class DocVerificationInput(ExpandedRetrievalState, total=True): - doc_to_verify: InferenceSection def doc_verification(state: DocVerificationInput) -> DocVerificationUpdate: @@ -17,14 +12,12 @@ def doc_verification(state: DocVerificationInput) -> DocVerificationUpdate: Check whether the document is relevant for the original user question Args: - state (VerifierState): The current state + state (DocVerificationInput): The current state - Returns: - dict: ict: The updated state with the final decision + Updates: + verified_documents: list[InferenceSection] """ - print(f"doc_verification state: {state.keys()}") - original_query = state["search_request"].query doc_to_verify = state["doc_to_verify"] document_content = doc_to_verify.combined_content @@ -49,8 +42,6 @@ def doc_verification(state: DocVerificationInput) -> DocVerificationUpdate: decision_dict = {"decision": response_string.lower()} formatted_response = BinaryDecision.model_validate(decision_dict) - print(f"Verdict: {formatted_response.decision}") - verified_documents = [] if formatted_response.decision == "yes": verified_documents.append(doc_to_verify) diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py b/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py index 2a9620a0a9..50da6e9a64 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py @@ -1,9 +1,12 @@ from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalOutput +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalResult from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState def format_results(state: ExpandedRetrievalState) -> ExpandedRetrievalOutput: return ExpandedRetrievalOutput( - 
expanded_retrieval_results=state["expanded_retrieval_results"], - documents=state["reranked_documents"], + expanded_retrieval_result=ExpandedRetrievalResult( + expanded_queries_results=state["expanded_retrieval_results"], + all_documents=state["reranked_documents"], + ), ) diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/verification_kickoff.py b/backend/onyx/agent_search/expanded_retrieval/nodes/verification_kickoff.py index d40bf6f0da..0894088995 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/verification_kickoff.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/verification_kickoff.py @@ -3,6 +3,7 @@ from langgraph.types import Command from langgraph.types import Send +from onyx.agent_search.core_state import extract_primary_fields from onyx.agent_search.expanded_retrieval.nodes.doc_verification import ( DocVerificationInput, ) @@ -20,7 +21,10 @@ def verification_kickoff( goto=[ Send( node="doc_verification", - arg=DocVerificationInput(doc_to_verify=doc, **state), + arg=DocVerificationInput( + doc_to_verify=doc, + **extract_primary_fields(state), + ), ) for doc in documents ], diff --git a/backend/onyx/agent_search/expanded_retrieval/states.py b/backend/onyx/agent_search/expanded_retrieval/states.py index 81b96d95f1..5408e75e89 100644 --- a/backend/onyx/agent_search/expanded_retrieval/states.py +++ b/backend/onyx/agent_search/expanded_retrieval/states.py @@ -12,9 +12,14 @@ ### Models ### +class QueryResult(BaseModel): + query: str + documents_for_query: list[InferenceSection] + + class ExpandedRetrievalResult(BaseModel): - expanded_query: str - expanded_retrieval_documents: list[InferenceSection] + expanded_queries_results: list[QueryResult] + all_documents: list[InferenceSection] ### States ### @@ -30,7 +35,7 @@ class DocRerankingUpdate(TypedDict): class DocRetrievalUpdate(TypedDict): - expanded_retrieval_results: Annotated[list[ExpandedRetrievalResult], add] + expanded_retrieval_results: Annotated[list[QueryResult], add] retrieved_documents: Annotated[list[InferenceSection], dedup_inference_sections] @@ -50,8 +55,7 @@ class ExpandedRetrievalState( class ExpandedRetrievalOutput(TypedDict): - expanded_retrieval_results: list[ExpandedRetrievalResult] - documents: Annotated[list[InferenceSection], dedup_inference_sections] + expanded_retrieval_result: ExpandedRetrievalResult ## Input States @@ -61,5 +65,9 @@ class ExpandedRetrievalInput(PrimaryState): question: str +class DocVerificationInput(PrimaryState): + doc_to_verify: InferenceSection + + class RetrievalInput(PrimaryState): query_to_retrieve: str diff --git a/backend/onyx/agent_search/main/graph_builder.py b/backend/onyx/agent_search/main/graph_builder.py index 971398f9c9..f628ebf78c 100644 --- a/backend/onyx/agent_search/main/graph_builder.py +++ b/backend/onyx/agent_search/main/graph_builder.py @@ -41,14 +41,14 @@ def main_graph_builder() -> StateGraph: node="initial_retrieval", action=expanded_retrieval_subgraph, ) - graph.add_node( - node="ingest_answers", - action=ingest_answers, - ) graph.add_node( node="ingest_initial_retrieval", action=ingest_initial_retrieval, ) + graph.add_node( + node="ingest_answers", + action=ingest_answers, + ) graph.add_node( node="generate_initial_answer", action=generate_initial_answer, diff --git a/backend/onyx/agent_search/main/nodes/ingest_initial_retrieval.py b/backend/onyx/agent_search/main/nodes/ingest_initial_retrieval.py index e3a96e0b8e..3cd75860ce 100644 --- a/backend/onyx/agent_search/main/nodes/ingest_initial_retrieval.py +++ 
b/backend/onyx/agent_search/main/nodes/ingest_initial_retrieval.py @@ -4,6 +4,10 @@ def ingest_initial_retrieval(state: ExpandedRetrievalOutput) -> ExpandedRetrievalUpdate: return ExpandedRetrievalUpdate( - all_original_question_documents=state["documents"], - original_question_retrieval_results=state["expanded_retrieval_results"], + original_question_retrieval_results=state[ + "expanded_retrieval_result" + ].expanded_queries_results, + all_original_question_documents=state[ + "expanded_retrieval_result" + ].all_documents, ) diff --git a/backend/onyx/agent_search/main/states.py b/backend/onyx/agent_search/main/states.py index a6c296109f..081440344b 100644 --- a/backend/onyx/agent_search/main/states.py +++ b/backend/onyx/agent_search/main/states.py @@ -4,7 +4,7 @@ from onyx.agent_search.answer_question.states import QuestionAnswerResults from onyx.agent_search.core_state import PrimaryState -from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalResult +from onyx.agent_search.expanded_retrieval.states import QueryResult from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections from onyx.context.search.models import InferenceSection @@ -30,7 +30,7 @@ class ExpandedRetrievalUpdate(TypedDict): all_original_question_documents: Annotated[ list[InferenceSection], dedup_inference_sections ] - original_question_retrieval_results: list[ExpandedRetrievalResult] + original_question_retrieval_results: list[QueryResult] ## Graph State From 2d6f7462592f0a4b0fea89578965a9a102edde4d Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Wed, 18 Dec 2024 13:03:28 -0800 Subject: [PATCH 16/78] made query expansion explicit --- .../agent_search/expanded_retrieval/edges.py | 34 ++----------------- .../expanded_retrieval/graph_builder.py | 12 ++++++- .../nodes/expand_queries.py | 30 ++++++++++++++++ .../agent_search/expanded_retrieval/states.py | 5 +++ backend/onyx/agent_search/main/edges.py | 18 ++++++++-- .../onyx/agent_search/main/graph_builder.py | 14 +++----- 6 files changed, 70 insertions(+), 43 deletions(-) create mode 100644 backend/onyx/agent_search/expanded_retrieval/nodes/expand_queries.py diff --git a/backend/onyx/agent_search/expanded_retrieval/edges.py b/backend/onyx/agent_search/expanded_retrieval/edges.py index 1c62ba7dd9..d426ed3603 100644 --- a/backend/onyx/agent_search/expanded_retrieval/edges.py +++ b/backend/onyx/agent_search/expanded_retrieval/edges.py @@ -1,41 +1,13 @@ from collections.abc import Hashable -from langchain_core.messages import HumanMessage -from langchain_core.messages import merge_message_runs from langgraph.types import Send from onyx.agent_search.core_state import extract_primary_fields from onyx.agent_search.expanded_retrieval.nodes.doc_retrieval import RetrievalInput -from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput -from onyx.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI_ORIGINAL -from onyx.llm.interfaces import LLM +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState -def parallel_retrieval_edge(state: ExpandedRetrievalInput) -> list[Send | Hashable]: - print(f"parallel_retrieval_edge state: {state.keys()}") - - # This should be better... 
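Since the Send-based fan-out pattern recurs throughout these patches (parallelize_decompozed_answer_queries, parallel_retrieval_edge, verification_kickoff), a toy end-to-end example may help. The node names and bodies below are invented for illustration; only the wiring mirrors the patches, and it assumes the same langgraph APIs the code above imports (StateGraph, Send, START, END):

from operator import add
from typing import Annotated, TypedDict

from langgraph.graph import END, START, StateGraph
from langgraph.types import Send


class ToyState(TypedDict):
    queries: list[str]
    results: Annotated[list[str], add]  # parallel branches all append here


def expand(state: ToyState) -> dict:
    return {"queries": [f"{q} (rewritten)" for q in state["queries"]]}


def retrieve(state: dict) -> dict:
    # Each Send below invokes this node once, with its own payload.
    return {"results": [f"docs for: {state['query']}"]}


def fan_out(state: ToyState) -> list[Send]:
    return [Send("retrieve", {"query": q}) for q in state["queries"]]


builder = StateGraph(ToyState)
builder.add_node("expand", expand)
builder.add_node("retrieve", retrieve)
builder.add_edge(START, "expand")
builder.add_conditional_edges("expand", fan_out, path_map=["retrieve"])
builder.add_edge("retrieve", END)
graph = builder.compile()

# results arrive from all parallel branches, merged by the add reducer
print(graph.invoke({"queries": ["onyx", "danswer"], "results": []}))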
- question = state.get("question") or state["search_request"].query - llm: LLM = state["fast_llm"] - - msg = [ - HumanMessage( - content=REWRITE_PROMPT_MULTI_ORIGINAL.format(question=question), - ) - ] - llm_response_list = list( - llm.stream( - prompt=msg, - ) - ) - llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content - - print(f"llm_response: {llm_response}") - - rewritten_queries = llm_response.split("--") - - print(f"rewritten_queries: {rewritten_queries}") - +def parallel_retrieval_edge(state: ExpandedRetrievalState) -> list[Send | Hashable]: return [ Send( "doc_retrieval", @@ -44,5 +16,5 @@ def parallel_retrieval_edge(state: ExpandedRetrievalInput) -> list[Send | Hashab **extract_primary_fields(state), ), ) - for query in rewritten_queries + for query in state["expanded_queries"] ] diff --git a/backend/onyx/agent_search/expanded_retrieval/graph_builder.py b/backend/onyx/agent_search/expanded_retrieval/graph_builder.py index c2bfd1e346..8da14eea43 100644 --- a/backend/onyx/agent_search/expanded_retrieval/graph_builder.py +++ b/backend/onyx/agent_search/expanded_retrieval/graph_builder.py @@ -8,6 +8,7 @@ from onyx.agent_search.expanded_retrieval.nodes.doc_verification import ( doc_verification, ) +from onyx.agent_search.expanded_retrieval.nodes.expand_queries import expand_queries from onyx.agent_search.expanded_retrieval.nodes.format_results import format_results from onyx.agent_search.expanded_retrieval.nodes.verification_kickoff import ( verification_kickoff, @@ -26,6 +27,11 @@ def expanded_retrieval_graph_builder() -> StateGraph: ### Add nodes ### + graph.add_node( + node="expand_queries", + action=expand_queries, + ) + graph.add_node( node="doc_retrieval", action=doc_retrieval, @@ -48,9 +54,13 @@ def expanded_retrieval_graph_builder() -> StateGraph: ) ### Add edges ### + graph.add_edge( + start_key=START, + end_key="expand_queries", + ) graph.add_conditional_edges( - source=START, + source="expand_queries", path=parallel_retrieval_edge, path_map=["doc_retrieval"], ) diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/expand_queries.py b/backend/onyx/agent_search/expanded_retrieval/nodes/expand_queries.py new file mode 100644 index 0000000000..193d9b648f --- /dev/null +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/expand_queries.py @@ -0,0 +1,30 @@ +from langchain_core.messages import HumanMessage +from langchain_core.messages import merge_message_runs + +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput +from onyx.agent_search.expanded_retrieval.states import QueryExpansionUpdate +from onyx.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI_ORIGINAL +from onyx.llm.interfaces import LLM + + +def expand_queries(state: ExpandedRetrievalInput) -> QueryExpansionUpdate: + question = state.get("question") + llm: LLM = state["fast_llm"] + + msg = [ + HumanMessage( + content=REWRITE_PROMPT_MULTI_ORIGINAL.format(question=question), + ) + ] + llm_response_list = list( + llm.stream( + prompt=msg, + ) + ) + llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content + + rewritten_queries = llm_response.split("--") + + return QueryExpansionUpdate( + expanded_queries=rewritten_queries, + ) diff --git a/backend/onyx/agent_search/expanded_retrieval/states.py b/backend/onyx/agent_search/expanded_retrieval/states.py index 5408e75e89..25160073e9 100644 --- a/backend/onyx/agent_search/expanded_retrieval/states.py +++ b/backend/onyx/agent_search/expanded_retrieval/states.py @@ 
-34,6 +34,10 @@ class DocRerankingUpdate(TypedDict): reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections] +class QueryExpansionUpdate(TypedDict): + expanded_queries: list[str] + + class DocRetrievalUpdate(TypedDict): expanded_retrieval_results: Annotated[list[QueryResult], add] retrieved_documents: Annotated[list[InferenceSection], dedup_inference_sections] @@ -47,6 +51,7 @@ class ExpandedRetrievalState( DocRetrievalUpdate, DocVerificationUpdate, DocRerankingUpdate, + QueryExpansionUpdate, ): question: str diff --git a/backend/onyx/agent_search/main/edges.py b/backend/onyx/agent_search/main/edges.py index 484c0c354a..454b245ee8 100644 --- a/backend/onyx/agent_search/main/edges.py +++ b/backend/onyx/agent_search/main/edges.py @@ -1,14 +1,18 @@ from collections.abc import Hashable +from collections.abc import Sequence from langgraph.types import Send from onyx.agent_search.answer_question.states import AnswerQuestionInput from onyx.agent_search.core_state import extract_primary_fields +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput from onyx.agent_search.main.states import MainState -def parallelize_decompozed_answer_queries(state: MainState) -> list[Send | Hashable]: - return [ +def parallelize_decompozed_answer_queries( + state: MainState, +) -> Sequence[Send | Hashable]: + answer_query_edges = [ Send( "answer_query", AnswerQuestionInput( @@ -18,6 +22,16 @@ def parallelize_decompozed_answer_queries(state: MainState) -> list[Send | Hasha ) for question in state["initial_decomp_questions"] ] + initial_retrieval_edges = [ + Send( + "initial_retrieval", + ExpandedRetrievalInput( + **extract_primary_fields(state), + question=state["search_request"].query, + ), + ) + ] + return answer_query_edges + initial_retrieval_edges # def continue_to_answer_sub_questions(state: QAState) -> Union[Hashable, list[Hashable]]: diff --git a/backend/onyx/agent_search/main/graph_builder.py b/backend/onyx/agent_search/main/graph_builder.py index f628ebf78c..0c85bac7c1 100644 --- a/backend/onyx/agent_search/main/graph_builder.py +++ b/backend/onyx/agent_search/main/graph_builder.py @@ -55,14 +55,6 @@ def main_graph_builder() -> StateGraph: ) ### Add edges ### - graph.add_edge( - start_key=START, - end_key="initial_retrieval", - ) - graph.add_edge( - start_key="initial_retrieval", - end_key="ingest_initial_retrieval", - ) graph.add_edge( start_key=START, @@ -71,7 +63,11 @@ def main_graph_builder() -> StateGraph: graph.add_conditional_edges( source="base_decomp", path=parallelize_decompozed_answer_queries, - path_map=["answer_query"], + path_map=["answer_query", "initial_retrieval"], + ) + graph.add_edge( + start_key="initial_retrieval", + end_key="ingest_initial_retrieval", + ) graph.add_edge( start_key="answer_query", From ffc81f6e45b15c7bd2bfbdd476247a70831c20d9 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Wed, 18 Dec 2024 13:10:46 -0800 Subject: [PATCH 17/78] separate edge for initial retrieval --- backend/onyx/agent_search/main/edges.py | 14 +++++++------- backend/onyx/agent_search/main/graph_builder.py | 17 ++++++++++++----- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/backend/onyx/agent_search/main/edges.py b/backend/onyx/agent_search/main/edges.py index 454b245ee8..0836882ce6 100644 --- a/backend/onyx/agent_search/main/edges.py +++ b/backend/onyx/agent_search/main/edges.py @@ -1,18 +1,16 @@ from collections.abc import Hashable -from collections.abc import Sequence from langgraph.types import Send from
onyx.agent_search.answer_question.states import AnswerQuestionInput from onyx.agent_search.core_state import extract_primary_fields from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput +from onyx.agent_search.main.states import MainInput from onyx.agent_search.main.states import MainState -def parallelize_decompozed_answer_queries( - state: MainState, -) -> Sequence[Send | Hashable]: - answer_query_edges = [ +def parallelize_decompozed_answer_queries(state: MainState) -> list[Send | Hashable]: + return [ Send( "answer_query", AnswerQuestionInput( @@ -22,7 +20,10 @@ def parallelize_decompozed_answer_queries( ) for question in state["initial_decomp_questions"] ] - initial_retrieval_edges = [ + + +def send_to_initial_retrieval(state: MainInput) -> list[Send | Hashable]: + return [ Send( "initial_retrieval", ExpandedRetrievalInput( @@ -31,7 +32,6 @@ def parallelize_decompozed_answer_queries( ), ) ] - return answer_query_edges + initial_retrieval_edges # def continue_to_answer_sub_questions(state: QAState) -> Union[Hashable, list[Hashable]]: diff --git a/backend/onyx/agent_search/main/graph_builder.py b/backend/onyx/agent_search/main/graph_builder.py index 0c85bac7c1..dc09000435 100644 --- a/backend/onyx/agent_search/main/graph_builder.py +++ b/backend/onyx/agent_search/main/graph_builder.py @@ -7,6 +7,7 @@ expanded_retrieval_graph_builder, ) from onyx.agent_search.main.edges import parallelize_decompozed_answer_queries +from onyx.agent_search.main.edges import send_to_initial_retrieval from onyx.agent_search.main.nodes.base_decomp import main_decomp_base from onyx.agent_search.main.nodes.generate_initial_answer import ( generate_initial_answer, @@ -56,6 +57,16 @@ def main_graph_builder() -> StateGraph: ### Add edges ### + graph.add_conditional_edges( + source=START, + path=send_to_initial_retrieval, + path_map=["initial_retrieval"], + ) + graph.add_edge( + start_key="initial_retrieval", + end_key="ingest_initial_retrieval", + ) + graph.add_edge( start_key=START, end_key="base_decomp", @@ -63,11 +74,7 @@ def main_graph_builder() -> StateGraph: graph.add_conditional_edges( source="base_decomp", path=parallelize_decompozed_answer_queries, - path_map=["answer_query", "initial_retrieval"], - ) - graph.add_edge( - start_key="initial_retrieval", - end_key="ingest_initial_retrieval", + path_map=["answer_query"], ) graph.add_edge( start_key="answer_query", From cebe237705121c66d490f3431a434b8c565f7f03 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Thu, 19 Dec 2024 08:47:39 -0800 Subject: [PATCH 18/78] renamed PrimaryState to CoreState --- .../agent_search/answer_question/edges.py | 4 ++-- .../agent_search/answer_question/states.py | 20 +++++++++---------- backend/onyx/agent_search/core_state.py | 18 +++++++++-------- .../agent_search/expanded_retrieval/edges.py | 4 ++-- .../nodes/verification_kickoff.py | 4 ++-- backend/onyx/agent_search/main/edges.py | 6 +++--- backend/onyx/agent_search/main/states.py | 19 +++++++++--------- 7 files changed, 39 insertions(+), 36 deletions(-) diff --git a/backend/onyx/agent_search/answer_question/edges.py b/backend/onyx/agent_search/answer_question/edges.py index 261821f8cc..bdd9864e6e 100644 --- a/backend/onyx/agent_search/answer_question/edges.py +++ b/backend/onyx/agent_search/answer_question/edges.py @@ -3,7 +3,7 @@ from langgraph.types import Send from onyx.agent_search.answer_question.states import AnswerQuestionInput -from onyx.agent_search.core_state import extract_primary_fields +from onyx.agent_search.core_state import 
extract_core_fields from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput @@ -11,7 +11,7 @@ def send_to_expanded_retrieval(state: AnswerQuestionInput) -> Send | Hashable: return Send( "decomped_expanded_retrieval", ExpandedRetrievalInput( - **extract_primary_fields(state), + **extract_core_fields(state), question=state["question"], ), ) diff --git a/backend/onyx/agent_search/answer_question/states.py b/backend/onyx/agent_search/answer_question/states.py index 898a035b7b..2964df0ab5 100644 --- a/backend/onyx/agent_search/answer_question/states.py +++ b/backend/onyx/agent_search/answer_question/states.py @@ -4,7 +4,7 @@ from pydantic import BaseModel -from onyx.agent_search.core_state import PrimaryState +from onyx.agent_search.core_state import CoreState from onyx.agent_search.expanded_retrieval.states import QueryResult from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections from onyx.context.search.models import InferenceSection @@ -39,23 +39,23 @@ class RetrievalIngestionUpdate(TypedDict): documents: Annotated[list[InferenceSection], dedup_inference_sections] +## Graph Input State + + +class AnswerQuestionInput(CoreState): + question: str + + ## Graph State class AnswerQuestionState( - PrimaryState, + AnswerQuestionInput, QAGenerationUpdate, QACheckUpdate, RetrievalIngestionUpdate, ): - question: str - - -## Input State - - -class AnswerQuestionInput(PrimaryState): - question: str + pass ## Graph Output State diff --git a/backend/onyx/agent_search/core_state.py b/backend/onyx/agent_search/core_state.py index ee490e0a33..cbc8f3d5c4 100644 --- a/backend/onyx/agent_search/core_state.py +++ b/backend/onyx/agent_search/core_state.py @@ -7,7 +7,11 @@ from onyx.llm.interfaces import LLM -class PrimaryState(TypedDict, total=False): +class CoreState(TypedDict, total=False): + """ + This is the core state that is shared across all subgraphs. 
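+    Because this TypedDict is declared with total=False, none of these keys are required at runtime; extract_core_fields below simply copies over whichever of them a given state carries.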
+ """ + search_request: SearchRequest primary_llm: LLM fast_llm: LLM @@ -16,12 +20,10 @@ class PrimaryState(TypedDict, total=False): db_session: Session -# This ensures that the state passed in extends the PrimaryState -T = TypeVar("T", bound=PrimaryState) +# This ensures that the state passed in extends the CoreState +T = TypeVar("T", bound=CoreState) -def extract_primary_fields(state: T) -> PrimaryState: - filtered_dict = { - k: v for k, v in state.items() if k in PrimaryState.__annotations__ - } - return PrimaryState(**dict(filtered_dict)) # type: ignore +def extract_core_fields(state: T) -> CoreState: + filtered_dict = {k: v for k, v in state.items() if k in CoreState.__annotations__} + return CoreState(**dict(filtered_dict)) # type: ignore diff --git a/backend/onyx/agent_search/expanded_retrieval/edges.py b/backend/onyx/agent_search/expanded_retrieval/edges.py index d426ed3603..61d994a687 100644 --- a/backend/onyx/agent_search/expanded_retrieval/edges.py +++ b/backend/onyx/agent_search/expanded_retrieval/edges.py @@ -2,7 +2,7 @@ from langgraph.types import Send -from onyx.agent_search.core_state import extract_primary_fields +from onyx.agent_search.core_state import extract_core_fields from onyx.agent_search.expanded_retrieval.nodes.doc_retrieval import RetrievalInput from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState @@ -13,7 +13,7 @@ def parallel_retrieval_edge(state: ExpandedRetrievalState) -> list[Send | Hashab "doc_retrieval", RetrievalInput( query_to_retrieve=query, - **extract_primary_fields(state), + **extract_core_fields(state), ), ) for query in state["expanded_queries"] diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/verification_kickoff.py b/backend/onyx/agent_search/expanded_retrieval/nodes/verification_kickoff.py index 0894088995..81fe7f9229 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/verification_kickoff.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/verification_kickoff.py @@ -3,7 +3,7 @@ from langgraph.types import Command from langgraph.types import Send -from onyx.agent_search.core_state import extract_primary_fields +from onyx.agent_search.core_state import extract_core_fields from onyx.agent_search.expanded_retrieval.nodes.doc_verification import ( DocVerificationInput, ) @@ -23,7 +23,7 @@ def verification_kickoff( node="doc_verification", arg=DocVerificationInput( doc_to_verify=doc, - **extract_primary_fields(state), + **extract_core_fields(state), ), ) for doc in documents diff --git a/backend/onyx/agent_search/main/edges.py b/backend/onyx/agent_search/main/edges.py index 0836882ce6..7791498d3e 100644 --- a/backend/onyx/agent_search/main/edges.py +++ b/backend/onyx/agent_search/main/edges.py @@ -3,7 +3,7 @@ from langgraph.types import Send from onyx.agent_search.answer_question.states import AnswerQuestionInput -from onyx.agent_search.core_state import extract_primary_fields +from onyx.agent_search.core_state import extract_core_fields from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput from onyx.agent_search.main.states import MainInput from onyx.agent_search.main.states import MainState @@ -14,7 +14,7 @@ def parallelize_decompozed_answer_queries(state: MainState) -> list[Send | Hasha Send( "answer_query", AnswerQuestionInput( - **extract_primary_fields(state), + **extract_core_fields(state), question=question, ), ) @@ -27,7 +27,7 @@ def send_to_initial_retrieval(state: MainInput) -> list[Send | Hashable]: Send( "initial_retrieval", ExpandedRetrievalInput( - 
**extract_primary_fields(state), + **extract_core_fields(state), question=state["search_request"].query, ), ) diff --git a/backend/onyx/agent_search/main/states.py b/backend/onyx/agent_search/main/states.py index 081440344b..6fb44da490 100644 --- a/backend/onyx/agent_search/main/states.py +++ b/backend/onyx/agent_search/main/states.py @@ -3,7 +3,7 @@ from typing import TypedDict from onyx.agent_search.answer_question.states import QuestionAnswerResults -from onyx.agent_search.core_state import PrimaryState +from onyx.agent_search.core_state import CoreState from onyx.agent_search.expanded_retrieval.states import QueryResult from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections from onyx.context.search.models import InferenceSection @@ -33,11 +33,19 @@ class ExpandedRetrievalUpdate(TypedDict): original_question_retrieval_results: list[QueryResult] +## Graph Input State + + +class MainInput(CoreState): + pass + + ## Graph State class MainState( - PrimaryState, + # This includes the core state + MainInput, BaseDecompUpdate, InitialAnswerUpdate, DecompAnswersUpdate, @@ -46,13 +54,6 @@ class MainState( pass -## Input States - - -class MainInput(PrimaryState): - pass - - ## Graph Output State From 34aa054c5d1c1e51e88b651edb9a6c3ee2ec22f7 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Thu, 19 Dec 2024 08:48:05 -0800 Subject: [PATCH 19/78] added chunk_ids and stats to QueryResult --- .../expanded_retrieval/nodes/doc_retrieval.py | 2 ++ .../agent_search/expanded_retrieval/states.py | 27 ++++++++++++------- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py index c0b60ef38d..54d5421119 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py @@ -35,6 +35,8 @@ def doc_retrieval(state: RetrievalInput) -> DocRetrievalUpdate: expanded_retrieval_result = QueryResult( query=query_to_retrieve, documents_for_query=documents[:4], + chunk_ids=[], + stats={}, ) return DocRetrievalUpdate( expanded_retrieval_results=[expanded_retrieval_result], diff --git a/backend/onyx/agent_search/expanded_retrieval/states.py b/backend/onyx/agent_search/expanded_retrieval/states.py index 25160073e9..71c845cd6f 100644 --- a/backend/onyx/agent_search/expanded_retrieval/states.py +++ b/backend/onyx/agent_search/expanded_retrieval/states.py @@ -1,10 +1,11 @@ from operator import add from typing import Annotated +from typing import Any from typing import TypedDict from pydantic import BaseModel -from onyx.agent_search.core_state import PrimaryState +from onyx.agent_search.core_state import CoreState from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections from onyx.context.search.models import InferenceSection @@ -15,6 +16,8 @@ class QueryResult(BaseModel): query: str documents_for_query: list[InferenceSection] + chunk_ids: list[str] + stats: dict[str, Any] class ExpandedRetrievalResult(BaseModel): @@ -43,17 +46,25 @@ class DocRetrievalUpdate(TypedDict): retrieved_documents: Annotated[list[InferenceSection], dedup_inference_sections] +## Graph Input State + + +class ExpandedRetrievalInput(CoreState): + question: str + + ## Graph State class ExpandedRetrievalState( - PrimaryState, + # This includes the core state + ExpandedRetrievalInput, DocRetrievalUpdate, DocVerificationUpdate, DocRerankingUpdate, QueryExpansionUpdate, ): - 
question: str + pass ## Graph Output State @@ -63,16 +74,12 @@ class ExpandedRetrievalOutput(TypedDict): expanded_retrieval_result: ExpandedRetrievalResult -## Input States - - -class ExpandedRetrievalInput(PrimaryState): - question: str +## Conditional Input States -class DocVerificationInput(PrimaryState): +class DocVerificationInput(CoreState): doc_to_verify: InferenceSection -class RetrievalInput(PrimaryState): +class RetrievalInput(CoreState): query_to_retrieve: str From 2a3328fc3d9da2cdc029eca184c209a334979de7 Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Fri, 20 Dec 2024 14:41:20 -0800 Subject: [PATCH 20/78] initial_test_flow --- .gitignore | 4 +- .../nodes/answer_generation.py | 6 +- .../answer_question/nodes/format_answer.py | 1 + .../answer_question/nodes/ingest_retrieval.py | 3 + .../agent_search/answer_question/states.py | 7 ++ .../expanded_retrieval/nodes/doc_reranking.py | 5 +- .../expanded_retrieval/nodes/doc_retrieval.py | 56 ++++++++- .../nodes/format_results.py | 57 +++++++++ .../agent_search/expanded_retrieval/states.py | 3 +- .../onyx/agent_search/main/graph_builder.py | 26 +++- .../nodes/generate_initial_BASE_answer.py | 31 +++++ .../main/nodes/generate_initial_answer.py | 104 +++++++++++++++- .../main/nodes/ingest_initial_retrieval.py | 3 + backend/onyx/agent_search/main/states.py | 27 ++++- .../shared_graph_utils/calculations.py | 11 ++ .../shared_graph_utils/prompts.py | 58 +++++++-- .../regression/answer_quality/agent_test.py | 112 ++++++++++++++++++ 17 files changed, 479 insertions(+), 35 deletions(-) create mode 100644 backend/onyx/agent_search/main/nodes/generate_initial_BASE_answer.py create mode 100644 backend/onyx/agent_search/shared_graph_utils/calculations.py create mode 100644 backend/tests/regression/answer_quality/agent_test.py diff --git a/.gitignore b/.gitignore index b97fb309d7..24739991f2 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,6 @@ .vscode/ *.sw? 
/backend/tests/regression/answer_quality/search_test_config.yaml -/web/test-results/ \ No newline at end of file +/web/test-results/ +backend/onyx/agent_search/main/test_data.json +backend/tests/regression/answer_quality/test_data.json diff --git a/backend/onyx/agent_search/answer_question/nodes/answer_generation.py b/backend/onyx/agent_search/answer_question/nodes/answer_generation.py index d47d1aaf77..7e1e326b68 100644 --- a/backend/onyx/agent_search/answer_question/nodes/answer_generation.py +++ b/backend/onyx/agent_search/answer_question/nodes/answer_generation.py @@ -15,7 +15,11 @@ def answer_generation(state: AnswerQuestionState) -> QAGenerationUpdate: msg = [ HumanMessage( - content=BASE_RAG_PROMPT.format(question=question, context=format_docs(docs)) + content=BASE_RAG_PROMPT.format( + question=question, + context=format_docs(docs), + original_question=state["search_request"].query, + ) ) ] diff --git a/backend/onyx/agent_search/answer_question/nodes/format_answer.py b/backend/onyx/agent_search/answer_question/nodes/format_answer.py index c789729472..95a0ac38bf 100644 --- a/backend/onyx/agent_search/answer_question/nodes/format_answer.py +++ b/backend/onyx/agent_search/answer_question/nodes/format_answer.py @@ -12,6 +12,7 @@ def format_answer(state: AnswerQuestionState) -> AnswerQuestionOutput: answer=state["answer"], expanded_retrieval_results=state["expanded_retrieval_results"], documents=state["documents"], + sub_question_retrieval_stats=state["sub_question_retrieval_stats"], ) ], ) diff --git a/backend/onyx/agent_search/answer_question/nodes/ingest_retrieval.py b/backend/onyx/agent_search/answer_question/nodes/ingest_retrieval.py index f20ec7d86d..54830b1873 100644 --- a/backend/onyx/agent_search/answer_question/nodes/ingest_retrieval.py +++ b/backend/onyx/agent_search/answer_question/nodes/ingest_retrieval.py @@ -8,4 +8,7 @@ def ingest_retrieval(state: ExpandedRetrievalOutput) -> RetrievalIngestionUpdate "expanded_retrieval_result" ].expanded_queries_results, documents=state["expanded_retrieval_result"].all_documents, + sub_question_retrieval_stats=state[ + "expanded_retrieval_result" + ].sub_question_retrieval_stats, ) diff --git a/backend/onyx/agent_search/answer_question/states.py b/backend/onyx/agent_search/answer_question/states.py index 2964df0ab5..a58d4439be 100644 --- a/backend/onyx/agent_search/answer_question/states.py +++ b/backend/onyx/agent_search/answer_question/states.py @@ -13,12 +13,17 @@ ### Models ### +class AnswerRetrievalStats(BaseModel): + answer_retrieval_stats: dict[str, float | int] + + class QuestionAnswerResults(BaseModel): question: str answer: str quality: str expanded_retrieval_results: list[QueryResult] documents: list[InferenceSection] + sub_question_retrieval_stats: dict ### States ### @@ -32,11 +37,13 @@ class QACheckUpdate(TypedDict): class QAGenerationUpdate(TypedDict): answer: str + # answer_stat: AnswerStats class RetrievalIngestionUpdate(TypedDict): expanded_retrieval_results: list[QueryResult] documents: Annotated[list[InferenceSection], dedup_inference_sections] + sub_question_retrieval_stats: dict ## Graph Input State diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py index 6f8e3df063..f5a5bd3cd1 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py @@ -3,7 +3,10 @@ def doc_reranking(state: ExpandedRetrievalState) -> DocRerankingUpdate: + 
state["expanded_retrieval_results"] verified_documents = state["verified_documents"] reranked_documents = verified_documents - return DocRerankingUpdate(reranked_documents=reranked_documents) + return DocRerankingUpdate( + reranked_documents=reranked_documents, + ) diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py index 54d5421119..c5fec00fea 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py @@ -1,3 +1,7 @@ +from backend.onyx.agent_search.shared_graph_utils.calculations import ( + calculate_rank_shift, +) + from onyx.agent_search.expanded_retrieval.states import DocRetrievalUpdate from onyx.agent_search.expanded_retrieval.states import QueryResult from onyx.agent_search.expanded_retrieval.states import RetrievalInput @@ -30,15 +34,57 @@ def doc_retrieval(state: RetrievalInput) -> DocRetrievalUpdate: llm=llm, fast_llm=fast_llm, db_session=state["db_session"], - ).reranked_sections + ) + + # Initial calculations of scores for the retrieval quality + + ranked_sections = { + "initial": documents.final_context_sections, + "reranked": documents.reranked_sections, + } + + fit_scores = {} + + for rank_type, docs in ranked_sections.items(): + fit_scores[rank_type] = {} + for i in [1, 5, 10]: + fit_scores[rank_type][i] = ( + sum([doc.center_chunk.score for doc in docs[:i]]) / i + ) + + fit_scores[rank_type]["fit_score"] = ( + 1 + / 3 + * ( + fit_scores[rank_type][1] + + fit_scores[rank_type][5] + + fit_scores[rank_type][10] + ) + ) + + fit_scores[rank_type]["fit_score"] = fit_scores[rank_type][1] + + fit_scores[rank_type]["chunk_ids"] = [doc.center_chunk.chunk_id for doc in docs] + + fit_score_lift = ( + fit_scores["reranked"]["fit_score"] / fit_scores["initial"]["fit_score"] + ) + + average_rank_change = calculate_rank_shift( + fit_scores["initial"]["chunk_ids"], fit_scores["reranked"]["chunk_ids"] + ) + + fit_scores["rerank_effect"] = average_rank_change + fit_scores["fit_score_lift"] = fit_score_lift + + documents = documents.reranked_sections[:4] expanded_retrieval_result = QueryResult( query=query_to_retrieve, - documents_for_query=documents[:4], - chunk_ids=[], - stats={}, + documents_for_query=documents, + stats=fit_scores, ) return DocRetrievalUpdate( expanded_retrieval_results=[expanded_retrieval_result], - retrieved_documents=documents[:4], + retrieved_documents=documents, ) diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py b/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py index 50da6e9a64..cf1463d5b5 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py @@ -1,12 +1,69 @@ +from collections import defaultdict + +import numpy as np + from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalOutput from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalResult from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState +from onyx.agent_search.expanded_retrieval.states import InferenceSection +from onyx.agent_search.expanded_retrieval.states import QueryResult + + +def _calculate_sub_question_retrieval_stats( + verified_documents: list[InferenceSection], + expanded_retrieval_results: list[QueryResult], +) -> dict[str, float | int]: + chunk_scores = defaultdict(lambda: defaultdict(list)) + 
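# Aggregate retrieval scores per unique (document_id, chunk_id) key, so a chunk returned by several expanded queries is averaged below rather than double-counted. +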
for expanded_retrieval_result in expanded_retrieval_results: + for doc in expanded_retrieval_result.documents_for_query: + doc_chunk_id = f"{doc.center_chunk.document_id}_{doc.center_chunk.chunk_id}" + chunk_scores[doc_chunk_id]["score"].append(doc.center_chunk.score) + + verified_doc_chunk_ids = [ + f"{verified_document.center_chunk.document_id}_{verified_document.center_chunk.chunk_id}" + for verified_document in verified_documents + ] + dismissed_doc_chunk_ids = [] + + raw_chunk_stats = defaultdict(float) + for doc_chunk_id, chunk_data in chunk_scores.items(): + if doc_chunk_id in verified_doc_chunk_ids: + raw_chunk_stats["verified_count"] += 1 + raw_chunk_stats["verified_scores"] += np.mean(chunk_data["score"]) + else: + raw_chunk_stats["rejected_count"] += 1 + raw_chunk_stats["rejected_scores"] += np.mean(chunk_data["score"]) + dismissed_doc_chunk_ids.append(doc_chunk_id) + + if raw_chunk_stats["verified_count"] == 0: + verified_avg_scores = 0 + else: + verified_avg_scores = ( + raw_chunk_stats["verified_scores"] / raw_chunk_stats["verified_count"] + ) + + chunk_stats = { + "verified_count": raw_chunk_stats["verified_count"], + "verified_avg_scores": verified_avg_scores, + "rejected_count": raw_chunk_stats["rejected_count"], + "rejected_avg_scores": raw_chunk_stats["rejected_scores"] + / raw_chunk_stats["rejected_count"], + "verified_doc_chunk_ids": verified_doc_chunk_ids, + "dismissed_doc_chunk_ids": dismissed_doc_chunk_ids, + } + + return chunk_stats def format_results(state: ExpandedRetrievalState) -> ExpandedRetrievalOutput: + sub_question_retrieval_stats = _calculate_sub_question_retrieval_stats( + verified_documents=state["verified_documents"], + expanded_retrieval_results=state["expanded_retrieval_results"], + ) return ExpandedRetrievalOutput( expanded_retrieval_result=ExpandedRetrievalResult( expanded_queries_results=state["expanded_retrieval_results"], all_documents=state["reranked_documents"], + sub_question_retrieval_stats=sub_question_retrieval_stats, ), ) diff --git a/backend/onyx/agent_search/expanded_retrieval/states.py b/backend/onyx/agent_search/expanded_retrieval/states.py index 71c845cd6f..898c168224 100644 --- a/backend/onyx/agent_search/expanded_retrieval/states.py +++ b/backend/onyx/agent_search/expanded_retrieval/states.py @@ -16,13 +16,13 @@ class QueryResult(BaseModel): query: str documents_for_query: list[InferenceSection] - chunk_ids: list[str] stats: dict[str, Any] class ExpandedRetrievalResult(BaseModel): expanded_queries_results: list[QueryResult] all_documents: list[InferenceSection] + sub_question_retrieval_stats: dict ### States ### @@ -35,6 +35,7 @@ class DocVerificationUpdate(TypedDict): class DocRerankingUpdate(TypedDict): reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections] + sub_question_retrieval_stats: Annotated[list[dict[str, float | int]], add] class QueryExpansionUpdate(TypedDict): diff --git a/backend/onyx/agent_search/main/graph_builder.py b/backend/onyx/agent_search/main/graph_builder.py index dc09000435..11c203df2e 100644 --- a/backend/onyx/agent_search/main/graph_builder.py +++ b/backend/onyx/agent_search/main/graph_builder.py @@ -12,6 +12,9 @@ from onyx.agent_search.main.nodes.generate_initial_answer import ( generate_initial_answer, ) +from onyx.agent_search.main.nodes.generate_initial_BASE_answer import ( + generate_initial_base_answer, +) from onyx.agent_search.main.nodes.ingest_answers import ingest_answers from onyx.agent_search.main.nodes.ingest_initial_retrieval import ( ingest_initial_retrieval, @@ 
-54,6 +57,10 @@ def main_graph_builder() -> StateGraph: node="generate_initial_answer", action=generate_initial_answer, ) + graph.add_node( + node="generate_initial_base_answer", + action=generate_initial_base_answer, + ) ### Add edges ### @@ -86,7 +93,11 @@ def main_graph_builder() -> StateGraph: end_key="generate_initial_answer", ) graph.add_edge( - start_key="generate_initial_answer", + start_key=["ingest_answers", "ingest_initial_retrieval"], + end_key="generate_initial_base_answer", + ) + graph.add_edge( + start_key=["generate_initial_answer", "generate_initial_base_answer"], end_key=END, ) @@ -94,6 +105,8 @@ def main_graph_builder() -> StateGraph: if __name__ == "__main__": + pass + from onyx.db.engine import get_session_context_manager from onyx.llm.factory import get_default_llms from onyx.context.search.models import SearchRequest @@ -101,16 +114,21 @@ def main_graph_builder() -> StateGraph: graph = main_graph_builder() compiled_graph = graph.compile() primary_llm, fast_llm = get_default_llms() - search_request = SearchRequest( - query="what can you do with onyx or danswer?", - ) + with get_session_context_manager() as db_session: + durations = [] + chunk_expansion_ratios = [] + support_effectiveness_ratios = [] + + search_request = SearchRequest(query="Who created Excel?") + inputs = MainInput( search_request=search_request, primary_llm=primary_llm, fast_llm=fast_llm, db_session=db_session, ) + for thing in compiled_graph.stream( input=inputs, # stream_mode="debug", diff --git a/backend/onyx/agent_search/main/nodes/generate_initial_BASE_answer.py b/backend/onyx/agent_search/main/nodes/generate_initial_BASE_answer.py new file mode 100644 index 0000000000..444d12a2f2 --- /dev/null +++ b/backend/onyx/agent_search/main/nodes/generate_initial_BASE_answer.py @@ -0,0 +1,31 @@ +from langchain_core.messages import HumanMessage + +from onyx.agent_search.main.states import InitialAnswerBASEUpdate +from onyx.agent_search.main.states import MainState +from onyx.agent_search.shared_graph_utils.prompts import INITIAL_RAG_BASE_PROMPT +from onyx.agent_search.shared_graph_utils.utils import format_docs + + +def generate_initial_base_answer(state: MainState) -> InitialAnswerBASEUpdate: + print("---GENERATE INITIAL BASE ANSWER---") + + question = state["search_request"].query + original_question_docs = state["all_original_question_documents"] + + msg = [ + HumanMessage( + content=INITIAL_RAG_BASE_PROMPT.format( + question=question, + context=format_docs(original_question_docs), + ) + ) + ] + + # Grader + model = state["fast_llm"] + response = model.invoke(msg) + answer = response.pretty_repr() + + print() + print(f"---INITIAL BASE ANSWER START--- {answer} ---INITIAL BASE ANSWER END---") + return InitialAnswerBASEUpdate(initial_base_answer=answer) diff --git a/backend/onyx/agent_search/main/nodes/generate_initial_answer.py b/backend/onyx/agent_search/main/nodes/generate_initial_answer.py index 828472d6ea..27fcaf6e48 100644 --- a/backend/onyx/agent_search/main/nodes/generate_initial_answer.py +++ b/backend/onyx/agent_search/main/nodes/generate_initial_answer.py @@ -1,27 +1,108 @@ +from onyx.agent_search.answer_question.states import QuestionAnswerResults from langchain_core.messages import HumanMessage +from onyx.agent_search.main.states import AgentStats from onyx.agent_search.main.states import InitialAnswerUpdate from onyx.agent_search.main.states import MainState from onyx.agent_search.shared_graph_utils.prompts import INITIAL_RAG_PROMPT from onyx.agent_search.shared_graph_utils.utils 
import format_docs +def _calculate_initial_agent_stats( + decomp_answer_results: list[QuestionAnswerResults], original_question_stats: dict +) -> AgentStats: + initial_agent_dict = { + "sub_questions": {}, + "original_question": {}, + "agent_effectiveness": {}, + } + + verified_document_chunk_ids = [] + support_scores = 0 + + for decomp_answer_result in decomp_answer_results: + verified_document_chunk_ids += ( + decomp_answer_result.sub_question_retrieval_stats["verified_doc_chunk_ids"] + ) + support_scores += decomp_answer_result.sub_question_retrieval_stats[ + "verified_avg_scores" + ] + + verified_document_chunk_ids = list(set(verified_document_chunk_ids)) + + # Calculate sub-question stats + if verified_document_chunk_ids: + sub_question_stats = { + "num_verified_documents": len(verified_document_chunk_ids), + "verified_avg_score": support_scores / len(decomp_answer_results), + } + else: + sub_question_stats = {"num_verified_documents": 0, "verified_avg_score": None} + initial_agent_dict["sub_questions"].update(sub_question_stats) + + # Get original question stats + initial_agent_dict["original_question"].update( + { + "num_verified_documents": original_question_stats.get("verified_count", 0), + "verified_avg_score": original_question_stats.get( + "verified_avg_scores", None + ), + } + ) + + # Calculate chunk utilization ratio + sub_verified = initial_agent_dict["sub_questions"]["num_verified_documents"] + orig_verified = initial_agent_dict["original_question"]["num_verified_documents"] + + chunk_ratio = None + if orig_verified > 0: + chunk_ratio = sub_verified / orig_verified if sub_verified > 0 else 0 + elif sub_verified > 0: + chunk_ratio = 10 + + initial_agent_dict["agent_effectiveness"]["utilized_chunk_ratio"] = chunk_ratio + + if ( + initial_agent_dict["original_question"]["verified_avg_score"] is None + and initial_agent_dict["sub_questions"]["verified_avg_score"] is None + ): + initial_agent_dict["agent_effectiveness"]["support_ratio"] = None + elif initial_agent_dict["original_question"]["verified_avg_score"] is None: + initial_agent_dict["agent_effectiveness"]["support_ratio"] = 10 + elif initial_agent_dict["sub_questions"]["verified_avg_score"] is None: + initial_agent_dict["agent_effectiveness"]["support_ratio"] = 0 + else: + initial_agent_dict["agent_effectiveness"]["support_ratio"] = ( + initial_agent_dict["sub_questions"]["verified_avg_score"] + / initial_agent_dict["original_question"]["verified_avg_score"] + ) + + return initial_agent_dict + + def generate_initial_answer(state: MainState) -> InitialAnswerUpdate: print("---GENERATE INITIAL---") question = state["search_request"].query - docs = state["documents"] + sub_question_docs = state["documents"] all_original_question_documents = state["all_original_question_documents"] - combined_docs = docs + all_original_question_documents + # combined_docs = dedup_inference_sections(docs + all_original_question_documents) + + net_new_original_question_docs = [] + for all_original_question_doc in all_original_question_documents: + if all_original_question_doc not in sub_question_docs: + net_new_original_question_docs.append(all_original_question_doc) decomp_answer_results = state["decomp_answer_results"] good_qa_list: list[str] = [] + decomp_questions = [] _SUB_QUESTION_ANSWER_TEMPLATE = """ Sub-Question:\n - {sub_question}\n --\nAnswer:\n - {sub_answer}\n\n """ for decomp_answer_result in decomp_answer_results: + decomp_questions.append(decomp_answer_result.question) if ( decomp_answer_result.quality.lower() == "yes" and 
len(decomp_answer_result.answer) > 0 @@ -40,8 +121,9 @@ def generate_initial_answer(state: MainState) -> InitialAnswerUpdate: HumanMessage( content=INITIAL_RAG_PROMPT.format( question=question, - context=format_docs(combined_docs), answered_sub_questions=sub_question_answer_str, + sub_question_docs_context=format_docs(sub_question_docs), + additional_relevant_docs=format_docs(net_new_original_question_docs), ) ) ] @@ -51,5 +133,17 @@ def generate_initial_answer(state: MainState) -> InitialAnswerUpdate: response = model.invoke(msg) answer = response.pretty_repr() - print(answer) - return InitialAnswerUpdate(initial_answer=answer) + initial_agent_stats = _calculate_initial_agent_stats( + state["decomp_answer_results"], state["sub_question_retrieval_stats"] + ) + + print("") + print( + f"---INITIAL AGENT ANSWER START--- {answer} ---INITIAL AGENT ANSWER END---" + ) + + return InitialAnswerUpdate( + initial_answer=answer, + initial_agent_stats=initial_agent_stats, + generated_sub_questions=decomp_questions, + ) diff --git a/backend/onyx/agent_search/main/nodes/ingest_initial_retrieval.py b/backend/onyx/agent_search/main/nodes/ingest_initial_retrieval.py index 3cd75860ce..3a8403fe9b 100644 --- a/backend/onyx/agent_search/main/nodes/ingest_initial_retrieval.py +++ b/backend/onyx/agent_search/main/nodes/ingest_initial_retrieval.py @@ -10,4 +10,7 @@ def ingest_initial_retrieval(state: ExpandedRetrievalOutput) -> ExpandedRetrieva all_original_question_documents=state[ "expanded_retrieval_result" ].all_documents, + sub_question_retrieval_stats=state[ + "expanded_retrieval_result" + ].sub_question_retrieval_stats, ) diff --git a/backend/onyx/agent_search/main/states.py b/backend/onyx/agent_search/main/states.py index 6fb44da490..48cf31faa0 100644 --- a/backend/onyx/agent_search/main/states.py +++ b/backend/onyx/agent_search/main/states.py @@ -2,12 +2,23 @@ from typing import Annotated from typing import TypedDict +from pydantic import BaseModel + from onyx.agent_search.answer_question.states import QuestionAnswerResults from onyx.agent_search.core_state import CoreState from onyx.agent_search.expanded_retrieval.states import QueryResult from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections from onyx.context.search.models import InferenceSection +## Models + + +class AgentStats(BaseModel): + sub_question_stats: list[dict[str, float | int]] + original_question_stats: dict[str, float | int] + agent_stats: dict[str, float | int] + + ### States ### ## Update States @@ -17,8 +28,14 @@ class BaseDecompUpdate(TypedDict): initial_decomp_questions: list[str] +class InitialAnswerBASEUpdate(TypedDict): + initial_base_answer: str + + class InitialAnswerUpdate(TypedDict): initial_answer: str + initial_agent_stats: dict + generated_sub_questions: list[str] class DecompAnswersUpdate(TypedDict): @@ -31,6 +48,7 @@ class ExpandedRetrievalUpdate(TypedDict): list[InferenceSection], dedup_inference_sections ] original_question_retrieval_results: list[QueryResult] + sub_question_retrieval_stats: dict ## Graph Input State @@ -48,6 +66,7 @@ class MainState( MainInput, BaseDecompUpdate, InitialAnswerUpdate, + InitialAnswerBASEUpdate, DecompAnswersUpdate, ExpandedRetrievalUpdate, ): @@ -58,7 +77,7 @@ class MainState( class MainOutput(TypedDict): - """ - This is not used because defining the output only matters for filtering the output of - a .invoke() call but we are streaming so we just yield the entire state. 
- """ + initial_answer: str + initial_base_answer: str + initial_agent_stats: dict + generated_sub_questions: list[str] diff --git a/backend/onyx/agent_search/shared_graph_utils/calculations.py b/backend/onyx/agent_search/shared_graph_utils/calculations.py new file mode 100644 index 0000000000..4ff6eb0e03 --- /dev/null +++ b/backend/onyx/agent_search/shared_graph_utils/calculations.py @@ -0,0 +1,11 @@ +def calculate_rank_shift(list1: list, list2: list, top_n: int = 20) -> float: + shift = 0 + for rank_first, doc_id in enumerate(list1[:top_n], 1): + try: + rank_second = list2.index(doc_id) + 1 + except ValueError: + rank_second = len(list2)  # Document not found in second list + + shift += (rank_first - rank_second) ** 2 / (rank_first * rank_second) + + return shift / top_n diff --git a/backend/onyx/agent_search/shared_graph_utils/prompts.py b/backend/onyx/agent_search/shared_graph_utils/prompts.py index 229a980762..3763b88c1a 100644 --- a/backend/onyx/agent_search/shared_graph_utils/prompts.py +++ b/backend/onyx/agent_search/shared_graph_utils/prompts.py @@ -20,22 +20,31 @@ BASE_RAG_PROMPT = """ \n You are an assistant for question-answering tasks. Use the context provided below - and only the - provided context - to answer the question. If you don't know the answer or if the provided context is - empty, just say "I don't know". Do not use your internal knowledge! + provided context - to answer the given question. (Note that the answer is in service of answering a broader + question, given below as 'motivation'.) Again, only use the provided context and do not use your internal knowledge! If you cannot answer the question based on the context, say "I don't know". It is a matter of life and death that you do NOT use your internal knowledge, just the provided information! - Use three sentences maximum and keep the answer concise. - answer concise.\nQuestion:\n {question} \nContext:\n {context} \n\n + Make sure that you keep all relevant information, specifically as it concerns the ultimate goal. + (But keep other details as well.) + + If you don't know the answer or if the provided context is + empty, just say "I don't know". Do not use your internal knowledge! + + \nQuestion:\n {question} \n + + \nContext:\n {context} \n + + Motivation:\n {original_question} \n\n \n\n Answer:""" -SUB_CHECK_PROMPT = """ \n +SUB_CHECK_PROMPT = """ Your task is to see whether a given answer addresses a given question. Please do not use any internal knowledge you may have - just focus on whether the answer - as given seems to address the question as given. + as given seems to largely address the question as given. Here is the question: \n ------- \n {question} @@ -378,6 +387,30 @@ Answer: """ +INITIAL_RAG_BASE_PROMPT = """ \n + You are an assistant for question-answering tasks. Use the information provided below - and only the + provided information - to answer the provided question. + + The information provided below consists of a number of documents that were also deemed relevant for the question. + + If you don't know the answer or if the provided information is empty or insufficient, just say + "I don't know". Do not use your internal knowledge! + + Again, only use the provided information and do not use your internal knowledge! It is a matter of life + and death that you do NOT use your internal knowledge, just the provided information! + + Try to keep your answer concise.
+ + And here is the question and the provided information: + \n + \nQuestion:\n {question} \n + + \nContext:\n {context} \n\n + \n\n + + Answer:""" + + INITIAL_RAG_PROMPT = """ \n You are an assistant for question-answering tasks. Use the information provided below - and only the provided information - to answer the provided question. @@ -390,20 +423,19 @@ If you don't know the answer or if the provided information is empty or insufficient, just say "I don't know". Do not use your internal knowledge! - Again, only use the provided informationand do not use your internal knowledge! It is a matter of life + Again, only use the provided information and do not use your internal knowledge! It is a matter of life and death that you do NOT use your internal knowledge, just the provided information! Try to keep your answer concise. And here is the question and the provided information: \n - \nQuestion:\n {question} - - \nAnswered Sub-questions:\n {answered_sub_questions} - - \nContext:\n {context} \n\n - \n\n + \nQuestion:\n {question}\n\n + Answered Sub-questions:\n {answered_sub_questions}\n\n + Documents supporting the sub-questions answers:\n {sub_question_docs_context}\n\n + And here are additional relevant documents:\n\n + {additional_relevant_docs} \n\n\n Answer:""" ENTITY_TERM_PROMPT = """ \n diff --git a/backend/tests/regression/answer_quality/agent_test.py b/backend/tests/regression/answer_quality/agent_test.py new file mode 100644 index 0000000000..45e923123a --- /dev/null +++ b/backend/tests/regression/answer_quality/agent_test.py @@ -0,0 +1,112 @@ +import csv +import datetime +import json +import os + +import yaml + +from onyx.agent_search.main.graph_builder import main_graph_builder +from onyx.agent_search.main.states import MainInput +from onyx.context.search.models import SearchRequest +from onyx.db.engine import get_session_context_manager +from onyx.llm.factory import get_default_llms + +cwd = os.getcwd() +CONFIG = yaml.safe_load( + open(f"{cwd}/backend/tests/regression/answer_quality/search_test_config.yaml") +) +INPUT_DIR = CONFIG["agent_test_input_folder"] +OUTPUT_DIR = CONFIG["agent_test_output_folder"] + + +graph = main_graph_builder() +compiled_graph = graph.compile() +primary_llm, fast_llm = get_default_llms() + +# create a local json test data file and use it here + + +input_file_object = open( + f"{INPUT_DIR}/agent_test_data.json", +) +output_file = f"{OUTPUT_DIR}/agent_test_output.csv" + + +test_data = json.load(input_file_object) +examples = test_data["examples"] + +with get_session_context_manager() as db_session: + output_data = [] + + for example in examples: + example_id = example["id"] + example_question = example["question"] + target_sub_questions = example.get("target_sub_questions", []) + num_target_sub_questions = len(target_sub_questions) + search_request = SearchRequest(query=example_question) + + inputs = MainInput( + search_request=search_request, + primary_llm=primary_llm, + fast_llm=fast_llm, + db_session=db_session, + ) + + start_time = datetime.datetime.now() + + question_result = compiled_graph.invoke(input=inputs) + end_time = datetime.datetime.now() + + duration = end_time - start_time + chunk_expansion_ratio = ( + question_result["initial_agent_stats"] + .get("agent_effectiveness", {}) + .get("utilized_chunk_ratio", None) + ) + support_effectiveness_ratio = ( + question_result["initial_agent_stats"] + .get("agent_effectiveness", {}) + .get("support_ratio", None) + ) + generated_sub_questions = question_result.get("generated_sub_questions", []) + 
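# The answer fields were produced with pretty_repr(), which is assumed here to prepend an "=== Ai Message ===" style banner; the .split("==")[-1] calls below keep only the text after it. +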
num_generated_sub_questions = len(generated_sub_questions) + base_answer = question_result["initial_base_answer"].split("==")[-1] + agent_answer = question_result["initial_answer"].split("==")[-1] + + output_point = { + "example_id": example_id, + "question": example_question, + "duration": duration, + "target_sub_questions": target_sub_questions, + "generated_sub_questions": generated_sub_questions, + "num_target_sub_questions": num_target_sub_questions, + "num_generated_sub_questions": num_generated_sub_questions, + "chunk_expansion_ratio": chunk_expansion_ratio, + "support_effectiveness_ratio": support_effectiveness_ratio, + "base_answer": base_answer, + "agent_answer": agent_answer, + } + + output_data.append(output_point) + + +with open(output_file, "w", newline="") as csvfile: + fieldnames = [ + "example_id", + "question", + "duration", + "target_sub_questions", + "generated_sub_questions", + "num_target_sub_questions", + "num_generated_sub_questions", + "chunk_expansion_ratio", + "support_effectiveness_ratio", + "base_answer", + "agent_answer", + ] + + writer = csv.DictWriter(csvfile, fieldnames=fieldnames, delimiter="\t") + writer.writeheader() + writer.writerows(output_data) + +print("DONE") From f4c826c4e52822059af787aa53aea8e0a6601a0b Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Fri, 20 Dec 2024 14:52:09 -0800 Subject: [PATCH 21/78] regression-test graph vs regular graph --- .../onyx/agent_search/main/graph_builder.py | 33 +++++++++++-------- .../shared_graph_utils/prompts.py | 4 +-- .../regression/answer_quality/agent_test.py | 2 +- 3 files changed, 23 insertions(+), 16 deletions(-) diff --git a/backend/onyx/agent_search/main/graph_builder.py b/backend/onyx/agent_search/main/graph_builder.py index 11c203df2e..5c697a665d 100644 --- a/backend/onyx/agent_search/main/graph_builder.py +++ b/backend/onyx/agent_search/main/graph_builder.py @@ -23,7 +23,7 @@ from onyx.agent_search.main.states import MainState -def main_graph_builder() -> StateGraph: +def main_graph_builder(test_mode: bool = False) -> StateGraph: graph = StateGraph( state_schema=MainState, input=MainInput, @@ -57,10 +57,11 @@ def main_graph_builder() -> StateGraph: node="generate_initial_answer", action=generate_initial_answer, ) - graph.add_node( - node="generate_initial_base_answer", - action=generate_initial_base_answer, - ) + if test_mode: + graph.add_node( + node="generate_initial_base_answer", + action=generate_initial_base_answer, + ) ### Add edges ### @@ -92,14 +93,20 @@ def main_graph_builder() -> StateGraph: start_key=["ingest_answers", "ingest_initial_retrieval"], end_key="generate_initial_answer", ) - graph.add_edge( - start_key=["ingest_answers", "ingest_initial_retrieval"], - end_key="generate_initial_base_answer", - ) - graph.add_edge( - start_key=["generate_initial_answer", "generate_initial_base_answer"], - end_key=END, - ) + if test_mode: + graph.add_edge( + start_key=["ingest_answers", "ingest_initial_retrieval"], + end_key="generate_initial_base_answer", + ) + graph.add_edge( + start_key=["generate_initial_answer", "generate_initial_base_answer"], + end_key=END, + ) + else: + graph.add_edge( + start_key="generate_initial_answer", + end_key=END, + ) return graph diff --git a/backend/onyx/agent_search/shared_graph_utils/prompts.py b/backend/onyx/agent_search/shared_graph_utils/prompts.py index 3763b88c1a..d5d2cd76d0 100644 --- a/backend/onyx/agent_search/shared_graph_utils/prompts.py +++ b/backend/onyx/agent_search/shared_graph_utils/prompts.py @@ -403,9 +403,9 @@ And here is the question and
the provided information: \n - \nQuestion:\n {question} \n + \nQuestion:\n {question}\n - \nContext:\n {context} \n\n + \nContext:\n {context}\n\n \n\n Answer:""" diff --git a/backend/tests/regression/answer_quality/agent_test.py b/backend/tests/regression/answer_quality/agent_test.py index 45e923123a..28329c6b33 100644 --- a/backend/tests/regression/answer_quality/agent_test.py +++ b/backend/tests/regression/answer_quality/agent_test.py @@ -19,7 +19,7 @@ OUTPUT_DIR = CONFIG["agent_test_output_folder"] -graph = main_graph_builder() +graph = main_graph_builder(test_mode=True) compiled_graph = graph.compile() primary_llm, fast_llm = get_default_llms() From fa481019e894e71b93a1d98b10b767f5575bd12c Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Tue, 24 Dec 2024 08:16:56 -0800 Subject: [PATCH 22/78] improvements --- .../agent_search/expanded_retrieval/edges.py | 5 +- .../expanded_retrieval/nodes/doc_reranking.py | 37 +++- .../expanded_retrieval/nodes/doc_retrieval.py | 63 ++---- .../nodes/doc_verification.py | 21 +- .../nodes/format_results.py | 11 +- .../nodes/verification_kickoff.py | 1 + .../agent_search/expanded_retrieval/states.py | 12 +- backend/onyx/agent_search/main/edges.py | 30 ++- .../onyx/agent_search/main/graph_builder.py | 2 +- .../agent_search/main/nodes/base_decomp.py | 6 +- .../nodes/generate_initial_BASE_answer.py | 4 +- .../main/nodes/generate_initial_answer.py | 60 +++-- .../agent_search/main/nodes/ingest_answers.py | 5 +- .../shared_graph_utils/calculations.py | 84 ++++++- .../agent_search/shared_graph_utils/models.py | 16 ++ .../shared_graph_utils/prompts.py | 205 ++++++++++++------ backend/onyx/configs/dev_configs.py | 13 ++ .../regression/answer_quality/agent_test.py | 33 ++- 18 files changed, 421 insertions(+), 187 deletions(-) create mode 100644 backend/onyx/configs/dev_configs.py diff --git a/backend/onyx/agent_search/expanded_retrieval/edges.py b/backend/onyx/agent_search/expanded_retrieval/edges.py index 61d994a687..cd5f2c6175 100644 --- a/backend/onyx/agent_search/expanded_retrieval/edges.py +++ b/backend/onyx/agent_search/expanded_retrieval/edges.py @@ -8,6 +8,9 @@ def parallel_retrieval_edge(state: ExpandedRetrievalState) -> list[Send | Hashable]: + question = state.get("question", state["search_request"].query) + + query_expansions = state.get("expanded_queries", []) + [question] return [ Send( "doc_retrieval", @@ -16,5 +19,5 @@ def parallel_retrieval_edge(state: ExpandedRetrievalState) -> list[Send | Hashab **extract_core_fields(state), ), ) - for query in state["expanded_queries"] + for query in query_expansions ] diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py index f5a5bd3cd1..ffc90a9b83 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py @@ -1,12 +1,43 @@ from onyx.agent_search.expanded_retrieval.states import DocRerankingUpdate from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState +from onyx.agent_search.shared_graph_utils.calculations import get_fit_scores +from onyx.context.search.pipeline import retrieval_preprocessing +from onyx.context.search.pipeline import search_postprocessing +from onyx.context.search.pipeline import SearchRequest def doc_reranking(state: ExpandedRetrievalState) -> DocRerankingUpdate: - state["expanded_retrieval_results"] + AGENT_TEST = True + AGENT_TEST_MAX_QUERY_RETRIEVAL_RESULTS = 10 + 
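# (Assumed) these local AGENT_TEST values are temporary stand-ins for the same-named flags in onyx.configs.dev_configs that doc_retrieval imports. +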
verified_documents = state["verified_documents"] - reranked_documents = verified_documents + + # Rerank post retrieval and verification. First, create a search query + # then create the list of reranked sections + + _search_query = retrieval_preprocessing( + search_request=SearchRequest(query=state["question"]), + user=None, + llm=state["fast_llm"], + db_session=state["db_session"], + ) + + reranked_documents = list( + search_postprocessing( + search_query=_search_query, + retrieved_sections=verified_documents, + llm=state["fast_llm"], + ) + )[ + 0 + ] # only get the reranked sections, not the SectionRelevancePiece + + if AGENT_TEST: + fit_scores = get_fit_scores(verified_documents, reranked_documents) + else: + fit_scores = None return DocRerankingUpdate( - reranked_documents=reranked_documents, + reranked_documents=reranked_documents[:AGENT_TEST_MAX_QUERY_RETRIEVAL_RESULTS], + fit_scores=fit_scores, ) diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py index c5fec00fea..25b1560ca5 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py @@ -1,11 +1,9 @@ -from backend.onyx.agent_search.shared_graph_utils.calculations import ( - calculate_rank_shift, -) - from onyx.agent_search.expanded_retrieval.states import DocRetrievalUpdate from onyx.agent_search.expanded_retrieval.states import QueryResult from onyx.agent_search.expanded_retrieval.states import RetrievalInput -from onyx.context.search.models import InferenceSection +from onyx.agent_search.shared_graph_utils.calculations import get_fit_scores +from onyx.configs.dev_configs import AGENT_TEST +from onyx.configs.dev_configs import AGENT_TEST_MAX_QUERY_RETRIEVAL_RESULTS from onyx.context.search.models import SearchRequest from onyx.context.search.pipeline import SearchPipeline @@ -26,7 +24,7 @@ def doc_retrieval(state: RetrievalInput) -> DocRetrievalUpdate: fast_llm = state["fast_llm"] query_to_retrieve = state["query_to_retrieve"] - documents: list[InferenceSection] = SearchPipeline( + search_results = SearchPipeline( search_request=SearchRequest( query=query_to_retrieve, ), @@ -36,55 +34,24 @@ def doc_retrieval(state: RetrievalInput) -> DocRetrievalUpdate: db_session=state["db_session"], ) - # Initial calculations of scores for the retrieval quality - - ranked_sections = { - "initial": documents.final_context_sections, - "reranked": documents.reranked_sections, - } - - fit_scores = {} + retrieved_docs = search_results._get_sections()[ + :AGENT_TEST_MAX_QUERY_RETRIEVAL_RESULTS ] - for rank_type, docs in ranked_sections.items(): - fit_scores[rank_type] = {} - for i in [1, 5, 10]: - fit_scores[rank_type][i] = ( - sum([doc.center_chunk.score for doc in docs[:i]]) / i - ) - - fit_scores[rank_type]["fit_score"] = ( - 1 - / 3 - * ( - fit_scores[rank_type][1] - + fit_scores[rank_type][5] - + fit_scores[rank_type][10] - ) + if AGENT_TEST: + fit_scores = get_fit_scores( + retrieved_docs, + search_results.reranked_sections[:AGENT_TEST_MAX_QUERY_RETRIEVAL_RESULTS], ) - - fit_scores[rank_type]["fit_score"] = fit_scores[rank_type][1] - - fit_scores[rank_type]["chunk_ids"] = [doc.center_chunk.chunk_id for doc in docs] - - fit_score_lift = ( - fit_scores["reranked"]["fit_score"] / fit_scores["initial"]["fit_score"] - ) - - average_rank_change = calculate_rank_shift( - fit_scores["initial"]["chunk_ids"], fit_scores["reranked"]["chunk_ids"] - ) - -
fit_scores["rerank_effect"] = average_rank_change - fit_scores["fit_score_lift"] = fit_score_lift - - documents = documents.reranked_sections[:4] + else: + fit_scores = None expanded_retrieval_result = QueryResult( query=query_to_retrieve, - documents_for_query=documents, + search_results=retrieved_docs, stats=fit_scores, ) return DocRetrievalUpdate( expanded_retrieval_results=[expanded_retrieval_result], - retrieved_documents=documents, + retrieved_documents=retrieved_docs, ) diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py index 3abebfcf2e..ea7c000eff 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py @@ -1,5 +1,6 @@ +import json + from langchain_core.messages import HumanMessage -from langchain_core.messages import merge_message_runs from onyx.agent_search.expanded_retrieval.states import DocVerificationInput from onyx.agent_search.expanded_retrieval.states import DocVerificationUpdate @@ -18,32 +19,30 @@ def doc_verification(state: DocVerificationInput) -> DocVerificationUpdate: verified_documents: list[InferenceSection] """ - original_query = state["search_request"].query + state["search_request"].query + question = state["question"] doc_to_verify = state["doc_to_verify"] document_content = doc_to_verify.combined_content msg = [ HumanMessage( content=VERIFIER_PROMPT.format( - question=original_query, document_content=document_content + question=question, document_content=document_content ) ) ] fast_llm = state["fast_llm"] - response = list( - fast_llm.stream( - prompt=msg, - ) + response = json.loads( + fast_llm.invoke(msg, structured_response_format=BinaryDecision).content ) - response_string = merge_message_runs(response, chunk_separator="")[0].content + # response_string = response.content.get("decision", "no").lower() # Convert string response to proper dictionary format - decision_dict = {"decision": response_string.lower()} - formatted_response = BinaryDecision.model_validate(decision_dict) + # decision_dict = {"decision": response.content.lower()} verified_documents = [] - if formatted_response.decision == "yes": + if response["decision"] == "yes": verified_documents.append(doc_to_verify) return DocVerificationUpdate( diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py b/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py index cf1463d5b5..7d9d660263 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py @@ -15,7 +15,7 @@ def _calculate_sub_question_retrieval_stats( ) -> dict[str, float | int]: chunk_scores = defaultdict(lambda: defaultdict(list)) for expanded_retrieval_result in expanded_retrieval_results: - for doc in expanded_retrieval_result.documents_for_query: + for doc in expanded_retrieval_result.search_results: doc_chunk_id = f"{doc.center_chunk.document_id}_{doc.center_chunk.chunk_id}" chunk_scores[doc_chunk_id]["score"].append(doc.center_chunk.score) @@ -42,12 +42,17 @@ def _calculate_sub_question_retrieval_stats( raw_chunk_stats["verified_scores"] / raw_chunk_stats["verified_count"] ) + rejected_scores = raw_chunk_stats.get("rejected_scores", None) + if rejected_scores is not None: + rejected_avg_scores = rejected_scores / raw_chunk_stats["rejected_count"] + else: + rejected_avg_scores = None + chunk_stats 
= { "verified_count": raw_chunk_stats["verified_count"], "verified_avg_scores": verified_avg_scores, "rejected_count": raw_chunk_stats["rejected_count"], - "rejected_avg_scores": raw_chunk_stats["rejected_scores"] - / raw_chunk_stats["rejected_count"], + "rejected_avg_scores": rejected_avg_scores, "verified_doc_chunk_ids": verified_doc_chunk_ids, "dismissed_doc_chunk_ids": dismissed_doc_chunk_ids, } diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/verification_kickoff.py b/backend/onyx/agent_search/expanded_retrieval/nodes/verification_kickoff.py index 81fe7f9229..05b2465263 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/verification_kickoff.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/verification_kickoff.py @@ -23,6 +23,7 @@ def verification_kickoff( node="doc_verification", arg=DocVerificationInput( doc_to_verify=doc, + question=state["question"], **extract_core_fields(state), ), ) diff --git a/backend/onyx/agent_search/expanded_retrieval/states.py b/backend/onyx/agent_search/expanded_retrieval/states.py index 898c168224..55dae753e5 100644 --- a/backend/onyx/agent_search/expanded_retrieval/states.py +++ b/backend/onyx/agent_search/expanded_retrieval/states.py @@ -9,14 +9,13 @@ from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections from onyx.context.search.models import InferenceSection - ### Models ### class QueryResult(BaseModel): query: str - documents_for_query: list[InferenceSection] - stats: dict[str, Any] + search_results: list[InferenceSection] + stats: dict[str, Any] | None class ExpandedRetrievalResult(BaseModel): @@ -40,6 +39,7 @@ class DocRerankingUpdate(TypedDict): class QueryExpansionUpdate(TypedDict): expanded_queries: list[str] + question: str class DocRetrievalUpdate(TypedDict): @@ -78,9 +78,11 @@ class ExpandedRetrievalOutput(TypedDict): ## Conditional Input States -class DocVerificationInput(CoreState): +class DocVerificationInput(ExpandedRetrievalInput): doc_to_verify: InferenceSection + query_to_retrieve: str + question: str -class RetrievalInput(CoreState): +class RetrievalInput(ExpandedRetrievalInput): query_to_retrieve: str diff --git a/backend/onyx/agent_search/main/edges.py b/backend/onyx/agent_search/main/edges.py index 7791498d3e..492214f7c2 100644 --- a/backend/onyx/agent_search/main/edges.py +++ b/backend/onyx/agent_search/main/edges.py @@ -3,6 +3,7 @@ from langgraph.types import Send from onyx.agent_search.answer_question.states import AnswerQuestionInput +from onyx.agent_search.answer_question.states import AnswerQuestionOutput from onyx.agent_search.core_state import extract_core_fields from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput from onyx.agent_search.main.states import MainInput @@ -10,16 +11,25 @@ def parallelize_decompozed_answer_queries(state: MainState) -> list[Send | Hashable]: - return [ - Send( - "answer_query", - AnswerQuestionInput( - **extract_core_fields(state), - question=question, - ), - ) - for question in state["initial_decomp_questions"] - ] + if len(state["initial_decomp_questions"]) > 0: + return [ + Send( + "answer_query", + AnswerQuestionInput( + **extract_core_fields(state), + question=question, + ), + ) + for question in state["initial_decomp_questions"] + ] + + else: + return [ + Send( + "ingest_answers", + AnswerQuestionOutput(), + ) + ] def send_to_initial_retrieval(state: MainInput) -> list[Send | Hashable]: diff --git a/backend/onyx/agent_search/main/graph_builder.py b/backend/onyx/agent_search/main/graph_builder.py 
index 5c697a665d..f23ee9176f 100644 --- a/backend/onyx/agent_search/main/graph_builder.py +++ b/backend/onyx/agent_search/main/graph_builder.py @@ -82,7 +82,7 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph: graph.add_conditional_edges( source="base_decomp", path=parallelize_decompozed_answer_queries, - path_map=["answer_query"], + path_map=["answer_query", "ingest_answers"], ) graph.add_edge( start_key="answer_query", diff --git a/backend/onyx/agent_search/main/nodes/base_decomp.py b/backend/onyx/agent_search/main/nodes/base_decomp.py index 05b095794b..8285680cb9 100644 --- a/backend/onyx/agent_search/main/nodes/base_decomp.py +++ b/backend/onyx/agent_search/main/nodes/base_decomp.py @@ -2,7 +2,9 @@ from onyx.agent_search.main.states import BaseDecompUpdate from onyx.agent_search.main.states import MainState -from onyx.agent_search.shared_graph_utils.prompts import INITIAL_DECOMPOSITION_PROMPT +from onyx.agent_search.shared_graph_utils.prompts import ( + INITIAL_DECOMPOSITION_PROMPT_QUESTIONS, +) from onyx.agent_search.shared_graph_utils.utils import clean_and_parse_list_string @@ -11,7 +13,7 @@ def main_decomp_base(state: MainState) -> BaseDecompUpdate: msg = [ HumanMessage( - content=INITIAL_DECOMPOSITION_PROMPT.format(question=question), + content=INITIAL_DECOMPOSITION_PROMPT_QUESTIONS.format(question=question), ) ] diff --git a/backend/onyx/agent_search/main/nodes/generate_initial_BASE_answer.py b/backend/onyx/agent_search/main/nodes/generate_initial_BASE_answer.py index 444d12a2f2..00bc742f06 100644 --- a/backend/onyx/agent_search/main/nodes/generate_initial_BASE_answer.py +++ b/backend/onyx/agent_search/main/nodes/generate_initial_BASE_answer.py @@ -27,5 +27,7 @@ def generate_initial_base_answer(state: MainState) -> InitialAnswerBASEUpdate: answer = response.pretty_repr() print() - print(f"---INITIAL BASE ANSWER START--- {answer} ---INITIAL BASE ANSWER END---") + print( + f"\n\n---INITIAL BASE ANSWER START---\n\nBase: {answer}\n\n ---INITIAL BASE ANSWER END---\n\n" + ) return InitialAnswerBASEUpdate(initial_base_answer=answer) diff --git a/backend/onyx/agent_search/main/nodes/generate_initial_answer.py b/backend/onyx/agent_search/main/nodes/generate_initial_answer.py index 27fcaf6e48..1cd1863852 100644 --- a/backend/onyx/agent_search/main/nodes/generate_initial_answer.py +++ b/backend/onyx/agent_search/main/nodes/generate_initial_answer.py @@ -1,10 +1,16 @@ +import json + from backend.onyx.agent_search.answer_question.states import QuestionAnswerResults from langchain_core.messages import HumanMessage from onyx.agent_search.main.states import AgentStats from onyx.agent_search.main.states import InitialAnswerUpdate from onyx.agent_search.main.states import MainState +from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections from onyx.agent_search.shared_graph_utils.prompts import INITIAL_RAG_PROMPT +from onyx.agent_search.shared_graph_utils.prompts import ( + INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS, +) from onyx.agent_search.shared_graph_utils.utils import format_docs @@ -17,6 +23,9 @@ def _calculate_initial_agent_stats( "agent_effectiveness": {}, } + orig_verified = original_question_stats["verified_count"] + orig_support_score = original_question_stats["verified_avg_scores"] + verified_document_chunk_ids = [] support_scores = 0 @@ -52,7 +61,6 @@ def _calculate_initial_agent_stats( # Calculate chunk utilization ratio sub_verified = initial_agent_dict["sub_questions"]["num_verified_documents"] - orig_verified = 
initial_agent_dict["original_question"]["num_verified_documents"] chunk_ratio = None if orig_verified > 0: @@ -63,18 +71,18 @@ initial_agent_dict["agent_effectiveness"]["utilized_chunk_ratio"] = chunk_ratio if ( - initial_agent_dict["original_question"]["verified_avg_score"] is None + orig_support_score is None and initial_agent_dict["sub_questions"]["verified_avg_score"] is None ): initial_agent_dict["agent_effectiveness"]["support_ratio"] = None - elif initial_agent_dict["original_question"]["verified_avg_score"] is None: + elif orig_support_score is None: initial_agent_dict["agent_effectiveness"]["support_ratio"] = 10 elif initial_agent_dict["sub_questions"]["verified_avg_score"] is None: initial_agent_dict["agent_effectiveness"]["support_ratio"] = 0 else: initial_agent_dict["agent_effectiveness"]["support_ratio"] = ( initial_agent_dict["sub_questions"]["verified_avg_score"] - / initial_agent_dict["original_question"]["verified_avg_score"] + / orig_support_score ) return initial_agent_dict @@ -86,7 +94,9 @@ def generate_initial_answer(state: MainState) -> InitialAnswerUpdate: question = state["search_request"].query sub_question_docs = state["documents"] all_original_question_documents = state["all_original_question_documents"] - # combined_docs = dedup_inference_sections(docs + all_original_question_documents) + relevant_docs = dedup_inference_sections( + sub_question_docs, all_original_question_documents + ) net_new_original_question_docs = [] for all_original_question_doc in all_original_question_documents: @@ -104,7 +114,7 @@ for decomp_answer_result in decomp_answer_results: decomp_questions.append(decomp_answer_result.question) if ( - decomp_answer_result.quality.lower() == "yes" + decomp_answer_result.quality.lower().startswith("yes") and len(decomp_answer_result.answer) > 0 and decomp_answer_result.answer != "I don't know" ): @@ -117,16 +127,25 @@ sub_question_answer_str = "\n\n------\n\n".join(good_qa_list) - msg = [ - HumanMessage( - content=INITIAL_RAG_PROMPT.format( - question=question, - answered_sub_questions=sub_question_answer_str, - sub_question_docs_context=format_docs(sub_question_docs), - additional_relevant_docs=format_docs(net_new_original_question_docs), + if len(good_qa_list) > 0: + msg = [ + HumanMessage( + content=INITIAL_RAG_PROMPT.format( + question=question, + answered_sub_questions=sub_question_answer_str, + relevant_docs=format_docs(relevant_docs), + ) ) - ) - ] + ] + else: + msg = [ + HumanMessage( + content=INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS.format( + question=question, + relevant_docs=format_docs(relevant_docs), + ) + ) + ] # Grader model = state["fast_llm"] @@ -137,10 +156,13 @@ state["decomp_answer_results"], state["sub_question_retrieval_stats"] ) - print("") - print( - f"---INITIAL AGENT ANSWER START--- {answer} ---INITIAL AGENT ANSWER END---" - ) + print(f"\n\n---INITIAL AGENT ANSWER START---\n\n Answer:\n Agent: {answer}") + + print(f"\n\nSub-Questions:\n\n{sub_question_answer_str}\n\nStats:\n\n") + + print(json.dumps(initial_agent_stats, indent=4)) + + print("\n\n ---INITIAL AGENT ANSWER END---\n\n") return InitialAnswerUpdate( initial_answer=answer, diff --git a/backend/onyx/agent_search/main/nodes/ingest_answers.py b/backend/onyx/agent_search/main/nodes/ingest_answers.py index c86f3f3104..5eac7670e9
100644 --- a/backend/onyx/agent_search/main/nodes/ingest_answers.py +++ b/backend/onyx/agent_search/main/nodes/ingest_answers.py @@ -5,11 +5,12 @@ def ingest_answers(state: AnswerQuestionOutput) -> DecompAnswersUpdate: documents = [] - for answer_result in state["answer_results"]: + answer_results = state.get("answer_results", []) + for answer_result in answer_results: documents.extend(answer_result.documents) return DecompAnswersUpdate( # Deduping is done by the documents operator for the main graph # so we might not need to dedup here documents=dedup_inference_sections(documents, []), - decomp_answer_results=state["answer_results"], + decomp_answer_results=answer_results, ) diff --git a/backend/onyx/agent_search/shared_graph_utils/calculations.py b/backend/onyx/agent_search/shared_graph_utils/calculations.py index 4ff6eb0e03..93947c252f 100644 --- a/backend/onyx/agent_search/shared_graph_utils/calculations.py +++ b/backend/onyx/agent_search/shared_graph_utils/calculations.py @@ -1,3 +1,17 @@ +import numpy as np + +from onyx.agent_search.shared_graph_utils.models import FitScoreMetrics +from onyx.agent_search.shared_graph_utils.models import FitScores +from onyx.context.search.models import InferenceSection +from onyx.utils.logger import setup_logger + +logger = setup_logger() + + +def unique_chunk_id(doc: InferenceSection) -> str: + return f"{doc.center_chunk.document_id}_{doc.center_chunk.chunk_id}" + + def calculate_rank_shift(list1: list, list2: list, top_n: int = 20) -> float: shift = 0 for rank_first, doc_id in enumerate(list1[:top_n], 1): @@ -6,6 +20,74 @@ def calculate_rank_shift(list1: list, list2: list, top_n: int = 20) -> float: except ValueError: rank_second = len(list2) # Document not found in second list - shift += (rank_first - rank_second) ** 2 / (rank_first * rank_second) + shift += np.abs(rank_first - rank_second) / np.log(1 + rank_first * rank_second) return shift / top_n + + +def get_fit_scores( + pre_reranked_results: list[InferenceSection], + post_reranked_results: list[InferenceSection], +) -> dict | None: + """ + Calculate retrieval metrics for search purposes + """ + + if len(pre_reranked_results) == 0 or len(post_reranked_results) == 0: + return None + + ranked_sections = { + "initial": pre_reranked_results, + "reranked": post_reranked_results, + } + + fit_eval = {} + + fit_eval: FitScores = FitScores( + fit_score_lift=0, + rerank_effect=0, + fit_scores={ + "initial": FitScoreMetrics(scores={}, chunk_ids=[]), + "reranked": FitScoreMetrics(scores={}, chunk_ids=[]), + }, + ) + + # logger.error("wwwww") + + for rank_type, docs in ranked_sections.items(): + print(f"rank_type: {rank_type}") + + for i in [1, 5, 10]: + fit_eval.fit_scores[rank_type].scores[str(i)] = ( + sum([doc.center_chunk.score for doc in docs[:i]]) / i + ) + + fit_eval.fit_scores[rank_type].scores["fit_score"] = ( + 1 + / 3 + * ( + fit_eval.fit_scores[rank_type].scores["1"] + + fit_eval.fit_scores[rank_type].scores["5"] + + fit_eval.fit_scores[rank_type].scores["10"] + ) + ) + + fit_eval.fit_scores[rank_type].scores["fit_score"] = fit_eval.fit_scores[ + rank_type + ].scores["1"] + + fit_eval.fit_scores[rank_type].chunk_ids = [ + unique_chunk_id(doc) for doc in docs + ] + + fit_eval.fit_score_lift = ( + fit_eval.fit_scores["reranked"].scores["fit_score"] + / fit_eval.fit_scores["initial"].scores["fit_score"] + ) + + fit_eval.rerank_effect = calculate_rank_shift( + fit_eval.fit_scores["initial"].chunk_ids, + fit_eval.fit_scores["reranked"].chunk_ids, + ) + + return fit_eval diff --git 
a/backend/onyx/agent_search/shared_graph_utils/models.py b/backend/onyx/agent_search/shared_graph_utils/models.py index 162d651fe5..198ae0c475 100644 --- a/backend/onyx/agent_search/shared_graph_utils/models.py +++ b/backend/onyx/agent_search/shared_graph_utils/models.py @@ -10,3 +10,19 @@ class RewrittenQueries(BaseModel): class BinaryDecision(BaseModel): decision: Literal["yes", "no"] + + +class BinaryDecisionWithReasoning(BaseModel): + reasoning: str + decision: Literal["yes", "no"] + + +class FitScoreMetrics(BaseModel): + scores: dict[str, float] + chunk_ids: list[str] + + +class FitScores(BaseModel): + fit_score_lift: float + rerank_effect: float + fit_scores: dict[str, FitScoreMetrics] diff --git a/backend/onyx/agent_search/shared_graph_utils/prompts.py b/backend/onyx/agent_search/shared_graph_utils/prompts.py index d5d2cd76d0..918e14c8f7 100644 --- a/backend/onyx/agent_search/shared_graph_utils/prompts.py +++ b/backend/onyx/agent_search/shared_graph_utils/prompts.py @@ -30,21 +30,18 @@ Make sure that you keep all relevant information, specifically as it concerns to the ultimate goal. (But keep other details as well.) - If you don't know the answer or if the provided context is - empty, just say "I don't know". Do not use your internal knowledge! - - \nQuestion:\n {question} \n - \nContext:\n {context} \n Motivation:\n {original_question} \n\n \n\n - Answer:""" + And here is the question I want you to answer based on the context above (with the motivation in mind): + \n--\n {question} \n--\n + """ SUB_CHECK_PROMPT = """ Your task is to see whether a given answer addresses a given question. Please do not use any internal knowledge you may have - just focus on whether the answer - as given seems to largely address the question as given. + as given seems to largely address the question as given, or at least addresses part of the question. Here is the question: \n ------- \n {question} @@ -53,7 +50,7 @@ \n ------- \n {base_answer} \n ------- \n - Please answer with yes or no:""" + Does the suggested answer address the question? Please answer with yes or no:""" BASE_CHECK_PROMPT = """ \n @@ -73,37 +70,50 @@ \n ------- \n Please answer with yes or no:""" -VERIFIER_PROMPT = """ \n - Please check whether the document provided below seems to be relevant - to get an answer to the provided question. Please - only answer with 'yes' or 'no' \n - Here is the initial question: - \n ------- \n - {question} - \n ------- \n - Here is the document text: - \n ------- \n - {document_content} - \n ------- \n - Please answer with yes or no:""" +VERIFIER_PROMPT = """ +You are supposed to judge whether a document text contains data or information that is potentially relevant for a question. + +Here is a document text that you can take as a fact: +-- +DOCUMENT INFORMATION: +{document_content} +-- + +Do you think that this information is useful and relevant to answer the following question? +(Other documents may supply additional information, so do not worry if the provided information +is not enough to answer the question, but it needs to be relevant to the question.) +-- +QUESTION: +{question} +-- + +Please answer with 'yes' or 'no' and format your answer as a json object with the following format: +{{"decision": "<yes/no>"}} + +ANSWER: + +""" INITIAL_DECOMPOSITION_PROMPT_BASIC = """ \n - Please decompose an initial user question into not more than 4 appropriate sub-questions that help to - answer the original question.
The purpose for this decomposition is to isolate individulal entities - (i.e., 'compare sales of company A and company B' -> 'what are sales for company A' + 'what are sales - for company B'), split ambiguous terms (i.e., 'what is our success with company A' -> 'what are our - sales with company A' + 'what is our market share with company A' + 'is company A a reference customer - for us'), etc. Each sub-question should be realistically be answerable by a good RAG system. \n +If you think it is helpful, please decompose an initial user question into not more +than 4 appropriate sub-questions that help to answer the original question. +The purpose for this decomposition is to isolate individual entities +(i.e., 'compare sales of company A and company B' -> 'what are sales for company A' + 'what are sales +for company B'), split ambiguous terms (i.e., 'what is our success with company A' -> 'what are our +sales with company A' + 'what is our market share with company A' + 'is company A a reference customer + for us'), etc. Each sub-question should realistically be answerable by a good RAG system. - Here is the initial question: - \n ------- \n - {question} - \n ------- \n +Importantly, if you think it is not needed or helpful, please just return an empty list. That is ok too. - Please formulate your answer as a list of subquestions: +Here is the initial question: +\n ------- \n +{question} +\n ------- \n - Answer: - """ +Please formulate your answer as a list of subquestions: + +Answer: +""" REWRITE_PROMPT_SINGLE = """ \n Please convert an initial user question into a more appropriate search query for retrievel from a @@ -364,6 +374,29 @@ """ +INITIAL_DECOMPOSITION_PROMPT_QUESTIONS = """ +If you think it is helpful, please decompose an initial user question into 2 to 4 appropriate sub-questions that help to +answer the original question. The purpose for this decomposition is to + 1) isolate individual entities (i.e., 'compare sales of company A and company B' -> ['what are sales for company A', + 'what are sales for company B')] + 2) clarify or disambiguate ambiguous terms (i.e., 'what is our success with company A' -> ['what are our sales with company A', + 'what is our market share with company A', 'is company A a reference customer for us', etc.]) + 3) if a term or a metric is essentially clear, but it could relate to various components of an entity and you are generally + familiar with the entity, then you can decompose the question into sub-questions that are more specific to components + (i.e., 'what do we do to improve scalability of product X', 'what do we do to improve performance of product X', + 'what do we do to improve stability of product X', ...]) + +If you think that a decomposition is not needed or helpful, please just return an empty list. That is ok too. + +Here is the initial question: +------- +{question} +------- +Please formulate your answer as a list of json objects with the following format: +[{{"sub_question": "<sub-question>"}}, ...] + +Answer:""" + INITIAL_DECOMPOSITION_PROMPT = """ \n Please decompose an initial user question into 2 or 3 appropriate sub-questions that help to answer the original question. The purpose for this decomposition is to isolate individulal entities @@ -388,55 +421,91 @@ """ INITIAL_RAG_BASE_PROMPT = """ \n - You are an assistant for question-answering tasks.
Use the information provided below - and only the +provided information - to answer the provided question. - The information provided below consists of a number of documents that were also deemed relevant for the question. +The information provided below consists of a number of documents that were deemed relevant for the question. - If you don't know the answer or if the provided information is empty or insufficient, just say - "I don't know". Do not use your internal knowledge! +IMPORTANT RULES: +- If you cannot reliably answer the question solely using the provided information, say that you cannot reliably answer. +You may give some additional facts you learned, but do not try to invent an answer. +- If the information is empty or irrelevant, just say "I don't know". +- If the information is relevant but not fully conclusive, specify that the information is not conclusive and say why. - Again, only use the provided information and do not use your internal knowledge! It is a matter of life - and death that you do NOT use your internal knowledge, just the provided information! +Try to keep your answer concise. - Try to keep your answer concise. +Here is the contextual information from the document store: +\n ------- \n +{context} \n\n\n +\n ------- \n +And here is the question I want you to answer based on the context above: +\n--\n {question} \n--\n +Answer:""" - And here is the question and the provided information: - \n - \nQuestion:\n {question}\n - \nContext:\n {context}\n\n - \n\n +INITIAL_RAG_PROMPT = """ \n +You are an assistant for question-answering tasks. Use the information provided below - and only the +provided information - to answer the provided question. - Answer:""" +The information provided below consists of: + 1) a number of answered sub-questions - these are very important(!) and definitely should be + considered to answer the question. + 2) a number of documents that were also deemed relevant for the question. +IMPORTANT RULES: + - If you cannot reliably answer the question solely using the provided information, say that you cannot reliably answer. + You may give some additional facts you learned, but do not try to invent an answer. + - If the information is empty or irrelevant, just say "I don't know". + - If the information is relevant but not fully conclusive, specify that the information is not conclusive and say why. -INITIAL_RAG_PROMPT = """ \n - You are an assistant for question-answering tasks. Use the information provided below - and only the - provided information - to answer the provided question. +Again, you should be sure that the answer is supported by the information provided! - The information provided below consists of: - 1) a number of answered sub-questions - these are very important(!) and definitely should be - considered to answer the question. - 2) a number of documents that were also deemed relevant for the question. +Try to keep your answer concise. But also highlight uncertainties you may have should there be substantial ones, +or assumptions you made. - If you don't know the answer or if the provided information is empty or insufficient, just say - "I don't know". Do not use your internal knowledge! +Here is the contextual information: +\n-------\n +*Answered Sub-questions (these should really matter!): +{answered_sub_questions} - Again, only use the provided information and do not use your internal knowledge!
It is a matter of life - and death that you do NOT use your internal knowledge, just the provided information! +And here is relevant document information that supports the sub-question answers, or that is relevant for the actual question:\n - Try to keep your answer concise. +{relevant_docs} - And here is the question and the provided information: - \n - \nQuestion:\n {question}\n\n - Answered Sub-questions:\n {answered_sub_questions}\n\n - Documents supporting the sub-questions answers:\n {sub_question_docs_context}\n\n +\n-------\n +\n +And here is the question I want you to answer based on the information above: +\n--\n +{question} +\n--\n\n +Answer:""" - And here are additional relevant documents:\n\n - {additional_relevant_docs} \n\n\n - Answer:""" +INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS = """ +You are an assistant for question-answering tasks. Use the information provided below +- and only the provided information - to answer the provided question. +The information provided below consists of a number of documents that were deemed relevant for the question. + +IMPORTANT RULES: + - If you cannot reliably answer the question solely using the provided information, say that you cannot reliably answer. + You may give some additional facts you learned, but do not try to invent an answer. + - If the information is irrelevant, just say "I don't know". + - If the information is relevant but not fully conclusive, specify that the information is not conclusive and say why. + +Again, you should be sure that the answer is supported by the information provided! + +Try to keep your answer concise. + +Here is the relevant context information: +\n-------\n +{relevant_docs} +\n-------\n + +And here is the question I want you to answer based on the context above: +\n--\n +{question} +\n--\n + +Answer:""" ENTITY_TERM_PROMPT = """ \n Based on the original question and the context retieved from a dataset, please generate a list of diff --git a/backend/onyx/configs/dev_configs.py b/backend/onyx/configs/dev_configs.py new file mode 100644 index 0000000000..c9a018e8e8 --- /dev/null +++ b/backend/onyx/configs/dev_configs.py @@ -0,0 +1,13 @@ +import os + +from backend.onyx.configs.chat_configs import NUM_RETURNED_HITS + + +##### +# Agent Configs +##### + +AGENT_TEST = os.environ.get("AGENT_TEST", False) +AGENT_TEST_MAX_QUERY_RETRIEVAL_RESULTS = os.environ.get( + "MAX_AGENT_QUERY_RETRIEVAL_RESULTS", NUM_RETURNED_HITS +) diff --git a/backend/tests/regression/answer_quality/agent_test.py b/backend/tests/regression/answer_quality/agent_test.py index 28329c6b33..d3087c8a6f 100644 --- a/backend/tests/regression/answer_quality/agent_test.py +++ b/backend/tests/regression/answer_quality/agent_test.py @@ -33,13 +33,17 @@ test_data = json.load(input_file_object) -examples = test_data["examples"] +example_data = test_data["examples"] +example_ids = test_data["example_ids"] with get_session_context_manager() as db_session: output_data = [] - for example in examples: + for example in example_data: example_id = example["id"] + if len(example_ids) > 0 and example_id not in example_ids: + continue + example_question = example["question"] target_sub_questions = example.get("target_sub_questions", []) num_target_sub_questions = len(target_sub_questions) @@ -58,16 +62,21 @@ end_time = datetime.datetime.now() duration = end_time - start_time - chunk_expansion_ratio = ( - question_result["initial_agent_stats"] - .get("agent_effectiveness", {}) - .get("utilized_chunk_ratio", None) - ) - support_effectiveness_ratio = (
question_result["initial_agent_stats"] - .get("agent_effectiveness", {}) - .get("support_ratio", None) - ) + if num_target_sub_questions > 0: + chunk_expansion_ratio = ( + question_result["initial_agent_stats"] + .get("agent_effectiveness", {}) + .get("utilized_chunk_ratio", None) + ) + support_effectiveness_ratio = ( + question_result["initial_agent_stats"] + .get("agent_effectiveness", {}) + .get("support_ratio", None) + ) + else: + chunk_expansion_ratio = None + support_effectiveness_ratio = None + generated_sub_questions = question_result.get("generated_sub_questions", []) num_generated_sub_questions = len(generated_sub_questions) base_answer = question_result["initial_base_answer"].split("==")[-1] From ca3f3beabe8c9f77e251122b700cddeed06c5e9e Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Tue, 24 Dec 2024 20:21:34 -0800 Subject: [PATCH 23/78] initial mypy changes --- .vscode/env_template.txt | 6 ++ .../nodes/doc_verification.py | 2 +- .../nodes/format_results.py | 63 ++++++++++++------- .../agent_search/expanded_retrieval/states.py | 7 ++- .../onyx/agent_search/main/graph_builder.py | 4 -- .../shared_graph_utils/calculations.py | 25 ++++---- .../agent_search/shared_graph_utils/models.py | 19 +++++- backend/onyx/configs/dev_configs.py | 41 ++++++++++-- 8 files changed, 119 insertions(+), 48 deletions(-) diff --git a/.vscode/env_template.txt b/.vscode/env_template.txt index 89faca0abf..74a83aba6b 100644 --- a/.vscode/env_template.txt +++ b/.vscode/env_template.txt @@ -49,3 +49,9 @@ BING_API_KEY= # Enable the full set of Danswer Enterprise Edition features # NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you are using this for local testing/development) ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False + +# Agent Search configs # TODO: remove or give proper names +AGENT_RETRIEVAL_STATS=False # Note: This setting will incur substantial re-ranking effort +AGENT_RERANKING_STATS=True +AGENT_MAX_QUERY_RETRIEVAL_RESULTS=20 +AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS=20 diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py index ea7c000eff..aec6b68ad6 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py @@ -34,7 +34,7 @@ def doc_verification(state: DocVerificationInput) -> DocVerificationUpdate: fast_llm = state["fast_llm"] response = json.loads( - fast_llm.invoke(msg, structured_response_format=BinaryDecision).content + str(fast_llm.invoke(msg, structured_response_format=BinaryDecision).content) ) # response_string = response.content.get("decision", "no").lower() diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py b/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py index 7d9d660263..4022649027 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py @@ -7,17 +7,22 @@ from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState from onyx.agent_search.expanded_retrieval.states import InferenceSection from onyx.agent_search.expanded_retrieval.states import QueryResult +from onyx.agent_search.shared_graph_utils.models import AgentChunkStats def _calculate_sub_question_retrieval_stats( verified_documents: list[InferenceSection], expanded_retrieval_results: list[QueryResult], -) -> dict[str, float |
int]: - chunk_scores = defaultdict(lambda: defaultdict(list)) +) -> AgentChunkStats: + chunk_scores: dict[str, dict[str, list[int | float]]] = defaultdict( + lambda: defaultdict(list) + ) + for expanded_retrieval_result in expanded_retrieval_results: for doc in expanded_retrieval_result.search_results: doc_chunk_id = f"{doc.center_chunk.document_id}_{doc.center_chunk.chunk_id}" - chunk_scores[doc_chunk_id]["score"].append(doc.center_chunk.score) + if doc.center_chunk.score is not None: + chunk_scores[doc_chunk_id]["score"].append(doc.center_chunk.score) verified_doc_chunk_ids = [ f"{verified_document.center_chunk.document_id}_{verified_document.center_chunk.chunk_id}" @@ -25,37 +30,51 @@ def _calculate_sub_question_retrieval_stats( ] dismissed_doc_chunk_ids = [] - raw_chunk_stats = defaultdict(float) + raw_chunk_stats_counts: dict[str, int] = defaultdict(int) + raw_chunk_stats_scores: dict[str, float] = defaultdict(float) for doc_chunk_id, chunk_data in chunk_scores.items(): if doc_chunk_id in verified_doc_chunk_ids: - raw_chunk_stats["verified_count"] += 1 - raw_chunk_stats["verified_scores"] += np.mean(chunk_data["score"]) + raw_chunk_stats_counts["verified_count"] += 1 + + valid_chunk_scores = [ + score for score in chunk_data["score"] if score is not None + ] + raw_chunk_stats_scores["verified_scores"] += float( + np.mean(valid_chunk_scores) + ) else: - raw_chunk_stats["rejected_count"] += 1 - raw_chunk_stats["rejected_scores"] += np.mean(chunk_data["score"]) + raw_chunk_stats_counts["rejected_count"] += 1 + valid_chunk_scores = [ + score for score in chunk_data["score"] if score is not None + ] + raw_chunk_stats_scores["rejected_scores"] += float( + np.mean(valid_chunk_scores) + ) dismissed_doc_chunk_ids.append(doc_chunk_id) - if raw_chunk_stats["verified_count"] == 0: - verified_avg_scores = 0 + if raw_chunk_stats_counts["verified_count"] == 0: + verified_avg_scores = 0.0 else: - verified_avg_scores = ( - raw_chunk_stats["verified_scores"] / raw_chunk_stats["verified_count"] + verified_avg_scores = raw_chunk_stats_scores["verified_scores"] / float( + raw_chunk_stats_counts["verified_count"] ) - rejected_scores = raw_chunk_stats.get("rejected_scores", None) + rejected_scores = raw_chunk_stats_scores.get("rejected_scores", None) if rejected_scores is not None: - rejected_avg_scores = rejected_scores / raw_chunk_stats["rejected_count"] + rejected_avg_scores = rejected_scores / float( + raw_chunk_stats_counts["rejected_count"] + ) else: rejected_avg_scores = None - chunk_stats = { - "verified_count": raw_chunk_stats["verified_count"], - "verified_avg_scores": verified_avg_scores, - "rejected_count": raw_chunk_stats["rejected_count"], - "rejected_avg_scores": rejected_avg_scores, - "verified_doc_chunk_ids": verified_doc_chunk_ids, - "dismissed_doc_chunk_ids": dismissed_doc_chunk_ids, - } + chunk_stats = AgentChunkStats( + verified_count=raw_chunk_stats_counts["verified_count"], + verified_avg_scores=verified_avg_scores, + rejected_count=raw_chunk_stats_counts["rejected_count"], + rejected_avg_scores=rejected_avg_scores, + verified_doc_chunk_ids=verified_doc_chunk_ids, + dismissed_doc_chunk_ids=dismissed_doc_chunk_ids, + ) return chunk_stats diff --git a/backend/onyx/agent_search/expanded_retrieval/states.py b/backend/onyx/agent_search/expanded_retrieval/states.py index 55dae753e5..ab8f20f87f 100644 --- a/backend/onyx/agent_search/expanded_retrieval/states.py +++ b/backend/onyx/agent_search/expanded_retrieval/states.py @@ -1,11 +1,12 @@ from operator import add from typing import Annotated 
-from typing import Any from typing import TypedDict from pydantic import BaseModel from onyx.agent_search.core_state import CoreState +from onyx.agent_search.shared_graph_utils.models import AgentChunkStats +from onyx.agent_search.shared_graph_utils.models import RetrievalFitStats from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections from onyx.context.search.models import InferenceSection @@ -15,13 +16,13 @@ class QueryResult(BaseModel): query: str search_results: list[InferenceSection] - stats: dict[str, Any] | None + stats: RetrievalFitStats | None class ExpandedRetrievalResult(BaseModel): expanded_queries_results: list[QueryResult] all_documents: list[InferenceSection] - sub_question_retrieval_stats: dict + sub_question_retrieval_stats: AgentChunkStats ### States ### diff --git a/backend/onyx/agent_search/main/graph_builder.py b/backend/onyx/agent_search/main/graph_builder.py index f23ee9176f..940333ac6f 100644 --- a/backend/onyx/agent_search/main/graph_builder.py +++ b/backend/onyx/agent_search/main/graph_builder.py @@ -123,10 +123,6 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph: primary_llm, fast_llm = get_default_llms() with get_session_context_manager() as db_session: - durations = [] - chunk_expansion_ratios = [] - support_effectiveness_ratios = [] - search_request = SearchRequest(query="Who created Excel?") inputs = MainInput( diff --git a/backend/onyx/agent_search/shared_graph_utils/calculations.py b/backend/onyx/agent_search/shared_graph_utils/calculations.py index 93947c252f..6e5a6c0a37 100644 --- a/backend/onyx/agent_search/shared_graph_utils/calculations.py +++ b/backend/onyx/agent_search/shared_graph_utils/calculations.py @@ -1,7 +1,7 @@ import numpy as np -from onyx.agent_search.shared_graph_utils.models import FitScoreMetrics -from onyx.agent_search.shared_graph_utils.models import FitScores +from onyx.agent_search.shared_graph_utils.models import RetrievalFitScoreMetrics +from onyx.agent_search.shared_graph_utils.models import RetrievalFitStats from onyx.context.search.models import InferenceSection from onyx.utils.logger import setup_logger @@ -28,7 +28,7 @@ def calculate_rank_shift(list1: list, list2: list, top_n: int = 20) -> float: def get_fit_scores( pre_reranked_results: list[InferenceSection], post_reranked_results: list[InferenceSection], -) -> dict | None: +) -> RetrievalFitStats | None: """ Calculate retrieval metrics for search purposes """ @@ -41,25 +41,28 @@ def get_fit_scores( "reranked": post_reranked_results, } - fit_eval = {} - - fit_eval: FitScores = FitScores( + fit_eval: RetrievalFitStats = RetrievalFitStats( fit_score_lift=0, rerank_effect=0, fit_scores={ - "initial": FitScoreMetrics(scores={}, chunk_ids=[]), - "reranked": FitScoreMetrics(scores={}, chunk_ids=[]), + "initial": RetrievalFitScoreMetrics(scores={}, chunk_ids=[]), + "reranked": RetrievalFitScoreMetrics(scores={}, chunk_ids=[]), }, ) - # logger.error("wwwww") - for rank_type, docs in ranked_sections.items(): print(f"rank_type: {rank_type}") for i in [1, 5, 10]: fit_eval.fit_scores[rank_type].scores[str(i)] = ( - sum([doc.center_chunk.score for doc in docs[:i]]) / i + sum( + [ + float(doc.center_chunk.score) + for doc in docs[:i] + if doc.center_chunk.score is not None + ] + ) + / i ) fit_eval.fit_scores[rank_type].scores["fit_score"] = ( diff --git a/backend/onyx/agent_search/shared_graph_utils/models.py b/backend/onyx/agent_search/shared_graph_utils/models.py index 198ae0c475..28044637cd 100644 --- 
a/backend/onyx/agent_search/shared_graph_utils/models.py +++ b/backend/onyx/agent_search/shared_graph_utils/models.py @@ -17,12 +17,25 @@ class BinaryDecisionWithReasoning(BaseModel): decision: Literal["yes", "no"] -class FitScoreMetrics(BaseModel): +class RetrievalFitScoreMetrics(BaseModel): scores: dict[str, float] chunk_ids: list[str] -class FitScores(BaseModel): +class RetrievalFitStats(BaseModel): fit_score_lift: float rerank_effect: float - fit_scores: dict[str, FitScoreMetrics] + fit_scores: dict[str, RetrievalFitScoreMetrics] + + +class AgentChunkScores(BaseModel): + scores: dict[str, dict[str, list[int | float]]] + + +class AgentChunkStats(BaseModel): + verified_count: int | None + verified_avg_scores: float | None + rejected_count: int | None + rejected_avg_scores: float | None + verified_doc_chunk_ids: list[str] + dismissed_doc_chunk_ids: list[str] diff --git a/backend/onyx/configs/dev_configs.py b/backend/onyx/configs/dev_configs.py index c9a018e8e8..bb8256a18b 100644 --- a/backend/onyx/configs/dev_configs.py +++ b/backend/onyx/configs/dev_configs.py @@ -1,13 +1,46 @@ import os -from backend.onyx.configs.chat_configs import NUM_RETURNED_HITS +from .chat_configs import NUM_RETURNED_HITS ##### # Agent Configs ##### -AGENT_TEST = os.environ.get("AGENT_TEST", False) -AGENT_TEST_MAX_QUERY_RETRIEVAL_RESULTS = os.environ.get( - "MAX_AGENT_QUERY_RETRIEVAL_RESULTS", NUM_RETURNED_HITS +AGENT_RETRIEVAL_STATS = os.environ.get("AGENT_RETRIEVAL_STATS", False) +if AGENT_RETRIEVAL_STATS == "True": + AGENT_RETRIEVAL_STATS = True +elif AGENT_RETRIEVAL_STATS: + AGENT_RETRIEVAL_STATS = False + +AGENT_MAX_QUERY_RETRIEVAL_RESULTS = os.environ.get( + "AGENT_MAX_QUERY_RETRIEVAL_RESULTS", NUM_RETURNED_HITS ) + +try: + atmqrr = int(AGENT_MAX_QUERY_RETRIEVAL_RESULTS) + AGENT_MAX_QUERY_RETRIEVAL_RESULTS = atmqrr +except ValueError: + raise ValueError( + f"AGENT_MAX_QUERY_RETRIEVAL_RESULTS must be an integer, got {AGENT_MAX_QUERY_RETRIEVAL_RESULTS}" + ) + + +# Reranking agent configs +AGENT_RERANKING_STATS = os.environ.get("AGENT_RERANKING_STATS", False) +if AGENT_RERANKING_STATS == "True": + AGENT_RERANKING_STATS = True +elif AGENT_RERANKING_STATS: + AGENT_RERANKING_STATS = False + +AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS = os.environ.get( + "AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS", NUM_RETURNED_HITS +) + +try: + atmqrr = int(AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS) + AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS = atmqrr +except ValueError: + raise ValueError( + f"AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS must be an integer, got {AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS}" + ) From c4af11c19b9c275edde0aea98599fa6a25b44f8e Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Wed, 25 Dec 2024 16:15:55 -0800 Subject: [PATCH 24/78] core mypy resolutions --- .../agent_search/answer_question/states.py | 6 +- .../expanded_retrieval/nodes/doc_reranking.py | 14 +-- .../expanded_retrieval/nodes/doc_retrieval.py | 12 +-- .../nodes/doc_verification.py | 12 +-- .../agent_search/expanded_retrieval/states.py | 2 +- .../main/nodes/generate_initial_answer.py | 96 ++++++++++--------- backend/onyx/agent_search/main/states.py | 16 +--- .../shared_graph_utils/calculations.py | 8 +- .../agent_search/shared_graph_utils/models.py | 6 ++ .../shared_graph_utils/prompts.py | 5 +- backend/onyx/configs/dev_configs.py | 35 ++++--- 11 files changed, 111 insertions(+), 101 deletions(-) diff --git a/backend/onyx/agent_search/answer_question/states.py b/backend/onyx/agent_search/answer_question/states.py index
a58d4439be..46c96e589c 100644 --- a/backend/onyx/agent_search/answer_question/states.py +++ b/backend/onyx/agent_search/answer_question/states.py @@ -6,10 +6,10 @@ from onyx.agent_search.core_state import CoreState from onyx.agent_search.expanded_retrieval.states import QueryResult +from onyx.agent_search.shared_graph_utils.models import AgentChunkStats from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections from onyx.context.search.models import InferenceSection - ### Models ### @@ -23,7 +23,7 @@ class QuestionAnswerResults(BaseModel): quality: str expanded_retrieval_results: list[QueryResult] documents: list[InferenceSection] - sub_question_retrieval_stats: dict + sub_question_retrieval_stats: AgentChunkStats ### States ### @@ -43,7 +43,7 @@ class QAGenerationUpdate(TypedDict): class RetrievalIngestionUpdate(TypedDict): expanded_retrieval_results: list[QueryResult] documents: Annotated[list[InferenceSection], dedup_inference_sections] - sub_question_retrieval_stats: dict + sub_question_retrieval_stats: AgentChunkStats ## Graph Input State diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py index ffc90a9b83..b90ed6b40a 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py @@ -1,15 +1,15 @@ from onyx.agent_search.expanded_retrieval.states import DocRerankingUpdate from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState from onyx.agent_search.shared_graph_utils.calculations import get_fit_scores +from onyx.configs.dev_configs import AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS +from onyx.configs.dev_configs import AGENT_RERANKING_STATS +from onyx.context.search.pipeline import InferenceSection from onyx.context.search.pipeline import retrieval_preprocessing from onyx.context.search.pipeline import search_postprocessing from onyx.context.search.pipeline import SearchRequest def doc_reranking(state: ExpandedRetrievalState) -> DocRerankingUpdate: - AGENT_TEST = True - AGENT_TEST_MAX_QUERY_RETRIEVAL_RESULTS = 10 - verified_documents = state["verified_documents"] # Rerank post retrieval and verification. 
First, create a search query @@ -32,12 +32,14 @@ def doc_reranking(state: ExpandedRetrievalState) -> DocRerankingUpdate: 0 ] # only get the reranked sections, not the SectionRelevancePiece - if AGENT_TEST: + if AGENT_RERANKING_STATS: fit_scores = get_fit_scores(verified_documents, reranked_documents) else: fit_scores = None return DocRerankingUpdate( - reranked_documents=reranked_documents[:AGENT_TEST_MAX_QUERY_RETRIEVAL_RESULTS], - fit_scores=fit_scores, + reranked_documents=[ + doc for doc in reranked_documents if type(doc) == InferenceSection + ][:AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS], + sub_question_retrieval_stats=fit_scores, ) diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py index 25b1560ca5..e5e7015e71 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py @@ -2,8 +2,8 @@ from onyx.agent_search.expanded_retrieval.states import QueryResult from onyx.agent_search.expanded_retrieval.states import RetrievalInput from onyx.agent_search.shared_graph_utils.calculations import get_fit_scores -from onyx.configs.dev_configs import AGENT_TEST -from onyx.configs.dev_configs import AGENT_TEST_MAX_QUERY_RETRIEVAL_RESULTS +from onyx.configs.dev_configs import AGENT_MAX_QUERY_RETRIEVAL_RESULTS +from onyx.configs.dev_configs import AGENT_RETRIEVAL_STATS from onyx.context.search.models import SearchRequest from onyx.context.search.pipeline import SearchPipeline @@ -34,14 +34,12 @@ def doc_retrieval(state: RetrievalInput) -> DocRetrievalUpdate: db_session=state["db_session"], ) - retrieved_docs = search_results._get_sections()[ - :AGENT_TEST_MAX_QUERY_RETRIEVAL_RESULTS - ] + retrieved_docs = search_results._get_sections()[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS] - if AGENT_TEST: + if AGENT_RETRIEVAL_STATS: fit_scores = get_fit_scores( retrieved_docs, - search_results.reranked_sections[:AGENT_TEST_MAX_QUERY_RETRIEVAL_RESULTS], + search_results.reranked_sections[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS], ) else: fit_scores = None diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py index aec6b68ad6..7c5579619f 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py @@ -1,10 +1,7 @@ -import json - from langchain_core.messages import HumanMessage from onyx.agent_search.expanded_retrieval.states import DocVerificationInput from onyx.agent_search.expanded_retrieval.states import DocVerificationUpdate -from onyx.agent_search.shared_graph_utils.models import BinaryDecision from onyx.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT @@ -33,16 +30,11 @@ def doc_verification(state: DocVerificationInput) -> DocVerificationUpdate: ] fast_llm = state["fast_llm"] - response = json.loads( - str(fast_llm.invoke(msg, structured_response_format=BinaryDecision).content) - ) - # response_string = response.content.get("decision", "no").lower() - # Convert string response to proper dictionary format - # decision_dict = {"decision": response.content.lower()} + response = fast_llm.invoke(msg) verified_documents = [] - if response["decision"] == "yes": + if "yes" in response.content.lower(): verified_documents.append(doc_to_verify) return DocVerificationUpdate( diff --git
a/backend/onyx/agent_search/expanded_retrieval/states.py b/backend/onyx/agent_search/expanded_retrieval/states.py index ab8f20f87f..677dbb8c8f 100644 --- a/backend/onyx/agent_search/expanded_retrieval/states.py +++ b/backend/onyx/agent_search/expanded_retrieval/states.py @@ -35,7 +35,7 @@ class DocVerificationUpdate(TypedDict): class DocRerankingUpdate(TypedDict): reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections] - sub_question_retrieval_stats: Annotated[list[dict[str, float | int]], add] + sub_question_retrieval_stats: RetrievalFitStats | None class QueryExpansionUpdate(TypedDict): diff --git a/backend/onyx/agent_search/main/nodes/generate_initial_answer.py b/backend/onyx/agent_search/main/nodes/generate_initial_answer.py index 1cd1863852..2522eb7b77 100644 --- a/backend/onyx/agent_search/main/nodes/generate_initial_answer.py +++ b/backend/onyx/agent_search/main/nodes/generate_initial_answer.py @@ -1,11 +1,10 @@ -import json - -from backend.onyx.agent_search.answer_question.states import QuestionAnswerResults from langchain_core.messages import HumanMessage -from onyx.agent_search.main.states import AgentStats +from onyx.agent_search.answer_question.states import QuestionAnswerResults from onyx.agent_search.main.states import InitialAnswerUpdate from onyx.agent_search.main.states import MainState +from onyx.agent_search.shared_graph_utils.models import AgentChunkStats +from onyx.agent_search.shared_graph_utils.models import InitialAgentResultStats from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections from onyx.agent_search.shared_graph_utils.prompts import INITIAL_RAG_PROMPT from onyx.agent_search.shared_graph_utils.prompts import ( @@ -15,77 +14,85 @@ def _calculate_initial_agent_stats( - decomp_answer_results: list[QuestionAnswerResults], original_question_stats: dict -) -> AgentStats: - initial_agent_dict = { - "sub_questions": {}, - "original_question": {}, - "agent_effectiveness": {}, - } + decomp_answer_results: list[QuestionAnswerResults], + original_question_stats: AgentChunkStats, +) -> InitialAgentResultStats: + initial_agent_result_stats: InitialAgentResultStats = InitialAgentResultStats( + sub_questions={}, + original_question={}, + agent_effectiveness={}, + ) - orig_verified = original_question_stats["verified_count"] - orig_support_score = original_question_stats["verified_avg_scores"] + orig_verified = original_question_stats.verified_count + orig_support_score = original_question_stats.verified_avg_scores verified_document_chunk_ids = [] - support_scores = 0 + support_scores = 0.0 for decomp_answer_result in decomp_answer_results: verified_document_chunk_ids += ( - decomp_answer_result.sub_question_retrieval_stats["verified_doc_chunk_ids"] + decomp_answer_result.sub_question_retrieval_stats.verified_doc_chunk_ids ) - support_scores += decomp_answer_result.sub_question_retrieval_stats[ - "verified_avg_scores" - ] + if ( + decomp_answer_result.sub_question_retrieval_stats.verified_avg_scores + is not None + ): + support_scores += ( + decomp_answer_result.sub_question_retrieval_stats.verified_avg_scores + ) verified_document_chunk_ids = list(set(verified_document_chunk_ids)) # Calculate sub-question stats - if verified_document_chunk_ids: - sub_question_stats = { + if ( + verified_document_chunk_ids + and len(verified_document_chunk_ids) > 0 + and support_scores is not None + ): + sub_question_stats: dict[str, float | int | None] = { "num_verified_documents": len(verified_document_chunk_ids), - "verified_avg_score": 
support_scores / len(decomp_answer_results), + "verified_avg_score": float(support_scores / len(decomp_answer_results)), } else: sub_question_stats = {"num_verified_documents": 0, "verified_avg_score": None} - initial_agent_dict["sub_questions"].update(sub_question_stats) + + initial_agent_result_stats.sub_questions.update(sub_question_stats) # Get original question stats - initial_agent_dict["original_question"].update( + initial_agent_result_stats.original_question.update( { - "num_verified_documents": original_question_stats.get("verified_count", 0), - "verified_avg_score": original_question_stats.get( - "verified_avg_scores", None - ), + "num_verified_documents": original_question_stats.verified_count, + "verified_avg_score": original_question_stats.verified_avg_scores, } ) # Calculate chunk utilization ratio - sub_verified = initial_agent_dict["sub_questions"]["num_verified_documents"] + sub_verified = initial_agent_result_stats.sub_questions["num_verified_documents"] - chunk_ratio = None - if orig_verified > 0: - chunk_ratio = sub_verified / orig_verified if sub_verified > 0 else 0 - elif sub_verified > 0: - chunk_ratio = 10 + chunk_ratio: float | None = None + if sub_verified is not None and orig_verified is not None and orig_verified > 0: + chunk_ratio = (float(sub_verified) / orig_verified) if sub_verified > 0 else 0.0 + elif sub_verified is not None and sub_verified > 0: + chunk_ratio = 10.0 - initial_agent_dict["agent_effectiveness"]["utilized_chunk_ratio"] = chunk_ratio + initial_agent_result_stats.agent_effectiveness["utilized_chunk_ratio"] = chunk_ratio if ( orig_support_score is None - and initial_agent_dict["sub_questions"]["verified_avg_score"] is None + and initial_agent_result_stats.sub_questions["verified_avg_score"] is None ): - initial_agent_dict["agent_effectiveness"]["support_ratio"] = None + initial_agent_result_stats.agent_effectiveness["support_ratio"] = None elif orig_support_score is None: - initial_agent_dict["agent_effectiveness"]["support_ratio"] = 10 - elif initial_agent_dict["sub_questions"]["verified_avg_score"] is None: - initial_agent_dict["agent_effectiveness"]["support_ratio"] = 0 + initial_agent_result_stats.agent_effectiveness["support_ratio"] = 10 + elif initial_agent_result_stats.sub_questions["verified_avg_score"] is None: + initial_agent_result_stats.agent_effectiveness["support_ratio"] = 0 else: - initial_agent_dict["agent_effectiveness"]["support_ratio"] = ( - initial_agent_dict["sub_questions"]["verified_avg_score"] + initial_agent_result_stats.agent_effectiveness["support_ratio"] = ( + initial_agent_result_stats.sub_questions["verified_avg_score"] / orig_support_score ) - return initial_agent_dict + return initial_agent_result_stats def generate_initial_answer(state: MainState) -> InitialAnswerUpdate: @@ -160,8 +167,9 @@ def generate_initial_answer(state: MainState) -> InitialAnswerUpdate: print(f"\n\nSub-Questions:\n\n{sub_question_answer_str}\n\nStats:\n\n") - print(json.dumps(initial_agent_stats, indent=4)) - + print(initial_agent_stats.original_question) + print(initial_agent_stats.sub_questions) + print(initial_agent_stats.agent_effectiveness) print("\n\n ---INITIAL AGENT ANSWER END---\n\n") return InitialAnswerUpdate( diff --git a/backend/onyx/agent_search/main/states.py b/backend/onyx/agent_search/main/states.py index 48cf31faa0..a783fb463e 100644 --- a/backend/onyx/agent_search/main/states.py +++ b/backend/onyx/agent_search/main/states.py @@ -2,22 +2,14 @@ from typing import Annotated from typing import TypedDict -from pydantic import 
BaseModel - from onyx.agent_search.answer_question.states import QuestionAnswerResults from onyx.agent_search.core_state import CoreState from onyx.agent_search.expanded_retrieval.states import QueryResult +from onyx.agent_search.shared_graph_utils.models import AgentChunkStats +from onyx.agent_search.shared_graph_utils.models import InitialAgentResultStats from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections from onyx.context.search.models import InferenceSection -## Models - - -class AgentStats(BaseModel): - sub_question_stats: list[dict[str, float | int]] - original_question_stats: dict[str, float | int] - agent_stats: dict[str, float | int] - ### States ### @@ -34,7 +26,7 @@ class InitialAnswerBASEUpdate(TypedDict): class InitialAnswerUpdate(TypedDict): initial_answer: str - initial_agent_stats: dict + initial_agent_stats: InitialAgentResultStats generated_sub_questions: list[str] @@ -48,7 +40,7 @@ class ExpandedRetrievalUpdate(TypedDict): list[InferenceSection], dedup_inference_sections ] original_question_retrieval_results: list[QueryResult] - sub_question_retrieval_stats: dict + sub_question_retrieval_stats: AgentChunkStats ## Graph Input State diff --git a/backend/onyx/agent_search/shared_graph_utils/calculations.py b/backend/onyx/agent_search/shared_graph_utils/calculations.py index 6e5a6c0a37..57deffe28c 100644 --- a/backend/onyx/agent_search/shared_graph_utils/calculations.py +++ b/backend/onyx/agent_search/shared_graph_utils/calculations.py @@ -2,6 +2,7 @@ from onyx.agent_search.shared_graph_utils.models import RetrievalFitScoreMetrics from onyx.agent_search.shared_graph_utils.models import RetrievalFitStats +from onyx.chat.models import SectionRelevancePiece from onyx.context.search.models import InferenceSection from onyx.utils.logger import setup_logger @@ -27,7 +28,7 @@ def calculate_rank_shift(list1: list, list2: list, top_n: int = 20) -> float: def get_fit_scores( pre_reranked_results: list[InferenceSection], - post_reranked_results: list[InferenceSection], + post_reranked_results: list[InferenceSection] | list[SectionRelevancePiece], ) -> RetrievalFitStats | None: """ Calculate retrieval metrics for search purposes @@ -59,7 +60,8 @@ def get_fit_scores( [ float(doc.center_chunk.score) for doc in docs[:i] - if doc.center_chunk.score is not None + if type(doc) == InferenceSection + and doc.center_chunk.score is not None ] ) / i @@ -80,7 +82,7 @@ def get_fit_scores( ].scores["1"] fit_eval.fit_scores[rank_type].chunk_ids = [ - unique_chunk_id(doc) for doc in docs + unique_chunk_id(doc) for doc in docs if type(doc) == InferenceSection ] fit_eval.fit_score_lift = ( diff --git a/backend/onyx/agent_search/shared_graph_utils/models.py b/backend/onyx/agent_search/shared_graph_utils/models.py index 28044637cd..5193b5dd2b 100644 --- a/backend/onyx/agent_search/shared_graph_utils/models.py +++ b/backend/onyx/agent_search/shared_graph_utils/models.py @@ -39,3 +39,9 @@ class AgentChunkStats(BaseModel): rejected_avg_scores: float | None verified_doc_chunk_ids: list[str] dismissed_doc_chunk_ids: list[str] + + +class InitialAgentResultStats(BaseModel): + sub_questions: dict[str, float | int | None] + original_question: dict[str, float | int | None] + agent_effectiveness: dict[str, float | int | None] diff --git a/backend/onyx/agent_search/shared_graph_utils/prompts.py b/backend/onyx/agent_search/shared_graph_utils/prompts.py index 918e14c8f7..07b935d91b 100644 --- a/backend/onyx/agent_search/shared_graph_utils/prompts.py +++ 
b/backend/onyx/agent_search/shared_graph_utils/prompts.py @@ -87,10 +87,9 @@ {question} -- -Please answer with 'yes' or 'no' and format your answer as a json object with the following format: -{{"decision": }} +Please answer with 'yes' or 'no': -AANSWER: +Answer: """ diff --git a/backend/onyx/configs/dev_configs.py b/backend/onyx/configs/dev_configs.py index bb8256a18b..a894880cf8 100644 --- a/backend/onyx/configs/dev_configs.py +++ b/backend/onyx/configs/dev_configs.py @@ -7,18 +7,23 @@ # Agent Configs ##### -AGENT_RETRIEVAL_STATS = os.environ.get("AGENT_RETRIEVAL_STATS", False) -if AGENT_RETRIEVAL_STATS == "True": +agent_retrieval_stats_os: bool | str | None = os.environ.get( + "AGENT_RETRIEVAL_STATS", False +) + +AGENT_RETRIEVAL_STATS: bool = False +if isinstance(agent_retrieval_stats_os, str) and agent_retrieval_stats_os == "True": + AGENT_RETRIEVAL_STATS = True +elif isinstance(agent_retrieval_stats_os, bool) and agent_retrieval_stats_os: AGENT_RETRIEVAL_STATS = True -elif AGENT_RETRIEVAL_STATS: - AGENT_RETRIEVAL_STATS = False -AGENT_MAX_QUERY_RETRIEVAL_RESULTS = os.environ.get( +agent_max_query_retrieval_results_os: int | str = os.environ.get( "AGENT_MAX_QUERY_RETRIEVAL_RESULTS", NUM_RETURNED_HITS ) +AGENT_MAX_QUERY_RETRIEVAL_RESULTS: int = NUM_RETURNED_HITS try: - atmqrr = int(AGENT_MAX_QUERY_RETRIEVAL_RESULTS) + atmqrr = int(agent_max_query_retrieval_results_os) AGENT_MAX_QUERY_RETRIEVAL_RESULTS = atmqrr except ValueError: raise ValueError( @@ -27,18 +32,24 @@ # Reranking agent configs -AGENT_RERANKING_STATS = os.environ.get("AGENT_RERANKING_TEST", False) -if AGENT_RERANKING_STATS == "True": +agent_reranking_stats_os: bool | str | None = os.environ.get( + "AGENT_RERANKING_TEST", False +) +AGENT_RERANKING_STATS: bool = False +if isinstance(agent_reranking_stats_os, str) and agent_reranking_stats_os == "True": AGENT_RERANKING_STATS = True -elif AGENT_RERANKING_STATS: - AGENT_RERANKING_STATS = False +elif isinstance(agent_reranking_stats_os, bool) and agent_reranking_stats_os: + AGENT_RERANKING_STATS = True + -AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS = os.environ.get( +agent_reranking_max_query_retrieval_results_os: int | str = os.environ.get( "AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS", NUM_RETURNED_HITS ) +AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS: int = NUM_RETURNED_HITS + try: - atmqrr = int(AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS) + atmqrr = int(agent_reranking_max_query_retrieval_results_os) AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS = atmqrr except ValueError: raise ValueError( From 21928133e0cde70b13d6e9c1ad6c6219a229688a Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Thu, 26 Dec 2024 10:59:10 -0800 Subject: [PATCH 25/78] tmp_state_model_sep --- .../agent_search/answer_question/states.py | 2 +- .../agent_search/expanded_retrieval/models.py | 19 ++++++++++++++++ .../expanded_retrieval/nodes/doc_retrieval.py | 2 +- .../nodes/format_results.py | 4 ++-- .../agent_search/expanded_retrieval/states.py | 22 ++----------------- backend/onyx/agent_search/main/edges.py | 2 +- backend/onyx/agent_search/main/states.py | 2 +- 7 files changed, 27 insertions(+), 26 deletions(-) create mode 100644 backend/onyx/agent_search/expanded_retrieval/models.py diff --git a/backend/onyx/agent_search/answer_question/states.py b/backend/onyx/agent_search/answer_question/states.py index 46c96e589c..3cdd6016d3 100644 --- a/backend/onyx/agent_search/answer_question/states.py +++ b/backend/onyx/agent_search/answer_question/states.py @@ -5,7 +5,7 @@ from pydantic import BaseModel from 
onyx.agent_search.core_state import CoreState -from onyx.agent_search.expanded_retrieval.states import QueryResult +from onyx.agent_search.expanded_retrieval.models import QueryResult from onyx.agent_search.shared_graph_utils.models import AgentChunkStats from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections from onyx.context.search.models import InferenceSection diff --git a/backend/onyx/agent_search/expanded_retrieval/models.py b/backend/onyx/agent_search/expanded_retrieval/models.py new file mode 100644 index 0000000000..4e8caa3605 --- /dev/null +++ b/backend/onyx/agent_search/expanded_retrieval/models.py @@ -0,0 +1,19 @@ +from pydantic import BaseModel + +from onyx.agent_search.shared_graph_utils.models import AgentChunkStats +from onyx.agent_search.shared_graph_utils.models import RetrievalFitStats +from onyx.context.search.models import InferenceSection + +### Models ### + + +class QueryResult(BaseModel): + query: str + search_results: list[InferenceSection] + stats: RetrievalFitStats | None + + +class ExpandedRetrievalResult(BaseModel): + expanded_queries_results: list[QueryResult] + all_documents: list[InferenceSection] + sub_question_retrieval_stats: AgentChunkStats diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py index e5e7015e71..db098b8b8b 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py @@ -1,5 +1,5 @@ +from onyx.agent_search.expanded_retrieval.models import QueryResult from onyx.agent_search.expanded_retrieval.states import DocRetrievalUpdate -from onyx.agent_search.expanded_retrieval.states import QueryResult from onyx.agent_search.expanded_retrieval.states import RetrievalInput from onyx.agent_search.shared_graph_utils.calculations import get_fit_scores from onyx.configs.dev_configs import AGENT_MAX_QUERY_RETRIEVAL_RESULTS diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py b/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py index 4022649027..72892dc227 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py @@ -2,11 +2,11 @@ import numpy as np +from onyx.agent_search.expanded_retrieval.models import ExpandedRetrievalResult +from onyx.agent_search.expanded_retrieval.models import QueryResult from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalOutput -from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalResult from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState from onyx.agent_search.expanded_retrieval.states import InferenceSection -from onyx.agent_search.expanded_retrieval.states import QueryResult from onyx.agent_search.shared_graph_utils.models import AgentChunkStats diff --git a/backend/onyx/agent_search/expanded_retrieval/states.py b/backend/onyx/agent_search/expanded_retrieval/states.py index 677dbb8c8f..2b01938cfa 100644 --- a/backend/onyx/agent_search/expanded_retrieval/states.py +++ b/backend/onyx/agent_search/expanded_retrieval/states.py @@ -2,28 +2,13 @@ from typing import Annotated from typing import TypedDict -from pydantic import BaseModel - from onyx.agent_search.core_state import CoreState -from onyx.agent_search.shared_graph_utils.models import AgentChunkStats +from onyx.agent_search.expanded_retrieval.models 
import ExpandedRetrievalResult +from onyx.agent_search.expanded_retrieval.models import QueryResult from onyx.agent_search.shared_graph_utils.models import RetrievalFitStats from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections from onyx.context.search.models import InferenceSection -### Models ### - - -class QueryResult(BaseModel): - query: str - search_results: list[InferenceSection] - stats: RetrievalFitStats | None - - -class ExpandedRetrievalResult(BaseModel): - expanded_queries_results: list[QueryResult] - all_documents: list[InferenceSection] - sub_question_retrieval_stats: AgentChunkStats - ### States ### ## Update States @@ -40,7 +25,6 @@ class DocRerankingUpdate(TypedDict): class QueryExpansionUpdate(TypedDict): expanded_queries: list[str] - question: str class DocRetrievalUpdate(TypedDict): @@ -81,8 +65,6 @@ class ExpandedRetrievalOutput(TypedDict): class DocVerificationInput(ExpandedRetrievalInput): doc_to_verify: InferenceSection - query_to_retrieve: str - question: str class RetrievalInput(ExpandedRetrievalInput): diff --git a/backend/onyx/agent_search/main/edges.py b/backend/onyx/agent_search/main/edges.py index 492214f7c2..f04fb097b3 100644 --- a/backend/onyx/agent_search/main/edges.py +++ b/backend/onyx/agent_search/main/edges.py @@ -37,8 +37,8 @@ def send_to_initial_retrieval(state: MainInput) -> list[Send | Hashable]: Send( "initial_retrieval", ExpandedRetrievalInput( - **extract_core_fields(state), question=state["search_request"].query, + **extract_core_fields(state), ), ) ] diff --git a/backend/onyx/agent_search/main/states.py b/backend/onyx/agent_search/main/states.py index a783fb463e..2cd8d26831 100644 --- a/backend/onyx/agent_search/main/states.py +++ b/backend/onyx/agent_search/main/states.py @@ -4,7 +4,7 @@ from onyx.agent_search.answer_question.states import QuestionAnswerResults from onyx.agent_search.core_state import CoreState -from onyx.agent_search.expanded_retrieval.states import QueryResult +from onyx.agent_search.expanded_retrieval.models import QueryResult from onyx.agent_search.shared_graph_utils.models import AgentChunkStats from onyx.agent_search.shared_graph_utils.models import InitialAgentResultStats from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections From 901d8c22c438c0384fea9e24fecf06fe155e6a13 Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Thu, 26 Dec 2024 14:08:34 -0800 Subject: [PATCH 26/78] further clean-up --- .../agent_search/answer_question/models.py | 78 +++++++++++++++++++ .../agent_search/expanded_retrieval/edges.py | 1 + .../nodes/verification_kickoff.py | 2 - .../agent_search/expanded_retrieval/states.py | 25 +++--- 4 files changed, 92 insertions(+), 14 deletions(-) create mode 100644 backend/onyx/agent_search/answer_question/models.py diff --git a/backend/onyx/agent_search/answer_question/models.py b/backend/onyx/agent_search/answer_question/models.py new file mode 100644 index 0000000000..3cdd6016d3 --- /dev/null +++ b/backend/onyx/agent_search/answer_question/models.py @@ -0,0 +1,78 @@ +from operator import add +from typing import Annotated +from typing import TypedDict + +from pydantic import BaseModel + +from onyx.agent_search.core_state import CoreState +from onyx.agent_search.expanded_retrieval.models import QueryResult +from onyx.agent_search.shared_graph_utils.models import AgentChunkStats +from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections +from onyx.context.search.models import InferenceSection + +### Models ### + + +class 
AnswerRetrievalStats(BaseModel): + answer_retrieval_stats: dict[str, float | int] + + +class QuestionAnswerResults(BaseModel): + question: str + answer: str + quality: str + expanded_retrieval_results: list[QueryResult] + documents: list[InferenceSection] + sub_question_retrieval_stats: AgentChunkStats + + +### States ### + +## Update States + + +class QACheckUpdate(TypedDict): + answer_quality: str + + +class QAGenerationUpdate(TypedDict): + answer: str + # answer_stat: AnswerStats + + +class RetrievalIngestionUpdate(TypedDict): + expanded_retrieval_results: list[QueryResult] + documents: Annotated[list[InferenceSection], dedup_inference_sections] + sub_question_retrieval_stats: AgentChunkStats + + +## Graph Input State + + +class AnswerQuestionInput(CoreState): + question: str + + +## Graph State + + +class AnswerQuestionState( + AnswerQuestionInput, + QAGenerationUpdate, + QACheckUpdate, + RetrievalIngestionUpdate, +): + pass + + +## Graph Output State + + +class AnswerQuestionOutput(TypedDict): + """ + This is a list of results even though each call of this subgraph only returns one result. + This is because if we parallelize the answer query subgraph, there will be multiple + results in a list so the add operator is used to add them together. + """ + + answer_results: Annotated[list[QuestionAnswerResults], add] diff --git a/backend/onyx/agent_search/expanded_retrieval/edges.py b/backend/onyx/agent_search/expanded_retrieval/edges.py index cd5f2c6175..04a9e37cae 100644 --- a/backend/onyx/agent_search/expanded_retrieval/edges.py +++ b/backend/onyx/agent_search/expanded_retrieval/edges.py @@ -16,6 +16,7 @@ def parallel_retrieval_edge(state: ExpandedRetrievalState) -> list[Send | Hashab "doc_retrieval", RetrievalInput( query_to_retrieve=query, + question=question, **extract_core_fields(state), ), ) diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/verification_kickoff.py b/backend/onyx/agent_search/expanded_retrieval/nodes/verification_kickoff.py index 05b2465263..3f1114ebae 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/verification_kickoff.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/verification_kickoff.py @@ -13,8 +13,6 @@ def verification_kickoff( state: ExpandedRetrievalState, ) -> Command[Literal["doc_verification"]]: - print(f"verification_kickoff state: {state.keys()}") - documents = state["retrieved_documents"] return Command( update={}, diff --git a/backend/onyx/agent_search/expanded_retrieval/states.py b/backend/onyx/agent_search/expanded_retrieval/states.py index 2b01938cfa..b02a831004 100644 --- a/backend/onyx/agent_search/expanded_retrieval/states.py +++ b/backend/onyx/agent_search/expanded_retrieval/states.py @@ -11,32 +11,33 @@ ### States ### -## Update States +## Graph Input State -class DocVerificationUpdate(TypedDict): - verified_documents: Annotated[list[InferenceSection], dedup_inference_sections] +class ExpandedRetrievalInput(CoreState): + question: str -class DocRerankingUpdate(TypedDict): - reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections] - sub_question_retrieval_stats: RetrievalFitStats | None + +## Update/Return States class QueryExpansionUpdate(TypedDict): expanded_queries: list[str] +class DocVerificationUpdate(TypedDict): + verified_documents: Annotated[list[InferenceSection], dedup_inference_sections] + + class DocRetrievalUpdate(TypedDict): expanded_retrieval_results: Annotated[list[QueryResult], add] retrieved_documents: Annotated[list[InferenceSection], dedup_inference_sections] 
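+
+# How the Annotated[..., reducer] fields above behave: when parallel branches
+# return updates for the same state key, LangGraph merges them with the named
+# reducer. A minimal sketch of the merge semantics (values qr_a and qr_b are
+# assumed for illustration):
+#
+#     from operator import add
+#     add([qr_a], [qr_b])  # -> [qr_a, qr_b]: per-branch lists are concatenated
+#
+# dedup_inference_sections concatenates the same way while dropping duplicate
+# sections.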
-## Graph Input State - - -class ExpandedRetrievalInput(CoreState): - question: str +class DocRerankingUpdate(TypedDict): + reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections] + sub_question_retrieval_stats: RetrievalFitStats | None ## Graph State @@ -45,10 +46,10 @@ class ExpandedRetrievalInput(CoreState): class ExpandedRetrievalState( # This includes the core state ExpandedRetrievalInput, + QueryExpansionUpdate, DocRetrievalUpdate, DocVerificationUpdate, DocRerankingUpdate, - QueryExpansionUpdate, ): pass From cc76486d21c8332bc4a65ebcaa48c360d9283134 Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Sat, 28 Dec 2024 09:08:30 -0800 Subject: [PATCH 27/78] added streaming and small fixes --- .../agent_search/answer_question/models.py | 58 --------- .../agent_search/answer_question/states.py | 22 +--- .../onyx/agent_search/main/graph_builder.py | 2 +- backend/onyx/agent_search/run_graph.py | 112 +++++++++++++++--- 4 files changed, 97 insertions(+), 97 deletions(-) diff --git a/backend/onyx/agent_search/answer_question/models.py b/backend/onyx/agent_search/answer_question/models.py index 3cdd6016d3..788ef54106 100644 --- a/backend/onyx/agent_search/answer_question/models.py +++ b/backend/onyx/agent_search/answer_question/models.py @@ -1,13 +1,7 @@ -from operator import add -from typing import Annotated -from typing import TypedDict - from pydantic import BaseModel -from onyx.agent_search.core_state import CoreState from onyx.agent_search.expanded_retrieval.models import QueryResult from onyx.agent_search.shared_graph_utils.models import AgentChunkStats -from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections from onyx.context.search.models import InferenceSection ### Models ### @@ -24,55 +18,3 @@ class QuestionAnswerResults(BaseModel): expanded_retrieval_results: list[QueryResult] documents: list[InferenceSection] sub_question_retrieval_stats: AgentChunkStats - - -### States ### - -## Update States - - -class QACheckUpdate(TypedDict): - answer_quality: str - - -class QAGenerationUpdate(TypedDict): - answer: str - # answer_stat: AnswerStats - - -class RetrievalIngestionUpdate(TypedDict): - expanded_retrieval_results: list[QueryResult] - documents: Annotated[list[InferenceSection], dedup_inference_sections] - sub_question_retrieval_stats: AgentChunkStats - - -## Graph Input State - - -class AnswerQuestionInput(CoreState): - question: str - - -## Graph State - - -class AnswerQuestionState( - AnswerQuestionInput, - QAGenerationUpdate, - QACheckUpdate, - RetrievalIngestionUpdate, -): - pass - - -## Graph Output State - - -class AnswerQuestionOutput(TypedDict): - """ - This is a list of results even though each call of this subgraph only returns one result. - This is because if we parallelize the answer query subgraph, there will be multiple - results in a list so the add operator is used to add them together. 
- """ - - answer_results: Annotated[list[QuestionAnswerResults], add] diff --git a/backend/onyx/agent_search/answer_question/states.py b/backend/onyx/agent_search/answer_question/states.py index 3cdd6016d3..8081451eae 100644 --- a/backend/onyx/agent_search/answer_question/states.py +++ b/backend/onyx/agent_search/answer_question/states.py @@ -2,35 +2,15 @@ from typing import Annotated from typing import TypedDict -from pydantic import BaseModel - +from onyx.agent_search.answer_question.models import QuestionAnswerResults from onyx.agent_search.core_state import CoreState from onyx.agent_search.expanded_retrieval.models import QueryResult from onyx.agent_search.shared_graph_utils.models import AgentChunkStats from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections from onyx.context.search.models import InferenceSection -### Models ### - - -class AnswerRetrievalStats(BaseModel): - answer_retrieval_stats: dict[str, float | int] - - -class QuestionAnswerResults(BaseModel): - question: str - answer: str - quality: str - expanded_retrieval_results: list[QueryResult] - documents: list[InferenceSection] - sub_question_retrieval_stats: AgentChunkStats - - -### States ### ## Update States - - class QACheckUpdate(TypedDict): answer_quality: str diff --git a/backend/onyx/agent_search/main/graph_builder.py b/backend/onyx/agent_search/main/graph_builder.py index 940333ac6f..cb6d090daa 100644 --- a/backend/onyx/agent_search/main/graph_builder.py +++ b/backend/onyx/agent_search/main/graph_builder.py @@ -82,7 +82,7 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph: graph.add_conditional_edges( source="base_decomp", path=parallelize_decompozed_answer_queries, - path_map=["answer_query", "ingest_answers"], + path_map=["answer_query"], ) graph.add_edge( start_key="answer_query", diff --git a/backend/onyx/agent_search/run_graph.py b/backend/onyx/agent_search/run_graph.py index 9a93dbba64..b060cba571 100644 --- a/backend/onyx/agent_search/run_graph.py +++ b/backend/onyx/agent_search/run_graph.py @@ -1,27 +1,105 @@ +import asyncio +from collections.abc import AsyncIterable +from collections.abc import Iterable + +from langchain_core.runnables.schema import StreamEvent +from langgraph.graph.state import CompiledStateGraph + from onyx.agent_search.main.graph_builder import main_graph_builder +from onyx.agent_search.main.states import MainInput from onyx.chat.answer import AnswerStream +from onyx.chat.models import AnswerQuestionPossibleReturn +from onyx.chat.models import OnyxAnswerPiece +from onyx.context.search.models import SearchRequest +from onyx.db.engine import get_session_context_manager from onyx.llm.interfaces import LLM -from onyx.tools.tool import Tool +from onyx.tools.models import ToolResponse +from onyx.tools.tool_runner import ToolCallKickoff + + +def _parse_agent_event( + event: StreamEvent, +) -> AnswerQuestionPossibleReturn | ToolCallKickoff | ToolResponse | None: + """ + Parse the event into a typed object. + Return None if we are not interested in the event. 
+ """ + if event["name"] == "LangGraph": + return None + event_type = event["event"] + if event_type == "tool_call_kickoff": + return ToolCallKickoff(**event["data"]) + elif event_type == "tool_response": + return ToolResponse(**event["data"]) + elif event_type == "on_chat_model_stream": + return OnyxAnswerPiece(answer_piece=event["data"]["chunk"].content) + return None + + +def _manage_async_event_streaming( + compiled_graph: CompiledStateGraph, + graph_input: MainInput, +) -> Iterable[StreamEvent]: + async def _run_async_event_stream() -> AsyncIterable[StreamEvent]: + async for event in compiled_graph.astream_events( + input=graph_input, + # indicating v2 here deserves further scrutiny + version="v2", + ): + yield event + + # This might be able to be simplified + def _yield_async_to_sync() -> Iterable[StreamEvent]: + loop = asyncio.new_event_loop() + try: + # Get the async generator + async_gen = _run_async_event_stream() + # Convert to AsyncIterator + async_iter = async_gen.__aiter__() + while True: + try: + # Create a coroutine by calling anext with the async iterator + next_coro = anext(async_iter) + # Run the coroutine to get the next event + event = loop.run_until_complete(next_coro) + yield event + except StopAsyncIteration: + break + finally: + loop.close() + + return _yield_async_to_sync() def run_graph( - query: str, - llm: LLM, - tools: list[Tool], + compiled_graph: CompiledStateGraph, + search_request: SearchRequest, + primary_llm: LLM, + fast_llm: LLM, ) -> AnswerStream: - graph = main_graph_builder() - - inputs = { - "original_query": query, - "messages": [], - "tools": tools, - "llm": llm, - } - compiled_graph = graph.compile() - output = compiled_graph.invoke(input=inputs) - yield from output + with get_session_context_manager() as db_session: + input = MainInput( + search_request=search_request, + primary_llm=primary_llm, + fast_llm=fast_llm, + db_session=db_session, + ) + for event in _manage_async_event_streaming( + compiled_graph=compiled_graph, graph_input=input + ): + if parsed_object := _parse_agent_event(event): + yield parsed_object if __name__ == "__main__": - pass - # run_graph("What is the capital of France?", llm, []) + from onyx.llm.factory import get_default_llms + from onyx.context.search.models import SearchRequest + + graph = main_graph_builder() + compiled_graph = graph.compile() + primary_llm, fast_llm = get_default_llms() + search_request = SearchRequest( + query="what can you do with onyx or danswer?", + ) + for output in run_graph(compiled_graph, search_request, primary_llm, fast_llm): + print(output) From 0333ff648ad9e4729f0c6f4a36a4a6081594baa9 Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Sun, 29 Dec 2024 10:09:12 -0800 Subject: [PATCH 28/78] left or right --- .../agent_search/answer_question/edges.py | 7 +- .../agent_search/answer_question/models.py | 2 +- .../answer_question/nodes/answer_check.py | 2 +- .../nodes/answer_generation.py | 4 +- .../answer_question/nodes/format_answer.py | 12 +- .../answer_question/nodes/ingest_retrieval.py | 12 +- .../agent_search/answer_question/states.py | 4 +- backend/onyx/agent_search/base_lg_tests.py | 66 ++++++++ .../base_raw_search/graph_builder.py | 70 +++++++++ .../agent_search/base_raw_search/models.py | 20 +++ .../nodes/format_raw_search_results.py | 10 ++ .../nodes/generate_raw_search_data.py | 14 ++ .../agent_search/base_raw_search/states.py | 40 +++++ backend/onyx/agent_search/core_state.py | 32 ++++ .../deep_answer/nodes/answer_generation.py | 2 +- .../agent_search/expanded_retrieval/edges.py | 
6 +- .../agent_search/expanded_retrieval/models.py | 5 +- .../expanded_retrieval/nodes/doc_reranking.py | 13 +- .../expanded_retrieval/nodes/doc_retrieval.py | 6 +- .../nodes/doc_verification.py | 3 +- .../nodes/expand_queries.py | 2 +- .../nodes/format_results.py | 6 + .../nodes/verification_kickoff.py | 9 +- .../agent_search/expanded_retrieval/states.py | 24 +-- backend/onyx/agent_search/main/edges.py | 7 +- .../onyx/agent_search/main/graph_builder.py | 142 ++++++++++++------ .../main/nodes/generate_initial_answer.py | 7 +- .../main/nodes/ingest_initial_retrieval.py | 12 +- .../main/nodes/prep_for_initial_retrieval.py | 12 ++ backend/onyx/agent_search/main/states.py | 6 +- backend/onyx/agent_search/run_graph.py | 21 ++- 31 files changed, 472 insertions(+), 106 deletions(-) create mode 100644 backend/onyx/agent_search/base_lg_tests.py create mode 100644 backend/onyx/agent_search/base_raw_search/graph_builder.py create mode 100644 backend/onyx/agent_search/base_raw_search/models.py create mode 100644 backend/onyx/agent_search/base_raw_search/nodes/format_raw_search_results.py create mode 100644 backend/onyx/agent_search/base_raw_search/nodes/generate_raw_search_data.py create mode 100644 backend/onyx/agent_search/base_raw_search/states.py create mode 100644 backend/onyx/agent_search/main/nodes/prep_for_initial_retrieval.py diff --git a/backend/onyx/agent_search/answer_question/edges.py b/backend/onyx/agent_search/answer_question/edges.py index bdd9864e6e..1caeec6e9f 100644 --- a/backend/onyx/agent_search/answer_question/edges.py +++ b/backend/onyx/agent_search/answer_question/edges.py @@ -3,15 +3,18 @@ from langgraph.types import Send from onyx.agent_search.answer_question.states import AnswerQuestionInput -from onyx.agent_search.core_state import extract_core_fields +from onyx.agent_search.core_state import in_subgraph_extract_core_fields from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput def send_to_expanded_retrieval(state: AnswerQuestionInput) -> Send | Hashable: + print("sending to expanded retrieval via edge") + return Send( "decomped_expanded_retrieval", ExpandedRetrievalInput( - **extract_core_fields(state), + **in_subgraph_extract_core_fields(state), question=state["question"], + dummy="1" ), ) diff --git a/backend/onyx/agent_search/answer_question/models.py b/backend/onyx/agent_search/answer_question/models.py index 788ef54106..6ee67c9e36 100644 --- a/backend/onyx/agent_search/answer_question/models.py +++ b/backend/onyx/agent_search/answer_question/models.py @@ -17,4 +17,4 @@ class QuestionAnswerResults(BaseModel): quality: str expanded_retrieval_results: list[QueryResult] documents: list[InferenceSection] - sub_question_retrieval_stats: AgentChunkStats + sub_question_retrieval_stats: list[AgentChunkStats] diff --git a/backend/onyx/agent_search/answer_question/nodes/answer_check.py b/backend/onyx/agent_search/answer_question/nodes/answer_check.py index 83cc46280f..6349552f34 100644 --- a/backend/onyx/agent_search/answer_question/nodes/answer_check.py +++ b/backend/onyx/agent_search/answer_question/nodes/answer_check.py @@ -16,7 +16,7 @@ def answer_check(state: AnswerQuestionState) -> QACheckUpdate: ) ] - fast_llm = state["fast_llm"] + fast_llm = state["subgraph_fast_llm"] response = list( fast_llm.stream( prompt=msg, diff --git a/backend/onyx/agent_search/answer_question/nodes/answer_generation.py b/backend/onyx/agent_search/answer_question/nodes/answer_generation.py index 7e1e326b68..0403583567 100644 --- 
a/backend/onyx/agent_search/answer_question/nodes/answer_generation.py +++ b/backend/onyx/agent_search/answer_question/nodes/answer_generation.py @@ -18,12 +18,12 @@ def answer_generation(state: AnswerQuestionState) -> QAGenerationUpdate: content=BASE_RAG_PROMPT.format( question=question, context=format_docs(docs), - original_question=state["search_request"].query, + original_question=state["subgraph_search_request"].query, ) ) ] - fast_llm = state["fast_llm"] + fast_llm = state["subgraph_fast_llm"] response = list( fast_llm.stream( prompt=msg, diff --git a/backend/onyx/agent_search/answer_question/nodes/format_answer.py b/backend/onyx/agent_search/answer_question/nodes/format_answer.py index 95a0ac38bf..e748ac0c3f 100644 --- a/backend/onyx/agent_search/answer_question/nodes/format_answer.py +++ b/backend/onyx/agent_search/answer_question/nodes/format_answer.py @@ -4,6 +4,16 @@ def format_answer(state: AnswerQuestionState) -> AnswerQuestionOutput: + sub_question_retrieval_stats = state["sub_question_retrieval_stats"] + if sub_question_retrieval_stats is None: + sub_question_retrieval_stats = [] + elif isinstance(sub_question_retrieval_stats, list): + sub_question_retrieval_stats = sub_question_retrieval_stats + if isinstance(sub_question_retrieval_stats[0], list): + sub_question_retrieval_stats = sub_question_retrieval_stats[0] + else: + sub_question_retrieval_stats = [sub_question_retrieval_stats] + return AnswerQuestionOutput( answer_results=[ QuestionAnswerResults( @@ -12,7 +22,7 @@ def format_answer(state: AnswerQuestionState) -> AnswerQuestionOutput: answer=state["answer"], expanded_retrieval_results=state["expanded_retrieval_results"], documents=state["documents"], - sub_question_retrieval_stats=state["sub_question_retrieval_stats"], + sub_question_retrieval_stats=sub_question_retrieval_stats, ) ], ) diff --git a/backend/onyx/agent_search/answer_question/nodes/ingest_retrieval.py b/backend/onyx/agent_search/answer_question/nodes/ingest_retrieval.py index 54830b1873..a7efd854fa 100644 --- a/backend/onyx/agent_search/answer_question/nodes/ingest_retrieval.py +++ b/backend/onyx/agent_search/answer_question/nodes/ingest_retrieval.py @@ -3,12 +3,18 @@ def ingest_retrieval(state: ExpandedRetrievalOutput) -> RetrievalIngestionUpdate: + sub_question_retrieval_stats = state[ + "expanded_retrieval_result" + ].sub_question_retrieval_stats + if sub_question_retrieval_stats is None: + sub_question_retrieval_stats = [] + else: + sub_question_retrieval_stats = [sub_question_retrieval_stats] + return RetrievalIngestionUpdate( expanded_retrieval_results=state[ "expanded_retrieval_result" ].expanded_queries_results, documents=state["expanded_retrieval_result"].all_documents, - sub_question_retrieval_stats=state[ - "expanded_retrieval_result" - ].sub_question_retrieval_stats, + sub_question_retrieval_stats=sub_question_retrieval_stats, ) diff --git a/backend/onyx/agent_search/answer_question/states.py b/backend/onyx/agent_search/answer_question/states.py index 8081451eae..28f4dc2134 100644 --- a/backend/onyx/agent_search/answer_question/states.py +++ b/backend/onyx/agent_search/answer_question/states.py @@ -3,7 +3,7 @@ from typing import TypedDict from onyx.agent_search.answer_question.models import QuestionAnswerResults -from onyx.agent_search.core_state import CoreState +from onyx.agent_search.core_state import SubgraphCoreState from onyx.agent_search.expanded_retrieval.models import QueryResult from onyx.agent_search.shared_graph_utils.models import AgentChunkStats from 
onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections @@ -29,7 +29,7 @@ class RetrievalIngestionUpdate(TypedDict): ## Graph Input State -class AnswerQuestionInput(CoreState): +class AnswerQuestionInput(SubgraphCoreState): question: str diff --git a/backend/onyx/agent_search/base_lg_tests.py b/backend/onyx/agent_search/base_lg_tests.py new file mode 100644 index 0000000000..e937219d36 --- /dev/null +++ b/backend/onyx/agent_search/base_lg_tests.py @@ -0,0 +1,66 @@ +from typing import TypedDict + +from langgraph.graph import END +from langgraph.graph import START +from langgraph.graph import StateGraph + + +# The overall state of the graph (this is the public state shared across nodes) +class OverallState(TypedDict): + a: str + + +# Output from node_1 contains private data that is not part of the overall state +class Node1Output(TypedDict): + private_data: str + + +# Node 2 input only requests the private data available after node_1 +class Node2Input(TypedDict): + private_data: str + + +# The private data is only shared between node_1 and node_2 +def node_1(state: OverallState) -> Node2Input: + output = {"private_data": "set by node_1"} + print(f"Entered node `node_1`:\n\tInput: {state}.\n\tReturned: {output}") + return output + + +def node_2(state: Node2Input) -> OverallState: + output = {"a": "set by node_2"} + print(f"Entered node `node_2`:\n\tInput: {state}.\n\tReturned: {output}") + return output + + +# Node 3 only has access to the overall state (no access to private data from node_1) +def node_3(state: OverallState) -> OverallState: + output = {"a": "set by node_3"} + print(f"Entered node `node_3`:\n\tInput: {state}.\n\tReturned: {output}") + return output + + +# Build the state graph +builder = StateGraph(OverallState) +builder.add_node(node_1) # node_1 is the first node +builder.add_node( + node_2 +) # node_2 is the second node and accepts private data from node_1 +builder.add_node(node_3) # node_3 is the third node and does not see the private data +builder.add_edge(START, "node_1") # Start the graph with node_1 +builder.add_edge("node_1", "node_2") # Pass from node_1 to node_2 +builder.add_edge( + "node_2", "node_3" +) # Pass from node_2 to node_3 (only overall state is shared) +builder.add_edge("node_3", END) # End the graph after node_3 +graph = builder.compile() + +# Invoke the graph with the initial state +response = graph.invoke( + { + "a": "set at start", + } +) + +print() +print(f"Output of graph invocation: {response}") diff --git a/backend/onyx/agent_search/base_raw_search/graph_builder.py b/backend/onyx/agent_search/base_raw_search/graph_builder.py new file mode 100644 index 0000000000..5de90a8884 --- /dev/null +++ b/backend/onyx/agent_search/base_raw_search/graph_builder.py @@ -0,0 +1,70 @@ +from langgraph.graph import END +from langgraph.graph import START +from langgraph.graph import StateGraph + +from onyx.agent_search.base_raw_search.nodes.format_raw_search_results import ( + format_raw_search_results, +) +from onyx.agent_search.base_raw_search.nodes.generate_raw_search_data import ( + generate_raw_search_data, +) +from onyx.agent_search.base_raw_search.states import BaseRawSearchInput +from onyx.agent_search.base_raw_search.states import BaseRawSearchOutput +from onyx.agent_search.base_raw_search.states import BaseRawSearchState +from onyx.agent_search.expanded_retrieval.graph_builder import ( + expanded_retrieval_graph_builder, +) + + +def base_raw_search_graph_builder() -> StateGraph: + graph = StateGraph( + 
state_schema=BaseRawSearchState, + input=BaseRawSearchInput, + output=BaseRawSearchOutput, + ) + + ### Add nodes ### + + expanded_retrieval = expanded_retrieval_graph_builder().compile() + graph.add_node( + node="generate_raw_search_data", + action=generate_raw_search_data, + ) + + graph.add_node( + node="expanded_retrieval_base_search", + action=expanded_retrieval, + ) + graph.add_node( + node="format_raw_search_results", + action=format_raw_search_results, + ) + + ### Add edges ### + + graph.add_edge(start_key=START, end_key="generate_raw_search_data") + + graph.add_edge( + start_key="generate_raw_search_data", + end_key="expanded_retrieval_base_search", + ) + graph.add_edge( + start_key="expanded_retrieval_base_search", + end_key="format_raw_search_results", + ) + + # graph.add_edge( + # start_key="expanded_retrieval_base_search", + # end_key=END, + # ) + + graph.add_edge( + start_key="format_raw_search_results", + end_key=END, + ) + + return graph + + +if __name__ == "__main__": + pass diff --git a/backend/onyx/agent_search/base_raw_search/models.py b/backend/onyx/agent_search/base_raw_search/models.py new file mode 100644 index 0000000000..6ee67c9e36 --- /dev/null +++ b/backend/onyx/agent_search/base_raw_search/models.py @@ -0,0 +1,20 @@ +from pydantic import BaseModel + +from onyx.agent_search.expanded_retrieval.models import QueryResult +from onyx.agent_search.shared_graph_utils.models import AgentChunkStats +from onyx.context.search.models import InferenceSection + +### Models ### + + +class AnswerRetrievalStats(BaseModel): + answer_retrieval_stats: dict[str, float | int] + + +class QuestionAnswerResults(BaseModel): + question: str + answer: str + quality: str + expanded_retrieval_results: list[QueryResult] + documents: list[InferenceSection] + sub_question_retrieval_stats: list[AgentChunkStats] diff --git a/backend/onyx/agent_search/base_raw_search/nodes/format_raw_search_results.py b/backend/onyx/agent_search/base_raw_search/nodes/format_raw_search_results.py new file mode 100644 index 0000000000..dfd2b47e6b --- /dev/null +++ b/backend/onyx/agent_search/base_raw_search/nodes/format_raw_search_results.py @@ -0,0 +1,10 @@ +from onyx.agent_search.base_raw_search.states import BaseRawSearchOutput +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalOutput + + +def format_raw_search_results(state: ExpandedRetrievalOutput) -> BaseRawSearchOutput: + print("format_raw_search_results") + return BaseRawSearchOutput( + base_retrieval_results=[state["expanded_retrieval_result"]], + base_search_documents=[], + ) diff --git a/backend/onyx/agent_search/base_raw_search/nodes/generate_raw_search_data.py b/backend/onyx/agent_search/base_raw_search/nodes/generate_raw_search_data.py new file mode 100644 index 0000000000..a09729a4b1 --- /dev/null +++ b/backend/onyx/agent_search/base_raw_search/nodes/generate_raw_search_data.py @@ -0,0 +1,14 @@ +from onyx.agent_search.core_state import CoreState +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput + + +def generate_raw_search_data(state: CoreState) -> ExpandedRetrievalInput: + print("generate_raw_search_data") + return ExpandedRetrievalInput( + subgraph_search_request=state["search_request"], + subgraph_primary_llm=state["primary_llm"], + subgraph_fast_llm=state["fast_llm"], + subgraph_db_session=state["db_session"], + question=state["search_request"].query, + dummy="7", + ) diff --git a/backend/onyx/agent_search/base_raw_search/states.py b/backend/onyx/agent_search/base_raw_search/states.py new file 
mode 100644 index 0000000000..920bab97a2 --- /dev/null +++ b/backend/onyx/agent_search/base_raw_search/states.py @@ -0,0 +1,40 @@ +from typing import TypedDict + +from onyx.agent_search.core_state import CoreState +from onyx.agent_search.core_state import SubgraphCoreState +from onyx.agent_search.expanded_retrieval.models import ExpandedRetrievalResult + + +## Update States + + +## Graph Input State + + +class BaseRawSearchInput(CoreState, SubgraphCoreState): + pass + + +## Graph Output State + + +class BaseRawSearchOutput(TypedDict): + """ + This is a list of results even though each call of this subgraph only returns one result. + This is because if we parallelize the answer query subgraph, there will be multiple + results in a list so the add operator is used to add them together. + """ + + # base_search_documents: Annotated[list[InferenceSection], dedup_inference_sections] + # base_retrieval_results: Annotated[list[ExpandedRetrievalResult], add] + expanded_retrieval_result: ExpandedRetrievalResult + + +## Graph State + + +class BaseRawSearchState( + BaseRawSearchInput, + BaseRawSearchOutput, +): + pass diff --git a/backend/onyx/agent_search/core_state.py b/backend/onyx/agent_search/core_state.py index cbc8f3d5c4..6dd8d0f8b8 100644 --- a/backend/onyx/agent_search/core_state.py +++ b/backend/onyx/agent_search/core_state.py @@ -1,3 +1,5 @@ +from operator import add +from typing import Annotated from typing import TypedDict from typing import TypeVar @@ -18,12 +20,42 @@ class CoreState(TypedDict, total=False): # a single session for the entire agent search # is fine if we are only reading db_session: Session + log_messages: Annotated[list[str], add] + dummy: str + + +class SubgraphCoreState(TypedDict, total=False): + """ + This is the core state that is shared across all subgraphs. 
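+
+ Keys mirror CoreState but carry a "subgraph_" prefix so that parent-graph
+ and subgraph state channels do not collide. Illustrative example (state
+ values assumed): extract_core_fields_for_subgraph maps state["fast_llm"]
+ to the "subgraph_fast_llm" key of the returned SubgraphCoreState.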
+ """ + + subgraph_search_request: SearchRequest + subgraph_primary_llm: LLM + subgraph_fast_llm: LLM + # a single session for the entire agent search + # is fine if we are only reading + subgraph_db_session: Session # This ensures that the state passed in extends the CoreState T = TypeVar("T", bound=CoreState) +T_SUBGRAPH = TypeVar("T_SUBGRAPH", bound=SubgraphCoreState) def extract_core_fields(state: T) -> CoreState: filtered_dict = {k: v for k, v in state.items() if k in CoreState.__annotations__} return CoreState(**dict(filtered_dict)) # type: ignore + + +def extract_core_fields_for_subgraph(state: T) -> SubgraphCoreState: + filtered_dict = { + "subgraph_" + k: v for k, v in state.items() if k in CoreState.__annotations__ + } + return SubgraphCoreState(**dict(filtered_dict)) # type: ignore + + +def in_subgraph_extract_core_fields(state: T_SUBGRAPH) -> SubgraphCoreState: + filtered_dict = { + k: v for k, v in state.items() if k in SubgraphCoreState.__annotations__ + } + return SubgraphCoreState(**dict(filtered_dict)) diff --git a/backend/onyx/agent_search/deep_answer/nodes/answer_generation.py b/backend/onyx/agent_search/deep_answer/nodes/answer_generation.py index f0a94b398a..67959efbd2 100644 --- a/backend/onyx/agent_search/deep_answer/nodes/answer_generation.py +++ b/backend/onyx/agent_search/deep_answer/nodes/answer_generation.py @@ -70,7 +70,7 @@ def final_stuff(state: MainState) -> dict[str, Any]: time_ordered_messages.sort() print("Message Log:") - print("\n".join(time_ordered_messages)) + # print("\n".join(time_ordered_messages)) initial_sub_qas = state["initial_sub_qas"] initial_sub_qa_list = [] diff --git a/backend/onyx/agent_search/expanded_retrieval/edges.py b/backend/onyx/agent_search/expanded_retrieval/edges.py index 04a9e37cae..006d915ead 100644 --- a/backend/onyx/agent_search/expanded_retrieval/edges.py +++ b/backend/onyx/agent_search/expanded_retrieval/edges.py @@ -2,13 +2,13 @@ from langgraph.types import Send -from onyx.agent_search.core_state import extract_core_fields +from onyx.agent_search.core_state import in_subgraph_extract_core_fields from onyx.agent_search.expanded_retrieval.nodes.doc_retrieval import RetrievalInput from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState def parallel_retrieval_edge(state: ExpandedRetrievalState) -> list[Send | Hashable]: - question = state.get("question", state["search_request"].query) + question = state.get("question", state["subgraph_search_request"].query) query_expansions = state.get("expanded_queries", []) + [question] return [ @@ -17,7 +17,7 @@ def parallel_retrieval_edge(state: ExpandedRetrievalState) -> list[Send | Hashab RetrievalInput( query_to_retrieve=query, question=question, - **extract_core_fields(state), + **in_subgraph_extract_core_fields(state), ), ) for query in query_expansions diff --git a/backend/onyx/agent_search/expanded_retrieval/models.py b/backend/onyx/agent_search/expanded_retrieval/models.py index 4e8caa3605..25ab3b899c 100644 --- a/backend/onyx/agent_search/expanded_retrieval/models.py +++ b/backend/onyx/agent_search/expanded_retrieval/models.py @@ -1,3 +1,6 @@ +from operator import add +from typing import Annotated + from pydantic import BaseModel from onyx.agent_search.shared_graph_utils.models import AgentChunkStats @@ -16,4 +19,4 @@ class QueryResult(BaseModel): class ExpandedRetrievalResult(BaseModel): expanded_queries_results: list[QueryResult] all_documents: list[InferenceSection] - sub_question_retrieval_stats: AgentChunkStats + sub_question_retrieval_stats: 
Annotated[list[AgentChunkStats], add] diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py index b90ed6b40a..0f1e182f9f 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py @@ -15,27 +15,28 @@ def doc_reranking(state: ExpandedRetrievalState) -> DocRerankingUpdate: # Rerank post retrieval and verification. First, create a search query # then create the list of reranked sections + question = state.get("question", state["subgraph_search_request"].query) _search_query = retrieval_preprocessing( - search_request=SearchRequest(query=state["question"]), + search_request=SearchRequest(query=question), user=None, - llm=state["fast_llm"], - db_session=state["db_session"], + llm=state["subgraph_fast_llm"], + db_session=state["subgraph_db_session"], ) reranked_documents = list( search_postprocessing( search_query=_search_query, retrieved_sections=verified_documents, - llm=state["fast_llm"], + llm=state["subgraph_fast_llm"], ) )[ 0 ] # only get the reranked sections, not the SectionRelevancePiece if AGENT_RERANKING_STATS: - fit_scores = get_fit_scores(verified_documents, reranked_documents) + fit_scores = [get_fit_scores(verified_documents, reranked_documents)] else: - fit_scores = None + fit_scores = [] return DocRerankingUpdate( reranked_documents=[ diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py index db098b8b8b..4e4d44cada 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py @@ -20,8 +20,8 @@ def doc_retrieval(state: RetrievalInput) -> DocRetrievalUpdate: retrieved_documents: list[InferenceSection] """ - llm = state["primary_llm"] - fast_llm = state["fast_llm"] + llm = state["subgraph_primary_llm"] + fast_llm = state["subgraph_fast_llm"] query_to_retrieve = state["query_to_retrieve"] search_results = SearchPipeline( @@ -31,7 +31,7 @@ def doc_retrieval(state: RetrievalInput) -> DocRetrievalUpdate: user=None, llm=llm, fast_llm=fast_llm, - db_session=state["db_session"], + db_session=state["subgraph_db_session"], ) retrieved_docs = search_results._get_sections()[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS] diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py index 7c5579619f..11574a9fe5 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py @@ -16,7 +16,6 @@ def doc_verification(state: DocVerificationInput) -> DocVerificationUpdate: verified_documents: list[InferenceSection] """ - state["search_request"].query question = state["question"] doc_to_verify = state["doc_to_verify"] document_content = doc_to_verify.combined_content @@ -29,7 +28,7 @@ def doc_verification(state: DocVerificationInput) -> DocVerificationUpdate: ) ] - fast_llm = state["fast_llm"] + fast_llm = state["subgraph_fast_llm"] response = fast_llm.invoke(msg) diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/expand_queries.py b/backend/onyx/agent_search/expanded_retrieval/nodes/expand_queries.py index 193d9b648f..87599ce726 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/expand_queries.py +++ 
b/backend/onyx/agent_search/expanded_retrieval/nodes/expand_queries.py @@ -9,7 +9,7 @@ def expand_queries(state: ExpandedRetrievalInput) -> QueryExpansionUpdate: question = state.get("question") - llm: LLM = state["fast_llm"] + llm: LLM = state["subgraph_fast_llm"] msg = [ HumanMessage( diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py b/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py index 72892dc227..c292c14163 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py @@ -84,6 +84,12 @@ def format_results(state: ExpandedRetrievalState) -> ExpandedRetrievalOutput: verified_documents=state["verified_documents"], expanded_retrieval_results=state["expanded_retrieval_results"], ) + + if sub_question_retrieval_stats is None: + sub_question_retrieval_stats = [] + else: + sub_question_retrieval_stats = [sub_question_retrieval_stats] + return ExpandedRetrievalOutput( expanded_retrieval_result=ExpandedRetrievalResult( expanded_queries_results=state["expanded_retrieval_results"], diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/verification_kickoff.py b/backend/onyx/agent_search/expanded_retrieval/nodes/verification_kickoff.py index 3f1114ebae..00adf495a3 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/verification_kickoff.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/verification_kickoff.py @@ -3,7 +3,7 @@ from langgraph.types import Command from langgraph.types import Send -from onyx.agent_search.core_state import extract_core_fields +from onyx.agent_search.core_state import in_subgraph_extract_core_fields from onyx.agent_search.expanded_retrieval.nodes.doc_verification import ( DocVerificationInput, ) @@ -14,6 +14,9 @@ def verification_kickoff( state: ExpandedRetrievalState, ) -> Command[Literal["doc_verification"]]: documents = state["retrieved_documents"] + verification_question = state.get( + "question", state["subgraph_search_request"].query + ) return Command( update={}, goto=[ @@ -21,8 +24,8 @@ def verification_kickoff( node="doc_verification", arg=DocVerificationInput( doc_to_verify=doc, - question=state["question"], - **extract_core_fields(state), + question=verification_question, + **in_subgraph_extract_core_fields(state), ), ) for doc in documents diff --git a/backend/onyx/agent_search/expanded_retrieval/states.py b/backend/onyx/agent_search/expanded_retrieval/states.py index b02a831004..28eae89717 100644 --- a/backend/onyx/agent_search/expanded_retrieval/states.py +++ b/backend/onyx/agent_search/expanded_retrieval/states.py @@ -2,10 +2,10 @@ from typing import Annotated from typing import TypedDict -from onyx.agent_search.core_state import CoreState +from onyx.agent_search.core_state import SubgraphCoreState from onyx.agent_search.expanded_retrieval.models import ExpandedRetrievalResult from onyx.agent_search.expanded_retrieval.models import QueryResult -from onyx.agent_search.shared_graph_utils.models import RetrievalFitStats +from onyx.agent_search.shared_graph_utils.models import AgentChunkStats from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections from onyx.context.search.models import InferenceSection @@ -15,8 +15,9 @@ ## Graph Input State -class ExpandedRetrievalInput(CoreState): +class ExpandedRetrievalInput(SubgraphCoreState): question: str + dummy: str ## Update/Return States @@ -37,7 +38,14 @@ class DocRetrievalUpdate(TypedDict): class 
DocRerankingUpdate(TypedDict): reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections] - sub_question_retrieval_stats: RetrievalFitStats | None + sub_question_retrieval_stats: Annotated[list[AgentChunkStats | None], add] + + +## Graph Output State + + +class ExpandedRetrievalOutput(TypedDict): + expanded_retrieval_result: ExpandedRetrievalResult ## Graph State @@ -50,17 +58,11 @@ class ExpandedRetrievalState( DocRetrievalUpdate, DocVerificationUpdate, DocRerankingUpdate, + ExpandedRetrievalOutput, ): pass -## Graph Output State - - -class ExpandedRetrievalOutput(TypedDict): - expanded_retrieval_result: ExpandedRetrievalResult - - ## Conditional Input States diff --git a/backend/onyx/agent_search/main/edges.py b/backend/onyx/agent_search/main/edges.py index f04fb097b3..cded24d43f 100644 --- a/backend/onyx/agent_search/main/edges.py +++ b/backend/onyx/agent_search/main/edges.py @@ -4,7 +4,7 @@ from onyx.agent_search.answer_question.states import AnswerQuestionInput from onyx.agent_search.answer_question.states import AnswerQuestionOutput -from onyx.agent_search.core_state import extract_core_fields +from onyx.agent_search.core_state import extract_core_fields_for_subgraph from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput from onyx.agent_search.main.states import MainInput from onyx.agent_search.main.states import MainState @@ -16,7 +16,7 @@ def parallelize_decompozed_answer_queries(state: MainState) -> list[Send | Hasha Send( "answer_query", AnswerQuestionInput( - **extract_core_fields(state), + **extract_core_fields_for_subgraph(state), question=question, ), ) @@ -33,12 +33,13 @@ def parallelize_decompozed_answer_queries(state: MainState) -> list[Send | Hasha def send_to_initial_retrieval(state: MainInput) -> list[Send | Hashable]: + print("sending to initial retrieval via edge") return [ Send( "initial_retrieval", ExpandedRetrievalInput( question=state["search_request"].query, - **extract_core_fields(state), + **extract_core_fields_for_subgraph(state), ), ) ] diff --git a/backend/onyx/agent_search/main/graph_builder.py b/backend/onyx/agent_search/main/graph_builder.py index cb6d090daa..91f8e30db5 100644 --- a/backend/onyx/agent_search/main/graph_builder.py +++ b/backend/onyx/agent_search/main/graph_builder.py @@ -3,25 +3,17 @@ from langgraph.graph import StateGraph from onyx.agent_search.answer_question.graph_builder import answer_query_graph_builder -from onyx.agent_search.expanded_retrieval.graph_builder import ( - expanded_retrieval_graph_builder, -) from onyx.agent_search.main.edges import parallelize_decompozed_answer_queries -from onyx.agent_search.main.edges import send_to_initial_retrieval from onyx.agent_search.main.nodes.base_decomp import main_decomp_base from onyx.agent_search.main.nodes.generate_initial_answer import ( generate_initial_answer, ) -from onyx.agent_search.main.nodes.generate_initial_BASE_answer import ( - generate_initial_base_answer, -) from onyx.agent_search.main.nodes.ingest_answers import ingest_answers -from onyx.agent_search.main.nodes.ingest_initial_retrieval import ( - ingest_initial_retrieval, -) from onyx.agent_search.main.states import MainInput from onyx.agent_search.main.states import MainState +test_mode = False + def main_graph_builder(test_mode: bool = False) -> StateGraph: graph = StateGraph( @@ -40,15 +32,27 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph: node="answer_query", action=answer_query_subgraph, ) - expanded_retrieval_subgraph = 
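The Annotated[list[...], add] channels introduced above lean on LangGraph's reducer convention: the second argument to Annotated tells the graph how to merge concurrent writes to a key, so parallel nodes append to the list instead of overwriting each other. A self-contained illustration with a throwaway state:

from operator import add
from typing import Annotated, TypedDict

from langgraph.graph import END, START, StateGraph


class StatsState(TypedDict):
    hits: Annotated[list[str], add]  # concurrent writes are concatenated


def node_a(state: StatsState) -> dict:
    return {"hits": ["from_a"]}


def node_b(state: StatsState) -> dict:
    return {"hits": ["from_b"]}


builder = StateGraph(StatsState)
builder.add_node("a", node_a)
builder.add_node("b", node_b)
builder.add_edge(START, "a")
builder.add_edge(START, "b")
builder.add_edge("a", END)
builder.add_edge("b", END)
print(builder.compile().invoke({"hits": []}))  # both writes survive the merge
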
expanded_retrieval_graph_builder().compile() - graph.add_node( - node="initial_retrieval", - action=expanded_retrieval_subgraph, - ) - graph.add_node( - node="ingest_initial_retrieval", - action=ingest_initial_retrieval, - ) + + # graph.add_node( + # node="prep_for_initial_retrieval", + # action=prep_for_initial_retrieval, + # ) + + # expanded_retrieval_subgraph = expanded_retrieval_graph_builder().compile() + # graph.add_node( + # node="initial_retrieval", + # action=expanded_retrieval_subgraph, + # ) + + # base_raw_search_subgraph = base_raw_search_graph_builder().compile() + # graph.add_node( + # node="base_raw_search_data", + # action=base_raw_search_subgraph, + # ) + # graph.add_node( + # node="ingest_initial_retrieval", + # action=ingest_initial_retrieval, + # ) graph.add_node( node="ingest_answers", action=ingest_answers, @@ -57,24 +61,50 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph: node="generate_initial_answer", action=generate_initial_answer, ) - if test_mode: - graph.add_node( - node="generate_initial_base_answer", - action=generate_initial_base_answer, - ) + # if test_mode: + # graph.add_node( + # node="generate_initial_base_answer", + # action=generate_initial_base_answer, + # ) ### Add edges ### - graph.add_conditional_edges( - source=START, - path=send_to_initial_retrieval, - path_map=["initial_retrieval"], - ) - graph.add_edge( - start_key="initial_retrieval", - end_key="ingest_initial_retrieval", - ) - + # graph.add_conditional_edges( + # source=START, + # path=send_to_initial_retrieval, + # path_map=["initial_retrieval"], + # ) + + # graph.add_edge( + # start_key=START, + # end_key="prep_for_initial_retrieval", + # ) + # graph.add_edge( + # start_key="prep_for_initial_retrieval", + # end_key="initial_retrieval", + # ) + # graph.add_edge( + # start_key="initial_retrieval", + # end_key="ingest_initial_retrieval", + # ) + + # graph.add_edge( + # start_key=START, + # end_key="base_raw_search_data" + # ) + + # # graph.add_edge( + # # start_key="base_raw_search_data", + # # end_key=END + # # ) + # graph.add_edge( + # start_key="base_raw_search_data", + # end_key="ingest_initial_retrieval", + # ) + # graph.add_edge( + # start_key="ingest_initial_retrieval", + # end_key=END + # ) graph.add_edge( start_key=START, end_key="base_decomp", @@ -90,23 +120,37 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph: ) graph.add_edge( - start_key=["ingest_answers", "ingest_initial_retrieval"], + start_key="ingest_answers", end_key="generate_initial_answer", ) - if test_mode: - graph.add_edge( - start_key=["ingest_answers", "ingest_initial_retrieval"], - end_key="generate_initial_base_answer", - ) - graph.add_edge( - start_key=["generate_initial_answer", "generate_initial_base_answer"], - end_key=END, - ) - else: - graph.add_edge( - start_key="generate_initial_answer", - end_key=END, - ) + + # graph.add_edge( + # start_key=["ingest_answers", "ingest_initial_retrieval"], + # end_key="generate_initial_answer", + # ) + + graph.add_edge( + start_key="generate_initial_answer", + end_key=END, + ) + # graph.add_edge( + # start_key="ingest_answers", + # end_key="generate_initial_answer", + # ) + # if test_mode: + # graph.add_edge( + # start_key=["ingest_answers", "ingest_initial_retrieval"], + # end_key="generate_initial_base_answer", + # ) + # graph.add_edge( + # start_key=["generate_initial_answer", "generate_initial_base_answer"], + # end_key=END, + # ) + # else: + # graph.add_edge( + # start_key="generate_initial_answer", + # end_key=END, + # ) return graph diff 
--git a/backend/onyx/agent_search/main/nodes/generate_initial_answer.py b/backend/onyx/agent_search/main/nodes/generate_initial_answer.py index 2522eb7b77..cb131115e4 100644 --- a/backend/onyx/agent_search/main/nodes/generate_initial_answer.py +++ b/backend/onyx/agent_search/main/nodes/generate_initial_answer.py @@ -159,9 +159,10 @@ def generate_initial_answer(state: MainState) -> InitialAnswerUpdate: response = model.invoke(msg) answer = response.pretty_repr() - initial_agent_stats = _calculate_initial_agent_stats( - state["decomp_answer_results"], state["sub_question_retrieval_stats"] - ) + # initial_agent_stats = _calculate_initial_agent_stats( + # state["decomp_answer_results"], state["sub_question_retrieval_stats"] + # ) + initial_agent_stats = None print(f"\n\n---INITIAL AGENT ANSWER START---\n\n Answer:\n Agent: {answer}") diff --git a/backend/onyx/agent_search/main/nodes/ingest_initial_retrieval.py b/backend/onyx/agent_search/main/nodes/ingest_initial_retrieval.py index 3a8403fe9b..513dff1fb3 100644 --- a/backend/onyx/agent_search/main/nodes/ingest_initial_retrieval.py +++ b/backend/onyx/agent_search/main/nodes/ingest_initial_retrieval.py @@ -3,6 +3,14 @@ def ingest_initial_retrieval(state: ExpandedRetrievalOutput) -> ExpandedRetrievalUpdate: + sub_question_retrieval_stats = state[ + "expanded_retrieval_result" + ].sub_question_retrieval_stats + if sub_question_retrieval_stats is None: + sub_question_retrieval_stats = [] + else: + sub_question_retrieval_stats = [sub_question_retrieval_stats] + return ExpandedRetrievalUpdate( original_question_retrieval_results=state[ "expanded_retrieval_result" @@ -10,7 +18,5 @@ def ingest_initial_retrieval(state: ExpandedRetrievalOutput) -> ExpandedRetrieva all_original_question_documents=state[ "expanded_retrieval_result" ].all_documents, - sub_question_retrieval_stats=state[ - "expanded_retrieval_result" - ].sub_question_retrieval_stats, + sub_question_retrieval_stats=sub_question_retrieval_stats, ) diff --git a/backend/onyx/agent_search/main/nodes/prep_for_initial_retrieval.py b/backend/onyx/agent_search/main/nodes/prep_for_initial_retrieval.py new file mode 100644 index 0000000000..d7732113e6 --- /dev/null +++ b/backend/onyx/agent_search/main/nodes/prep_for_initial_retrieval.py @@ -0,0 +1,12 @@ +from onyx.agent_search.core_state import extract_core_fields_for_subgraph +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput +from onyx.agent_search.main.states import MainState + + +def prep_for_initial_retrieval(state: MainState) -> ExpandedRetrievalInput: + print("prepping") + return ExpandedRetrievalInput( + question=state["search_request"].query, + dummy="0", + **extract_core_fields_for_subgraph(state) + ) diff --git a/backend/onyx/agent_search/main/states.py b/backend/onyx/agent_search/main/states.py index 2cd8d26831..7a7c4ffcc8 100644 --- a/backend/onyx/agent_search/main/states.py +++ b/backend/onyx/agent_search/main/states.py @@ -4,6 +4,7 @@ from onyx.agent_search.answer_question.states import QuestionAnswerResults from onyx.agent_search.core_state import CoreState +from onyx.agent_search.expanded_retrieval.models import ExpandedRetrievalResult from onyx.agent_search.expanded_retrieval.models import QueryResult from onyx.agent_search.shared_graph_utils.models import AgentChunkStats from onyx.agent_search.shared_graph_utils.models import InitialAgentResultStats @@ -40,7 +41,7 @@ class ExpandedRetrievalUpdate(TypedDict): list[InferenceSection], dedup_inference_sections ] original_question_retrieval_results: 
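A pattern recurs in the stats plumbing above (doc_reranking, format_results, ingest_initial_retrieval): an optional singleton is wrapped in a one-element list, or replaced by an empty list, before being written to an additive channel. That way a missing value contributes nothing to the merged list instead of poisoning it with None. The convention in miniature:

from typing import TypeVar

T = TypeVar("T")


def as_channel_value(stats: T | None) -> list[T]:
    # None means "no stats computed"; an empty list is a no-op under operator.add.
    return [] if stats is None else [stats]
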
list[QueryResult] - sub_question_retrieval_stats: AgentChunkStats + sub_question_retrieval_stats: Annotated[list[AgentChunkStats], add] ## Graph Input State @@ -62,7 +63,8 @@ class MainState( DecompAnswersUpdate, ExpandedRetrievalUpdate, ): - pass + # expanded_retrieval_result: Annotated[list[ExpandedRetrievalResult], add] + base_raw_search_result: Annotated[list[ExpandedRetrievalResult], add] ## Graph Output State diff --git a/backend/onyx/agent_search/run_graph.py b/backend/onyx/agent_search/run_graph.py index b060cba571..de207af628 100644 --- a/backend/onyx/agent_search/run_graph.py +++ b/backend/onyx/agent_search/run_graph.py @@ -24,8 +24,22 @@ def _parse_agent_event( Parse the event into a typed object. Return None if we are not interested in the event. """ - if event["name"] == "LangGraph": - return None + # if event["name"] == "LangGraph": + # return None + + event_type = event["event"] + langgraph_node = event["metadata"].get("langgraph_node", "_graph_") + if "input" in event["data"] and isinstance(event["data"]["input"], str): + input_data = f'\nINPUT: {langgraph_node} -- {str(event["data"]["input"])}' + else: + input_data = "" + if "output" in event["data"] and isinstance(event["data"]["output"], str): + output_data = f'\nOUTPUT: {langgraph_node} -- {str(event["data"]["output"])}' + else: + output_data = "" + if len(input_data) > 0 or len(output_data) > 0: + return input_data + output_data + event_type = event["event"] if event_type == "tool_call_kickoff": return ToolCallKickoff(**event["data"]) @@ -102,4 +116,5 @@ def run_graph( query="what can you do with onyx or danswer?", ) for output in run_graph(compiled_graph, search_request, primary_llm, fast_llm): - print(output) + print("a") + # print(output) From 568bc16536754162cd1a50cfb9be5904936e8a49 Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Sun, 29 Dec 2024 10:51:31 -0800 Subject: [PATCH 29/78] all 3 options --- .../onyx/agent_search/main/graph_builder.py | 523 +++++++++++++----- 1 file changed, 398 insertions(+), 125 deletions(-) diff --git a/backend/onyx/agent_search/main/graph_builder.py b/backend/onyx/agent_search/main/graph_builder.py index 91f8e30db5..e1b231886e 100644 --- a/backend/onyx/agent_search/main/graph_builder.py +++ b/backend/onyx/agent_search/main/graph_builder.py @@ -3,15 +3,22 @@ from langgraph.graph import StateGraph from onyx.agent_search.answer_question.graph_builder import answer_query_graph_builder +from onyx.agent_search.base_raw_search.graph_builder import ( + base_raw_search_graph_builder, +) from onyx.agent_search.main.edges import parallelize_decompozed_answer_queries from onyx.agent_search.main.nodes.base_decomp import main_decomp_base from onyx.agent_search.main.nodes.generate_initial_answer import ( generate_initial_answer, ) from onyx.agent_search.main.nodes.ingest_answers import ingest_answers +from onyx.agent_search.main.nodes.ingest_initial_retrieval import ( + ingest_initial_retrieval, +) from onyx.agent_search.main.states import MainInput from onyx.agent_search.main.states import MainState + test_mode = False @@ -21,136 +28,402 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph: input=MainInput, ) - ### Add nodes ### + graph_component = "both" + # graph_component = "right" + # graph_component = "left" - graph.add_node( - node="base_decomp", - action=main_decomp_base, - ) - answer_query_subgraph = answer_query_graph_builder().compile() - graph.add_node( - node="answer_query", - action=answer_query_subgraph, - ) + if graph_component == "left": + ### Add nodes ### - # 
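The _parse_agent_event instrumentation above assumes the event shape that LangChain's astream_events API emits: each event is a dict with an "event" type, a "data" payload that may carry "input"/"output", and "metadata" that includes "langgraph_node" when the event originates inside a graph node. A hedged sketch of that filter on its own:

def describe_event(event: dict) -> str | None:
    node = event.get("metadata", {}).get("langgraph_node", "_graph_")
    data = event.get("data", {})
    parts = [
        f"{label}: {node} -- {data[key]}"
        for label, key in (("INPUT", "input"), ("OUTPUT", "output"))
        if isinstance(data.get(key), str)
    ]
    return "\n".join(parts) if parts else None
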
graph.add_node( - # node="prep_for_initial_retrieval", - # action=prep_for_initial_retrieval, - # ) - - # expanded_retrieval_subgraph = expanded_retrieval_graph_builder().compile() - # graph.add_node( - # node="initial_retrieval", - # action=expanded_retrieval_subgraph, - # ) - - # base_raw_search_subgraph = base_raw_search_graph_builder().compile() - # graph.add_node( - # node="base_raw_search_data", - # action=base_raw_search_subgraph, - # ) - # graph.add_node( - # node="ingest_initial_retrieval", - # action=ingest_initial_retrieval, - # ) - graph.add_node( - node="ingest_answers", - action=ingest_answers, - ) - graph.add_node( - node="generate_initial_answer", - action=generate_initial_answer, - ) - # if test_mode: - # graph.add_node( - # node="generate_initial_base_answer", - # action=generate_initial_base_answer, - # ) - - ### Add edges ### - - # graph.add_conditional_edges( - # source=START, - # path=send_to_initial_retrieval, - # path_map=["initial_retrieval"], - # ) - - # graph.add_edge( - # start_key=START, - # end_key="prep_for_initial_retrieval", - # ) - # graph.add_edge( - # start_key="prep_for_initial_retrieval", - # end_key="initial_retrieval", - # ) - # graph.add_edge( - # start_key="initial_retrieval", - # end_key="ingest_initial_retrieval", - # ) - - # graph.add_edge( - # start_key=START, - # end_key="base_raw_search_data" - # ) - - # # graph.add_edge( - # # start_key="base_raw_search_data", - # # end_key=END - # # ) - # graph.add_edge( - # start_key="base_raw_search_data", - # end_key="ingest_initial_retrieval", - # ) - # graph.add_edge( - # start_key="ingest_initial_retrieval", - # end_key=END - # ) - graph.add_edge( - start_key=START, - end_key="base_decomp", - ) - graph.add_conditional_edges( - source="base_decomp", - path=parallelize_decompozed_answer_queries, - path_map=["answer_query"], - ) - graph.add_edge( - start_key="answer_query", - end_key="ingest_answers", - ) + graph.add_node( + node="base_decomp", + action=main_decomp_base, + ) + answer_query_subgraph = answer_query_graph_builder().compile() + graph.add_node( + node="answer_query", + action=answer_query_subgraph, + ) - graph.add_edge( - start_key="ingest_answers", - end_key="generate_initial_answer", - ) + # graph.add_node( + # node="prep_for_initial_retrieval", + # action=prep_for_initial_retrieval, + # ) - # graph.add_edge( - # start_key=["ingest_answers", "ingest_initial_retrieval"], - # end_key="generate_initial_answer", - # ) + # expanded_retrieval_subgraph = expanded_retrieval_graph_builder().compile() + # graph.add_node( + # node="initial_retrieval", + # action=expanded_retrieval_subgraph, + # ) - graph.add_edge( - start_key="generate_initial_answer", - end_key=END, - ) - # graph.add_edge( - # start_key="ingest_answers", - # end_key="generate_initial_answer", - # ) - # if test_mode: - # graph.add_edge( - # start_key=["ingest_answers", "ingest_initial_retrieval"], - # end_key="generate_initial_base_answer", - # ) - # graph.add_edge( - # start_key=["generate_initial_answer", "generate_initial_base_answer"], - # end_key=END, - # ) - # else: - # graph.add_edge( - # start_key="generate_initial_answer", - # end_key=END, - # ) + # base_raw_search_subgraph = base_raw_search_graph_builder().compile() + # graph.add_node( + # node="base_raw_search_data", + # action=base_raw_search_subgraph, + # ) + # graph.add_node( + # node="ingest_initial_retrieval", + # action=ingest_initial_retrieval, + # ) + graph.add_node( + node="ingest_answers", + action=ingest_answers, + ) + graph.add_node( + 
node="generate_initial_answer", + action=generate_initial_answer, + ) + # if test_mode: + # graph.add_node( + # node="generate_initial_base_answer", + # action=generate_initial_base_answer, + # ) + + ### Add edges ### + + # graph.add_conditional_edges( + # source=START, + # path=send_to_initial_retrieval, + # path_map=["initial_retrieval"], + # ) + + # graph.add_edge( + # start_key=START, + # end_key="prep_for_initial_retrieval", + # ) + # graph.add_edge( + # start_key="prep_for_initial_retrieval", + # end_key="initial_retrieval", + # ) + # graph.add_edge( + # start_key="initial_retrieval", + # end_key="ingest_initial_retrieval", + # ) + + # graph.add_edge( + # start_key=START, + # end_key="base_raw_search_data" + # ) + + # # graph.add_edge( + # # start_key="base_raw_search_data", + # # end_key=END + # # ) + # graph.add_edge( + # start_key="base_raw_search_data", + # end_key="ingest_initial_retrieval", + # ) + # graph.add_edge( + # start_key="ingest_initial_retrieval", + # end_key=END + # ) + graph.add_edge( + start_key=START, + end_key="base_decomp", + ) + graph.add_conditional_edges( + source="base_decomp", + path=parallelize_decompozed_answer_queries, + path_map=["answer_query"], + ) + graph.add_edge( + start_key="answer_query", + end_key="ingest_answers", + ) + + graph.add_edge( + start_key="ingest_answers", + end_key="generate_initial_answer", + ) + + # graph.add_edge( + # start_key=["ingest_answers", "ingest_initial_retrieval"], + # end_key="generate_initial_answer", + # ) + + graph.add_edge( + start_key="generate_initial_answer", + end_key=END, + ) + # graph.add_edge( + # start_key="ingest_answers", + # end_key="generate_initial_answer", + # ) + # if test_mode: + # graph.add_edge( + # start_key=["ingest_answers", "ingest_initial_retrieval"], + # end_key="generate_initial_base_answer", + # ) + # graph.add_edge( + # start_key=["generate_initial_answer", "generate_initial_base_answer"], + # end_key=END, + # ) + # else: + # graph.add_edge( + # start_key="generate_initial_answer", + # end_key=END, + # ) + + elif graph_component == "right": + ### Add nodes ### + + # graph.add_node( + # node="base_decomp", + # action=main_decomp_base, + # ) + # answer_query_subgraph = answer_query_graph_builder().compile() + # graph.add_node( + # node="answer_query", + # action=answer_query_subgraph, + # ) + + # graph.add_node( + # node="prep_for_initial_retrieval", + # action=prep_for_initial_retrieval, + # ) + + # expanded_retrieval_subgraph = expanded_retrieval_graph_builder().compile() + # graph.add_node( + # node="initial_retrieval", + # action=expanded_retrieval_subgraph, + # ) + + base_raw_search_subgraph = base_raw_search_graph_builder().compile() + graph.add_node( + node="base_raw_search_data", + action=base_raw_search_subgraph, + ) + graph.add_node( + node="ingest_initial_retrieval", + action=ingest_initial_retrieval, + ) + # graph.add_node( + # node="ingest_answers", + # action=ingest_answers, + # ) + graph.add_node( + node="generate_initial_answer", + action=generate_initial_answer, + ) + # if test_mode: + # graph.add_node( + # node="generate_initial_base_answer", + # action=generate_initial_base_answer, + # ) + + ### Add edges ### + + # graph.add_conditional_edges( + # source=START, + # path=send_to_initial_retrieval, + # path_map=["initial_retrieval"], + # ) + + # graph.add_edge( + # start_key=START, + # end_key="prep_for_initial_retrieval", + # ) + # graph.add_edge( + # start_key="prep_for_initial_retrieval", + # end_key="initial_retrieval", + # ) + # graph.add_edge( + # 
start_key="initial_retrieval", + # end_key="ingest_initial_retrieval", + # ) + + graph.add_edge(start_key=START, end_key="base_raw_search_data") + + # # graph.add_edge( + # # start_key="base_raw_search_data", + # # end_key=END + # # ) + graph.add_edge( + start_key="base_raw_search_data", + end_key="ingest_initial_retrieval", + ) + # graph.add_edge( + # start_key="ingest_initial_retrieval", + # end_key=END + # ) + # graph.add_edge( + # start_key=START, + # end_key="base_decomp", + # ) + # graph.add_conditional_edges( + # source="base_decomp", + # path=parallelize_decompozed_answer_queries, + # path_map=["answer_query"], + # ) + # graph.add_edge( + # start_key="answer_query", + # end_key="ingest_answers", + # ) + + # graph.add_edge( + # start_key="ingest_answers", + # end_key="generate_initial_answer", + # ) + + graph.add_edge( + start_key="ingest_initial_retrieval", + end_key="generate_initial_answer", + ) + + # graph.add_edge( + # start_key=["ingest_answers", "ingest_initial_retrieval"], + # end_key="generate_initial_answer", + # ) + + graph.add_edge( + start_key="generate_initial_answer", + end_key=END, + ) + # graph.add_edge( + # start_key="ingest_answers", + # end_key="generate_initial_answer", + # ) + # if test_mode: + # graph.add_edge( + # start_key=["ingest_answers", "ingest_initial_retrieval"], + # end_key="generate_initial_base_answer", + # ) + # graph.add_edge( + # start_key=["generate_initial_answer", "generate_initial_base_answer"], + # end_key=END, + # ) + # else: + # graph.add_edge( + # start_key="generate_initial_answer", + # end_key=END, + # ) + + else: + graph.add_node( + node="base_decomp", + action=main_decomp_base, + ) + answer_query_subgraph = answer_query_graph_builder().compile() + graph.add_node( + node="answer_query", + action=answer_query_subgraph, + ) + + # graph.add_node( + # node="prep_for_initial_retrieval", + # action=prep_for_initial_retrieval, + # ) + + # expanded_retrieval_subgraph = expanded_retrieval_graph_builder().compile() + # graph.add_node( + # node="initial_retrieval", + # action=expanded_retrieval_subgraph, + # ) + + base_raw_search_subgraph = base_raw_search_graph_builder().compile() + graph.add_node( + node="base_raw_search_data", + action=base_raw_search_subgraph, + ) + graph.add_node( + node="ingest_initial_retrieval", + action=ingest_initial_retrieval, + ) + graph.add_node( + node="ingest_answers", + action=ingest_answers, + ) + graph.add_node( + node="generate_initial_answer", + action=generate_initial_answer, + ) + # if test_mode: + # graph.add_node( + # node="generate_initial_base_answer", + # action=generate_initial_base_answer, + # ) + + ### Add edges ### + + # graph.add_conditional_edges( + # source=START, + # path=send_to_initial_retrieval, + # path_map=["initial_retrieval"], + # ) + + # graph.add_edge( + # start_key=START, + # end_key="prep_for_initial_retrieval", + # ) + # graph.add_edge( + # start_key="prep_for_initial_retrieval", + # end_key="initial_retrieval", + # ) + # graph.add_edge( + # start_key="initial_retrieval", + # end_key="ingest_initial_retrieval", + # ) + + graph.add_edge(start_key=START, end_key="base_raw_search_data") + + # # graph.add_edge( + # # start_key="base_raw_search_data", + # # end_key=END + # # ) + graph.add_edge( + start_key="base_raw_search_data", + end_key="ingest_initial_retrieval", + ) + # graph.add_edge( + # start_key="ingest_initial_retrieval", + # end_key=END + # ) + graph.add_edge( + start_key=START, + end_key="base_decomp", + ) + graph.add_conditional_edges( + source="base_decomp", + 
path=parallelize_decompozed_answer_queries, + path_map=["answer_query"], + ) + graph.add_edge( + start_key="answer_query", + end_key="ingest_answers", + ) + + # graph.add_edge( + # start_key="ingest_answers", + # end_key="generate_initial_answer", + # ) + + graph.add_edge( + start_key=["ingest_answers", "ingest_initial_retrieval"], + end_key="generate_initial_answer", + ) + + graph.add_edge( + start_key="generate_initial_answer", + end_key=END, + ) + # graph.add_edge( + # start_key="ingest_answers", + # end_key="generate_initial_answer", + # ) + # if test_mode: + # graph.add_edge( + # start_key=["ingest_answers", "ingest_initial_retrieval"], + # end_key="generate_initial_base_answer", + # ) + # graph.add_edge( + # start_key=["generate_initial_answer", "generate_initial_base_answer"], + # end_key=END, + # ) + # else: + # graph.add_edge( + # start_key="generate_initial_answer", + # end_key=END, + # ) return graph From 683978ddb02b6148dec48a486485cda690aebade Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Sun, 29 Dec 2024 12:23:39 -0800 Subject: [PATCH 30/78] base functioning --- .../base_raw_search/nodes/format_raw_search_results.py | 5 +++-- .../base_raw_search/nodes/generate_raw_search_data.py | 1 + backend/onyx/agent_search/base_raw_search/states.py | 2 +- backend/onyx/agent_search/expanded_retrieval/states.py | 2 ++ .../main/nodes/ingest_initial_retrieval.py | 10 +++++----- 5 files changed, 12 insertions(+), 8 deletions(-) diff --git a/backend/onyx/agent_search/base_raw_search/nodes/format_raw_search_results.py b/backend/onyx/agent_search/base_raw_search/nodes/format_raw_search_results.py index dfd2b47e6b..42e0b45731 100644 --- a/backend/onyx/agent_search/base_raw_search/nodes/format_raw_search_results.py +++ b/backend/onyx/agent_search/base_raw_search/nodes/format_raw_search_results.py @@ -5,6 +5,7 @@ def format_raw_search_results(state: ExpandedRetrievalOutput) -> BaseRawSearchOutput: print("format_raw_search_results") return BaseRawSearchOutput( - base_retrieval_results=[state["expanded_retrieval_result"]], - base_search_documents=[], + base_expanded_retrieval_result=state["expanded_retrieval_result"], + # base_retrieval_results=[state["expanded_retrieval_result"]], + # base_search_documents=[], ) diff --git a/backend/onyx/agent_search/base_raw_search/nodes/generate_raw_search_data.py b/backend/onyx/agent_search/base_raw_search/nodes/generate_raw_search_data.py index a09729a4b1..60153edf2b 100644 --- a/backend/onyx/agent_search/base_raw_search/nodes/generate_raw_search_data.py +++ b/backend/onyx/agent_search/base_raw_search/nodes/generate_raw_search_data.py @@ -11,4 +11,5 @@ def generate_raw_search_data(state: CoreState) -> ExpandedRetrievalInput: subgraph_db_session=state["db_session"], question=state["search_request"].query, dummy="7", + base_search=True, ) diff --git a/backend/onyx/agent_search/base_raw_search/states.py b/backend/onyx/agent_search/base_raw_search/states.py index 920bab97a2..fb073454c4 100644 --- a/backend/onyx/agent_search/base_raw_search/states.py +++ b/backend/onyx/agent_search/base_raw_search/states.py @@ -27,7 +27,7 @@ class BaseRawSearchOutput(TypedDict): # base_search_documents: Annotated[list[InferenceSection], dedup_inference_sections] # base_retrieval_results: Annotated[list[ExpandedRetrievalResult], add] - expanded_retrieval_result: ExpandedRetrievalResult + base_expanded_retrieval_result: ExpandedRetrievalResult ## Graph State diff --git a/backend/onyx/agent_search/expanded_retrieval/states.py 
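Of the three wirings switched on graph_component, the "both" branch above is the one that exercises LangGraph's join semantics: generate_initial_answer is added with a list-valued start_key, so it fires only after both upstream ingest nodes have completed, merging the decomposed-answer branch with the base raw-search branch. Schematically:

# START -> base_raw_search_data -> ingest_initial_retrieval ----+
#                                                               +-> generate_initial_answer -> END
# START -> base_decomp -> answer_query -> ingest_answers -------+
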
b/backend/onyx/agent_search/expanded_retrieval/states.py index 28eae89717..e7e687bd60 100644 --- a/backend/onyx/agent_search/expanded_retrieval/states.py +++ b/backend/onyx/agent_search/expanded_retrieval/states.py @@ -18,6 +18,7 @@ class ExpandedRetrievalInput(SubgraphCoreState): question: str dummy: str + base_search: bool = False ## Update/Return States @@ -46,6 +47,7 @@ class DocRerankingUpdate(TypedDict): class ExpandedRetrievalOutput(TypedDict): expanded_retrieval_result: ExpandedRetrievalResult + base_expanded_retrieval_result: ExpandedRetrievalResult ## Graph State diff --git a/backend/onyx/agent_search/main/nodes/ingest_initial_retrieval.py b/backend/onyx/agent_search/main/nodes/ingest_initial_retrieval.py index 513dff1fb3..e341be34de 100644 --- a/backend/onyx/agent_search/main/nodes/ingest_initial_retrieval.py +++ b/backend/onyx/agent_search/main/nodes/ingest_initial_retrieval.py @@ -1,10 +1,10 @@ -from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalOutput +from onyx.agent_search.base_raw_search.states import BaseRawSearchOutput from onyx.agent_search.main.states import ExpandedRetrievalUpdate -def ingest_initial_retrieval(state: ExpandedRetrievalOutput) -> ExpandedRetrievalUpdate: +def ingest_initial_retrieval(state: BaseRawSearchOutput) -> ExpandedRetrievalUpdate: sub_question_retrieval_stats = state[ - "expanded_retrieval_result" + "base_expanded_retrieval_result" ].sub_question_retrieval_stats if sub_question_retrieval_stats is None: sub_question_retrieval_stats = [] @@ -13,10 +13,10 @@ def ingest_initial_retrieval(state: ExpandedRetrievalOutput) -> ExpandedRetrieva return ExpandedRetrievalUpdate( original_question_retrieval_results=state[ - "expanded_retrieval_result" + "base_expanded_retrieval_result" ].expanded_queries_results, all_original_question_documents=state[ - "expanded_retrieval_result" + "base_expanded_retrieval_result" ].all_documents, sub_question_retrieval_stats=sub_question_retrieval_stats, ) From 69894418510b3c1748b3e6fcbdd42f2a9075ca14 Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Mon, 30 Dec 2024 09:11:45 -0800 Subject: [PATCH 31/78] pre-metrics clean-up --- .../agent_search/main/nodes/generate_initial_answer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/backend/onyx/agent_search/main/nodes/generate_initial_answer.py b/backend/onyx/agent_search/main/nodes/generate_initial_answer.py index cb131115e4..3b3e8fd055 100644 --- a/backend/onyx/agent_search/main/nodes/generate_initial_answer.py +++ b/backend/onyx/agent_search/main/nodes/generate_initial_answer.py @@ -168,9 +168,10 @@ def generate_initial_answer(state: MainState) -> InitialAnswerUpdate: print(f"\n\nSub-Questions:\n\n{sub_question_answer_str}\n\nStas:\n\n") - print(initial_agent_stats.original_question) - print(initial_agent_stats.sub_questions) - print(initial_agent_stats.agent_effectiveness) + if initial_agent_stats: + print(initial_agent_stats.original_question) + print(initial_agent_stats.sub_questions) + print(initial_agent_stats.agent_effectiveness) print("\n\n ---INITIAL AGENT ANSWER END---\n\n") return InitialAnswerUpdate( From 6dc81bbb7caca22c01fb2dfe9f54392d73363f31 Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Mon, 30 Dec 2024 12:23:11 -0800 Subject: [PATCH 32/78] fixed stats --- .../agent_search/answer_question/models.py | 5 ++--- .../answer_question/nodes/format_answer.py | 22 +++++++++---------- .../answer_question/nodes/ingest_retrieval.py | 5 ++--- .../agent_search/expanded_retrieval/models.py | 5 +---- 
.../expanded_retrieval/nodes/doc_reranking.py | 5 +++-- .../nodes/format_results.py | 6 ++--- .../agent_search/expanded_retrieval/states.py | 4 ++-- .../main/nodes/generate_initial_answer.py | 7 +++--- .../main/nodes/ingest_initial_retrieval.py | 7 +++--- backend/onyx/agent_search/main/states.py | 2 +- 10 files changed, 32 insertions(+), 36 deletions(-) diff --git a/backend/onyx/agent_search/answer_question/models.py b/backend/onyx/agent_search/answer_question/models.py index 6ee67c9e36..ea9fb8f971 100644 --- a/backend/onyx/agent_search/answer_question/models.py +++ b/backend/onyx/agent_search/answer_question/models.py @@ -1,6 +1,5 @@ from pydantic import BaseModel -from onyx.agent_search.expanded_retrieval.models import QueryResult from onyx.agent_search.shared_graph_utils.models import AgentChunkStats from onyx.context.search.models import InferenceSection @@ -15,6 +14,6 @@ class QuestionAnswerResults(BaseModel): question: str answer: str quality: str - expanded_retrieval_results: list[QueryResult] + # expanded_retrieval_results: list[QueryResult] documents: list[InferenceSection] - sub_question_retrieval_stats: list[AgentChunkStats] + sub_question_retrieval_stats: AgentChunkStats diff --git a/backend/onyx/agent_search/answer_question/nodes/format_answer.py b/backend/onyx/agent_search/answer_question/nodes/format_answer.py index e748ac0c3f..06977a0ad9 100644 --- a/backend/onyx/agent_search/answer_question/nodes/format_answer.py +++ b/backend/onyx/agent_search/answer_question/nodes/format_answer.py @@ -4,15 +4,15 @@ def format_answer(state: AnswerQuestionState) -> AnswerQuestionOutput: - sub_question_retrieval_stats = state["sub_question_retrieval_stats"] - if sub_question_retrieval_stats is None: - sub_question_retrieval_stats = [] - elif isinstance(sub_question_retrieval_stats, list): - sub_question_retrieval_stats = sub_question_retrieval_stats - if isinstance(sub_question_retrieval_stats[0], list): - sub_question_retrieval_stats = sub_question_retrieval_stats[0] - else: - sub_question_retrieval_stats = [sub_question_retrieval_stats] + # sub_question_retrieval_stats = state["sub_question_retrieval_stats"] + # if sub_question_retrieval_stats is None: + # sub_question_retrieval_stats = [] + # elif isinstance(sub_question_retrieval_stats, list): + # sub_question_retrieval_stats = sub_question_retrieval_stats + # if isinstance(sub_question_retrieval_stats[0], list): + # sub_question_retrieval_stats = sub_question_retrieval_stats[0] + # else: + # sub_question_retrieval_stats = [sub_question_retrieval_stats] return AnswerQuestionOutput( answer_results=[ @@ -20,9 +20,9 @@ def format_answer(state: AnswerQuestionState) -> AnswerQuestionOutput: question=state["question"], quality=state["answer_quality"], answer=state["answer"], - expanded_retrieval_results=state["expanded_retrieval_results"], + # expanded_retrieval_results=state["expanded_retrieval_results"], documents=state["documents"], - sub_question_retrieval_stats=sub_question_retrieval_stats, + sub_question_retrieval_stats=state["sub_question_retrieval_stats"], ) ], ) diff --git a/backend/onyx/agent_search/answer_question/nodes/ingest_retrieval.py b/backend/onyx/agent_search/answer_question/nodes/ingest_retrieval.py index a7efd854fa..fd84ccb382 100644 --- a/backend/onyx/agent_search/answer_question/nodes/ingest_retrieval.py +++ b/backend/onyx/agent_search/answer_question/nodes/ingest_retrieval.py @@ -1,5 +1,6 @@ from onyx.agent_search.answer_question.states import RetrievalIngestionUpdate from onyx.agent_search.expanded_retrieval.states 
import ExpandedRetrievalOutput +from onyx.agent_search.shared_graph_utils.models import AgentChunkStats def ingest_retrieval(state: ExpandedRetrievalOutput) -> RetrievalIngestionUpdate: @@ -7,9 +8,7 @@ def ingest_retrieval(state: ExpandedRetrievalOutput) -> RetrievalIngestionUpdate "expanded_retrieval_result" ].sub_question_retrieval_stats if sub_question_retrieval_stats is None: - sub_question_retrieval_stats = [] - else: - sub_question_retrieval_stats = [sub_question_retrieval_stats] + sub_question_retrieval_stats = AgentChunkStats() return RetrievalIngestionUpdate( expanded_retrieval_results=state[ diff --git a/backend/onyx/agent_search/expanded_retrieval/models.py b/backend/onyx/agent_search/expanded_retrieval/models.py index 25ab3b899c..4e8caa3605 100644 --- a/backend/onyx/agent_search/expanded_retrieval/models.py +++ b/backend/onyx/agent_search/expanded_retrieval/models.py @@ -1,6 +1,3 @@ -from operator import add -from typing import Annotated - from pydantic import BaseModel from onyx.agent_search.shared_graph_utils.models import AgentChunkStats @@ -19,4 +16,4 @@ class QueryResult(BaseModel): class ExpandedRetrievalResult(BaseModel): expanded_queries_results: list[QueryResult] all_documents: list[InferenceSection] - sub_question_retrieval_stats: Annotated[list[AgentChunkStats], add] + sub_question_retrieval_stats: AgentChunkStats diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py index 0f1e182f9f..6edcbec974 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py @@ -1,6 +1,7 @@ from onyx.agent_search.expanded_retrieval.states import DocRerankingUpdate from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState from onyx.agent_search.shared_graph_utils.calculations import get_fit_scores +from onyx.agent_search.shared_graph_utils.models import RetrievalFitStats from onyx.configs.dev_configs import AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS from onyx.configs.dev_configs import AGENT_RERANKING_STATS from onyx.context.search.pipeline import InferenceSection @@ -34,9 +35,9 @@ def doc_reranking(state: ExpandedRetrievalState) -> DocRerankingUpdate: ] # only get the reranked szections, not the SectionRelevancePiece if AGENT_RERANKING_STATS: - fit_scores = [get_fit_scores(verified_documents, reranked_documents)] + fit_scores = get_fit_scores(verified_documents, reranked_documents) else: - fit_scores = [] + fit_scores = RetrievalFitStats(fit_score_lift=0, rerank_effect=0, fit_scores={}) return DocRerankingUpdate( reranked_documents=[ diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py b/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py index c292c14163..0010cdf839 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py @@ -86,9 +86,9 @@ def format_results(state: ExpandedRetrievalState) -> ExpandedRetrievalOutput: ) if sub_question_retrieval_stats is None: - sub_question_retrieval_stats = [] - else: - sub_question_retrieval_stats = [sub_question_retrieval_stats] + sub_question_retrieval_stats = AgentChunkStats() + # else: + # sub_question_retrieval_stats = [sub_question_retrieval_stats] return ExpandedRetrievalOutput( expanded_retrieval_result=ExpandedRetrievalResult( diff --git 
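PATCH 32 settles the stats plumbing on single objects instead of lists: a missing AgentChunkStats is now coalesced to a default-constructed instance, so downstream consumers never branch on None and the list-wrapping from the earlier commits disappears. The convention, sketched with a stand-in Pydantic model (field names illustrative):

from pydantic import BaseModel


class ChunkStats(BaseModel):  # stand-in for AgentChunkStats
    verified_count: int = 0
    rejected_count: int = 0


def normalize(stats: ChunkStats | None) -> ChunkStats:
    # A zeroed stats object is the identity value for "nothing was computed".
    return stats if stats is not None else ChunkStats()
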
a/backend/onyx/agent_search/expanded_retrieval/states.py b/backend/onyx/agent_search/expanded_retrieval/states.py index e7e687bd60..d3de047cd3 100644 --- a/backend/onyx/agent_search/expanded_retrieval/states.py +++ b/backend/onyx/agent_search/expanded_retrieval/states.py @@ -5,7 +5,7 @@ from onyx.agent_search.core_state import SubgraphCoreState from onyx.agent_search.expanded_retrieval.models import ExpandedRetrievalResult from onyx.agent_search.expanded_retrieval.models import QueryResult -from onyx.agent_search.shared_graph_utils.models import AgentChunkStats +from onyx.agent_search.shared_graph_utils.models import RetrievalFitStats from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections from onyx.context.search.models import InferenceSection @@ -39,7 +39,7 @@ class DocRetrievalUpdate(TypedDict): class DocRerankingUpdate(TypedDict): reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections] - sub_question_retrieval_stats: Annotated[list[AgentChunkStats | None], add] + sub_question_retrieval_stats: RetrievalFitStats | None ## Graph Output State diff --git a/backend/onyx/agent_search/main/nodes/generate_initial_answer.py b/backend/onyx/agent_search/main/nodes/generate_initial_answer.py index 3b3e8fd055..3b53758dbc 100644 --- a/backend/onyx/agent_search/main/nodes/generate_initial_answer.py +++ b/backend/onyx/agent_search/main/nodes/generate_initial_answer.py @@ -159,10 +159,9 @@ def generate_initial_answer(state: MainState) -> InitialAnswerUpdate: response = model.invoke(msg) answer = response.pretty_repr() - # initial_agent_stats = _calculate_initial_agent_stats( - # state["decomp_answer_results"], state["sub_question_retrieval_stats"] - # ) - initial_agent_stats = None + initial_agent_stats = _calculate_initial_agent_stats( + state["decomp_answer_results"], state["original_question_retrieval_stats"] + ) print(f"\n\n---INITIAL AGENT ANSWER START---\n\n Answer:\n Agent: {answer}") diff --git a/backend/onyx/agent_search/main/nodes/ingest_initial_retrieval.py b/backend/onyx/agent_search/main/nodes/ingest_initial_retrieval.py index e341be34de..13177313a5 100644 --- a/backend/onyx/agent_search/main/nodes/ingest_initial_retrieval.py +++ b/backend/onyx/agent_search/main/nodes/ingest_initial_retrieval.py @@ -1,5 +1,6 @@ from onyx.agent_search.base_raw_search.states import BaseRawSearchOutput from onyx.agent_search.main.states import ExpandedRetrievalUpdate +from onyx.agent_search.shared_graph_utils.models import AgentChunkStats def ingest_initial_retrieval(state: BaseRawSearchOutput) -> ExpandedRetrievalUpdate: @@ -7,9 +8,9 @@ def ingest_initial_retrieval(state: BaseRawSearchOutput) -> ExpandedRetrievalUpd "base_expanded_retrieval_result" ].sub_question_retrieval_stats if sub_question_retrieval_stats is None: - sub_question_retrieval_stats = [] + sub_question_retrieval_stats = AgentChunkStats() else: - sub_question_retrieval_stats = [sub_question_retrieval_stats] + sub_question_retrieval_stats = sub_question_retrieval_stats return ExpandedRetrievalUpdate( original_question_retrieval_results=state[ @@ -18,5 +19,5 @@ def ingest_initial_retrieval(state: BaseRawSearchOutput) -> ExpandedRetrievalUpd all_original_question_documents=state[ "base_expanded_retrieval_result" ].all_documents, - sub_question_retrieval_stats=sub_question_retrieval_stats, + original_question_retrieval_stats=sub_question_retrieval_stats, ) diff --git a/backend/onyx/agent_search/main/states.py b/backend/onyx/agent_search/main/states.py index 7a7c4ffcc8..d4bac1c3f5 100644 --- 
a/backend/onyx/agent_search/main/states.py +++ b/backend/onyx/agent_search/main/states.py @@ -41,7 +41,7 @@ class ExpandedRetrievalUpdate(TypedDict): list[InferenceSection], dedup_inference_sections ] original_question_retrieval_results: list[QueryResult] - sub_question_retrieval_stats: Annotated[list[AgentChunkStats], add] + original_question_retrieval_stats: AgentChunkStats ## Graph Input State From 821b226d25fb4c28a02c4f6409b8f92ca2890294 Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Mon, 30 Dec 2024 12:38:09 -0800 Subject: [PATCH 33/78] removed test file --- backend/onyx/agent_search/base_lg_tests.py | 66 ---------------------- 1 file changed, 66 deletions(-) delete mode 100644 backend/onyx/agent_search/base_lg_tests.py diff --git a/backend/onyx/agent_search/base_lg_tests.py b/backend/onyx/agent_search/base_lg_tests.py deleted file mode 100644 index e937219d36..0000000000 --- a/backend/onyx/agent_search/base_lg_tests.py +++ /dev/null @@ -1,66 +0,0 @@ -from typing import TypedDict - -from langgraph.graph import END -from langgraph.graph import START -from langgraph.graph import StateGraph - - -# The overall state of the graph (this is the public state shared across nodes) -class OverallState(TypedDict): - a: str - - -# Output from node_1 contains private data that is not part of the overall state -class Node1Output(TypedDict): - private_data: str - - -# Node 2 input only requests the private data available after node_1 -class Node2Input(TypedDict): - private_data: str - - -# The private data is only shared between node_1 and node_2 -def node_1(state: OverallState) -> Node2Input: - output = {"private_data": "set by node_1"} - print(f"Entered node `node_1`:\n\tInput: {state}.\n\tReturned: {output}") - return output - - -def node_2(state: Node2Input) -> OverallState: - output = {"a": "set by node_2"} - print(f"Entered node `node_2`:\n\tInput: {state}.\n\tReturned: {output}") - return output - - -# Node 3 only has access to the overall state (no access to private data from node_1) -def node_3(state: OverallState) -> OverallState: - output = {"a": "set by node_3"} - print(f"Entered node `node_3`:\n\tInput: {state}.\n\tReturned: {output}") - return output - - -# Build the state graph -builder = StateGraph(OverallState) -builder.add_node(node_1) # node_1 is the first node -builder.add_node( - node_2 -) # node_2 is the second node and accepts private data from node_1 -builder.add_node(node_3) # node_3 is the third node and does not see the private data -builder.add_edge(START, "node_1") # Start the graph with node_1 -builder.add_edge("node_1", "node_2") # Pass from node_1 to node_2 -builder.add_edge( - "node_2", "node_3" -) # Pass from node_2 to node_3 (only overall state is shared) -builder.add_edge("node_3", END) # End the graph after node_3 -graph = builder.compile() - -# Invoke the graph with the initial state -response = graph.invoke( - { - "a": "set at start", - } -) - -print() -print(f"Output of graph invocation: {response}") From d68cf98e771770e9e68bacd18da76465d2150ea6 Mon Sep 17 00:00:00 2001 From: Evan Lohn Date: Mon, 30 Dec 2024 12:58:35 -0800 Subject: [PATCH 34/78] nodes in one file plus some mypy fixes --- .../agent_search/answer_question/edges.py | 3 +- .../answer_question/graph_builder.py | 4 - .../answer_question/nodes/format_answer.py | 17 +- .../answer_question/nodes/ingest_retrieval.py | 4 - .../agent_search/answer_question/states.py | 2 +- .../agent_search/expanded_retrieval/edges.py | 4 +- .../expanded_retrieval/graph_builder.py | 22 +-- 
.../expanded_retrieval/nodes/doc_reranking.py | 46 ----- .../expanded_retrieval/nodes/doc_retrieval.py | 55 ------ .../nodes/doc_verification.py | 41 ---- .../nodes/expand_queries.py | 30 --- .../nodes/format_results.py | 99 ---------- .../nodes/verification_kickoff.py | 33 ---- .../agent_search/expanded_retrieval/states.py | 2 +- backend/onyx/agent_search/main/edges.py | 6 +- .../onyx/agent_search/main/graph_builder.py | 12 +- .../agent_search/main/nodes/base_decomp.py | 33 ---- .../nodes/generate_initial_BASE_answer.py | 33 ---- .../main/nodes/generate_initial_answer.py | 181 ------------------ .../agent_search/main/nodes/ingest_answers.py | 16 -- .../main/nodes/ingest_initial_retrieval.py | 22 --- .../main/nodes/prep_for_initial_retrieval.py | 12 -- backend/onyx/agent_search/main/states.py | 2 +- 23 files changed, 34 insertions(+), 645 deletions(-) delete mode 100644 backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py delete mode 100644 backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py delete mode 100644 backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py delete mode 100644 backend/onyx/agent_search/expanded_retrieval/nodes/expand_queries.py delete mode 100644 backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py delete mode 100644 backend/onyx/agent_search/expanded_retrieval/nodes/verification_kickoff.py delete mode 100644 backend/onyx/agent_search/main/nodes/base_decomp.py delete mode 100644 backend/onyx/agent_search/main/nodes/generate_initial_BASE_answer.py delete mode 100644 backend/onyx/agent_search/main/nodes/generate_initial_answer.py delete mode 100644 backend/onyx/agent_search/main/nodes/ingest_answers.py delete mode 100644 backend/onyx/agent_search/main/nodes/ingest_initial_retrieval.py delete mode 100644 backend/onyx/agent_search/main/nodes/prep_for_initial_retrieval.py diff --git a/backend/onyx/agent_search/answer_question/edges.py b/backend/onyx/agent_search/answer_question/edges.py index 1caeec6e9f..0bb96d8d82 100644 --- a/backend/onyx/agent_search/answer_question/edges.py +++ b/backend/onyx/agent_search/answer_question/edges.py @@ -15,6 +15,7 @@ def send_to_expanded_retrieval(state: AnswerQuestionInput) -> Send | Hashable: ExpandedRetrievalInput( **in_subgraph_extract_core_fields(state), question=state["question"], - dummy="1" + dummy="1", + base_search=False ), ) diff --git a/backend/onyx/agent_search/answer_question/graph_builder.py b/backend/onyx/agent_search/answer_question/graph_builder.py index 0aebb045de..e01aa950cb 100644 --- a/backend/onyx/agent_search/answer_question/graph_builder.py +++ b/backend/onyx/agent_search/answer_question/graph_builder.py @@ -90,10 +90,6 @@ def answer_query_graph_builder() -> StateGraph: ) with get_session_context_manager() as db_session: inputs = AnswerQuestionInput( - search_request=search_request, - primary_llm=primary_llm, - fast_llm=fast_llm, - db_session=db_session, question="what can you do with onyx?", ) for thing in compiled_graph.stream( diff --git a/backend/onyx/agent_search/answer_question/nodes/format_answer.py b/backend/onyx/agent_search/answer_question/nodes/format_answer.py index e748ac0c3f..5bbfd8118f 100644 --- a/backend/onyx/agent_search/answer_question/nodes/format_answer.py +++ b/backend/onyx/agent_search/answer_question/nodes/format_answer.py @@ -5,14 +5,15 @@ def format_answer(state: AnswerQuestionState) -> AnswerQuestionOutput: sub_question_retrieval_stats = state["sub_question_retrieval_stats"] - if sub_question_retrieval_stats is None: - 
sub_question_retrieval_stats = [] - elif isinstance(sub_question_retrieval_stats, list): - sub_question_retrieval_stats = sub_question_retrieval_stats - if isinstance(sub_question_retrieval_stats[0], list): - sub_question_retrieval_stats = sub_question_retrieval_stats[0] - else: - sub_question_retrieval_stats = [sub_question_retrieval_stats] + + # if sub_question_retrieval_stats_raw is None: + # sub_question_retrieval_stats = [] + # elif isinstance(sub_question_retrieval_stats_raw, list): + # sub_question_retrieval_stats = sub_question_retrieval_stats_raw + # if isinstance(sub_question_retrieval_stats[0], list): + # sub_question_retrieval_stats = sub_question_retrieval_stats[0] + # else: + # sub_question_retrieval_stats = [sub_question_retrieval_stats_raw] return AnswerQuestionOutput( answer_results=[ diff --git a/backend/onyx/agent_search/answer_question/nodes/ingest_retrieval.py b/backend/onyx/agent_search/answer_question/nodes/ingest_retrieval.py index a7efd854fa..2aa8c6d5ec 100644 --- a/backend/onyx/agent_search/answer_question/nodes/ingest_retrieval.py +++ b/backend/onyx/agent_search/answer_question/nodes/ingest_retrieval.py @@ -6,10 +6,6 @@ def ingest_retrieval(state: ExpandedRetrievalOutput) -> RetrievalIngestionUpdate sub_question_retrieval_stats = state[ "expanded_retrieval_result" ].sub_question_retrieval_stats - if sub_question_retrieval_stats is None: - sub_question_retrieval_stats = [] - else: - sub_question_retrieval_stats = [sub_question_retrieval_stats] return RetrievalIngestionUpdate( expanded_retrieval_results=state[ diff --git a/backend/onyx/agent_search/answer_question/states.py b/backend/onyx/agent_search/answer_question/states.py index 28f4dc2134..98b52ddfc3 100644 --- a/backend/onyx/agent_search/answer_question/states.py +++ b/backend/onyx/agent_search/answer_question/states.py @@ -23,7 +23,7 @@ class QAGenerationUpdate(TypedDict): class RetrievalIngestionUpdate(TypedDict): expanded_retrieval_results: list[QueryResult] documents: Annotated[list[InferenceSection], dedup_inference_sections] - sub_question_retrieval_stats: AgentChunkStats + sub_question_retrieval_stats: list[AgentChunkStats] ## Graph Input State diff --git a/backend/onyx/agent_search/expanded_retrieval/edges.py b/backend/onyx/agent_search/expanded_retrieval/edges.py index 006d915ead..eaeb6a1115 100644 --- a/backend/onyx/agent_search/expanded_retrieval/edges.py +++ b/backend/onyx/agent_search/expanded_retrieval/edges.py @@ -3,7 +3,7 @@ from langgraph.types import Send from onyx.agent_search.core_state import in_subgraph_extract_core_fields -from onyx.agent_search.expanded_retrieval.nodes.doc_retrieval import RetrievalInput +from onyx.agent_search.expanded_retrieval.nodes import RetrievalInput from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState @@ -18,6 +18,8 @@ def parallel_retrieval_edge(state: ExpandedRetrievalState) -> list[Send | Hashab query_to_retrieve=query, question=question, **in_subgraph_extract_core_fields(state), + dummy="1", + base_search=False, ), ) for query in query_expansions diff --git a/backend/onyx/agent_search/expanded_retrieval/graph_builder.py b/backend/onyx/agent_search/expanded_retrieval/graph_builder.py index 8da14eea43..ed3b80a2ce 100644 --- a/backend/onyx/agent_search/expanded_retrieval/graph_builder.py +++ b/backend/onyx/agent_search/expanded_retrieval/graph_builder.py @@ -3,16 +3,12 @@ from langgraph.graph import StateGraph from onyx.agent_search.expanded_retrieval.edges import parallel_retrieval_edge -from 
onyx.agent_search.expanded_retrieval.nodes.doc_reranking import doc_reranking -from onyx.agent_search.expanded_retrieval.nodes.doc_retrieval import doc_retrieval -from onyx.agent_search.expanded_retrieval.nodes.doc_verification import ( - doc_verification, -) -from onyx.agent_search.expanded_retrieval.nodes.expand_queries import expand_queries -from onyx.agent_search.expanded_retrieval.nodes.format_results import format_results -from onyx.agent_search.expanded_retrieval.nodes.verification_kickoff import ( - verification_kickoff, -) +from onyx.agent_search.expanded_retrieval.nodes import doc_reranking +from onyx.agent_search.expanded_retrieval.nodes import doc_retrieval +from onyx.agent_search.expanded_retrieval.nodes import doc_verification +from onyx.agent_search.expanded_retrieval.nodes import expand_queries +from onyx.agent_search.expanded_retrieval.nodes import format_results +from onyx.agent_search.expanded_retrieval.nodes import verification_kickoff from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalOutput from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState @@ -97,11 +93,9 @@ def expanded_retrieval_graph_builder() -> StateGraph: ) with get_session_context_manager() as db_session: inputs = ExpandedRetrievalInput( - search_request=search_request, - primary_llm=primary_llm, - fast_llm=fast_llm, - db_session=db_session, question="what can you do with onyx?", + dummy="1", + base_search=False, ) for thing in compiled_graph.stream( input=inputs, diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py deleted file mode 100644 index 0f1e182f9f..0000000000 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py +++ /dev/null @@ -1,46 +0,0 @@ -from onyx.agent_search.expanded_retrieval.states import DocRerankingUpdate -from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState -from onyx.agent_search.shared_graph_utils.calculations import get_fit_scores -from onyx.configs.dev_configs import AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS -from onyx.configs.dev_configs import AGENT_RERANKING_STATS -from onyx.context.search.pipeline import InferenceSection -from onyx.context.search.pipeline import retrieval_preprocessing -from onyx.context.search.pipeline import search_postprocessing -from onyx.context.search.pipeline import SearchRequest - - -def doc_reranking(state: ExpandedRetrievalState) -> DocRerankingUpdate: - verified_documents = state["verified_documents"] - - # Rerank post retrieval and verification. 
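Judging by the rewritten imports above, PATCH 34 collapses the one-module-per-node package into a single nodes module; the per-node files deleted here are the other half of that move (the new consolidated module itself is not shown in this hunk). The import style at call sites flattens accordingly:

# before: one module per node function
# from onyx.agent_search.expanded_retrieval.nodes.doc_reranking import doc_reranking

# after: all node functions importable from one place
# from onyx.agent_search.expanded_retrieval.nodes import doc_reranking
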
First, create a search query - # then create the list of reranked sections - - question = state.get("question", state["subgraph_search_request"].query) - _search_query = retrieval_preprocessing( - search_request=SearchRequest(query=question), - user=None, - llm=state["subgraph_fast_llm"], - db_session=state["subgraph_db_session"], - ) - - reranked_documents = list( - search_postprocessing( - search_query=_search_query, - retrieved_sections=verified_documents, - llm=state["subgraph_fast_llm"], - ) - )[ - 0 - ] # only get the reranked szections, not the SectionRelevancePiece - - if AGENT_RERANKING_STATS: - fit_scores = [get_fit_scores(verified_documents, reranked_documents)] - else: - fit_scores = [] - - return DocRerankingUpdate( - reranked_documents=[ - doc for doc in reranked_documents if type(doc) == InferenceSection - ][:AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS], - sub_question_retrieval_stats=fit_scores, - ) diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py deleted file mode 100644 index 4e4d44cada..0000000000 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py +++ /dev/null @@ -1,55 +0,0 @@ -from onyx.agent_search.expanded_retrieval.models import QueryResult -from onyx.agent_search.expanded_retrieval.states import DocRetrievalUpdate -from onyx.agent_search.expanded_retrieval.states import RetrievalInput -from onyx.agent_search.shared_graph_utils.calculations import get_fit_scores -from onyx.configs.dev_configs import AGENT_MAX_QUERY_RETRIEVAL_RESULTS -from onyx.configs.dev_configs import AGENT_RETRIEVAL_STATS -from onyx.context.search.models import SearchRequest -from onyx.context.search.pipeline import SearchPipeline - - -def doc_retrieval(state: RetrievalInput) -> DocRetrievalUpdate: - """ - Retrieve documents - - Args: - state (RetrievalInput): Primary state + the query to retrieve - - Updates: - expanded_retrieval_results: list[ExpandedRetrievalResult] - retrieved_documents: list[InferenceSection] - """ - - llm = state["subgraph_primary_llm"] - fast_llm = state["subgraph_fast_llm"] - query_to_retrieve = state["query_to_retrieve"] - - search_results = SearchPipeline( - search_request=SearchRequest( - query=query_to_retrieve, - ), - user=None, - llm=llm, - fast_llm=fast_llm, - db_session=state["subgraph_db_session"], - ) - - retrieved_docs = search_results._get_sections()[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS] - - if AGENT_RETRIEVAL_STATS: - fit_scores = get_fit_scores( - retrieved_docs, - search_results.reranked_sections[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS], - ) - else: - fit_scores = None - - expanded_retrieval_result = QueryResult( - query=query_to_retrieve, - search_results=retrieved_docs, - stats=fit_scores, - ) - return DocRetrievalUpdate( - expanded_retrieval_results=[expanded_retrieval_result], - retrieved_documents=retrieved_docs, - ) diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py deleted file mode 100644 index 11574a9fe5..0000000000 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py +++ /dev/null @@ -1,41 +0,0 @@ -from langchain_core.messages import HumanMessage - -from onyx.agent_search.expanded_retrieval.states import DocVerificationInput -from onyx.agent_search.expanded_retrieval.states import DocVerificationUpdate -from onyx.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT - - -def 
doc_verification(state: DocVerificationInput) -> DocVerificationUpdate: - """ - Check whether the document is relevant for the original user question - - Args: - state (DocVerificationInput): The current state - - Updates: - verified_documents: list[InferenceSection] - """ - - question = state["question"] - doc_to_verify = state["doc_to_verify"] - document_content = doc_to_verify.combined_content - - msg = [ - HumanMessage( - content=VERIFIER_PROMPT.format( - question=question, document_content=document_content - ) - ) - ] - - fast_llm = state["subgraph_fast_llm"] - - response = fast_llm.invoke(msg) - - verified_documents = [] - if "yes" in response.content.lower(): - verified_documents.append(doc_to_verify) - - return DocVerificationUpdate( - verified_documents=verified_documents, - ) diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/expand_queries.py b/backend/onyx/agent_search/expanded_retrieval/nodes/expand_queries.py deleted file mode 100644 index 87599ce726..0000000000 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/expand_queries.py +++ /dev/null @@ -1,30 +0,0 @@ -from langchain_core.messages import HumanMessage -from langchain_core.messages import merge_message_runs - -from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput -from onyx.agent_search.expanded_retrieval.states import QueryExpansionUpdate -from onyx.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI_ORIGINAL -from onyx.llm.interfaces import LLM - - -def expand_queries(state: ExpandedRetrievalInput) -> QueryExpansionUpdate: - question = state.get("question") - llm: LLM = state["subgraph_fast_llm"] - - msg = [ - HumanMessage( - content=REWRITE_PROMPT_MULTI_ORIGINAL.format(question=question), - ) - ] - llm_response_list = list( - llm.stream( - prompt=msg, - ) - ) - llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content - - rewritten_queries = llm_response.split("--") - - return QueryExpansionUpdate( - expanded_queries=rewritten_queries, - ) diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py b/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py deleted file mode 100644 index c292c14163..0000000000 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py +++ /dev/null @@ -1,99 +0,0 @@ -from collections import defaultdict - -import numpy as np - -from onyx.agent_search.expanded_retrieval.models import ExpandedRetrievalResult -from onyx.agent_search.expanded_retrieval.models import QueryResult -from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalOutput -from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState -from onyx.agent_search.expanded_retrieval.states import InferenceSection -from onyx.agent_search.shared_graph_utils.models import AgentChunkStats - - -def _calculate_sub_question_retrieval_stats( - verified_documents: list[InferenceSection], - expanded_retrieval_results: list[QueryResult], -) -> AgentChunkStats: - chunk_scores: dict[str, dict[str, list[int | float]]] = defaultdict( - lambda: defaultdict(list) - ) - - for expanded_retrieval_result in expanded_retrieval_results: - for doc in expanded_retrieval_result.search_results: - doc_chunk_id = f"{doc.center_chunk.document_id}_{doc.center_chunk.chunk_id}" - if doc.center_chunk.score is not None: - chunk_scores[doc_chunk_id]["score"].append(doc.center_chunk.score) - - verified_doc_chunk_ids = [ - 
f"{verified_document.center_chunk.document_id}_{verified_document.center_chunk.chunk_id}" - for verified_document in verified_documents - ] - dismissed_doc_chunk_ids = [] - - raw_chunk_stats_counts: dict[str, int] = defaultdict(int) - raw_chunk_stats_scores: dict[str, float] = defaultdict(float) - for doc_chunk_id, chunk_data in chunk_scores.items(): - if doc_chunk_id in verified_doc_chunk_ids: - raw_chunk_stats_counts["verified_count"] += 1 - - valid_chunk_scores = [ - score for score in chunk_data["score"] if score is not None - ] - raw_chunk_stats_scores["verified_scores"] += float( - np.mean(valid_chunk_scores) - ) - else: - raw_chunk_stats_counts["rejected_count"] += 1 - valid_chunk_scores = [ - score for score in chunk_data["score"] if score is not None - ] - raw_chunk_stats_scores["rejected_scores"] += float( - np.mean(valid_chunk_scores) - ) - dismissed_doc_chunk_ids.append(doc_chunk_id) - - if raw_chunk_stats_counts["verified_count"] == 0: - verified_avg_scores = 0.0 - else: - verified_avg_scores = raw_chunk_stats_scores["verified_scores"] / float( - raw_chunk_stats_counts["verified_count"] - ) - - rejected_scores = raw_chunk_stats_scores.get("rejected_scores", None) - if rejected_scores is not None: - rejected_avg_scores = rejected_scores / float( - raw_chunk_stats_counts["rejected_count"] - ) - else: - rejected_avg_scores = None - - chunk_stats = AgentChunkStats( - verified_count=raw_chunk_stats_counts["verified_count"], - verified_avg_scores=verified_avg_scores, - rejected_count=raw_chunk_stats_counts["rejected_count"], - rejected_avg_scores=rejected_avg_scores, - verified_doc_chunk_ids=verified_doc_chunk_ids, - dismissed_doc_chunk_ids=dismissed_doc_chunk_ids, - ) - - return chunk_stats - - -def format_results(state: ExpandedRetrievalState) -> ExpandedRetrievalOutput: - sub_question_retrieval_stats = _calculate_sub_question_retrieval_stats( - verified_documents=state["verified_documents"], - expanded_retrieval_results=state["expanded_retrieval_results"], - ) - - if sub_question_retrieval_stats is None: - sub_question_retrieval_stats = [] - else: - sub_question_retrieval_stats = [sub_question_retrieval_stats] - - return ExpandedRetrievalOutput( - expanded_retrieval_result=ExpandedRetrievalResult( - expanded_queries_results=state["expanded_retrieval_results"], - all_documents=state["reranked_documents"], - sub_question_retrieval_stats=sub_question_retrieval_stats, - ), - ) diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/verification_kickoff.py b/backend/onyx/agent_search/expanded_retrieval/nodes/verification_kickoff.py deleted file mode 100644 index 00adf495a3..0000000000 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/verification_kickoff.py +++ /dev/null @@ -1,33 +0,0 @@ -from typing import Literal - -from langgraph.types import Command -from langgraph.types import Send - -from onyx.agent_search.core_state import in_subgraph_extract_core_fields -from onyx.agent_search.expanded_retrieval.nodes.doc_verification import ( - DocVerificationInput, -) -from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState - - -def verification_kickoff( - state: ExpandedRetrievalState, -) -> Command[Literal["doc_verification"]]: - documents = state["retrieved_documents"] - verification_question = state.get( - "question", state["subgraph_search_request"].query - ) - return Command( - update={}, - goto=[ - Send( - node="doc_verification", - arg=DocVerificationInput( - doc_to_verify=doc, - question=verification_question, - 
**in_subgraph_extract_core_fields(state), - ), - ) - for doc in documents - ], - ) diff --git a/backend/onyx/agent_search/expanded_retrieval/states.py b/backend/onyx/agent_search/expanded_retrieval/states.py index e7e687bd60..4fe348011c 100644 --- a/backend/onyx/agent_search/expanded_retrieval/states.py +++ b/backend/onyx/agent_search/expanded_retrieval/states.py @@ -18,7 +18,7 @@ class ExpandedRetrievalInput(SubgraphCoreState): question: str dummy: str - base_search: bool = False + base_search: bool ## Update/Return States diff --git a/backend/onyx/agent_search/main/edges.py b/backend/onyx/agent_search/main/edges.py index cded24d43f..3870126ac9 100644 --- a/backend/onyx/agent_search/main/edges.py +++ b/backend/onyx/agent_search/main/edges.py @@ -27,7 +27,9 @@ def parallelize_decompozed_answer_queries(state: MainState) -> list[Send | Hasha return [ Send( "ingest_answers", - AnswerQuestionOutput(), + AnswerQuestionOutput( + answer_results=[], + ), ) ] @@ -40,6 +42,8 @@ def send_to_initial_retrieval(state: MainInput) -> list[Send | Hashable]: ExpandedRetrievalInput( question=state["search_request"].query, **extract_core_fields_for_subgraph(state), + dummy="retrieval", + base_search=False, ), ) ] diff --git a/backend/onyx/agent_search/main/graph_builder.py b/backend/onyx/agent_search/main/graph_builder.py index e1b231886e..50a839cb21 100644 --- a/backend/onyx/agent_search/main/graph_builder.py +++ b/backend/onyx/agent_search/main/graph_builder.py @@ -7,14 +7,10 @@ base_raw_search_graph_builder, ) from onyx.agent_search.main.edges import parallelize_decompozed_answer_queries -from onyx.agent_search.main.nodes.base_decomp import main_decomp_base -from onyx.agent_search.main.nodes.generate_initial_answer import ( - generate_initial_answer, -) -from onyx.agent_search.main.nodes.ingest_answers import ingest_answers -from onyx.agent_search.main.nodes.ingest_initial_retrieval import ( - ingest_initial_retrieval, -) +from onyx.agent_search.main.nodes import generate_initial_answer +from onyx.agent_search.main.nodes import ingest_answers +from onyx.agent_search.main.nodes import ingest_initial_retrieval +from onyx.agent_search.main.nodes import main_decomp_base from onyx.agent_search.main.states import MainInput from onyx.agent_search.main.states import MainState diff --git a/backend/onyx/agent_search/main/nodes/base_decomp.py b/backend/onyx/agent_search/main/nodes/base_decomp.py deleted file mode 100644 index 8285680cb9..0000000000 --- a/backend/onyx/agent_search/main/nodes/base_decomp.py +++ /dev/null @@ -1,33 +0,0 @@ -from langchain_core.messages import HumanMessage - -from onyx.agent_search.main.states import BaseDecompUpdate -from onyx.agent_search.main.states import MainState -from onyx.agent_search.shared_graph_utils.prompts import ( - INITIAL_DECOMPOSITION_PROMPT_QUESTIONS, -) -from onyx.agent_search.shared_graph_utils.utils import clean_and_parse_list_string - - -def main_decomp_base(state: MainState) -> BaseDecompUpdate: - question = state["search_request"].query - - msg = [ - HumanMessage( - content=INITIAL_DECOMPOSITION_PROMPT_QUESTIONS.format(question=question), - ) - ] - - # Get the rewritten queries in a defined format - model = state["fast_llm"] - response = model.invoke(msg) - - content = response.pretty_repr() - list_of_subquestions = clean_and_parse_list_string(content) - - decomp_list: list[str] = [ - sub_question["sub_question"].strip() for sub_question in list_of_subquestions - ] - - return BaseDecompUpdate( - initial_decomp_questions=decomp_list, - ) diff --git 
a/backend/onyx/agent_search/main/nodes/generate_initial_BASE_answer.py b/backend/onyx/agent_search/main/nodes/generate_initial_BASE_answer.py deleted file mode 100644 index 00bc742f06..0000000000 --- a/backend/onyx/agent_search/main/nodes/generate_initial_BASE_answer.py +++ /dev/null @@ -1,33 +0,0 @@ -from langchain_core.messages import HumanMessage - -from onyx.agent_search.main.states import InitialAnswerBASEUpdate -from onyx.agent_search.main.states import MainState -from onyx.agent_search.shared_graph_utils.prompts import INITIAL_RAG_BASE_PROMPT -from onyx.agent_search.shared_graph_utils.utils import format_docs - - -def generate_initial_base_answer(state: MainState) -> InitialAnswerBASEUpdate: - print("---GENERATE INITIAL BASE ANSWER---") - - question = state["search_request"].query - original_question_docs = state["all_original_question_documents"] - - msg = [ - HumanMessage( - content=INITIAL_RAG_BASE_PROMPT.format( - question=question, - context=format_docs(original_question_docs), - ) - ) - ] - - # Grader - model = state["fast_llm"] - response = model.invoke(msg) - answer = response.pretty_repr() - - print() - print( - f"\n\n---INITIAL BASE ANSWER START---\n\nBase: {answer}\n\n ---INITIAL BASE ANSWER END---\n\n" - ) - return InitialAnswerBASEUpdate(initial_base_answer=answer) diff --git a/backend/onyx/agent_search/main/nodes/generate_initial_answer.py b/backend/onyx/agent_search/main/nodes/generate_initial_answer.py deleted file mode 100644 index 3b3e8fd055..0000000000 --- a/backend/onyx/agent_search/main/nodes/generate_initial_answer.py +++ /dev/null @@ -1,181 +0,0 @@ -from langchain_core.messages import HumanMessage - -from onyx.agent_search.answer_question.states import QuestionAnswerResults -from onyx.agent_search.main.states import InitialAnswerUpdate -from onyx.agent_search.main.states import MainState -from onyx.agent_search.shared_graph_utils.models import AgentChunkStats -from onyx.agent_search.shared_graph_utils.models import InitialAgentResultStats -from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections -from onyx.agent_search.shared_graph_utils.prompts import INITIAL_RAG_PROMPT -from onyx.agent_search.shared_graph_utils.prompts import ( - INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS, -) -from onyx.agent_search.shared_graph_utils.utils import format_docs - - -def _calculate_initial_agent_stats( - decomp_answer_results: list[QuestionAnswerResults], - original_question_stats: AgentChunkStats, -) -> InitialAgentResultStats: - initial_agent_result_stats: InitialAgentResultStats = InitialAgentResultStats( - sub_questions={}, - original_question={}, - agent_effectiveness={}, - ) - - orig_verified = original_question_stats.verified_count - orig_support_score = original_question_stats.verified_avg_scores - - verified_document_chunk_ids = [] - support_scores = 0.0 - - for decomp_answer_result in decomp_answer_results: - verified_document_chunk_ids += ( - decomp_answer_result.sub_question_retrieval_stats.verified_doc_chunk_ids - ) - if ( - decomp_answer_result.sub_question_retrieval_stats.verified_avg_scores - is not None - ): - support_scores += ( - decomp_answer_result.sub_question_retrieval_stats.verified_avg_scores - ) - - verified_document_chunk_ids = list(set(verified_document_chunk_ids)) - - # Calculate sub-question stats - if ( - verified_document_chunk_ids - and len(verified_document_chunk_ids) > 0 - and support_scores is not None - ): - sub_question_stats: dict[str, float | int | None] = { - "num_verified_documents": len(verified_document_chunk_ids), 
- "verified_avg_score": float(support_scores / len(decomp_answer_results)), - } - else: - sub_question_stats = {"num_verified_documents": 0, "verified_avg_score": None} - - initial_agent_result_stats.sub_questions.update(sub_question_stats) - - # Get original question stats - initial_agent_result_stats.original_question.update( - { - "num_verified_documents": original_question_stats.verified_count, - "verified_avg_score": original_question_stats.verified_avg_scores, - } - ) - - # Calculate chunk utilization ratio - sub_verified = initial_agent_result_stats.sub_questions["num_verified_documents"] - - chunk_ratio: float | None = None - if sub_verified is not None and orig_verified is not None and orig_verified > 0: - chunk_ratio = (float(sub_verified) / orig_verified) if sub_verified > 0 else 0.0 - elif sub_verified is not None and sub_verified > 0: - chunk_ratio = 10.0 - - initial_agent_result_stats.agent_effectiveness["utilized_chunk_ratio"] = chunk_ratio - - if ( - orig_support_score is None - and initial_agent_result_stats.sub_questions["verified_avg_score"] is None - ): - initial_agent_result_stats.agent_effectiveness["support_ratio"] = None - elif orig_support_score is None: - initial_agent_result_stats.agent_effectiveness["support_ratio"] = 10 - elif initial_agent_result_stats.sub_questions["verified_avg_score"] is None: - initial_agent_result_stats.agent_effectiveness["support_ratio"] = 0 - else: - initial_agent_result_stats.agent_effectiveness["support_ratio"] = ( - initial_agent_result_stats.sub_questions["verified_avg_score"] - / orig_support_score - ) - - return initial_agent_result_stats - - -def generate_initial_answer(state: MainState) -> InitialAnswerUpdate: - print("---GENERATE INITIAL---") - - question = state["search_request"].query - sub_question_docs = state["documents"] - all_original_question_documents = state["all_original_question_documents"] - relevant_docs = dedup_inference_sections( - sub_question_docs, all_original_question_documents - ) - - net_new_original_question_docs = [] - for all_original_question_doc in all_original_question_documents: - if all_original_question_doc not in sub_question_docs: - net_new_original_question_docs.append(all_original_question_doc) - - decomp_answer_results = state["decomp_answer_results"] - - good_qa_list: list[str] = [] - decomp_questions = [] - - _SUB_QUESTION_ANSWER_TEMPLATE = """ - Sub-Question:\n - {sub_question}\n --\nAnswer:\n - {sub_answer}\n\n - """ - for decomp_answer_result in decomp_answer_results: - decomp_questions.append(decomp_answer_result.question) - if ( - decomp_answer_result.quality.lower().startswith("yes") - and len(decomp_answer_result.answer) > 0 - and decomp_answer_result.answer != "I don't know" - ): - good_qa_list.append( - _SUB_QUESTION_ANSWER_TEMPLATE.format( - sub_question=decomp_answer_result.question, - sub_answer=decomp_answer_result.answer, - ) - ) - - sub_question_answer_str = "\n\n------\n\n".join(good_qa_list) - - if len(good_qa_list) > 0: - msg = [ - HumanMessage( - content=INITIAL_RAG_PROMPT.format( - question=question, - answered_sub_questions=sub_question_answer_str, - relevant_docs=format_docs(relevant_docs), - ) - ) - ] - else: - msg = [ - HumanMessage( - content=INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS.format( - question=question, - relevant_docs=format_docs(relevant_docs), - ) - ) - ] - - # Grader - model = state["fast_llm"] - response = model.invoke(msg) - answer = response.pretty_repr() - - # initial_agent_stats = _calculate_initial_agent_stats( - # state["decomp_answer_results"], 
state["sub_question_retrieval_stats"] - # ) - initial_agent_stats = None - - print(f"\n\n---INITIAL AGENT ANSWER START---\n\n Answer:\n Agent: {answer}") - - print(f"\n\nSub-Questions:\n\n{sub_question_answer_str}\n\nStas:\n\n") - - if initial_agent_stats: - print(initial_agent_stats.original_question) - print(initial_agent_stats.sub_questions) - print(initial_agent_stats.agent_effectiveness) - print("\n\n ---INITIAL AGENT ANSWER END---\n\n") - - return InitialAnswerUpdate( - initial_answer=answer, - initial_agent_stats=initial_agent_stats, - generated_sub_questions=decomp_questions, - ) diff --git a/backend/onyx/agent_search/main/nodes/ingest_answers.py b/backend/onyx/agent_search/main/nodes/ingest_answers.py deleted file mode 100644 index 5eac7670e9..0000000000 --- a/backend/onyx/agent_search/main/nodes/ingest_answers.py +++ /dev/null @@ -1,16 +0,0 @@ -from onyx.agent_search.answer_question.states import AnswerQuestionOutput -from onyx.agent_search.main.states import DecompAnswersUpdate -from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections - - -def ingest_answers(state: AnswerQuestionOutput) -> DecompAnswersUpdate: - documents = [] - answer_results = state.get("answer_results", []) - for answer_result in answer_results: - documents.extend(answer_result.documents) - return DecompAnswersUpdate( - # Deduping is done by the documents operator for the main graph - # so we might not need to dedup here - documents=dedup_inference_sections(documents, []), - decomp_answer_results=answer_results, - ) diff --git a/backend/onyx/agent_search/main/nodes/ingest_initial_retrieval.py b/backend/onyx/agent_search/main/nodes/ingest_initial_retrieval.py deleted file mode 100644 index e341be34de..0000000000 --- a/backend/onyx/agent_search/main/nodes/ingest_initial_retrieval.py +++ /dev/null @@ -1,22 +0,0 @@ -from onyx.agent_search.base_raw_search.states import BaseRawSearchOutput -from onyx.agent_search.main.states import ExpandedRetrievalUpdate - - -def ingest_initial_retrieval(state: BaseRawSearchOutput) -> ExpandedRetrievalUpdate: - sub_question_retrieval_stats = state[ - "base_expanded_retrieval_result" - ].sub_question_retrieval_stats - if sub_question_retrieval_stats is None: - sub_question_retrieval_stats = [] - else: - sub_question_retrieval_stats = [sub_question_retrieval_stats] - - return ExpandedRetrievalUpdate( - original_question_retrieval_results=state[ - "base_expanded_retrieval_result" - ].expanded_queries_results, - all_original_question_documents=state[ - "base_expanded_retrieval_result" - ].all_documents, - sub_question_retrieval_stats=sub_question_retrieval_stats, - ) diff --git a/backend/onyx/agent_search/main/nodes/prep_for_initial_retrieval.py b/backend/onyx/agent_search/main/nodes/prep_for_initial_retrieval.py deleted file mode 100644 index d7732113e6..0000000000 --- a/backend/onyx/agent_search/main/nodes/prep_for_initial_retrieval.py +++ /dev/null @@ -1,12 +0,0 @@ -from onyx.agent_search.core_state import extract_core_fields_for_subgraph -from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput -from onyx.agent_search.main.states import MainState - - -def prep_for_initial_retrieval(state: MainState) -> ExpandedRetrievalInput: - print("prepping") - return ExpandedRetrievalInput( - question=state["search_request"].query, - dummy="0", - **extract_core_fields_for_subgraph(state) - ) diff --git a/backend/onyx/agent_search/main/states.py b/backend/onyx/agent_search/main/states.py index 7a7c4ffcc8..da3346ec4f 100644 --- 
a/backend/onyx/agent_search/main/states.py +++ b/backend/onyx/agent_search/main/states.py @@ -27,7 +27,7 @@ class InitialAnswerBASEUpdate(TypedDict): class InitialAnswerUpdate(TypedDict): initial_answer: str - initial_agent_stats: InitialAgentResultStats + initial_agent_stats: InitialAgentResultStats | None generated_sub_questions: list[str] From 11db9647f38738faa52cee0c279fba2872b256f3 Mon Sep 17 00:00:00 2001 From: Evan Lohn Date: Mon, 30 Dec 2024 13:36:05 -0800 Subject: [PATCH 35/78] no more dummy plus nodes files --- .../agent_search/answer_question/edges.py | 1 - .../nodes/generate_raw_search_data.py | 1 - backend/onyx/agent_search/core_state.py | 1 - .../agent_search/expanded_retrieval/edges.py | 1 - .../expanded_retrieval/graph_builder.py | 1 - .../agent_search/expanded_retrieval/nodes.py | 291 ++++++++++++++++++ .../agent_search/expanded_retrieval/states.py | 1 - backend/onyx/agent_search/main/edges.py | 1 - backend/onyx/agent_search/main/nodes.py | 276 +++++++++++++++++ 9 files changed, 567 insertions(+), 7 deletions(-) create mode 100644 backend/onyx/agent_search/expanded_retrieval/nodes.py create mode 100644 backend/onyx/agent_search/main/nodes.py diff --git a/backend/onyx/agent_search/answer_question/edges.py b/backend/onyx/agent_search/answer_question/edges.py index 0bb96d8d82..badfc02f24 100644 --- a/backend/onyx/agent_search/answer_question/edges.py +++ b/backend/onyx/agent_search/answer_question/edges.py @@ -15,7 +15,6 @@ def send_to_expanded_retrieval(state: AnswerQuestionInput) -> Send | Hashable: ExpandedRetrievalInput( **in_subgraph_extract_core_fields(state), question=state["question"], - dummy="1", base_search=False ), ) diff --git a/backend/onyx/agent_search/base_raw_search/nodes/generate_raw_search_data.py b/backend/onyx/agent_search/base_raw_search/nodes/generate_raw_search_data.py index 60153edf2b..9b6a7b1485 100644 --- a/backend/onyx/agent_search/base_raw_search/nodes/generate_raw_search_data.py +++ b/backend/onyx/agent_search/base_raw_search/nodes/generate_raw_search_data.py @@ -10,6 +10,5 @@ def generate_raw_search_data(state: CoreState) -> ExpandedRetrievalInput: subgraph_fast_llm=state["fast_llm"], subgraph_db_session=state["db_session"], question=state["search_request"].query, - dummy="7", base_search=True, ) diff --git a/backend/onyx/agent_search/core_state.py b/backend/onyx/agent_search/core_state.py index 6dd8d0f8b8..7f0a34b07f 100644 --- a/backend/onyx/agent_search/core_state.py +++ b/backend/onyx/agent_search/core_state.py @@ -21,7 +21,6 @@ class CoreState(TypedDict, total=False): # is fine if we are only reading db_session: Session log_messages: Annotated[list[str], add] - dummy: str class SubgraphCoreState(TypedDict, total=False): diff --git a/backend/onyx/agent_search/expanded_retrieval/edges.py b/backend/onyx/agent_search/expanded_retrieval/edges.py index eaeb6a1115..eb9d1bc2b5 100644 --- a/backend/onyx/agent_search/expanded_retrieval/edges.py +++ b/backend/onyx/agent_search/expanded_retrieval/edges.py @@ -18,7 +18,6 @@ def parallel_retrieval_edge(state: ExpandedRetrievalState) -> list[Send | Hashab query_to_retrieve=query, question=question, **in_subgraph_extract_core_fields(state), - dummy="1", base_search=False, ), ) diff --git a/backend/onyx/agent_search/expanded_retrieval/graph_builder.py b/backend/onyx/agent_search/expanded_retrieval/graph_builder.py index ed3b80a2ce..35de29a2f9 100644 --- a/backend/onyx/agent_search/expanded_retrieval/graph_builder.py +++ b/backend/onyx/agent_search/expanded_retrieval/graph_builder.py @@ -94,7 +94,6 @@ def 
expanded_retrieval_graph_builder() -> StateGraph: with get_session_context_manager() as db_session: inputs = ExpandedRetrievalInput( question="what can you do with onyx?", - dummy="1", base_search=False, ) for thing in compiled_graph.stream( diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes.py b/backend/onyx/agent_search/expanded_retrieval/nodes.py new file mode 100644 index 0000000000..831e5a9433 --- /dev/null +++ b/backend/onyx/agent_search/expanded_retrieval/nodes.py @@ -0,0 +1,291 @@ +from collections import defaultdict +from typing import Literal + +import numpy as np +from langchain_core.messages import HumanMessage +from langchain_core.messages import merge_message_runs +from langgraph.types import Command +from langgraph.types import Send + +from onyx.agent_search.core_state import in_subgraph_extract_core_fields +from onyx.agent_search.expanded_retrieval.models import ExpandedRetrievalResult +from onyx.agent_search.expanded_retrieval.models import QueryResult +from onyx.agent_search.expanded_retrieval.states import DocRerankingUpdate +from onyx.agent_search.expanded_retrieval.states import DocRetrievalUpdate +from onyx.agent_search.expanded_retrieval.states import DocVerificationInput +from onyx.agent_search.expanded_retrieval.states import DocVerificationUpdate +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalOutput +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState +from onyx.agent_search.expanded_retrieval.states import InferenceSection +from onyx.agent_search.expanded_retrieval.states import QueryExpansionUpdate +from onyx.agent_search.expanded_retrieval.states import RetrievalInput +from onyx.agent_search.shared_graph_utils.calculations import get_fit_scores +from onyx.agent_search.shared_graph_utils.models import AgentChunkStats +from onyx.agent_search.shared_graph_utils.models import RetrievalFitStats +from onyx.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI_ORIGINAL +from onyx.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT +from onyx.configs.dev_configs import AGENT_MAX_QUERY_RETRIEVAL_RESULTS +from onyx.configs.dev_configs import AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS +from onyx.configs.dev_configs import AGENT_RERANKING_STATS +from onyx.configs.dev_configs import AGENT_RETRIEVAL_STATS +from onyx.context.search.models import SearchRequest +from onyx.context.search.pipeline import retrieval_preprocessing +from onyx.context.search.pipeline import search_postprocessing +from onyx.context.search.pipeline import SearchPipeline +from onyx.llm.interfaces import LLM + + +def doc_reranking(state: ExpandedRetrievalState) -> DocRerankingUpdate: + verified_documents = state["verified_documents"] + + # Rerank post retrieval and verification. 
First, create a search query + # then create the list of reranked sections + + question = state.get("question", state["subgraph_search_request"].query) + _search_query = retrieval_preprocessing( + search_request=SearchRequest(query=question), + user=None, + llm=state["subgraph_fast_llm"], + db_session=state["subgraph_db_session"], + ) + + reranked_documents = list( + search_postprocessing( + search_query=_search_query, + retrieved_sections=verified_documents, + llm=state["subgraph_fast_llm"], + ) + )[ + 0 + ] # only get the reranked sections, not the SectionRelevancePiece + + if AGENT_RERANKING_STATS: + fit_scores = get_fit_scores(verified_documents, reranked_documents) + else: + fit_scores = RetrievalFitStats(fit_score_lift=0, rerank_effect=0, fit_scores={}) + + return DocRerankingUpdate( + reranked_documents=[ + doc for doc in reranked_documents if type(doc) == InferenceSection + ][:AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS], + sub_question_retrieval_stats=fit_scores, + ) + + +def doc_retrieval(state: RetrievalInput) -> DocRetrievalUpdate: + """ + Retrieve documents + + Args: + state (RetrievalInput): Primary state + the query to retrieve + + Updates: + expanded_retrieval_results: list[ExpandedRetrievalResult] + retrieved_documents: list[InferenceSection] + """ + + llm = state["subgraph_primary_llm"] + fast_llm = state["subgraph_fast_llm"] + query_to_retrieve = state["query_to_retrieve"] + + search_results = SearchPipeline( + search_request=SearchRequest( + query=query_to_retrieve, + ), + user=None, + llm=llm, + fast_llm=fast_llm, + db_session=state["subgraph_db_session"], + ) + + retrieved_docs = search_results._get_sections()[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS] + + if AGENT_RETRIEVAL_STATS: + fit_scores = get_fit_scores( + retrieved_docs, + search_results.reranked_sections[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS], + ) + else: + fit_scores = None + + expanded_retrieval_result = QueryResult( + query=query_to_retrieve, + search_results=retrieved_docs, + stats=fit_scores, + ) + return DocRetrievalUpdate( + expanded_retrieval_results=[expanded_retrieval_result], + retrieved_documents=retrieved_docs, + ) + + +def doc_verification(state: DocVerificationInput) -> DocVerificationUpdate: + """ + Check whether the document is relevant for the original user question + + Args: + state (DocVerificationInput): The current state + + Updates: + verified_documents: list[InferenceSection] + """ + + question = state["question"] + doc_to_verify = state["doc_to_verify"] + document_content = doc_to_verify.combined_content + + msg = [ + HumanMessage( + content=VERIFIER_PROMPT.format( + question=question, document_content=document_content + ) + ) + ] + + fast_llm = state["subgraph_fast_llm"] + + response = fast_llm.invoke(msg) + + verified_documents = [] + if isinstance(response.content, str) and "yes" in response.content.lower(): + verified_documents.append(doc_to_verify) + + return DocVerificationUpdate( + verified_documents=verified_documents, + ) + + +def expand_queries(state: ExpandedRetrievalInput) -> QueryExpansionUpdate: + question = state.get("question") + llm: LLM = state["subgraph_fast_llm"] + + msg = [ + HumanMessage( + content=REWRITE_PROMPT_MULTI_ORIGINAL.format(question=question), + ) + ] + llm_response_list = list( + llm.stream( + prompt=msg, + ) + ) + llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content + + rewritten_queries = llm_response.split("--") + + return QueryExpansionUpdate( + expanded_queries=rewritten_queries, + ) + + +def 
_calculate_sub_question_retrieval_stats( + verified_documents: list[InferenceSection], + expanded_retrieval_results: list[QueryResult], +) -> AgentChunkStats: + chunk_scores: dict[str, dict[str, list[int | float]]] = defaultdict( + lambda: defaultdict(list) + ) + + for expanded_retrieval_result in expanded_retrieval_results: + for doc in expanded_retrieval_result.search_results: + doc_chunk_id = f"{doc.center_chunk.document_id}_{doc.center_chunk.chunk_id}" + if doc.center_chunk.score is not None: + chunk_scores[doc_chunk_id]["score"].append(doc.center_chunk.score) + + verified_doc_chunk_ids = [ + f"{verified_document.center_chunk.document_id}_{verified_document.center_chunk.chunk_id}" + for verified_document in verified_documents + ] + dismissed_doc_chunk_ids = [] + + raw_chunk_stats_counts: dict[str, int] = defaultdict(int) + raw_chunk_stats_scores: dict[str, float] = defaultdict(float) + for doc_chunk_id, chunk_data in chunk_scores.items(): + if doc_chunk_id in verified_doc_chunk_ids: + raw_chunk_stats_counts["verified_count"] += 1 + + valid_chunk_scores = [ + score for score in chunk_data["score"] if score is not None + ] + raw_chunk_stats_scores["verified_scores"] += float( + np.mean(valid_chunk_scores) + ) + else: + raw_chunk_stats_counts["rejected_count"] += 1 + valid_chunk_scores = [ + score for score in chunk_data["score"] if score is not None + ] + raw_chunk_stats_scores["rejected_scores"] += float( + np.mean(valid_chunk_scores) + ) + dismissed_doc_chunk_ids.append(doc_chunk_id) + + if raw_chunk_stats_counts["verified_count"] == 0: + verified_avg_scores = 0.0 + else: + verified_avg_scores = raw_chunk_stats_scores["verified_scores"] / float( + raw_chunk_stats_counts["verified_count"] + ) + + rejected_scores = raw_chunk_stats_scores.get("rejected_scores", None) + if rejected_scores is not None: + rejected_avg_scores = rejected_scores / float( + raw_chunk_stats_counts["rejected_count"] + ) + else: + rejected_avg_scores = None + + chunk_stats = AgentChunkStats( + verified_count=raw_chunk_stats_counts["verified_count"], + verified_avg_scores=verified_avg_scores, + rejected_count=raw_chunk_stats_counts["rejected_count"], + rejected_avg_scores=rejected_avg_scores, + verified_doc_chunk_ids=verified_doc_chunk_ids, + dismissed_doc_chunk_ids=dismissed_doc_chunk_ids, + ) + + return chunk_stats + + +def format_results(state: ExpandedRetrievalState) -> ExpandedRetrievalOutput: + sub_question_retrieval_stats = _calculate_sub_question_retrieval_stats( + verified_documents=state["verified_documents"], + expanded_retrieval_results=state["expanded_retrieval_results"], + ) + + if sub_question_retrieval_stats is None: + sub_question_retrieval_stats = AgentChunkStats() + # else: + # sub_question_retrieval_stats = [sub_question_retrieval_stats] + + return ExpandedRetrievalOutput( + expanded_retrieval_result=ExpandedRetrievalResult( + expanded_queries_results=state["expanded_retrieval_results"], + all_documents=state["reranked_documents"], + sub_question_retrieval_stats=sub_question_retrieval_stats, + ), + ) + + +def verification_kickoff( + state: ExpandedRetrievalState, +) -> Command[Literal["doc_verification"]]: + documents = state["retrieved_documents"] + verification_question = state.get( + "question", state["subgraph_search_request"].query + ) + return Command( + update={}, + goto=[ + Send( + node="doc_verification", + arg=DocVerificationInput( + doc_to_verify=doc, + question=verification_question, + **in_subgraph_extract_core_fields(state), + ), + ) + for doc in documents + ], + ) diff --git 
a/backend/onyx/agent_search/expanded_retrieval/states.py b/backend/onyx/agent_search/expanded_retrieval/states.py index ada8d9b299..b6059f8ff3 100644 --- a/backend/onyx/agent_search/expanded_retrieval/states.py +++ b/backend/onyx/agent_search/expanded_retrieval/states.py @@ -17,7 +17,6 @@ class ExpandedRetrievalInput(SubgraphCoreState): question: str - dummy: str base_search: bool diff --git a/backend/onyx/agent_search/main/edges.py b/backend/onyx/agent_search/main/edges.py index 3870126ac9..22cba21933 100644 --- a/backend/onyx/agent_search/main/edges.py +++ b/backend/onyx/agent_search/main/edges.py @@ -42,7 +42,6 @@ def send_to_initial_retrieval(state: MainInput) -> list[Send | Hashable]: ExpandedRetrievalInput( question=state["search_request"].query, **extract_core_fields_for_subgraph(state), - dummy="retrieval", base_search=False, ), ) diff --git a/backend/onyx/agent_search/main/nodes.py b/backend/onyx/agent_search/main/nodes.py new file mode 100644 index 0000000000..5fba915c56 --- /dev/null +++ b/backend/onyx/agent_search/main/nodes.py @@ -0,0 +1,276 @@ +from langchain_core.messages import HumanMessage + +from onyx.agent_search.answer_question.states import AnswerQuestionOutput +from onyx.agent_search.answer_question.states import QuestionAnswerResults +from onyx.agent_search.base_raw_search.states import BaseRawSearchOutput +from onyx.agent_search.main.states import BaseDecompUpdate +from onyx.agent_search.main.states import DecompAnswersUpdate +from onyx.agent_search.main.states import ExpandedRetrievalUpdate +from onyx.agent_search.main.states import InitialAnswerBASEUpdate +from onyx.agent_search.main.states import InitialAnswerUpdate +from onyx.agent_search.main.states import MainState +from onyx.agent_search.shared_graph_utils.models import AgentChunkStats +from onyx.agent_search.shared_graph_utils.models import InitialAgentResultStats +from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections +from onyx.agent_search.shared_graph_utils.prompts import ( + INITIAL_DECOMPOSITION_PROMPT_QUESTIONS, +) +from onyx.agent_search.shared_graph_utils.prompts import INITIAL_RAG_BASE_PROMPT +from onyx.agent_search.shared_graph_utils.prompts import INITIAL_RAG_PROMPT +from onyx.agent_search.shared_graph_utils.prompts import ( + INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS, +) +from onyx.agent_search.shared_graph_utils.utils import clean_and_parse_list_string +from onyx.agent_search.shared_graph_utils.utils import format_docs + + +def main_decomp_base(state: MainState) -> BaseDecompUpdate: + question = state["search_request"].query + + msg = [ + HumanMessage( + content=INITIAL_DECOMPOSITION_PROMPT_QUESTIONS.format(question=question), + ) + ] + + # Get the rewritten queries in a defined format + model = state["fast_llm"] + response = model.invoke(msg) + + content = response.pretty_repr() + list_of_subquestions = clean_and_parse_list_string(content) + + decomp_list: list[str] = [ + sub_question["sub_question"].strip() for sub_question in list_of_subquestions + ] + + return BaseDecompUpdate( + initial_decomp_questions=decomp_list, + ) + + +def _calculate_initial_agent_stats( + decomp_answer_results: list[QuestionAnswerResults], + original_question_stats: AgentChunkStats, +) -> InitialAgentResultStats: + initial_agent_result_stats: InitialAgentResultStats = InitialAgentResultStats( + sub_questions={}, + original_question={}, + agent_effectiveness={}, + ) + + orig_verified = original_question_stats.verified_count + orig_support_score = original_question_stats.verified_avg_scores + + 
verified_document_chunk_ids = [] + support_scores = 0.0 + + for decomp_answer_result in decomp_answer_results: + verified_document_chunk_ids += ( + decomp_answer_result.sub_question_retrieval_stats.verified_doc_chunk_ids + ) + if ( + decomp_answer_result.sub_question_retrieval_stats.verified_avg_scores + is not None + ): + support_scores += ( + decomp_answer_result.sub_question_retrieval_stats.verified_avg_scores + ) + + verified_document_chunk_ids = list(set(verified_document_chunk_ids)) + + # Calculate sub-question stats + if ( + verified_document_chunk_ids + and len(verified_document_chunk_ids) > 0 + and support_scores is not None + ): + sub_question_stats: dict[str, float | int | None] = { + "num_verified_documents": len(verified_document_chunk_ids), + "verified_avg_score": float(support_scores / len(decomp_answer_results)), + } + else: + sub_question_stats = {"num_verified_documents": 0, "verified_avg_score": None} + + initial_agent_result_stats.sub_questions.update(sub_question_stats) + + # Get original question stats + initial_agent_result_stats.original_question.update( + { + "num_verified_documents": original_question_stats.verified_count, + "verified_avg_score": original_question_stats.verified_avg_scores, + } + ) + + # Calculate chunk utilization ratio + sub_verified = initial_agent_result_stats.sub_questions["num_verified_documents"] + + chunk_ratio: float | None = None + if sub_verified is not None and orig_verified is not None and orig_verified > 0: + chunk_ratio = (float(sub_verified) / orig_verified) if sub_verified > 0 else 0.0 + elif sub_verified is not None and sub_verified > 0: + chunk_ratio = 10.0 + + initial_agent_result_stats.agent_effectiveness["utilized_chunk_ratio"] = chunk_ratio + + if ( + orig_support_score is None + and initial_agent_result_stats.sub_questions["verified_avg_score"] is None + ): + initial_agent_result_stats.agent_effectiveness["support_ratio"] = None + elif orig_support_score is None: + initial_agent_result_stats.agent_effectiveness["support_ratio"] = 10 + elif initial_agent_result_stats.sub_questions["verified_avg_score"] is None: + initial_agent_result_stats.agent_effectiveness["support_ratio"] = 0 + else: + initial_agent_result_stats.agent_effectiveness["support_ratio"] = ( + initial_agent_result_stats.sub_questions["verified_avg_score"] + / orig_support_score + ) + + return initial_agent_result_stats + + +def generate_initial_answer(state: MainState) -> InitialAnswerUpdate: + print("---GENERATE INITIAL---") + + question = state["search_request"].query + sub_question_docs = state["documents"] + all_original_question_documents = state["all_original_question_documents"] + relevant_docs = dedup_inference_sections( + sub_question_docs, all_original_question_documents + ) + + net_new_original_question_docs = [] + for all_original_question_doc in all_original_question_documents: + if all_original_question_doc not in sub_question_docs: + net_new_original_question_docs.append(all_original_question_doc) + + decomp_answer_results = state["decomp_answer_results"] + + good_qa_list: list[str] = [] + decomp_questions = [] + + _SUB_QUESTION_ANSWER_TEMPLATE = """ + Sub-Question:\n - {sub_question}\n --\nAnswer:\n - {sub_answer}\n\n + """ + for decomp_answer_result in decomp_answer_results: + decomp_questions.append(decomp_answer_result.question) + if ( + decomp_answer_result.quality.lower().startswith("yes") + and len(decomp_answer_result.answer) > 0 + and decomp_answer_result.answer != "I don't know" + ): + good_qa_list.append( + 
_SUB_QUESTION_ANSWER_TEMPLATE.format( + sub_question=decomp_answer_result.question, + sub_answer=decomp_answer_result.answer, + ) + ) + + sub_question_answer_str = "\n\n------\n\n".join(good_qa_list) + + if len(good_qa_list) > 0: + msg = [ + HumanMessage( + content=INITIAL_RAG_PROMPT.format( + question=question, + answered_sub_questions=sub_question_answer_str, + relevant_docs=format_docs(relevant_docs), + ) + ) + ] + else: + msg = [ + HumanMessage( + content=INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS.format( + question=question, + relevant_docs=format_docs(relevant_docs), + ) + ) + ] + + # Grader + model = state["fast_llm"] + response = model.invoke(msg) + answer = response.pretty_repr() + + initial_agent_stats = _calculate_initial_agent_stats( + state["decomp_answer_results"], state["original_question_retrieval_stats"] + ) + + print(f"\n\n---INITIAL AGENT ANSWER START---\n\n Answer:\n Agent: {answer}") + + print(f"\n\nSub-Questions:\n\n{sub_question_answer_str}\n\nStats:\n\n") + + if initial_agent_stats: + print(initial_agent_stats.original_question) + print(initial_agent_stats.sub_questions) + print(initial_agent_stats.agent_effectiveness) + print("\n\n ---INITIAL AGENT ANSWER END---\n\n") + + return InitialAnswerUpdate( + initial_answer=answer, + initial_agent_stats=initial_agent_stats, + generated_sub_questions=decomp_questions, + ) + + +def generate_initial_base_answer(state: MainState) -> InitialAnswerBASEUpdate: + print("---GENERATE INITIAL BASE ANSWER---") + + question = state["search_request"].query + original_question_docs = state["all_original_question_documents"] + + msg = [ + HumanMessage( + content=INITIAL_RAG_BASE_PROMPT.format( + question=question, + context=format_docs(original_question_docs), + ) + ) + ] + + # Grader + model = state["fast_llm"] + response = model.invoke(msg) + answer = response.pretty_repr() + + print() + print( + f"\n\n---INITIAL BASE ANSWER START---\n\nBase: {answer}\n\n ---INITIAL BASE ANSWER END---\n\n" + ) + return InitialAnswerBASEUpdate(initial_base_answer=answer) + + +def ingest_answers(state: AnswerQuestionOutput) -> DecompAnswersUpdate: + documents = [] + answer_results = state.get("answer_results", []) + for answer_result in answer_results: + documents.extend(answer_result.documents) + return DecompAnswersUpdate( + # Deduping is done by the documents operator for the main graph + # so we might not need to dedup here + documents=dedup_inference_sections(documents, []), + decomp_answer_results=answer_results, + ) + + +def ingest_initial_retrieval(state: BaseRawSearchOutput) -> ExpandedRetrievalUpdate: + sub_question_retrieval_stats = state[ + "base_expanded_retrieval_result" + ].sub_question_retrieval_stats + if sub_question_retrieval_stats is None: + sub_question_retrieval_stats = AgentChunkStats() + else: + sub_question_retrieval_stats = sub_question_retrieval_stats + + return ExpandedRetrievalUpdate( + original_question_retrieval_results=state[ + "base_expanded_retrieval_result" + ].expanded_queries_results, + all_original_question_documents=state[ + "base_expanded_retrieval_result" + ].all_documents, + original_question_retrieval_stats=sub_question_retrieval_stats, + ) From d7812ee807ae897dc0613399a2d503ab823e324c Mon Sep 17 00:00:00 2001 From: Evan Lohn Date: Mon, 30 Dec 2024 14:01:03 -0800 Subject: [PATCH 36/78] my happy about graph building --- backend/onyx/agent_search/answer_question/states.py | 2 +- backend/onyx/agent_search/core_state.py | 2 +- backend/onyx/agent_search/expanded_retrieval/nodes.py | 7 ++++---
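One practical effect of the nodes-file consolidation above: graph builders now pull every node function from a single module instead of one file per node. The fragment below is a hypothetical sketch, not part of this patch; it assumes only the import paths introduced above and LangGraph's standard StateGraph API (doc_retrieval is dispatched dynamically via Send, so it gets no static wiring here).

from langgraph.graph import StateGraph

from onyx.agent_search.expanded_retrieval.nodes import (
    doc_reranking,
    doc_verification,
    expand_queries,
    format_results,
    verification_kickoff,
)
from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState

# One import site, one nodes.py: the wiring stays flat and greppable.
graph = StateGraph(ExpandedRetrievalState)
graph.add_node("expand_queries", expand_queries)
graph.add_node("verification_kickoff", verification_kickoff)
graph.add_node("doc_verification", doc_verification)
graph.add_node("doc_reranking", doc_reranking)
graph.add_node("format_results", format_results)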
backend/onyx/agent_search/expanded_retrieval/states.py | 4 ++++ 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/backend/onyx/agent_search/answer_question/states.py b/backend/onyx/agent_search/answer_question/states.py index 98b52ddfc3..28f4dc2134 100644 --- a/backend/onyx/agent_search/answer_question/states.py +++ b/backend/onyx/agent_search/answer_question/states.py @@ -23,7 +23,7 @@ class QAGenerationUpdate(TypedDict): class RetrievalIngestionUpdate(TypedDict): expanded_retrieval_results: list[QueryResult] documents: Annotated[list[InferenceSection], dedup_inference_sections] - sub_question_retrieval_stats: list[AgentChunkStats] + sub_question_retrieval_stats: AgentChunkStats ## Graph Input State diff --git a/backend/onyx/agent_search/core_state.py b/backend/onyx/agent_search/core_state.py index 7f0a34b07f..a035d25b31 100644 --- a/backend/onyx/agent_search/core_state.py +++ b/backend/onyx/agent_search/core_state.py @@ -57,4 +57,4 @@ def in_subgraph_extract_core_fields(state: T_SUBGRAPH) -> SubgraphCoreState: filtered_dict = { k: v for k, v in state.items() if k in SubgraphCoreState.__annotations__ } - return SubgraphCoreState(**dict(filtered_dict)) + return SubgraphCoreState(**dict(filtered_dict)) # type: ignore diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes.py b/backend/onyx/agent_search/expanded_retrieval/nodes.py index 831e5a9433..50973a20bc 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes.py @@ -15,8 +15,8 @@ from onyx.agent_search.expanded_retrieval.states import DocVerificationInput from onyx.agent_search.expanded_retrieval.states import DocVerificationUpdate from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput -from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalOutput from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalUpdate from onyx.agent_search.expanded_retrieval.states import InferenceSection from onyx.agent_search.expanded_retrieval.states import QueryExpansionUpdate from onyx.agent_search.expanded_retrieval.states import RetrievalInput @@ -248,7 +248,7 @@ def _calculate_sub_question_retrieval_stats( return chunk_stats -def format_results(state: ExpandedRetrievalState) -> ExpandedRetrievalOutput: +def format_results(state: ExpandedRetrievalState) -> ExpandedRetrievalUpdate: sub_question_retrieval_stats = _calculate_sub_question_retrieval_stats( verified_documents=state["verified_documents"], expanded_retrieval_results=state["expanded_retrieval_results"], @@ -259,7 +259,7 @@ def format_results(state: ExpandedRetrievalState) -> ExpandedRetrievalOutput: # else: # sub_question_retrieval_stats = [sub_question_retrieval_stats] - return ExpandedRetrievalOutput( + return ExpandedRetrievalUpdate( expanded_retrieval_result=ExpandedRetrievalResult( expanded_queries_results=state["expanded_retrieval_results"], all_documents=state["reranked_documents"], @@ -283,6 +283,7 @@ def verification_kickoff( arg=DocVerificationInput( doc_to_verify=doc, question=verification_question, + base_search=False, **in_subgraph_extract_core_fields(state), ), ) diff --git a/backend/onyx/agent_search/expanded_retrieval/states.py b/backend/onyx/agent_search/expanded_retrieval/states.py index b6059f8ff3..2129c22eda 100644 --- a/backend/onyx/agent_search/expanded_retrieval/states.py +++ b/backend/onyx/agent_search/expanded_retrieval/states.py @@ -41,6 +41,10 @@ 
class DocRerankingUpdate(TypedDict): sub_question_retrieval_stats: RetrievalFitStats | None +class ExpandedRetrievalUpdate(TypedDict): + expanded_retrieval_result: ExpandedRetrievalResult + + ## Graph Output State From 38a616c87caaabd183064f1e59ab676e229e5604 Mon Sep 17 00:00:00 2001 From: Evan Lohn Date: Mon, 30 Dec 2024 18:52:50 -0800 Subject: [PATCH 37/78] basic support for accessing langgraph through the UI --- .../server/query_and_chat/chat_backend.py | 2 + .../ee/onyx/server/query_and_chat/models.py | 8 ++++ .../server/query_and_chat/query_backend.py | 1 + backend/onyx/agent_search/run_graph.py | 37 +++++++++---------- backend/onyx/chat/chat_utils.py | 2 + backend/onyx/chat/process_message.py | 37 ++++++++++++++++++- backend/onyx/server/query_and_chat/models.py | 4 ++ 7 files changed, 70 insertions(+), 21 deletions(-) diff --git a/backend/ee/onyx/server/query_and_chat/chat_backend.py b/backend/ee/onyx/server/query_and_chat/chat_backend.py index 0a29ed0034..8d653ece33 100644 --- a/backend/ee/onyx/server/query_and_chat/chat_backend.py +++ b/backend/ee/onyx/server/query_and_chat/chat_backend.py @@ -179,6 +179,7 @@ def handle_simplified_chat_message( chunks_below=0, full_doc=chat_message_req.full_doc, structured_response_format=chat_message_req.structured_response_format, + use_pro_search=chat_message_req.use_pro_search, ) packets = stream_chat_message_objects( @@ -301,6 +302,7 @@ def handle_send_message_simple_with_history( chunks_below=0, full_doc=req.full_doc, structured_response_format=req.structured_response_format, + use_pro_search=req.use_pro_search, ) packets = stream_chat_message_objects( diff --git a/backend/ee/onyx/server/query_and_chat/models.py b/backend/ee/onyx/server/query_and_chat/models.py index 4726236e01..656ea68921 100644 --- a/backend/ee/onyx/server/query_and_chat/models.py +++ b/backend/ee/onyx/server/query_and_chat/models.py @@ -57,6 +57,9 @@ class BasicCreateChatMessageRequest(ChunkContext): # https://platform.openai.com/docs/guides/structured-outputs/introduction structured_response_format: dict | None = None + # If True, uses pro search instead of basic search + use_pro_search: bool = False + class BasicCreateChatMessageWithHistoryRequest(ChunkContext): # Last element is the new query. All previous elements are historical context @@ -71,6 +74,8 @@ class BasicCreateChatMessageWithHistoryRequest(ChunkContext): # only works if using an OpenAI model. 
See the following for more details: # https://platform.openai.com/docs/guides/structured-outputs/introduction structured_response_format: dict | None = None + # If True, uses pro search instead of basic search + use_pro_search: bool = False class SimpleDoc(BaseModel): @@ -123,6 +128,9 @@ class OneShotQARequest(ChunkContext): # If True, skips generative an AI response to the search query skip_gen_ai_answer_generation: bool = False + # If True, uses pro search instead of basic search + use_pro_search: bool = False + @model_validator(mode="after") def check_persona_fields(self) -> "OneShotQARequest": if self.persona_override_config is None and self.persona_id is None: diff --git a/backend/ee/onyx/server/query_and_chat/query_backend.py b/backend/ee/onyx/server/query_and_chat/query_backend.py index b8e7abd3e4..2a8ddf5085 100644 --- a/backend/ee/onyx/server/query_and_chat/query_backend.py +++ b/backend/ee/onyx/server/query_and_chat/query_backend.py @@ -196,6 +196,7 @@ def get_answer_stream( retrieval_details=query_request.retrieval_options, rerank_settings=query_request.rerank_settings, db_session=db_session, + use_pro_search=query_request.use_pro_search, ) packets = stream_chat_message_objects( diff --git a/backend/onyx/agent_search/run_graph.py b/backend/onyx/agent_search/run_graph.py index de207af628..fe48e6a8b8 100644 --- a/backend/onyx/agent_search/run_graph.py +++ b/backend/onyx/agent_search/run_graph.py @@ -24,29 +24,16 @@ def _parse_agent_event( Parse the event into a typed object. Return None if we are not interested in the event. """ - # if event["name"] == "LangGraph": - # return None event_type = event["event"] - langgraph_node = event["metadata"].get("langgraph_node", "_graph_") - if "input" in event["data"] and isinstance(event["data"]["input"], str): - input_data = f'\nINPUT: {langgraph_node} -- {str(event["data"]["input"])}' - else: - input_data = "" - if "output" in event["data"] and isinstance(event["data"]["output"], str): - output_data = f'\nOUTPUT: {langgraph_node} -- {str(event["data"]["output"])}' - else: - output_data = "" - if len(input_data) > 0 or len(output_data) > 0: - return input_data + output_data - - event_type = event["event"] - if event_type == "tool_call_kickoff": - return ToolCallKickoff(**event["data"]) - elif event_type == "tool_response": - return ToolResponse(**event["data"]) - elif event_type == "on_chat_model_stream": + if event_type == "on_chat_model_stream": return OnyxAnswerPiece(answer_piece=event["data"]["chunk"].content) + elif event_type == "search_result": + # TODO: clean this up (weirdness to make mypy happy) + return ToolResponse( + id=str(event["data"].get("id", "error")), + response=event["data"].get("response", "error"), + ) return None @@ -105,6 +92,16 @@ def run_graph( yield parsed_object +def run_main_graph( + search_request: SearchRequest, + primary_llm: LLM, + fast_llm: LLM, +) -> AnswerStream: + graph = main_graph_builder() + compiled_graph = graph.compile() + return run_graph(compiled_graph, search_request, primary_llm, fast_llm) + + if __name__ == "__main__": from onyx.llm.factory import get_default_llms from onyx.context.search.models import SearchRequest diff --git a/backend/onyx/chat/chat_utils.py b/backend/onyx/chat/chat_utils.py index 70083a0da2..da5b7bfe4a 100644 --- a/backend/onyx/chat/chat_utils.py +++ b/backend/onyx/chat/chat_utils.py @@ -48,6 +48,7 @@ def prepare_chat_message_request( retrieval_details: RetrievalDetails | None, rerank_settings: RerankingDetails | None, db_session: Session, + use_pro_search: bool = False, ) 
-> CreateChatMessageRequest: # Typically used for one shot flows like SlackBot or non-chat API endpoint use cases new_chat_session = create_chat_session( @@ -72,6 +73,7 @@ def prepare_chat_message_request( search_doc_ids=None, retrieval_options=retrieval_details, rerank_settings=rerank_settings, + use_pro_search=use_pro_search, ) diff --git a/backend/onyx/chat/process_message.py b/backend/onyx/chat/process_message.py index 28d67d0323..1ffc2e8bc1 100644 --- a/backend/onyx/chat/process_message.py +++ b/backend/onyx/chat/process_message.py @@ -6,6 +6,7 @@ from sqlalchemy.orm import Session +from onyx.agent_search.run_graph import run_main_graph from onyx.chat.answer import Answer from onyx.chat.chat_utils import create_chat_chain from onyx.chat.chat_utils import create_temporary_persona @@ -33,11 +34,13 @@ from onyx.configs.constants import MessageType from onyx.configs.constants import MilestoneRecordType from onyx.configs.constants import NO_AUTH_USER_ID +from onyx.context.search.enums import LLMEvaluationType from onyx.context.search.enums import OptionalSearchSetting from onyx.context.search.enums import QueryFlow from onyx.context.search.enums import SearchType from onyx.context.search.models import InferenceSection from onyx.context.search.models import RetrievalDetails +from onyx.context.search.models import SearchRequest from onyx.context.search.retrieval.search_runner import inference_sections_from_ids from onyx.context.search.utils import chunks_or_sections_to_search_docs from onyx.context.search.utils import dedupe_documents @@ -716,7 +719,39 @@ def stream_chat_message_objects( dropped_indices = None tool_result = None - for packet in answer.processed_streamed_output: + if not new_msg_req.use_pro_search: + answer_stream = answer.processed_streamed_output + else: + search_request = SearchRequest( + query=final_msg.message, + evaluation_type=( + LLMEvaluationType.BASIC + if persona.llm_relevance_filter + else LLMEvaluationType.SKIP + ), + human_selected_filters=( + retrieval_options.filters if retrieval_options else None + ), + persona=persona, + offset=(retrieval_options.offset if retrieval_options else None), + limit=retrieval_options.limit if retrieval_options else None, + rerank_settings=new_msg_req.rerank_settings, + chunks_above=new_msg_req.chunks_above, + chunks_below=new_msg_req.chunks_below, + full_doc=new_msg_req.full_doc, + enable_auto_detect_filters=( + retrieval_options.enable_auto_detect_filters + if retrieval_options + else None + ), + ) + answer_stream = run_main_graph( + search_request=search_request, + primary_llm=llm, + fast_llm=fast_llm, + ) + + for packet in answer_stream: if isinstance(packet, ToolResponse): if packet.id == SEARCH_RESPONSE_SUMMARY_ID: ( diff --git a/backend/onyx/server/query_and_chat/models.py b/backend/onyx/server/query_and_chat/models.py index 16b7d1b061..185d6cb0e2 100644 --- a/backend/onyx/server/query_and_chat/models.py +++ b/backend/onyx/server/query_and_chat/models.py @@ -124,6 +124,10 @@ class CreateChatMessageRequest(ChunkContext): # https://platform.openai.com/docs/guides/structured-outputs/introduction structured_response_format: dict | None = None + # If true, ignores most of the search options and uses pro search instead. 
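With use_pro_search now threaded from CreateChatMessageRequest through stream_chat_message_objects into run_main_graph, the pro-search path can also be smoke-tested directly, bypassing the API layer. A minimal sketch, assuming default LLM providers are configured; the packet stream interleaves OnyxAnswerPiece chunks with ToolResponse search results, as in run_graph.py above.

from onyx.agent_search.run_graph import run_main_graph
from onyx.context.search.models import SearchRequest
from onyx.llm.factory import get_default_llms

primary_llm, fast_llm = get_default_llms()

# Mirrors what stream_chat_message_objects does when use_pro_search=True.
for packet in run_main_graph(
    search_request=SearchRequest(query="what can you do with onyx?"),
    primary_llm=primary_llm,
    fast_llm=fast_llm,
):
    print(packet)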
+ # TODO: decide how many of the above options we want to pass through to pro search + use_pro_search: bool = False + @model_validator(mode="after") def check_search_doc_ids_or_retrieval_options(self) -> "CreateChatMessageRequest": if self.search_doc_ids is None and self.retrieval_options is None: From dd64c3a175ce3c3c07a30599ba2c979401b79f05 Mon Sep 17 00:00:00 2001 From: Evan Lohn Date: Mon, 30 Dec 2024 20:28:50 -0800 Subject: [PATCH 38/78] ugly but workable calling langgraph from the UI --- web/src/app/chat/ChatPage.tsx | 13 +++++++++++++ web/src/app/chat/lib.tsx | 3 +++ 2 files changed, 16 insertions(+) diff --git a/web/src/app/chat/ChatPage.tsx b/web/src/app/chat/ChatPage.tsx index 3dde26a8bc..b2f8e0dfe5 100644 --- a/web/src/app/chat/ChatPage.tsx +++ b/web/src/app/chat/ChatPage.tsx @@ -179,6 +179,7 @@ export function ChatPage({ const [documentSidebarToggled, setDocumentSidebarToggled] = useState(false); const [filtersToggled, setFiltersToggled] = useState(false); + const [langgraphEnabled, setLanggraphEnabled] = useState(false); const [userSettingsToggled, setUserSettingsToggled] = useState(false); @@ -1253,6 +1254,7 @@ export function ChatPage({ systemPromptOverride: searchParams.get(SEARCH_PARAM_NAMES.SYSTEM_PROMPT) || undefined, useExistingUserMessage: isSeededChat, + useLanggraph: langgraphEnabled, }); const delay = (ms: number) => { @@ -2210,6 +2212,17 @@ export function ChatPage({ llmOverrideManager={llmOverrideManager} /> )} +
+ +
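+          {/* One possible shape for the control above - a plain checkbox bound to the langgraphEnabled state; the markup in the original commit may differ: +              <label> +                <input +                  type="checkbox" +                  checked={langgraphEnabled} +                  onChange={(e) => setLanggraphEnabled(e.target.checked)} +                /> +                Use LangGraph (pro search) +              </label> +          */}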
{documentSidebarInitialWidth !== undefined && isReady ? ( diff --git a/web/src/app/chat/lib.tsx b/web/src/app/chat/lib.tsx index 2b44da8ce0..e96a1b2314 100644 --- a/web/src/app/chat/lib.tsx +++ b/web/src/app/chat/lib.tsx @@ -128,6 +128,7 @@ export async function* sendMessage({ useExistingUserMessage, alternateAssistantId, signal, + useLanggraph, }: { regenerate: boolean; message: string; @@ -146,6 +147,7 @@ export async function* sendMessage({ useExistingUserMessage?: boolean; alternateAssistantId?: number; signal?: AbortSignal; + useLanggraph?: boolean; }): AsyncGenerator { const documentsAreSelected = selectedDocumentIds && selectedDocumentIds.length > 0; @@ -186,6 +188,7 @@ export async function* sendMessage({ } : null, use_existing_user_message: useExistingUserMessage, + use_pro_search: useLanggraph, }); const response = await fetch(`/api/chat/send-message`, { From 121827e34c84a9067121833e51fa58869629412c Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Thu, 2 Jan 2025 13:49:36 -0800 Subject: [PATCH 39/78] initial end-to-end functioning --- .../answer_follow_up_question/edges.py | 20 ++ .../graph_builder.py | 104 ++++++ .../answer_follow_up_question/models.py | 19 ++ .../nodes/answer_check.py | 30 ++ .../nodes/answer_generation.py | 36 ++ .../nodes/format_answer.py | 28 ++ .../nodes/ingest_retrieval.py | 19 ++ .../answer_follow_up_question/states.py | 58 ++++ .../answer_question/nodes/format_answer.py | 2 +- .../agent_search/answer_question/states.py | 1 + .../nodes/entity_term_extraction.py | 4 +- backend/onyx/agent_search/main/edges.py | 13 + .../onyx/agent_search/main/graph_builder.py | 134 ++++---- backend/onyx/agent_search/main/models.py | 27 ++ backend/onyx/agent_search/main/nodes.py | 316 +++++++++++++++++- backend/onyx/agent_search/main/states.py | 49 ++- .../agent_search/refined_answers/edges.py | 33 ++ .../refined_answers/graph_builder.py | 114 +++++++ .../agent_search/refined_answers/models.py | 8 + .../agent_search/refined_answers/nodes.py | 129 +++++++ .../agent_search/refined_answers/states.py | 19 ++ backend/onyx/agent_search/run_graph.py | 2 +- .../agent_search/shared_graph_utils/models.py | 5 + .../shared_graph_utils/prompts.py | 72 +++- .../agent_search/shared_graph_utils/utils.py | 22 +- 25 files changed, 1181 insertions(+), 83 deletions(-) create mode 100644 backend/onyx/agent_search/answer_follow_up_question/edges.py create mode 100644 backend/onyx/agent_search/answer_follow_up_question/graph_builder.py create mode 100644 backend/onyx/agent_search/answer_follow_up_question/models.py create mode 100644 backend/onyx/agent_search/answer_follow_up_question/nodes/answer_check.py create mode 100644 backend/onyx/agent_search/answer_follow_up_question/nodes/answer_generation.py create mode 100644 backend/onyx/agent_search/answer_follow_up_question/nodes/format_answer.py create mode 100644 backend/onyx/agent_search/answer_follow_up_question/nodes/ingest_retrieval.py create mode 100644 backend/onyx/agent_search/answer_follow_up_question/states.py create mode 100644 backend/onyx/agent_search/main/models.py create mode 100644 backend/onyx/agent_search/refined_answers/edges.py create mode 100644 backend/onyx/agent_search/refined_answers/graph_builder.py create mode 100644 backend/onyx/agent_search/refined_answers/models.py create mode 100644 backend/onyx/agent_search/refined_answers/nodes.py create mode 100644 backend/onyx/agent_search/refined_answers/states.py diff --git a/backend/onyx/agent_search/answer_follow_up_question/edges.py 
b/backend/onyx/agent_search/answer_follow_up_question/edges.py new file mode 100644 index 0000000000..34bc48b26b --- /dev/null +++ b/backend/onyx/agent_search/answer_follow_up_question/edges.py @@ -0,0 +1,20 @@ +from collections.abc import Hashable + +from langgraph.types import Send + +from onyx.agent_search.answer_question.states import AnswerQuestionInput +from onyx.agent_search.core_state import in_subgraph_extract_core_fields +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput + + +def send_to_expanded_follow_up_retrieval(state: AnswerQuestionInput) -> Send | Hashable: + print("sending to expanded retrieval for follow up question via edge") + + return Send( + "decomposed_follow_up_retrieval", + ExpandedRetrievalInput( + **in_subgraph_extract_core_fields(state), + question=state["question"], + base_search=False + ), + ) diff --git a/backend/onyx/agent_search/answer_follow_up_question/graph_builder.py b/backend/onyx/agent_search/answer_follow_up_question/graph_builder.py new file mode 100644 index 0000000000..47c75da4aa --- /dev/null +++ b/backend/onyx/agent_search/answer_follow_up_question/graph_builder.py @@ -0,0 +1,104 @@ +from langgraph.graph import END +from langgraph.graph import START +from langgraph.graph import StateGraph + +from onyx.agent_search.answer_follow_up_question.edges import ( + send_to_expanded_follow_up_retrieval, +) +from onyx.agent_search.answer_question.nodes.answer_check import answer_check +from onyx.agent_search.answer_question.nodes.answer_generation import answer_generation +from onyx.agent_search.answer_question.nodes.format_answer import format_answer +from onyx.agent_search.answer_question.nodes.ingest_retrieval import ingest_retrieval +from onyx.agent_search.answer_question.states import AnswerQuestionInput +from onyx.agent_search.answer_question.states import AnswerQuestionOutput +from onyx.agent_search.answer_question.states import AnswerQuestionState +from onyx.agent_search.expanded_retrieval.graph_builder import ( + expanded_retrieval_graph_builder, +) + + +def answer_follow_up_query_graph_builder() -> StateGraph: + graph = StateGraph( + state_schema=AnswerQuestionState, + input=AnswerQuestionInput, + output=AnswerQuestionOutput, + ) + + ### Add nodes ### + + expanded_retrieval = expanded_retrieval_graph_builder().compile() + graph.add_node( + node="decomposed_follow_up_retrieval", + action=expanded_retrieval, + ) + graph.add_node( + node="follow_up_answer_check", + action=answer_check, + ) + graph.add_node( + node="follow_up_answer_generation", + action=answer_generation, + ) + graph.add_node( + node="format_follow_up_answer", + action=format_answer, + ) + graph.add_node( + node="ingest_follow_up_retrieval", + action=ingest_retrieval, + ) + + ### Add edges ### + + graph.add_conditional_edges( + source=START, + path=send_to_expanded_follow_up_retrieval, + path_map=["decomposed_follow_up_retrieval"], + ) + graph.add_edge( + start_key="decomposed_follow_up_retrieval", + end_key="ingest_follow_up_retrieval", + ) + graph.add_edge( + start_key="ingest_follow_up_retrieval", + end_key="follow_up_answer_generation", + ) + graph.add_edge( + start_key="follow_up_answer_generation", + end_key="follow_up_answer_check", + ) + graph.add_edge( + start_key="follow_up_answer_check", + end_key="format_follow_up_answer", + ) + graph.add_edge( + start_key="format_follow_up_answer", + end_key=END, + ) + + return graph + + +if __name__ == "__main__": + from onyx.db.engine import get_session_context_manager + from onyx.llm.factory import 
get_default_llms + from onyx.context.search.models import SearchRequest + + graph = answer_follow_up_query_graph_builder() + compiled_graph = graph.compile() + primary_llm, fast_llm = get_default_llms() + search_request = SearchRequest( + query="what can you do with onyx or danswer?", + ) + with get_session_context_manager() as db_session: + inputs = AnswerQuestionInput( + question="what can you do with onyx?", + ) + for thing in compiled_graph.stream( + input=inputs, + # debug=True, + # subgraphs=True, + ): + print(thing) + # output = compiled_graph.invoke(inputs) + # print(output) diff --git a/backend/onyx/agent_search/answer_follow_up_question/models.py b/backend/onyx/agent_search/answer_follow_up_question/models.py new file mode 100644 index 0000000000..ea9fb8f971 --- /dev/null +++ b/backend/onyx/agent_search/answer_follow_up_question/models.py @@ -0,0 +1,19 @@ +from pydantic import BaseModel + +from onyx.agent_search.shared_graph_utils.models import AgentChunkStats +from onyx.context.search.models import InferenceSection + +### Models ### + + +class AnswerRetrievalStats(BaseModel): + answer_retrieval_stats: dict[str, float | int] + + +class QuestionAnswerResults(BaseModel): + question: str + answer: str + quality: str + # expanded_retrieval_results: list[QueryResult] + documents: list[InferenceSection] + sub_question_retrieval_stats: AgentChunkStats diff --git a/backend/onyx/agent_search/answer_follow_up_question/nodes/answer_check.py b/backend/onyx/agent_search/answer_follow_up_question/nodes/answer_check.py new file mode 100644 index 0000000000..6349552f34 --- /dev/null +++ b/backend/onyx/agent_search/answer_follow_up_question/nodes/answer_check.py @@ -0,0 +1,30 @@ +from langchain_core.messages import HumanMessage +from langchain_core.messages import merge_message_runs + +from onyx.agent_search.answer_question.states import AnswerQuestionState +from onyx.agent_search.answer_question.states import QACheckUpdate +from onyx.agent_search.shared_graph_utils.prompts import SUB_CHECK_PROMPT + + +def answer_check(state: AnswerQuestionState) -> QACheckUpdate: + msg = [ + HumanMessage( + content=SUB_CHECK_PROMPT.format( + question=state["question"], + base_answer=state["answer"], + ) + ) + ] + + fast_llm = state["subgraph_fast_llm"] + response = list( + fast_llm.stream( + prompt=msg, + ) + ) + + quality_str = merge_message_runs(response, chunk_separator="")[0].content + + return QACheckUpdate( + answer_quality=quality_str, + ) diff --git a/backend/onyx/agent_search/answer_follow_up_question/nodes/answer_generation.py b/backend/onyx/agent_search/answer_follow_up_question/nodes/answer_generation.py new file mode 100644 index 0000000000..0403583567 --- /dev/null +++ b/backend/onyx/agent_search/answer_follow_up_question/nodes/answer_generation.py @@ -0,0 +1,36 @@ +from langchain_core.messages import HumanMessage +from langchain_core.messages import merge_message_runs + +from onyx.agent_search.answer_question.states import AnswerQuestionState +from onyx.agent_search.answer_question.states import QAGenerationUpdate +from onyx.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT +from onyx.agent_search.shared_graph_utils.utils import format_docs + + +def answer_generation(state: AnswerQuestionState) -> QAGenerationUpdate: + question = state["question"] + docs = state["documents"] + + print(f"Number of verified retrieval docs: {len(docs)}") + + msg = [ + HumanMessage( + content=BASE_RAG_PROMPT.format( + question=question, + context=format_docs(docs), + 
original_question=state["subgraph_search_request"].query, + ) + ) + ] + + fast_llm = state["subgraph_fast_llm"] + response = list( + fast_llm.stream( + prompt=msg, + ) + ) + + answer_str = merge_message_runs(response, chunk_separator="")[0].content + return QAGenerationUpdate( + answer=answer_str, + ) diff --git a/backend/onyx/agent_search/answer_follow_up_question/nodes/format_answer.py b/backend/onyx/agent_search/answer_follow_up_question/nodes/format_answer.py new file mode 100644 index 0000000000..06977a0ad9 --- /dev/null +++ b/backend/onyx/agent_search/answer_follow_up_question/nodes/format_answer.py @@ -0,0 +1,28 @@ +from onyx.agent_search.answer_question.states import AnswerQuestionOutput +from onyx.agent_search.answer_question.states import AnswerQuestionState +from onyx.agent_search.answer_question.states import QuestionAnswerResults + + +def format_answer(state: AnswerQuestionState) -> AnswerQuestionOutput: + # sub_question_retrieval_stats = state["sub_question_retrieval_stats"] + # if sub_question_retrieval_stats is None: + # sub_question_retrieval_stats = [] + # elif isinstance(sub_question_retrieval_stats, list): + # sub_question_retrieval_stats = sub_question_retrieval_stats + # if isinstance(sub_question_retrieval_stats[0], list): + # sub_question_retrieval_stats = sub_question_retrieval_stats[0] + # else: + # sub_question_retrieval_stats = [sub_question_retrieval_stats] + + return AnswerQuestionOutput( + answer_results=[ + QuestionAnswerResults( + question=state["question"], + quality=state["answer_quality"], + answer=state["answer"], + # expanded_retrieval_results=state["expanded_retrieval_results"], + documents=state["documents"], + sub_question_retrieval_stats=state["sub_question_retrieval_stats"], + ) + ], + ) diff --git a/backend/onyx/agent_search/answer_follow_up_question/nodes/ingest_retrieval.py b/backend/onyx/agent_search/answer_follow_up_question/nodes/ingest_retrieval.py new file mode 100644 index 0000000000..cc9e5989ff --- /dev/null +++ b/backend/onyx/agent_search/answer_follow_up_question/nodes/ingest_retrieval.py @@ -0,0 +1,19 @@ +from onyx.agent_search.answer_question.states import RetrievalIngestionUpdate +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalOutput +from onyx.agent_search.shared_graph_utils.models import AgentChunkStats + + +def ingest_retrieval(state: ExpandedRetrievalOutput) -> RetrievalIngestionUpdate: + sub_question_retrieval_stats = state[ + "expanded_retrieval_result" + ].sub_question_retrieval_stats + if sub_question_retrieval_stats is None: + sub_question_retrieval_stats = [AgentChunkStats()] + + return RetrievalIngestionUpdate( + expanded_retrieval_results=state[ + "expanded_retrieval_result" + ].expanded_queries_results, + documents=state["expanded_retrieval_result"].all_documents, + sub_question_retrieval_stats=sub_question_retrieval_stats, + ) diff --git a/backend/onyx/agent_search/answer_follow_up_question/states.py b/backend/onyx/agent_search/answer_follow_up_question/states.py new file mode 100644 index 0000000000..28f4dc2134 --- /dev/null +++ b/backend/onyx/agent_search/answer_follow_up_question/states.py @@ -0,0 +1,58 @@ +from operator import add +from typing import Annotated +from typing import TypedDict + +from onyx.agent_search.answer_question.models import QuestionAnswerResults +from onyx.agent_search.core_state import SubgraphCoreState +from onyx.agent_search.expanded_retrieval.models import QueryResult +from onyx.agent_search.shared_graph_utils.models import AgentChunkStats +from 
onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections +from onyx.context.search.models import InferenceSection + + +## Update States +class QACheckUpdate(TypedDict): + answer_quality: str + + +class QAGenerationUpdate(TypedDict): + answer: str + # answer_stat: AnswerStats + + +class RetrievalIngestionUpdate(TypedDict): + expanded_retrieval_results: list[QueryResult] + documents: Annotated[list[InferenceSection], dedup_inference_sections] + sub_question_retrieval_stats: AgentChunkStats + + +## Graph Input State + + +class AnswerQuestionInput(SubgraphCoreState): + question: str + + +## Graph State + + +class AnswerQuestionState( + AnswerQuestionInput, + QAGenerationUpdate, + QACheckUpdate, + RetrievalIngestionUpdate, +): + pass + + +## Graph Output State + + +class AnswerQuestionOutput(TypedDict): + """ + This is a list of results even though each call of this subgraph only returns one result. + This is because if we parallelize the answer query subgraph, there will be multiple + results in a list so the add operator is used to add them together. + """ + + answer_results: Annotated[list[QuestionAnswerResults], add] diff --git a/backend/onyx/agent_search/answer_question/nodes/format_answer.py b/backend/onyx/agent_search/answer_question/nodes/format_answer.py index 06977a0ad9..902e0d4924 100644 --- a/backend/onyx/agent_search/answer_question/nodes/format_answer.py +++ b/backend/onyx/agent_search/answer_question/nodes/format_answer.py @@ -18,7 +18,7 @@ def format_answer(state: AnswerQuestionState) -> AnswerQuestionOutput: answer_results=[ QuestionAnswerResults( question=state["question"], - quality=state["answer_quality"], + quality=state.get("answer_quality", "No"), answer=state["answer"], # expanded_retrieval_results=state["expanded_retrieval_results"], documents=state["documents"], diff --git a/backend/onyx/agent_search/answer_question/states.py b/backend/onyx/agent_search/answer_question/states.py index 28f4dc2134..b5b9f0880d 100644 --- a/backend/onyx/agent_search/answer_question/states.py +++ b/backend/onyx/agent_search/answer_question/states.py @@ -31,6 +31,7 @@ class RetrievalIngestionUpdate(TypedDict): class AnswerQuestionInput(SubgraphCoreState): question: str + question_nr: int ## Graph State diff --git a/backend/onyx/agent_search/deep_answer/nodes/entity_term_extraction.py b/backend/onyx/agent_search/deep_answer/nodes/entity_term_extraction.py index 865a78f0a7..0ab2cca35d 100644 --- a/backend/onyx/agent_search/deep_answer/nodes/entity_term_extraction.py +++ b/backend/onyx/agent_search/deep_answer/nodes/entity_term_extraction.py @@ -13,8 +13,8 @@ def entity_term_extraction(state: MainState) -> dict[str, Any]: """Extract entities and terms from the question and context""" - question = state["original_question"] - docs = state["deduped_retrieval_docs"] + question = state["search_request"].query + docs = state["base_raw_search_result"] doc_context = format_docs(docs) diff --git a/backend/onyx/agent_search/main/edges.py b/backend/onyx/agent_search/main/edges.py index 22cba21933..bd62b5e3db 100644 --- a/backend/onyx/agent_search/main/edges.py +++ b/backend/onyx/agent_search/main/edges.py @@ -1,5 +1,7 @@ from collections.abc import Hashable +from typing import Literal +from langgraph.graph import END from langgraph.types import Send from onyx.agent_search.answer_question.states import AnswerQuestionInput @@ -8,6 +10,7 @@ from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput from onyx.agent_search.main.states import MainInput from 
onyx.agent_search.main.states import MainState +from onyx.agent_search.main.states import RequireRefinedAnswerUpdate def parallelize_decompozed_answer_queries(state: MainState) -> list[Send | Hashable]: @@ -48,6 +51,16 @@ def send_to_initial_retrieval(state: MainInput) -> list[Send | Hashable]: ] +# Define the function that determines whether to continue or not +def continue_to_refined_answer_or_end( + state: RequireRefinedAnswerUpdate, +) -> Literal["refined_answer_subgraph", "END"]: + if state["require_refined_answer"]: + return "refined_answer_subgraph" + else: + return END + + # def continue_to_answer_sub_questions(state: QAState) -> Union[Hashable, list[Hashable]]: # # Routes re-written queries to the (parallel) retrieval steps # # Notice the 'Send()' API that takes care of the parallelization diff --git a/backend/onyx/agent_search/main/graph_builder.py b/backend/onyx/agent_search/main/graph_builder.py index 50a839cb21..dcd1cf8647 100644 --- a/backend/onyx/agent_search/main/graph_builder.py +++ b/backend/onyx/agent_search/main/graph_builder.py @@ -6,13 +6,23 @@ from onyx.agent_search.base_raw_search.graph_builder import ( base_raw_search_graph_builder, ) +from onyx.agent_search.main.edges import continue_to_refined_answer_or_end from onyx.agent_search.main.edges import parallelize_decompozed_answer_queries +from onyx.agent_search.main.nodes import entity_term_extraction from onyx.agent_search.main.nodes import generate_initial_answer +from onyx.agent_search.main.nodes import generate_refined_answer from onyx.agent_search.main.nodes import ingest_answers from onyx.agent_search.main.nodes import ingest_initial_retrieval +from onyx.agent_search.main.nodes import initial_answer_quality_check from onyx.agent_search.main.nodes import main_decomp_base +from onyx.agent_search.main.nodes import refined_answer_decision from onyx.agent_search.main.states import MainInput from onyx.agent_search.main.states import MainState +from onyx.agent_search.refined_answers.graph_builder import ( + refined_answers_graph_builder, +) + +# from onyx.agent_search.main.nodes import check_refined_answer test_mode = False @@ -305,22 +315,28 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph: action=answer_query_subgraph, ) - # graph.add_node( - # node="prep_for_initial_retrieval", - # action=prep_for_initial_retrieval, - # ) - - # expanded_retrieval_subgraph = expanded_retrieval_graph_builder().compile() - # graph.add_node( - # node="initial_retrieval", - # action=expanded_retrieval_subgraph, - # ) - base_raw_search_subgraph = base_raw_search_graph_builder().compile() graph.add_node( node="base_raw_search_data", action=base_raw_search_subgraph, ) + + refined_answer_subgraph = refined_answers_graph_builder().compile() + graph.add_node( + node="refined_answer_subgraph", + action=refined_answer_subgraph, + ) + + graph.add_node( + node="generate_refined_answer", + action=generate_refined_answer, + ) + + # graph.add_node( + # node="check_refined_answer", + # action=check_refined_answer, + # ) + graph.add_node( node="ingest_initial_retrieval", action=ingest_initial_retrieval, @@ -333,6 +349,20 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph: node="generate_initial_answer", action=generate_initial_answer, ) + + graph.add_node( + node="initial_answer_quality_check", + action=initial_answer_quality_check, + ) + + graph.add_node( + node="entity_term_extraction", + action=entity_term_extraction, + ) + graph.add_node( + node="refined_answer_decision", + action=refined_answer_decision, + ) # if 
test_mode: # graph.add_node( # node="generate_initial_base_answer", @@ -341,39 +371,13 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph: ### Add edges ### - # graph.add_conditional_edges( - # source=START, - # path=send_to_initial_retrieval, - # path_map=["initial_retrieval"], - # ) - - # graph.add_edge( - # start_key=START, - # end_key="prep_for_initial_retrieval", - # ) - # graph.add_edge( - # start_key="prep_for_initial_retrieval", - # end_key="initial_retrieval", - # ) - # graph.add_edge( - # start_key="initial_retrieval", - # end_key="ingest_initial_retrieval", - # ) - graph.add_edge(start_key=START, end_key="base_raw_search_data") - # # graph.add_edge( - # # start_key="base_raw_search_data", - # # end_key=END - # # ) graph.add_edge( start_key="base_raw_search_data", end_key="ingest_initial_retrieval", ) - # graph.add_edge( - # start_key="ingest_initial_retrieval", - # end_key=END - # ) + graph.add_edge( start_key=START, end_key="base_decomp", @@ -388,38 +392,50 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph: end_key="ingest_answers", ) - # graph.add_edge( - # start_key="ingest_answers", - # end_key="generate_initial_answer", - # ) - graph.add_edge( start_key=["ingest_answers", "ingest_initial_retrieval"], end_key="generate_initial_answer", ) + graph.add_edge( + start_key=["ingest_answers", "ingest_initial_retrieval"], + end_key="entity_term_extraction", + ) + graph.add_edge( start_key="generate_initial_answer", + end_key="initial_answer_quality_check", + ) + + graph.add_edge( + start_key=["initial_answer_quality_check", "entity_term_extraction"], + end_key="refined_answer_decision", + ) + + graph.add_conditional_edges( + source="refined_answer_decision", + path=continue_to_refined_answer_or_end, + path_map=["refined_answer_subgraph", END], + ) + + graph.add_edge( + start_key="refined_answer_subgraph", + end_key="generate_refined_answer", + ) + graph.add_edge( + start_key="generate_refined_answer", end_key=END, ) + # graph.add_edge( - # start_key="ingest_answers", - # end_key="generate_initial_answer", + # start_key="generate_refined_answer", + # end_key="check_refined_answer", + # ) + + # graph.add_edge( + # start_key="check_refined_answer", + # end_key=END, # ) - # if test_mode: - # graph.add_edge( - # start_key=["ingest_answers", "ingest_initial_retrieval"], - # end_key="generate_initial_base_answer", - # ) - # graph.add_edge( - # start_key=["generate_initial_answer", "generate_initial_base_answer"], - # end_key=END, - # ) - # else: - # graph.add_edge( - # start_key="generate_initial_answer", - # end_key=END, - # ) return graph diff --git a/backend/onyx/agent_search/main/models.py b/backend/onyx/agent_search/main/models.py new file mode 100644 index 0000000000..0f011af3dd --- /dev/null +++ b/backend/onyx/agent_search/main/models.py @@ -0,0 +1,27 @@ +from pydantic import BaseModel + + +### Models ### + + +class Entity(BaseModel): + entity_name: str + entity_type: str + + +class Relationship(BaseModel): + relationship_name: str + relationship_type: str + relationship_entities: list[str] + + +class Term(BaseModel): + term_name: str + term_type: str + term_similar_to: list[str] + + +class EntityRelationshipTermExtraction(BaseModel): + entities: list[Entity] + relationships: list[Relationship] + terms: list[Term] diff --git a/backend/onyx/agent_search/main/nodes.py b/backend/onyx/agent_search/main/nodes.py index 5fba915c56..c4a624dd24 100644 --- a/backend/onyx/agent_search/main/nodes.py +++ b/backend/onyx/agent_search/main/nodes.py @@ -1,17 +1,31 @@ 
+import json +import re + from langchain_core.messages import HumanMessage +from langchain_core.messages import merge_message_runs from onyx.agent_search.answer_question.states import AnswerQuestionOutput from onyx.agent_search.answer_question.states import QuestionAnswerResults from onyx.agent_search.base_raw_search.states import BaseRawSearchOutput +from onyx.agent_search.main.models import Entity +from onyx.agent_search.main.models import EntityRelationshipTermExtraction +from onyx.agent_search.main.models import Relationship +from onyx.agent_search.main.models import Term from onyx.agent_search.main.states import BaseDecompUpdate from onyx.agent_search.main.states import DecompAnswersUpdate +from onyx.agent_search.main.states import EntityTermExtractionUpdate from onyx.agent_search.main.states import ExpandedRetrievalUpdate from onyx.agent_search.main.states import InitialAnswerBASEUpdate +from onyx.agent_search.main.states import InitialAnswerQualityUpdate from onyx.agent_search.main.states import InitialAnswerUpdate from onyx.agent_search.main.states import MainState +from onyx.agent_search.main.states import RefinedAnswerUpdate +from onyx.agent_search.main.states import RequireRefinedAnswerUpdate from onyx.agent_search.shared_graph_utils.models import AgentChunkStats from onyx.agent_search.shared_graph_utils.models import InitialAgentResultStats +from onyx.agent_search.shared_graph_utils.models import RefinedAgentStats from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections +from onyx.agent_search.shared_graph_utils.prompts import ENTITY_TERM_PROMPT from onyx.agent_search.shared_graph_utils.prompts import ( INITIAL_DECOMPOSITION_PROMPT_QUESTIONS, ) @@ -20,6 +34,7 @@ from onyx.agent_search.shared_graph_utils.prompts import ( INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS, ) +from onyx.agent_search.shared_graph_utils.prompts import REVISED_RAG_PROMPT from onyx.agent_search.shared_graph_utils.utils import clean_and_parse_list_string from onyx.agent_search.shared_graph_utils.utils import format_docs @@ -115,10 +130,11 @@ def _calculate_initial_agent_stats( if ( - orig_support_score is None + (orig_support_score is None or orig_support_score == 0.0) and initial_agent_result_stats.sub_questions["verified_avg_score"] is None ): initial_agent_result_stats.agent_effectiveness["support_ratio"] = None - elif orig_support_score is None: + elif orig_support_score is None or orig_support_score == 0.0: initial_agent_result_stats.agent_effectiveness["support_ratio"] = 10 elif initial_agent_result_stats.sub_questions["verified_avg_score"] is None: initial_agent_result_stats.agent_effectiveness["support_ratio"] = 0 @@ -216,6 +232,124 @@ def generate_initial_answer(state: MainState) -> InitialAnswerUpdate: ) + +def initial_answer_quality_check(state: MainState) -> InitialAnswerQualityUpdate: + """ + Check whether the final output satisfies the original user question + + Args: + state (messages): The current state + + Returns: + InitialAnswerQualityUpdate + """ + + # print("---CHECK INITIAL OUTPUT QUALITY---") + + # question = state["search_request"].query + # initial_answer = state["initial_answer"] + + # msg = [ + # HumanMessage( + # content=BASE_CHECK_PROMPT.format(question=question, initial_answer=initial_answer) + # ) + # ] + + # model = state["fast_llm"] + # response = model.invoke(msg) + + # if 'yes' in response.content.lower(): + # verdict = True + # else: + # verdict = False + + # print(f"Verdict: {verdict}") + + print("Checking for base answer validity - for now, set True/False manually") + + verdict = True +
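+    # Stub: the commented-out flow above would have the fast LLM grade the initial +    # answer via BASE_CHECK_PROMPT; until that check is enabled, every initial answer +    # is treated as acceptable.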
+ return InitialAnswerQualityUpdate(initial_answer_quality=verdict) + + +def entity_term_extraction(state: MainState) -> EntityTermExtractionUpdate: + print("---GENERATE ENTITIES & TERMS---") + + # first four lines duplicated from generate_initial_answer + question = state["search_request"].query + sub_question_docs = state["documents"] + all_original_question_documents = state["all_original_question_documents"] + relevant_docs = dedup_inference_sections( + sub_question_docs, all_original_question_documents + ) + + # start with the entity/term extraction + + doc_context = format_docs(relevant_docs) + + msg = [ + HumanMessage( + content=ENTITY_TERM_PROMPT.format(question=question, context=doc_context), + ) + ] + fast_llm = state["fast_llm"] + # Run the entity/term extraction with the fast LLM + llm_response_list = list( + fast_llm.stream( + prompt=msg, + ) + ) + llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content + + cleaned_response = re.sub(r"```json\n|\n```", "", llm_response) + parsed_response = json.loads(cleaned_response) + + entities = [] + relationships = [] + terms = [] + for entity in parsed_response.get("retrieved_entities_relationships", {}).get( + "entities", [] + ): + entity_name = entity.get("entity_name", "") + entity_type = entity.get("entity_type", "") + entities.append(Entity(entity_name=entity_name, entity_type=entity_type)) + + for relationship in parsed_response.get("retrieved_entities_relationships", {}).get( + "relationships", [] + ): + relationship_name = relationship.get("relationship_name", "") + relationship_type = relationship.get("relationship_type", "") + relationship_entities = relationship.get("relationship_entities", []) + relationships.append( + Relationship( + relationship_name=relationship_name, + relationship_type=relationship_type, + relationship_entities=relationship_entities, + ) + ) + + for term in parsed_response.get("retrieved_entities_relationships", {}).get( + "terms", [] + ): + term_name = term.get("term_name", "") + term_type = term.get("term_type", "") + term_similar_to = term.get("term_similar_to", []) + terms.append( + Term( + term_name=term_name, + term_type=term_type, + term_similar_to=term_similar_to, + ) + ) + + return EntityTermExtractionUpdate( + entity_retlation_term_extractions=EntityRelationshipTermExtraction( + entities=entities, + relationships=relationships, + terms=terms, + ) + ) + + def generate_initial_base_answer(state: MainState) -> InitialAnswerBASEUpdate: print("---GENERATE INITIAL BASE ANSWER---") @@ -274,3 +408,183 @@ def ingest_initial_retrieval(state: BaseRawSearchOutput) -> ExpandedRetrievalUpd ].all_documents, original_question_retrieval_stats=sub_question_retrieval_stats, ) + + +def refined_answer_decision(state: MainState) -> RequireRefinedAnswerUpdate: + print("---REFINED ANSWER DECISION---") + + if False: + return RequireRefinedAnswerUpdate(require_refined_answer=False) + + else: + return RequireRefinedAnswerUpdate(require_refined_answer=True) + + +def generate_refined_answer(state: MainState) -> RefinedAnswerUpdate: + print("---GENERATE REFINED ANSWER---") + + initial_documents = state["documents"] + revised_documents = state["follow_up_documents"] + + combined_documents = dedup_inference_sections(initial_documents, revised_documents) + + if len(initial_documents) > 0: + revision_doc_effectiveness = len(combined_documents) / len(initial_documents) + elif len(revised_documents) == 0: + revision_doc_effectiveness = 0.0 + else: + revision_doc_effectiveness = 10.0 + + question = state["search_request"].query + +
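# The two ratios computed below compare the refined pass with the initial one: +    # revision_doc_effectiveness grows when the follow-up retrieval surfaces documents +    # beyond the initial set, and revision_question_efficiency when it yields additional +    # usable sub-questions; 10.0 serves as an "all new" sentinel in both cases. +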
decomp_answer_results = state["decomp_answer_results"] + revised_answer_results = state["follow_up_decomp_answer_results"] + + good_qa_list: list[str] = [] + decomp_questions = [] + + _SUB_QUESTION_ANSWER_TEMPLATE = """ + Sub-Question:\n - {sub_question}\n --\nAnswer:\n - {sub_answer}\n\n + """ + + initial_good_sub_questions: list[str] = [] + new_revised_good_sub_questions: list[str] = [] + + for answer_set in [decomp_answer_results, revised_answer_results]: + for decomp_answer_result in answer_set: + decomp_questions.append(decomp_answer_result.question) + if ( + decomp_answer_result.quality.lower().startswith("yes") + and len(decomp_answer_result.answer) > 0 + and decomp_answer_result.answer != "I don't know" + ): + good_qa_list.append( + _SUB_QUESTION_ANSWER_TEMPLATE.format( + sub_question=decomp_answer_result.question, + sub_answer=decomp_answer_result.answer, + ) + ) + if answer_set == decomp_answer_results: + initial_good_sub_questions.append(decomp_answer_result.question) + else: + new_revised_good_sub_questions.append(decomp_answer_result.question) + + initial_good_sub_questions = list(set(initial_good_sub_questions)) + new_revised_good_sub_questions = list(set(new_revised_good_sub_questions)) + total_good_sub_questions = list( + set(initial_good_sub_questions + new_revised_good_sub_questions) + ) + # guard against division by zero when the initial pass produced no usable sub-questions + revision_question_efficiency = ( + len(total_good_sub_questions) / len(initial_good_sub_questions) + if initial_good_sub_questions + else 10.0 + ) + + sub_question_answer_str = "\n\n------\n\n".join(list(set(good_qa_list))) + + # original answer + + initial_answer = state["initial_answer"] + + if len(good_qa_list) > 0: + msg = [ + HumanMessage( + content=REVISED_RAG_PROMPT.format( + question=question, + answered_sub_questions=sub_question_answer_str, + relevant_docs=format_docs(combined_documents), + initial_answer=initial_answer, + ) + ) + ] + else: + msg = [ + HumanMessage( + content=INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS.format( + question=question, + relevant_docs=format_docs(combined_documents), + ) + ) + ] + + # Generate the refined answer with the fast LLM + model = state["fast_llm"] + response = model.invoke(msg) + answer = response.pretty_repr() + + # refined_agent_stats = _calculate_refined_agent_stats( + # state["decomp_answer_results"], state["original_question_retrieval_stats"] + # ) + + initial_good_sub_questions_str = "\n".join(list(set(initial_good_sub_questions))) + new_revised_good_sub_questions_str = "\n".join( + list(set(new_revised_good_sub_questions)) + ) + + refined_agent_stats = RefinedAgentStats( + revision_doc_efficiency=revision_doc_effectiveness, + revision_question_efficiency=revision_question_efficiency, + ) + + print(f"\n\n---INITIAL ANSWER START---\n\n Answer:\n Agent: {initial_answer}") + print("-" * 10) + print(f"\n\n---REVISED AGENT ANSWER START---\n\n Answer:\n Agent: {answer}") + + print("-" * 100) + print(f"\n\nINITIAL Sub-Questions\n\n{initial_good_sub_questions_str}\n\n") + print("-" * 10) + print(f"\n\nNEW REVISED Sub-Questions\n\n{new_revised_good_sub_questions_str}\n\n") + + print("-" * 100) + + print( + f"\n\nINITIAL & REVISED Sub-Questions & Answers:\n\n{sub_question_answer_str}\n\nStats:\n\n" + ) + + print("-" * 100) + + if state["initial_agent_stats"]: + initial_doc_boost_factor = state["initial_agent_stats"].agent_effectiveness.get( + "utilized_chunk_ratio", "--" + ) + initial_support_boost_factor = state[ + "initial_agent_stats" + ].agent_effectiveness.get("support_ratio", "--") + initial_verified_docs = state["initial_agent_stats"].original_question.get( + "num_verified_documents", "--" + ) + initial_verified_docs_avg_score =
state[ + "initial_agent_stats" + ].original_question.get("verified_avg_score", "--") + initial_sub_questions_verified_docs = state[ + "initial_agent_stats" + ].sub_questions.get("num_verified_documents", "--") + + print("INITIAL AGENT STATS") + print(f"Document Boost Factor: {initial_doc_boost_factor}") + print(f"Support Boost Factor: {initial_support_boost_factor}") + print(f"Originally Verified Docs: {initial_verified_docs}") + print(f"Originally Verified Docs Avg Score: {initial_verified_docs_avg_score}") + print(f"Sub-Questions Verified Docs: {initial_sub_questions_verified_docs}") + if refined_agent_stats: + print("-" * 10) + print("REFINED AGENT STATS") + print(f"Revision Doc Factor: {refined_agent_stats.revision_doc_efficiency}") + print( + f"Revision Question Factor: {refined_agent_stats.revision_question_efficiency}" + ) + + print("\n\n ---INITIAL AGENT ANSWER END---\n\n") + + return RefinedAnswerUpdate( + refined_answer=answer, + refined_answer_quality=True, # TODO: replace this with the actual check value + refined_agent_stats=refined_agent_stats, + ) + + +# def check_refined_answer(state: MainState) -> RefinedAnswerUpdate: +# print("---CHECK REFINED ANSWER---") + +# return RefinedAnswerUpdate( +# refined_answer="", +# refined_answer_quality=True +# ) diff --git a/backend/onyx/agent_search/main/states.py b/backend/onyx/agent_search/main/states.py index ff5f87b76d..9c90448c47 100644 --- a/backend/onyx/agent_search/main/states.py +++ b/backend/onyx/agent_search/main/states.py @@ -6,12 +6,14 @@ from onyx.agent_search.core_state import CoreState from onyx.agent_search.expanded_retrieval.models import ExpandedRetrievalResult from onyx.agent_search.expanded_retrieval.models import QueryResult +from onyx.agent_search.main.models import EntityRelationshipTermExtraction +from onyx.agent_search.refined_answers.models import FollowUpSubQuestion from onyx.agent_search.shared_graph_utils.models import AgentChunkStats from onyx.agent_search.shared_graph_utils.models import InitialAgentResultStats +from onyx.agent_search.shared_graph_utils.models import RefinedAgentStats from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections from onyx.context.search.models import InferenceSection - ### States ### ## Update States @@ -31,11 +33,30 @@ class InitialAnswerUpdate(TypedDict): generated_sub_questions: list[str] +class RefinedAnswerUpdate(TypedDict): + refined_answer: str + refined_agent_stats: RefinedAgentStats | None + refined_answer_quality: bool + + +class InitialAnswerQualityUpdate(TypedDict): + initial_answer_quality: bool + + +class RequireRefinedAnswerUpdate(TypedDict): + require_refined_answer: bool + + class DecompAnswersUpdate(TypedDict): documents: Annotated[list[InferenceSection], dedup_inference_sections] decomp_answer_results: Annotated[list[QuestionAnswerResults], add] +class FollowUpDecompAnswersUpdate(TypedDict): + follow_up_documents: Annotated[list[InferenceSection], dedup_inference_sections] + follow_up_decomp_answer_results: Annotated[list[QuestionAnswerResults], add] + + class ExpandedRetrievalUpdate(TypedDict): all_original_question_documents: Annotated[ list[InferenceSection], dedup_inference_sections @@ -44,6 +65,24 @@ class ExpandedRetrievalUpdate(TypedDict): original_question_retrieval_stats: AgentChunkStats +class EntityTermExtractionUpdate(TypedDict): + entity_retlation_term_extractions: EntityRelationshipTermExtraction + + +class FollowUpSubQuestionsUpdate(TypedDict): + follow_up_sub_questions: dict[int, FollowUpSubQuestion] + + +class 
FollowUpAnswerQuestionOutput(TypedDict): + """ + This is a list of results even though each call of this subgraph only returns one result. + This is because if we parallelize the answer query subgraph, there will be multiple + results in a list so the add operator is used to add them together. + """ + + follow_up_answer_results: Annotated[list[QuestionAnswerResults], add] + + ## Graph Input State @@ -62,6 +101,13 @@ class MainState( InitialAnswerBASEUpdate, DecompAnswersUpdate, ExpandedRetrievalUpdate, + EntityTermExtractionUpdate, + InitialAnswerQualityUpdate, + RequireRefinedAnswerUpdate, + FollowUpSubQuestionsUpdate, + FollowUpAnswerQuestionOutput, + FollowUpDecompAnswersUpdate, + RefinedAnswerUpdate, ): # expanded_retrieval_result: Annotated[list[ExpandedRetrievalResult], add] base_raw_search_result: Annotated[list[ExpandedRetrievalResult], add] @@ -75,3 +121,4 @@ class MainOutput(TypedDict): initial_base_answer: str initial_agent_stats: dict generated_sub_questions: list[str] + require_refined_answer: bool diff --git a/backend/onyx/agent_search/refined_answers/edges.py b/backend/onyx/agent_search/refined_answers/edges.py new file mode 100644 index 0000000000..2fa657cdb5 --- /dev/null +++ b/backend/onyx/agent_search/refined_answers/edges.py @@ -0,0 +1,33 @@ +from collections.abc import Hashable + +from langgraph.types import Send + +from onyx.agent_search.answer_question.states import AnswerQuestionInput +from onyx.agent_search.answer_question.states import AnswerQuestionOutput +from onyx.agent_search.core_state import extract_core_fields_for_subgraph +from onyx.agent_search.main.states import MainState + + +def parallelize_follow_up_answer_queries(state: MainState) -> list[Send | Hashable]: + if len(state["follow_up_sub_questions"]) > 0: + return [ + Send( + "answer_follow_up_question", + AnswerQuestionInput( + **extract_core_fields_for_subgraph(state), + question=question_data.sub_question, + question_nr=question_nr, + ), + ) + for question_nr, question_data in state["follow_up_sub_questions"].items() + ] + + else: + return [ + Send( + "ingest_follow_up_answers", + AnswerQuestionOutput( + answer_results=[], + ), + ) + ] diff --git a/backend/onyx/agent_search/refined_answers/graph_builder.py b/backend/onyx/agent_search/refined_answers/graph_builder.py new file mode 100644 index 0000000000..e628caef3a --- /dev/null +++ b/backend/onyx/agent_search/refined_answers/graph_builder.py @@ -0,0 +1,114 @@ +from langgraph.graph import END +from langgraph.graph import START +from langgraph.graph import StateGraph + +from onyx.agent_search.answer_follow_up_question.graph_builder import ( + answer_follow_up_query_graph_builder, +) +from onyx.agent_search.refined_answers.edges import parallelize_follow_up_answer_queries +from onyx.agent_search.refined_answers.nodes import dummy_node +from onyx.agent_search.refined_answers.nodes import follow_up_decompose +from onyx.agent_search.refined_answers.nodes import ingest_follow_up_answers +from onyx.agent_search.refined_answers.states import RefinedAnswerInput +from onyx.agent_search.refined_answers.states import RefinedAnswerOutput +from onyx.agent_search.refined_answers.states import RefinedAnswerState + + +def refined_answers_graph_builder() -> StateGraph: + graph = StateGraph( + state_schema=RefinedAnswerState, + input=RefinedAnswerInput, + output=RefinedAnswerOutput, + ) + + ### Add nodes ### + + graph.add_node( + node="dummy_node", + action=dummy_node, + ) + + graph.add_node( + node="follow_up_decompose", + action=follow_up_decompose, + ) + + 
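# The conditional edge wired up below fans out one parallel run of +    # "answer_follow_up_question" per follow-up sub-question via langgraph's Send API; +    # the per-question outputs are merged back by the add-annotated answer_results reducer. +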
answer_follow_up_question = answer_follow_up_query_graph_builder().compile() + graph.add_node( + node="answer_follow_up_question", + action=answer_follow_up_question, + ) + + graph.add_node( + node="ingest_follow_up_answers", + action=ingest_follow_up_answers, + ) + + # graph.add_node( + # node="format_follow_up_answer", + # action=format_follow_up_answer, + # ) + + ### Add edges ### + + graph.add_edge(start_key=START, end_key="dummy_node") + + graph.add_edge( + start_key="dummy_node", + end_key="follow_up_decompose", + ) + + graph.add_conditional_edges( + source="follow_up_decompose", + path=parallelize_follow_up_answer_queries, + path_map=["answer_follow_up_question"], + ) + graph.add_edge( + start_key="answer_follow_up_question", + end_key="ingest_follow_up_answers", + ) + + # graph.add_conditional_edges( + # start_key="answer_follow_up_question", + # end_key="ingest_follow_up_answers", + # ) + + # graph.add_conditional_edges( + # start_key="ingest_follow_up_answers", + # end_key="format_follow_up_answer", + # ) + + # graph.add_edge( + # start_key="format_follow_up_answer", + # end_key="generate_refined_answer", + # ) + + # graph.add_edge( + # start_key="generate_refined_answer", + # end_key="refined_answer_quality_check", + # ) + + # graph.add_edge( + # start_key="refined_answer_quality_check", + # end_key=END, + # ) + + # graph.add_edge( + # start_key="ingest_follow_up_answers", + # end_key="format_follow_up_answer", + # ) + # graph.add_edge( + # start_key="format_follow_up_answer", + # end_key=END, + # ) + + graph.add_edge( + start_key="ingest_follow_up_answers", + end_key=END, + ) + + return graph + + +if __name__ == "__main__": + pass diff --git a/backend/onyx/agent_search/refined_answers/models.py b/backend/onyx/agent_search/refined_answers/models.py new file mode 100644 index 0000000000..23347f6398 --- /dev/null +++ b/backend/onyx/agent_search/refined_answers/models.py @@ -0,0 +1,8 @@ +from pydantic import BaseModel + + +class FollowUpSubQuestion(BaseModel): + sub_question: str + verified: bool + answered: bool + answer: str diff --git a/backend/onyx/agent_search/refined_answers/nodes.py b/backend/onyx/agent_search/refined_answers/nodes.py new file mode 100644 index 0000000000..e9e5afa2fc --- /dev/null +++ b/backend/onyx/agent_search/refined_answers/nodes.py @@ -0,0 +1,129 @@ +import json +import re + +from langchain_core.messages import HumanMessage + +from onyx.agent_search.answer_question.states import AnswerQuestionOutput +from onyx.agent_search.answer_question.states import AnswerQuestionState +from onyx.agent_search.answer_question.states import QuestionAnswerResults +from onyx.agent_search.main.states import FollowUpAnswerQuestionOutput +from onyx.agent_search.main.states import FollowUpDecompAnswersUpdate +from onyx.agent_search.main.states import FollowUpSubQuestionsUpdate +from onyx.agent_search.main.states import MainState +from onyx.agent_search.refined_answers.models import FollowUpSubQuestion +from onyx.agent_search.refined_answers.states import RefinedAnswerInput +from onyx.agent_search.refined_answers.states import RefinedAnswerOutput +from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections +from onyx.agent_search.shared_graph_utils.prompts import DEEP_DECOMPOSE_PROMPT +from onyx.agent_search.shared_graph_utils.utils import format_entity_term_extraction + + +def dummy_node(state: RefinedAnswerInput) -> RefinedAnswerOutput: + print("---DUMMY NODE---") + return {"dummy_output": "this is a dummy output"} + + +def follow_up_decompose(state: 
MainState) -> FollowUpSubQuestionsUpdate: + """ """ + + question = state["search_request"].query + base_answer = state["initial_answer"] + + # get the entity term extraction dict and properly format it + entity_retlation_term_extractions = state["entity_retlation_term_extractions"] + + entity_term_extraction_str = format_entity_term_extraction( + entity_retlation_term_extractions + ) + + initial_question_answers = state["decomp_answer_results"] + + addressed_question_list = [ + x.question for x in initial_question_answers if "yes" in x.quality.lower() + ] + + failed_question_list = [ + x.question for x in initial_question_answers if "no" in x.quality.lower() + ] + + msg = [ + HumanMessage( + content=DEEP_DECOMPOSE_PROMPT.format( + question=question, + entity_term_extraction_str=entity_term_extraction_str, + base_answer=base_answer, + answered_sub_questions="\n - ".join(addressed_question_list), + failed_sub_questions="\n - ".join(failed_question_list), + ), + ) + ] + + # Grader + model = state["fast_llm"] + response = model.invoke(msg) + + if isinstance(response.content, str): + cleaned_response = re.sub(r"```json\n|\n```", "", response.content) + parsed_response = json.loads(cleaned_response) + else: + raise ValueError("LLM response is not a string") + + follow_up_sub_question_dict = {} + for sub_question_nr, sub_question_dict in enumerate( + parsed_response["sub_questions"] + ): + follow_up_sub_question = FollowUpSubQuestion( + sub_question=sub_question_dict["sub_question"], + verified=False, + answered=False, + answer="", + ) + + follow_up_sub_question_dict[sub_question_nr] = follow_up_sub_question + + return FollowUpSubQuestionsUpdate( + follow_up_sub_questions=follow_up_sub_question_dict + ) + + +def ingest_follow_up_answers( + state: AnswerQuestionOutput, +) -> FollowUpDecompAnswersUpdate: + documents = [] + answer_results = state.get("answer_results", []) + for answer_result in answer_results: + documents.extend(answer_result.documents) + return FollowUpDecompAnswersUpdate( + # Deduping is done by the documents operator for the main graph + # so we might not need to dedup here + follow_up_documents=dedup_inference_sections(documents, []), + follow_up_decomp_answer_results=answer_results, + ) + + +def format_follow_up_answer(state: AnswerQuestionState) -> FollowUpAnswerQuestionOutput: + return FollowUpAnswerQuestionOutput( + follow_up_answer_results=[ + QuestionAnswerResults( + question=state["question"], + quality=state.get("answer_quality", "No"), + answer=state["answer"], + # expanded_retrieval_results=state["expanded_retrieval_results"], + documents=state["documents"], + sub_question_retrieval_stats=state["sub_question_retrieval_stats"], + ) + ], + ) + + +# def ingest_follow_up_answers(state: AnswerQuestionOutput) -> DecompAnswersUpdate: +# documents = [] +# answer_results = state.get("answer_results", []) +# for answer_result in answer_results: +# documents.extend(answer_result.documents) +# return DecompAnswersUpdate( +# # Deduping is done by the documents operator for the main graph +# # so we might not need to dedup here +# documents=dedup_inference_sections(documents, []), +# decomp_answer_results=answer_results, +# ) diff --git a/backend/onyx/agent_search/refined_answers/states.py b/backend/onyx/agent_search/refined_answers/states.py new file mode 100644 index 0000000000..2f393355b3 --- /dev/null +++ b/backend/onyx/agent_search/refined_answers/states.py @@ -0,0 +1,19 @@ +from typing import TypedDict + +from onyx.agent_search.main.states import MainState + + +class 
RefinedAnswerInput(MainState): + pass + + +class RefinedAnswerOutput(TypedDict): + dummy_output: str + + +class FollowUpSubQuestionsUpdate(TypedDict): + follow_up_sub_question_dict: dict[str, dict[str, str]] + + +class RefinedAnswerState(RefinedAnswerInput, RefinedAnswerOutput): + pass diff --git a/backend/onyx/agent_search/run_graph.py b/backend/onyx/agent_search/run_graph.py index de207af628..5078a7c8b4 100644 --- a/backend/onyx/agent_search/run_graph.py +++ b/backend/onyx/agent_search/run_graph.py @@ -113,7 +113,7 @@ def run_graph( compiled_graph = graph.compile() primary_llm, fast_llm = get_default_llms() search_request = SearchRequest( - query="what can you do with onyx or danswer?", + query="What are the guiding principles behind the development of CockroachDB?", ) for output in run_graph(compiled_graph, search_request, primary_llm, fast_llm): print("a") diff --git a/backend/onyx/agent_search/shared_graph_utils/models.py b/backend/onyx/agent_search/shared_graph_utils/models.py index 5193b5dd2b..61e44d03f4 100644 --- a/backend/onyx/agent_search/shared_graph_utils/models.py +++ b/backend/onyx/agent_search/shared_graph_utils/models.py @@ -45,3 +45,8 @@ class InitialAgentResultStats(BaseModel): sub_questions: dict[str, float | int | None] original_question: dict[str, float | int | None] agent_effectiveness: dict[str, float | int | None] + + +class RefinedAgentStats(BaseModel): + revision_doc_efficiency: float | int | None + revision_question_efficiency: float | int | None diff --git a/backend/onyx/agent_search/shared_graph_utils/prompts.py b/backend/onyx/agent_search/shared_graph_utils/prompts.py index 07b935d91b..d45c36b1ec 100644 --- a/backend/onyx/agent_search/shared_graph_utils/prompts.py +++ b/backend/onyx/agent_search/shared_graph_utils/prompts.py @@ -66,7 +66,7 @@ \n ------- \n Here is the proposed answer: \n ------- \n - {base_answer} + {initial_answer} \n ------- \n Please answer with yes or no:""" @@ -221,7 +221,7 @@ were not directly answerable. Also, some entities, relationships and terms are given to you so that you have an idea of what the available data looks like. - Your role is to generate 4-6 new sub-questions that would help to answer the initial question, + Your role is to generate 3-5 new sub-questions that would help to answer the initial question, considering: 1) The initial question @@ -249,6 +249,8 @@ - good sub-question: "What is the name of the river that flows through Paris?" - For each sub-question, please also provide a search term that can be used to retrieve relevant documents from a document store. + - Consider specifically the sub-questions that were suggested but not answered. This is a sign that they are not + answerable with the available context, and you should not ask similar questions. \n\n Here is the initial question: \n ------- \n @@ -276,14 +278,19 @@ \n ------- \n Please generate the list of good, fully contextualized sub-questions that would help to address the - main question. Again, please find questions that are NOT overlapping too much with the already answered + main question. + + Specifically, also pay attention to the entities, relationships and terms extracted, as these indicate what types of + objects/relationships/terms you can ask about! Do not ask about entities, terms or relationships that are not + mentioned in the 'entities, relationships and terms' section. + + Again, please find questions that are NOT overlapping too much with the already answered sub-questions or those that already were suggested and failed.
In other words - what can we try in addition to what has been tried so far? Generate the list of json dictionaries with the following format: {{"sub_questions": [{{"sub_question": }}, ...]}} """ DECOMPOSE_PROMPT = """ \n @@ -310,8 +317,7 @@ - good sub-question: "What is the name of the river that flows through Paris?" - For each sub-question, please provide a short explanation for why it is a good sub-question. So generate a list of dictionaries with the following format: - [{{"sub_question": , "explanation": , "search_term": }}, ...] + [{{"sub_question": , "explanation": }}, ...] \n\n Here is the initial question: @@ -479,6 +485,50 @@ \n--\n\n Answer:""" +REVISED_RAG_PROMPT = """ \n +You are an assistant for question-answering tasks. Use the information provided below - and only the +provided information - to answer the provided question. + +The information provided below consists of: + 1) an initial answer that was given but found to be lacking in some way. + 2) a number of answered sub-questions - these are very important(!) and definitely should be + considered to answer the question. + 3) a number of documents that were also deemed relevant for the question. + +IMPORTANT RULES: + - If you cannot reliably answer the question solely using the provided information, say that you cannot reliably answer. + You may give some additional facts you learned, but do not try to invent an answer. + - If the information is empty or irrelevant, just say "I don't know". + - If the information is relevant but not fully conclusive, provide an answer to the extent you can but also + specify that the information is not conclusive and why. + +Again, you should be sure that the answer is supported by the information provided! + +Try to keep your answer concise. But also highlight uncertainties you may have should there be substantial ones, +or assumptions you made. + +Here is the contextual information: +\n-------\n + +*Initial Answer that was found to be lacking: +{initial_answer} + +*Answered Sub-questions (these should really matter! They also contain questions/answers that were not available when the original +answer was constructed): +{answered_sub_questions} + +And here is relevant document information that supports the sub-question answers, or that is relevant to the actual question:\n + +{relevant_docs} + +\n-------\n +\n +Lastly, here is the question I want you to answer based on the information above: +\n--\n +{question} +\n--\n\n +Answer:""" + INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS = """ You are an assistant for question-answering tasks. Use the information provided below - and only the provided information - to answer the provided question. @@ -529,14 +579,14 @@ "entity_type": }}], "relationships": [{{ - "name": , - "type": , - "entities": [, ] + "relationship_name": , + "relationship_type": , + "relationship_entities": [, , ...]
diff --git a/backend/onyx/agent_search/shared_graph_utils/utils.py b/backend/onyx/agent_search/shared_graph_utils/utils.py index a435860320..ba4f6d4228 100644 --- a/backend/onyx/agent_search/shared_graph_utils/utils.py +++ b/backend/onyx/agent_search/shared_graph_utils/utils.py @@ -6,6 +6,7 @@ from datetime import timedelta from typing import Any +from onyx.agent_search.main.models import EntityRelationshipTermExtraction from onyx.context.search.models import InferenceSection @@ -49,28 +50,35 @@ def clean_and_parse_json_string(json_string: str) -> dict[str, Any]: return json.loads(cleaned_string) -def format_entity_term_extraction(entity_term_extraction_dict: dict[str, Any]) -> str: - entities = entity_term_extraction_dict["entities"] - terms = entity_term_extraction_dict["terms"] - relationships = entity_term_extraction_dict["relationships"] +def format_entity_term_extraction( + entity_term_extraction_dict: EntityRelationshipTermExtraction, +) -> str: + entities = entity_term_extraction_dict.entities + terms = entity_term_extraction_dict.terms + relationships = entity_term_extraction_dict.relationships entity_strs = ["\nEntities:\n"] for entity in entities: - entity_str = f"{entity['entity_name']} ({entity['entity_type']})" + entity_str = f"{entity.entity_name} ({entity.entity_type})" entity_strs.append(entity_str) entity_str = "\n - ".join(entity_strs) relationship_strs = ["\n\nRelationships:\n"] for relationship in relationships: - relationship_str = f"{relationship['name']} ({relationship['type']}): {relationship['entities']}" + relationship_name = relationship.relationship_name + relationship_type = relationship.relationship_type + relationship_entities = relationship.relationship_entities + relationship_str = ( + f"""{relationship_name} ({relationship_type}): {relationship_entities}""" + ) relationship_strs.append(relationship_str) relationship_str = "\n - ".join(relationship_strs) term_strs = ["\n\nTerms:\n"] for term in terms: - term_str = f"{term['term_name']} ({term['term_type']}): similar to {term['similar_to']}" + term_str = f"{term.term_name} ({term.term_type}): similar to {', '.join(term.term_similar_to)}" term_strs.append(term_str) term_str = "\n - ".join(term_strs) From d773163502c60ccbe27d574fba5605828348f7c4 Mon Sep 17 00:00:00 2001 From: Evan Lohn Date: Thu, 2 Jan 2025 14:29:08 -0800 Subject: [PATCH 40/78] unorganized streaming of all relevant info --- .../nodes/answer_generation.py | 15 +- .../nodes/generate_raw_search_data.py | 1 + backend/onyx/agent_search/core_state.py | 4 + .../agent_search/expanded_retrieval/nodes.py | 184 ++++++++++-------- backend/onyx/agent_search/main/nodes.py | 35 +++- backend/onyx/agent_search/run_graph.py | 118 +++++++++-- backend/onyx/chat/answer.py | 42 ++-- backend/onyx/chat/models.py | 20 ++ backend/onyx/chat/process_message.py | 65 +++---- .../search/search_tool.py | 4 + 10 files changed, 332 insertions(+), 156 deletions(-) diff --git a/backend/onyx/agent_search/answer_question/nodes/answer_generation.py b/backend/onyx/agent_search/answer_question/nodes/answer_generation.py index 0403583567..b25c9f6072 100644 --- a/backend/onyx/agent_search/answer_question/nodes/answer_generation.py +++ b/backend/onyx/agent_search/answer_question/nodes/answer_generation.py @@ -1,3 +1,6 @@ +from typing import Any + +from langchain_core.callbacks.manager import dispatch_custom_event from langchain_core.messages import HumanMessage from langchain_core.messages
import merge_message_runs @@ -24,11 +27,15 @@ def answer_generation(state: AnswerQuestionState) -> QAGenerationUpdate: ] fast_llm = state["subgraph_fast_llm"] - response = list( - fast_llm.stream( - prompt=msg, + response: list[str | list[str | dict[str, Any]]] = [] + for message in fast_llm.stream( + prompt=msg, + ): + dispatch_custom_event( + "sub_answers", + message.content, ) - ) + response.append(message.content) answer_str = merge_message_runs(response, chunk_separator="")[0].content return QAGenerationUpdate( diff --git a/backend/onyx/agent_search/base_raw_search/nodes/generate_raw_search_data.py b/backend/onyx/agent_search/base_raw_search/nodes/generate_raw_search_data.py index 9b6a7b1485..259105e428 100644 --- a/backend/onyx/agent_search/base_raw_search/nodes/generate_raw_search_data.py +++ b/backend/onyx/agent_search/base_raw_search/nodes/generate_raw_search_data.py @@ -11,4 +11,5 @@ def generate_raw_search_data(state: CoreState) -> ExpandedRetrievalInput: subgraph_db_session=state["db_session"], question=state["search_request"].query, base_search=True, + subgraph_search_tool=state["search_tool"], ) diff --git a/backend/onyx/agent_search/core_state.py b/backend/onyx/agent_search/core_state.py index a035d25b31..2868e2176f 100644 --- a/backend/onyx/agent_search/core_state.py +++ b/backend/onyx/agent_search/core_state.py @@ -7,6 +7,7 @@ from onyx.context.search.models import SearchRequest from onyx.llm.interfaces import LLM +from onyx.tools.tool_implementations.search.search_tool import SearchTool class CoreState(TypedDict, total=False): @@ -21,6 +22,7 @@ class CoreState(TypedDict, total=False): # is fine if we are only reading db_session: Session log_messages: Annotated[list[str], add] + search_tool: SearchTool class SubgraphCoreState(TypedDict, total=False): @@ -35,6 +37,8 @@ class SubgraphCoreState(TypedDict, total=False): # is fine if we are only reading subgraph_db_session: Session + subgraph_search_tool: SearchTool + # This ensures that the state passed in extends the CoreState T = TypeVar("T", bound=CoreState) diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes.py b/backend/onyx/agent_search/expanded_retrieval/nodes.py index 50973a20bc..927b6b379c 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes.py @@ -1,7 +1,10 @@ from collections import defaultdict +from typing import Any +from typing import cast from typing import Literal import numpy as np +from langchain_core.callbacks.manager import dispatch_custom_event from langchain_core.messages import HumanMessage from langchain_core.messages import merge_message_runs from langgraph.types import Command @@ -32,44 +35,37 @@ from onyx.context.search.models import SearchRequest from onyx.context.search.pipeline import retrieval_preprocessing from onyx.context.search.pipeline import search_postprocessing -from onyx.context.search.pipeline import SearchPipeline from onyx.llm.interfaces import LLM +from onyx.tools.tool_implementations.search.search_tool import ( + SEARCH_RESPONSE_SUMMARY_ID, +) -def doc_reranking(state: ExpandedRetrievalState) -> DocRerankingUpdate: - verified_documents = state["verified_documents"] - - # Rerank post retrieval and verification. 
First, create a search query - # then create the list of reranked sections - - question = state.get("question", state["subgraph_search_request"].query) - _search_query = retrieval_preprocessing( - search_request=SearchRequest(query=question), - user=None, - llm=state["subgraph_fast_llm"], - db_session=state["subgraph_db_session"], - ) +def expand_queries(state: ExpandedRetrievalInput) -> QueryExpansionUpdate: + question = state.get("question") + llm: LLM = state["subgraph_fast_llm"] - reranked_documents = list( - search_postprocessing( - search_query=_search_query, - retrieved_sections=verified_documents, - llm=state["subgraph_fast_llm"], + msg = [ + HumanMessage( + content=REWRITE_PROMPT_MULTI_ORIGINAL.format(question=question), ) - )[ - 0 - ] # only get the reranked szections, not the SectionRelevancePiece + ] + llm_response_list: list[str | list[str | dict[str, Any]]] = [] + for message in llm.stream( + prompt=msg, + ): + dispatch_custom_event( + "subqueries", + message.content, + ) + llm_response_list.append(message.content) - if AGENT_RERANKING_STATS: - fit_scores = get_fit_scores(verified_documents, reranked_documents) - else: - fit_scores = RetrievalFitStats(fit_score_lift=0, rerank_effect=0, fit_scores={}) + llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content - return DocRerankingUpdate( - reranked_documents=[ - doc for doc in reranked_documents if type(doc) == InferenceSection - ][:AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS], - sub_question_retrieval_stats=fit_scores, + rewritten_queries = llm_response.split("--") + + return QueryExpansionUpdate( + expanded_queries=rewritten_queries, ) @@ -84,27 +80,31 @@ def doc_retrieval(state: RetrievalInput) -> DocRetrievalUpdate: expanded_retrieval_results: list[ExpandedRetrievalResult] retrieved_documents: list[InferenceSection] """ - - llm = state["subgraph_primary_llm"] - fast_llm = state["subgraph_fast_llm"] query_to_retrieve = state["query_to_retrieve"] + search_tool = state["subgraph_search_tool"] - search_results = SearchPipeline( - search_request=SearchRequest( - query=query_to_retrieve, - ), - user=None, - llm=llm, - fast_llm=fast_llm, - db_session=state["subgraph_db_session"], - ) + retrieved_docs: list[InferenceSection] = [] + for tool_response in search_tool.run(query=query_to_retrieve): + if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID: + retrieved_docs = cast( + list[InferenceSection], tool_response.response.top_sections + ) + dispatch_custom_event( + "tool_response", + tool_response, + ) - retrieved_docs = search_results._get_sections()[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS] + retrieved_docs = retrieved_docs[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS] + pre_rerank_docs = retrieved_docs + if search_tool.search_pipeline is not None: + pre_rerank_docs = ( + search_tool.search_pipeline._retrieved_sections or retrieved_docs + ) if AGENT_RETRIEVAL_STATS: fit_scores = get_fit_scores( + pre_rerank_docs, retrieved_docs, - search_results.reranked_sections[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS], ) else: fit_scores = None @@ -120,6 +120,30 @@ def doc_retrieval(state: RetrievalInput) -> DocRetrievalUpdate: ) +def verification_kickoff( + state: ExpandedRetrievalState, +) -> Command[Literal["doc_verification"]]: + documents = state["retrieved_documents"] + verification_question = state.get( + "question", state["subgraph_search_request"].query + ) + return Command( + update={}, + goto=[ + Send( + node="doc_verification", + arg=DocVerificationInput( + doc_to_verify=doc, + question=verification_question, + base_search=False, 
+ **in_subgraph_extract_core_fields(state), + ), + ) + for doc in documents + ], + ) + + def doc_verification(state: DocVerificationInput) -> DocVerificationUpdate: """ Check whether the document is relevant for the original user question @@ -156,26 +180,40 @@ ) -def expand_queries(state: ExpandedRetrievalInput) -> QueryExpansionUpdate: - question = state.get("question") - llm: LLM = state["subgraph_fast_llm"] +def doc_reranking(state: ExpandedRetrievalState) -> DocRerankingUpdate: + verified_documents = state["verified_documents"] - msg = [ - HumanMessage( - content=REWRITE_PROMPT_MULTI_ORIGINAL.format(question=question), - ) - ] - llm_response_list = list( - llm.stream( - prompt=msg, - ) + # Rerank post retrieval and verification. First, create a search query + # then create the list of reranked sections + + question = state.get("question", state["subgraph_search_request"].query) + _search_query = retrieval_preprocessing( + search_request=SearchRequest(query=question), + user=None, + llm=state["subgraph_fast_llm"], + db_session=state["subgraph_db_session"], ) - llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content - rewritten_queries = llm_response.split("--") + reranked_documents = list( + search_postprocessing( + search_query=_search_query, + retrieved_sections=verified_documents, + llm=state["subgraph_fast_llm"], + ) + )[ + 0 + ] # only get the reranked sections, not the SectionRelevancePiece - return QueryExpansionUpdate( - expanded_queries=rewritten_queries, + if AGENT_RERANKING_STATS: + fit_scores = get_fit_scores(verified_documents, reranked_documents) + else: + fit_scores = RetrievalFitStats(fit_score_lift=0, rerank_effect=0, fit_scores={}) + + return DocRerankingUpdate( + reranked_documents=[ + doc for doc in reranked_documents if type(doc) == InferenceSection + ][:AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS], + sub_question_retrieval_stats=fit_scores, ) @@ -266,27 +304,3 @@ def format_results(state: ExpandedRetrievalState) -> ExpandedRetrievalUpdate: sub_question_retrieval_stats=sub_question_retrieval_stats, ), ) - - -def verification_kickoff( - state: ExpandedRetrievalState, -) -> Command[Literal["doc_verification"]]: - documents = state["retrieved_documents"] - verification_question = state.get( - "question", state["subgraph_search_request"].query - ) - return Command( - update={}, - goto=[ - Send( - node="doc_verification", - arg=DocVerificationInput( - doc_to_verify=doc, - question=verification_question, - base_search=False, - **in_subgraph_extract_core_fields(state), - ), - ) - for doc in documents - ], - )
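The dispatch_custom_event calls added in this commit surface as on_custom_event entries when the compiled graph is consumed through LangChain's v2 event stream, which is how _parse_agent_event in run_graph.py (below) picks them up. A minimal consumer sketch, assuming a compiled graph and input built as elsewhere in this commit:

# Minimal sketch: consume the custom events dispatched above, assuming
# `compiled_graph` and `graph_input` are constructed as in run_graph.py.
async def print_custom_events(compiled_graph, graph_input) -> None:
    async for event in compiled_graph.astream_events(graph_input, version="v2"):
        if event["event"] == "on_custom_event":
            # event["name"] is the label passed to dispatch_custom_event
            # ("subqueries", "tool_response", ...); event["data"] is the payload.
            print(event["name"], event["data"])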
diff --git a/backend/onyx/agent_search/main/nodes.py b/backend/onyx/agent_search/main/nodes.py index 5fba915c56..2848284998 100644 --- a/backend/onyx/agent_search/main/nodes.py +++ b/backend/onyx/agent_search/main/nodes.py @@ -1,4 +1,9 @@ +from typing import Any +from typing import cast + +from langchain_core.callbacks.manager import dispatch_custom_event from langchain_core.messages import HumanMessage +from langchain_core.messages import merge_content from onyx.agent_search.answer_question.states import AnswerQuestionOutput from onyx.agent_search.answer_question.states import QuestionAnswerResults @@ -35,10 +40,21 @@ def main_decomp_base(state: MainState) -> BaseDecompUpdate: # Get the rewritten queries in a defined format model = state["fast_llm"] - response = model.invoke(msg) + streamed_tokens: list[str | list[str | dict[str, Any]]] = [""] + for message in model.stream(msg): + dispatch_custom_event( + "decomp_qs", + message.content, + ) + streamed_tokens.append(message.content) + + response = merge_content(*streamed_tokens) - content = response.pretty_repr() - list_of_subquestions = clean_and_parse_list_string(content) + # merge_content should only return a string here; the assert below is + # commented out for efficiency + # assert all(type(tok) == str for tok in streamed_tokens) + + # use cast() (a typing-time no-op) instead of str(), which would run a conversion + list_of_subquestions = clean_and_parse_list_string(cast(str, response)) decomp_list: list[str] = [ sub_question["sub_question"].strip() for sub_question in list_of_subquestions @@ -192,8 +208,15 @@ def generate_initial_answer(state: MainState) -> InitialAnswerUpdate: # Grader model = state["fast_llm"] - response = model.invoke(msg) - answer = response.pretty_repr() + streamed_tokens: list[str | list[str | dict[str, Any]]] = [""] + for message in model.stream(msg): + dispatch_custom_event( + "main_answer", + message.content, + ) + streamed_tokens.append(message.content) + response = merge_content(*streamed_tokens) + answer = cast(str, response) initial_agent_stats = _calculate_initial_agent_stats( state["decomp_answer_results"], state["original_question_retrieval_stats"] @@ -201,7 +224,7 @@ print(f"\n\n---INITIAL AGENT ANSWER START---\n\n Answer:\n Agent: {answer}") - print(f"\n\nSub-Questions:\n\n{sub_question_answer_str}\n\nStas:\n\n") + print(f"\n\nSub-Questions:\n\n{sub_question_answer_str}\n\nStats:\n\n") if initial_agent_stats: print(initial_agent_stats.original_question) diff --git a/backend/onyx/agent_search/run_graph.py b/backend/onyx/agent_search/run_graph.py index fe48e6a8b8..2b437c8fc6 100644 --- a/backend/onyx/agent_search/run_graph.py +++ b/backend/onyx/agent_search/run_graph.py @@ -1,39 +1,57 @@ import asyncio from collections.abc import AsyncIterable from collections.abc import Iterable +from typing import cast from langchain_core.runnables.schema import StreamEvent from langgraph.graph.state import CompiledStateGraph from onyx.agent_search.main.graph_builder import main_graph_builder from onyx.agent_search.main.states import MainInput -from onyx.chat.answer import AnswerStream -from onyx.chat.models import AnswerQuestionPossibleReturn +from onyx.chat.models import AnswerPacket +from onyx.chat.models import AnswerStream +from onyx.chat.models import AnswerStyleConfig +from onyx.chat.models import CitationConfig +from onyx.chat.models import DocumentPruningConfig from onyx.chat.models import OnyxAnswerPiece +from onyx.chat.models import PromptConfig +from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE +from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT +from onyx.configs.constants import DEFAULT_PERSONA_ID +from onyx.context.search.enums import LLMEvaluationType +from onyx.context.search.models import RetrievalDetails from onyx.context.search.models import SearchRequest from onyx.db.engine import get_session_context_manager +from onyx.db.persona import get_persona_by_id from onyx.llm.interfaces import LLM from onyx.tools.models import ToolResponse +from onyx.tools.tool_constructor import SearchToolConfig +from onyx.tools.tool_implementations.search.search_tool import SearchTool from onyx.tools.tool_runner import ToolCallKickoff def _parse_agent_event( event: StreamEvent, -) -> AnswerQuestionPossibleReturn | ToolCallKickoff | ToolResponse | None: +) -> AnswerPacket | None: """ Parse the event into a typed object.
Return None if we are not interested in the event. """ event_type = event["event"] - if event_type == "on_chat_model_stream": - return OnyxAnswerPiece(answer_piece=event["data"]["chunk"].content) - elif event_type == "search_result": - # TODO: clean this up (weirdness to make mypy happy) - return ToolResponse( - id=str(event["data"].get("id", "error")), - response=event["data"].get("response", "error"), - ) + + if event_type == "on_custom_event": + # TODO: different AnswerStream types for different events + if event["name"] == "decomp_qs": + return OnyxAnswerPiece(answer_piece=cast(str, event["data"])) + elif event["name"] == "subqueries": + return OnyxAnswerPiece(answer_piece=cast(str, event["data"])) + elif event["name"] == "sub_answers": + return OnyxAnswerPiece(answer_piece=cast(str, event["data"])) + elif event["name"] == "main_answer": + return OnyxAnswerPiece(answer_piece=cast(str, event["data"])) + elif event["name"] == "tool_response": + return cast(ToolResponse, event["data"]) return None @@ -75,6 +93,7 @@ def _yield_async_to_sync() -> Iterable[StreamEvent]: def run_graph( compiled_graph: CompiledStateGraph, search_request: SearchRequest, + search_tool: SearchTool, primary_llm: LLM, fast_llm: LLM, ) -> AnswerStream: @@ -84,6 +103,7 @@ def run_graph( primary_llm=primary_llm, fast_llm=fast_llm, db_session=db_session, + search_tool=search_tool, ) for event in _manage_async_event_streaming( compiled_graph=compiled_graph, graph_input=input @@ -94,12 +114,13 @@ def run_graph( def run_main_graph( search_request: SearchRequest, + search_tool: SearchTool, primary_llm: LLM, fast_llm: LLM, ) -> AnswerStream: graph = main_graph_builder() compiled_graph = graph.compile() - return run_graph(compiled_graph, search_request, primary_llm, fast_llm) + return run_graph(compiled_graph, search_request, search_tool, primary_llm, fast_llm) if __name__ == "__main__": @@ -110,8 +131,73 @@ def run_main_graph( compiled_graph = graph.compile() primary_llm, fast_llm = get_default_llms() search_request = SearchRequest( - query="what can you do with onyx or danswer?", + query="what can you do with gitlab?", ) - for output in run_graph(compiled_graph, search_request, primary_llm, fast_llm): - print("a") - # print(output) + with get_session_context_manager() as db_session: + persona = get_persona_by_id(DEFAULT_PERSONA_ID, None, db_session) + document_pruning_config = DocumentPruningConfig( + max_chunks=int( + persona.num_chunks + if persona.num_chunks is not None + else MAX_CHUNKS_FED_TO_CHAT + ), + max_window_percentage=CHAT_TARGET_CHUNK_PERCENTAGE, + ) + + answer_style_config = AnswerStyleConfig( + citation_config=CitationConfig( + # The docs retrieved by this flow are already relevance-filtered + all_docs_useful=True + ), + document_pruning_config=document_pruning_config, + structured_response_format=None, + ) + + search_tool_config = SearchToolConfig( + answer_style_config=answer_style_config, + document_pruning_config=document_pruning_config, + retrieval_options=RetrievalDetails(), # may want to set dedupe_docs=True + rerank_settings=None, # Can use this to change reranking model + selected_sections=None, + latest_query_files=None, + bypass_acl=False, + ) + + prompt_config = PromptConfig.from_model(persona.prompts[0]) + + search_tool = SearchTool( + db_session=db_session, + user=None, + persona=persona, + retrieval_options=search_tool_config.retrieval_options, + prompt_config=prompt_config, + llm=primary_llm, + fast_llm=fast_llm, + pruning_config=search_tool_config.document_pruning_config, + 
answer_style_config=search_tool_config.answer_style_config, + selected_sections=search_tool_config.selected_sections, + chunks_above=search_tool_config.chunks_above, + chunks_below=search_tool_config.chunks_below, + full_doc=search_tool_config.full_doc, + evaluation_type=( + LLMEvaluationType.BASIC + if persona.llm_relevance_filter + else LLMEvaluationType.SKIP + ), + rerank_settings=search_tool_config.rerank_settings, + bypass_acl=search_tool_config.bypass_acl, + ) + + with open("output.txt", "w") as f: + tool_responses = [] + for output in run_graph( + compiled_graph, search_request, search_tool, primary_llm, fast_llm + ): + if isinstance(output, OnyxAnswerPiece): + f.write(str(output.answer_piece) + "|") + elif isinstance(output, ToolCallKickoff): + pass + elif isinstance(output, ToolResponse): + tool_responses.append(output) + for tool_response in tool_responses: + f.write("tool response: " + str(tool_response.response) + "\n") diff --git a/backend/onyx/chat/answer.py b/backend/onyx/chat/answer.py index 51836c228d..8a1c65638b 100644 --- a/backend/onyx/chat/answer.py +++ b/backend/onyx/chat/answer.py @@ -1,13 +1,14 @@ from collections.abc import Callable -from collections.abc import Iterator from uuid import uuid4 from langchain.schema.messages import BaseMessage from langchain_core.messages import AIMessageChunk from langchain_core.messages import ToolCall +from onyx.agent_search.run_graph import run_main_graph from onyx.chat.llm_response_handler import LLMResponseHandlerManager -from onyx.chat.models import AnswerQuestionPossibleReturn +from onyx.chat.models import AnswerPacket +from onyx.chat.models import AnswerStream from onyx.chat.models import AnswerStyleConfig from onyx.chat.models import CitationInfo from onyx.chat.models import OnyxAnswerPiece @@ -24,31 +25,27 @@ ) from onyx.chat.stream_processing.utils import map_document_id_order from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler +from onyx.context.search.models import SearchRequest from onyx.file_store.utils import InMemoryChatFile from onyx.llm.interfaces import LLM from onyx.llm.models import PreviousMessage from onyx.natural_language_processing.utils import get_tokenizer from onyx.tools.force import ForceUseTool -from onyx.tools.models import ToolResponse from onyx.tools.tool import Tool from onyx.tools.tool_implementations.search.search_tool import SearchTool -from onyx.tools.tool_runner import ToolCallKickoff from onyx.tools.utils import explicit_tool_calling_supported from onyx.utils.logger import setup_logger - logger = setup_logger() -AnswerStream = Iterator[AnswerQuestionPossibleReturn | ToolCallKickoff | ToolResponse] - - class Answer: def __init__( self, question: str, answer_style_config: AnswerStyleConfig, llm: LLM, + fast_llm: LLM, prompt_config: PromptConfig, force_use_tool: ForceUseTool, # must be the same length as `docs`. 
If None, all docs are considered "relevant" @@ -67,6 +64,8 @@ def __init__( return_contexts: bool = False, skip_gen_ai_answer_generation: bool = False, is_connected: Callable[[], bool] | None = None, + use_pro_search: bool = False, + search_request: SearchRequest | None = None, ) -> None: if single_message_history and message_history: raise ValueError( @@ -90,6 +89,7 @@ def __init__( self.prompt_config = prompt_config self.llm = llm + self.fast_llm = fast_llm self.llm_tokenizer = get_tokenizer( provider_type=llm.config.model_provider, model_name=llm.config.model_name, @@ -98,9 +98,7 @@ def __init__( self._final_prompt: list[BaseMessage] | None = None self._streamed_output: list[str] | None = None - self._processed_stream: ( - list[AnswerQuestionPossibleReturn | ToolResponse | ToolCallKickoff] | None - ) = None + self._processed_stream: (list[AnswerPacket] | None) = None self._return_contexts = return_contexts self.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation @@ -113,6 +111,9 @@ def __init__( and not skip_explicit_tool_calling ) + self.use_pro_search = use_pro_search + self.pro_search_request = search_request + def _get_tools_list(self) -> list[Tool]: if not self.force_use_tool.force_use: return self.tools @@ -258,6 +259,25 @@ def processed_streamed_output(self) -> AnswerStream: yield from self._processed_stream return + if self.use_pro_search: + if self.pro_search_request is None: + raise ValueError("Search request must be provided for pro search") + search_tools = [tool for tool in self.tools if isinstance(tool, SearchTool)] + if len(search_tools) == 0: + raise ValueError("No search tool found") + elif len(search_tools) > 1: + # TODO: handle multiple search tools + raise ValueError("Multiple search tools found") + + search_tool = search_tools[0] + yield from run_main_graph( + search_request=self.pro_search_request, + primary_llm=self.llm, + fast_llm=self.fast_llm, + search_tool=search_tool, + ) + return + prompt_builder = AnswerPromptBuilder( user_message=default_build_user_message( user_query=self.question, diff --git a/backend/onyx/chat/models.py b/backend/onyx/chat/models.py index 44973446f5..91a5689d75 100644 --- a/backend/onyx/chat/models.py +++ b/backend/onyx/chat/models.py @@ -327,3 +327,23 @@ def from_model( | ToolCallFinalResult | StreamStopInfo ) + + +class SubQuery(BaseModel): + sub_query: str + + +class SubAnswer(BaseModel): + sub_answer: str + + +class SubQuestion(BaseModel): + question_id: str + sub_question: str + + +ProSearchPacket = SubQuestion | SubAnswer | SubQuery + +AnswerPacket = AnswerQuestionPossibleReturn | ToolCallKickoff | ToolResponse + +AnswerStream = Iterator[AnswerPacket] diff --git a/backend/onyx/chat/process_message.py b/backend/onyx/chat/process_message.py index 1ffc2e8bc1..96820b7c3d 100644 --- a/backend/onyx/chat/process_message.py +++ b/backend/onyx/chat/process_message.py @@ -6,7 +6,6 @@ from sqlalchemy.orm import Session -from onyx.agent_search.run_graph import run_main_graph from onyx.chat.answer import Answer from onyx.chat.chat_utils import create_chat_chain from onyx.chat.chat_utils import create_temporary_persona @@ -686,6 +685,33 @@ def stream_chat_message_objects( for tool_list in tool_dict.values(): tools.extend(tool_list) + search_request = None + if new_msg_req.use_pro_search: + search_request = SearchRequest( + query=final_msg.message, + evaluation_type=( + LLMEvaluationType.BASIC + if persona.llm_relevance_filter + else LLMEvaluationType.SKIP + ), + human_selected_filters=( + retrieval_options.filters if retrieval_options 
else None + ), + persona=persona, + offset=(retrieval_options.offset if retrieval_options else None), + limit=retrieval_options.limit if retrieval_options else None, + rerank_settings=new_msg_req.rerank_settings, + chunks_above=new_msg_req.chunks_above, + chunks_below=new_msg_req.chunks_below, + full_doc=new_msg_req.full_doc, + enable_auto_detect_filters=( + retrieval_options.enable_auto_detect_filters + if retrieval_options + else None + ), + ) + # TODO: add previous messages, answer style config, tools, etc. + # LLM prompt building, response capturing, etc. answer = Answer( is_connected=is_connected, @@ -705,11 +731,14 @@ def stream_chat_message_objects( ) ) ), + fast_llm=fast_llm, message_history=[ PreviousMessage.from_chat_message(msg, files) for msg in history_msgs ], tools=tools, force_use_tool=_get_force_search_settings(new_msg_req, tools), + search_request=search_request, + use_pro_search=new_msg_req.use_pro_search, ) reference_db_search_docs = None @@ -719,39 +748,7 @@ def stream_chat_message_objects( dropped_indices = None tool_result = None - if not new_msg_req.use_pro_search: - answer_stream = answer.processed_streamed_output - else: - search_request = SearchRequest( - query=final_msg.message, - evaluation_type=( - LLMEvaluationType.BASIC - if persona.llm_relevance_filter - else LLMEvaluationType.SKIP - ), - human_selected_filters=( - retrieval_options.filters if retrieval_options else None - ), - persona=persona, - offset=(retrieval_options.offset if retrieval_options else None), - limit=retrieval_options.limit if retrieval_options else None, - rerank_settings=new_msg_req.rerank_settings, - chunks_above=new_msg_req.chunks_above, - chunks_below=new_msg_req.chunks_below, - full_doc=new_msg_req.full_doc, - enable_auto_detect_filters=( - retrieval_options.enable_auto_detect_filters - if retrieval_options - else None - ), - ) - answer_stream = run_main_graph( - search_request=search_request, - primary_llm=llm, - fast_llm=fast_llm, - ) - - for packet in answer_stream: + for packet in answer.processed_streamed_output: if isinstance(packet, ToolResponse): if packet.id == SEARCH_RESPONSE_SUMMARY_ID: ( diff --git a/backend/onyx/tools/tool_implementations/search/search_tool.py b/backend/onyx/tools/tool_implementations/search/search_tool.py index 368111ca46..7dcfde4d76 100644 --- a/backend/onyx/tools/tool_implementations/search/search_tool.py +++ b/backend/onyx/tools/tool_implementations/search/search_tool.py @@ -117,6 +117,8 @@ def __init__( self.fast_llm = fast_llm self.evaluation_type = evaluation_type + self.search_pipeline: SearchPipeline | None = None + self.selected_sections = selected_sections self.full_doc = full_doc @@ -330,6 +332,8 @@ def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]: ), ) + self.search_pipeline = search_pipeline + yield ToolResponse( id=SEARCH_DOC_CONTENT_ID, response=OnyxContexts( From 60207589d2a3003ac1e57fe6218a6e9481f4e221 Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Thu, 2 Jan 2025 15:52:34 -0800 Subject: [PATCH 41/78] remove need for refined subgraph --- backend/onyx/agent_search/main/edges.py | 4 +- .../onyx/agent_search/main/graph_builder.py | 56 ++++++++++++++++--- 2 files changed, 50 insertions(+), 10 deletions(-) diff --git a/backend/onyx/agent_search/main/edges.py b/backend/onyx/agent_search/main/edges.py index bd62b5e3db..ac57faeaa7 100644 --- a/backend/onyx/agent_search/main/edges.py +++ b/backend/onyx/agent_search/main/edges.py @@ -54,9 +54,9 @@ def send_to_initial_retrieval(state: MainInput) -> list[Send | Hashable]: # 
Define the function that determines whether to continue or not def continue_to_refined_answer_or_end( state: RequireRefinedAnswerUpdate, -) -> Literal["refined_answer_subgraph", "END"]: +) -> Literal["follow_up_decompose", "END"]: if state["require_refined_answer"]: - return "refined_answer_subgraph" + return "follow_up_decompose" else: return END diff --git a/backend/onyx/agent_search/main/graph_builder.py b/backend/onyx/agent_search/main/graph_builder.py index dcd1cf8647..a986e420a2 100644 --- a/backend/onyx/agent_search/main/graph_builder.py +++ b/backend/onyx/agent_search/main/graph_builder.py @@ -2,6 +2,9 @@ from langgraph.graph import START from langgraph.graph import StateGraph +from onyx.agent_search.answer_follow_up_question.graph_builder import ( + answer_follow_up_query_graph_builder, +) from onyx.agent_search.answer_question.graph_builder import answer_query_graph_builder from onyx.agent_search.base_raw_search.graph_builder import ( base_raw_search_graph_builder, @@ -18,9 +21,9 @@ from onyx.agent_search.main.nodes import refined_answer_decision from onyx.agent_search.main.states import MainInput from onyx.agent_search.main.states import MainState -from onyx.agent_search.refined_answers.graph_builder import ( - refined_answers_graph_builder, -) +from onyx.agent_search.refined_answers.edges import parallelize_follow_up_answer_queries +from onyx.agent_search.refined_answers.nodes import follow_up_decompose +from onyx.agent_search.refined_answers.nodes import ingest_follow_up_answers # from onyx.agent_search.main.nodes import check_refined_answer @@ -321,10 +324,26 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph: action=base_raw_search_subgraph, ) - refined_answer_subgraph = refined_answers_graph_builder().compile() + # refined_answer_subgraph = refined_answers_graph_builder().compile() + # graph.add_node( + # node="refined_answer_subgraph", + # action=refined_answer_subgraph, + # ) + + graph.add_node( + node="follow_up_decompose", + action=follow_up_decompose, + ) + + answer_follow_up_question = answer_follow_up_query_graph_builder().compile() + graph.add_node( + node="answer_follow_up_question", + action=answer_follow_up_question, + ) + graph.add_node( - node="refined_answer_subgraph", - action=refined_answer_subgraph, + node="ingest_follow_up_answers", + action=ingest_follow_up_answers, ) graph.add_node( @@ -415,13 +434,34 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph: graph.add_conditional_edges( source="refined_answer_decision", path=continue_to_refined_answer_or_end, - path_map=["refined_answer_subgraph", END], + path_map=["follow_up_decompose", END], + ) + + graph.add_conditional_edges( + source="follow_up_decompose", + path=parallelize_follow_up_answer_queries, + path_map=["answer_follow_up_question"], + ) + graph.add_edge( + start_key="answer_follow_up_question", + end_key="ingest_follow_up_answers", ) graph.add_edge( - start_key="refined_answer_subgraph", + start_key="ingest_follow_up_answers", end_key="generate_refined_answer", ) + + # graph.add_conditional_edges( + # source="refined_answer_decision", + # path=continue_to_refined_answer_or_end, + # path_map=["refined_answer_subgraph", END], + # ) + + # graph.add_edge( + # start_key="refined_answer_subgraph", + # end_key="generate_refined_answer", + # ) graph.add_edge( start_key="generate_refined_answer", end_key=END, From a2067231911b6efb9ebb57b08155769104b6a53f Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Thu, 2 Jan 2025 16:22:45 -0800 Subject: [PATCH 42/78] moved object 
dependencies out of refined subgraph --- backend/onyx/agent_search/main/edges.py | 25 +++++ .../onyx/agent_search/main/graph_builder.py | 6 +- backend/onyx/agent_search/main/models.py | 7 ++ backend/onyx/agent_search/main/nodes.py | 94 +++++++++++++++++- backend/onyx/agent_search/main/states.py | 14 ++- .../agent_search/refined_answers/edges.py | 33 ------- .../refined_answers/graph_builder.py | 14 +-- .../agent_search/refined_answers/models.py | 8 -- .../agent_search/refined_answers/nodes.py | 98 ------------------- .../agent_search/refined_answers/states.py | 14 --- 10 files changed, 147 insertions(+), 166 deletions(-) diff --git a/backend/onyx/agent_search/main/edges.py b/backend/onyx/agent_search/main/edges.py index ac57faeaa7..8ff525c464 100644 --- a/backend/onyx/agent_search/main/edges.py +++ b/backend/onyx/agent_search/main/edges.py @@ -61,6 +61,31 @@ def continue_to_refined_answer_or_end( return END +def parallelize_follow_up_answer_queries(state: MainState) -> list[Send | Hashable]: + if len(state["follow_up_sub_questions"]) > 0: + return [ + Send( + "answer_follow_up_question", + AnswerQuestionInput( + **extract_core_fields_for_subgraph(state), + question=question_data.sub_question, + question_nr=question_nr, + ), + ) + for question_nr, question_data in state["follow_up_sub_questions"].items() + ] + + else: + return [ + Send( + "ingest_follow_up_answers", + AnswerQuestionOutput( + answer_results=[], + ), + ) + ] + + # def continue_to_answer_sub_questions(state: QAState) -> Union[Hashable, list[Hashable]]: # # Routes re-written queries to the (parallel) retrieval steps # # Notice the 'Send()' API that takes care of the parallelization diff --git a/backend/onyx/agent_search/main/graph_builder.py b/backend/onyx/agent_search/main/graph_builder.py index a986e420a2..d8c5e5e4ab 100644 --- a/backend/onyx/agent_search/main/graph_builder.py +++ b/backend/onyx/agent_search/main/graph_builder.py @@ -11,19 +11,19 @@ ) from onyx.agent_search.main.edges import continue_to_refined_answer_or_end from onyx.agent_search.main.edges import parallelize_decompozed_answer_queries +from onyx.agent_search.main.edges import parallelize_follow_up_answer_queries from onyx.agent_search.main.nodes import entity_term_extraction +from onyx.agent_search.main.nodes import follow_up_decompose from onyx.agent_search.main.nodes import generate_initial_answer from onyx.agent_search.main.nodes import generate_refined_answer from onyx.agent_search.main.nodes import ingest_answers +from onyx.agent_search.main.nodes import ingest_follow_up_answers from onyx.agent_search.main.nodes import ingest_initial_retrieval from onyx.agent_search.main.nodes import initial_answer_quality_check from onyx.agent_search.main.nodes import main_decomp_base from onyx.agent_search.main.nodes import refined_answer_decision from onyx.agent_search.main.states import MainInput from onyx.agent_search.main.states import MainState -from onyx.agent_search.refined_answers.edges import parallelize_follow_up_answer_queries -from onyx.agent_search.refined_answers.nodes import follow_up_decompose -from onyx.agent_search.refined_answers.nodes import ingest_follow_up_answers # from onyx.agent_search.main.nodes import check_refined_answer diff --git a/backend/onyx/agent_search/main/models.py b/backend/onyx/agent_search/main/models.py index 0f011af3dd..f23611011b 100644 --- a/backend/onyx/agent_search/main/models.py +++ b/backend/onyx/agent_search/main/models.py @@ -25,3 +25,10 @@ class EntityRelationshipTermExtraction(BaseModel): entities: list[Entity] 
relationships: list[Relationship] terms: list[Term] + + +class FollowUpSubQuestion(BaseModel): + sub_question: str + verified: bool + answered: bool + answer: str diff --git a/backend/onyx/agent_search/main/nodes.py b/backend/onyx/agent_search/main/nodes.py index c4a624dd24..b758e828a0 100644 --- a/backend/onyx/agent_search/main/nodes.py +++ b/backend/onyx/agent_search/main/nodes.py @@ -9,22 +9,28 @@ from onyx.agent_search.base_raw_search.states import BaseRawSearchOutput from onyx.agent_search.main.models import Entity from onyx.agent_search.main.models import EntityRelationshipTermExtraction +from onyx.agent_search.main.models import FollowUpSubQuestion from onyx.agent_search.main.models import Relationship from onyx.agent_search.main.models import Term from onyx.agent_search.main.states import BaseDecompUpdate from onyx.agent_search.main.states import DecompAnswersUpdate from onyx.agent_search.main.states import EntityTermExtractionUpdate from onyx.agent_search.main.states import ExpandedRetrievalUpdate +from onyx.agent_search.main.states import FollowUpDecompAnswersUpdate +from onyx.agent_search.main.states import FollowUpSubQuestionsUpdate from onyx.agent_search.main.states import InitialAnswerBASEUpdate from onyx.agent_search.main.states import InitialAnswerQualityUpdate from onyx.agent_search.main.states import InitialAnswerUpdate from onyx.agent_search.main.states import MainState +from onyx.agent_search.main.states import RefinedAnswerInput +from onyx.agent_search.main.states import RefinedAnswerOutput from onyx.agent_search.main.states import RefinedAnswerUpdate from onyx.agent_search.main.states import RequireRefinedAnswerUpdate from onyx.agent_search.shared_graph_utils.models import AgentChunkStats from onyx.agent_search.shared_graph_utils.models import InitialAgentResultStats from onyx.agent_search.shared_graph_utils.models import RefinedAgentStats from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections +from onyx.agent_search.shared_graph_utils.prompts import DEEP_DECOMPOSE_PROMPT from onyx.agent_search.shared_graph_utils.prompts import ENTITY_TERM_PROMPT from onyx.agent_search.shared_graph_utils.prompts import ( INITIAL_DECOMPOSITION_PROMPT_QUESTIONS, @@ -37,6 +43,7 @@ from onyx.agent_search.shared_graph_utils.prompts import REVISED_RAG_PROMPT from onyx.agent_search.shared_graph_utils.utils import clean_and_parse_list_string from onyx.agent_search.shared_graph_utils.utils import format_docs +from onyx.agent_search.shared_graph_utils.utils import format_entity_term_extraction def main_decomp_base(state: MainState) -> BaseDecompUpdate: @@ -548,7 +555,7 @@ def generate_refined_answer(state: MainState) -> RefinedAnswerUpdate: initial_support_boost_factor = state[ "initial_agent_stats" ].agent_effectiveness.get("support_ratio", "--") - initial_verified_docs = state["initial_agent_stats"].original_question.get( + num_initial_verified_docs = state["initial_agent_stats"].original_question.get( "num_verified_documents", "--" ) initial_verified_docs_avg_score = state[ @@ -561,7 +568,7 @@ def generate_refined_answer(state: MainState) -> RefinedAnswerUpdate: print("INITIAL AGENT STATS") print(f"Document Boost Factor: {initial_doc_boost_factor}") print(f"Support Boost Factor: {initial_support_boost_factor}") - print(f"Originally Verified Docs: {initial_verified_docs}") + print(f"Originally Verified Docs: {num_initial_verified_docs}") print(f"Originally Verified Docs Avg Score: {initial_verified_docs_avg_score}") print(f"Sub-Questions Verified Docs: 
{initial_sub_questions_verified_docs}") if refined_agent_stats: @@ -581,6 +588,89 @@ def generate_refined_answer(state: MainState) -> RefinedAnswerUpdate: ) +def follow_up_decompose(state: MainState) -> FollowUpSubQuestionsUpdate: + """ """ + + question = state["search_request"].query + base_answer = state["initial_answer"] + + # get the entity term extraction dict and properly format it + entity_retlation_term_extractions = state["entity_retlation_term_extractions"] + + entity_term_extraction_str = format_entity_term_extraction( + entity_retlation_term_extractions + ) + + initial_question_answers = state["decomp_answer_results"] + + addressed_question_list = [ + x.question for x in initial_question_answers if "yes" in x.quality.lower() + ] + + failed_question_list = [ + x.question for x in initial_question_answers if "no" in x.quality.lower() + ] + + msg = [ + HumanMessage( + content=DEEP_DECOMPOSE_PROMPT.format( + question=question, + entity_term_extraction_str=entity_term_extraction_str, + base_answer=base_answer, + answered_sub_questions="\n - ".join(addressed_question_list), + failed_sub_questions="\n - ".join(failed_question_list), + ), + ) + ] + + # Grader + model = state["fast_llm"] + response = model.invoke(msg) + + if isinstance(response.content, str): + cleaned_response = re.sub(r"```json\n|\n```", "", response.content) + parsed_response = json.loads(cleaned_response) + else: + raise ValueError("LLM response is not a string") + + follow_up_sub_question_dict = {} + for sub_question_nr, sub_question_dict in enumerate( + parsed_response["sub_questions"] + ): + follow_up_sub_question = FollowUpSubQuestion( + sub_question=sub_question_dict["sub_question"], + verified=False, + answered=False, + answer="", + ) + + follow_up_sub_question_dict[sub_question_nr] = follow_up_sub_question + + return FollowUpSubQuestionsUpdate( + follow_up_sub_questions=follow_up_sub_question_dict + ) + + +def ingest_follow_up_answers( + state: AnswerQuestionOutput, +) -> FollowUpDecompAnswersUpdate: + documents = [] + answer_results = state.get("answer_results", []) + for answer_result in answer_results: + documents.extend(answer_result.documents) + return FollowUpDecompAnswersUpdate( + # Deduping is done by the documents operator for the main graph + # so we might not need to dedup here + follow_up_documents=dedup_inference_sections(documents, []), + follow_up_decomp_answer_results=answer_results, + ) + + +def dummy_node(state: RefinedAnswerInput) -> RefinedAnswerOutput: + print("---DUMMY NODE---") + return {"dummy_output": "this is a dummy output"} + + # def check_refined_answer(state: MainState) -> RefinedAnswerUpdate: # print("---CHECK REFINED ANSWER---") diff --git a/backend/onyx/agent_search/main/states.py b/backend/onyx/agent_search/main/states.py index 9c90448c47..dabc195745 100644 --- a/backend/onyx/agent_search/main/states.py +++ b/backend/onyx/agent_search/main/states.py @@ -7,7 +7,7 @@ from onyx.agent_search.expanded_retrieval.models import ExpandedRetrievalResult from onyx.agent_search.expanded_retrieval.models import QueryResult from onyx.agent_search.main.models import EntityRelationshipTermExtraction -from onyx.agent_search.refined_answers.models import FollowUpSubQuestion +from onyx.agent_search.main.models import FollowUpSubQuestion from onyx.agent_search.shared_graph_utils.models import AgentChunkStats from onyx.agent_search.shared_graph_utils.models import InitialAgentResultStats from onyx.agent_search.shared_graph_utils.models import RefinedAgentStats @@ -122,3 +122,15 @@ class 
MainOutput(TypedDict): initial_agent_stats: dict generated_sub_questions: list[str] require_refined_answer: bool + + +class RefinedAnswerInput(MainState): + pass + + +class RefinedAnswerOutput(TypedDict): + dummy_output: str + + +class RefinedAnswerState(RefinedAnswerInput, RefinedAnswerOutput): + pass diff --git a/backend/onyx/agent_search/refined_answers/edges.py b/backend/onyx/agent_search/refined_answers/edges.py index 2fa657cdb5..e69de29bb2 100644 --- a/backend/onyx/agent_search/refined_answers/edges.py +++ b/backend/onyx/agent_search/refined_answers/edges.py @@ -1,33 +0,0 @@ -from collections.abc import Hashable - -from langgraph.types import Send - -from onyx.agent_search.answer_question.states import AnswerQuestionInput -from onyx.agent_search.answer_question.states import AnswerQuestionOutput -from onyx.agent_search.core_state import extract_core_fields_for_subgraph -from onyx.agent_search.main.states import MainState - - -def parallelize_follow_up_answer_queries(state: MainState) -> list[Send | Hashable]: - if len(state["follow_up_sub_questions"]) > 0: - return [ - Send( - "answer_follow_up_question", - AnswerQuestionInput( - **extract_core_fields_for_subgraph(state), - question=question_data.sub_question, - question_nr=question_nr, - ), - ) - for question_nr, question_data in state["follow_up_sub_questions"].items() - ] - - else: - return [ - Send( - "ingest_follow_up_answers", - AnswerQuestionOutput( - answer_results=[], - ), - ) - ] diff --git a/backend/onyx/agent_search/refined_answers/graph_builder.py b/backend/onyx/agent_search/refined_answers/graph_builder.py index e628caef3a..7a897bba44 100644 --- a/backend/onyx/agent_search/refined_answers/graph_builder.py +++ b/backend/onyx/agent_search/refined_answers/graph_builder.py @@ -5,13 +5,13 @@ from onyx.agent_search.answer_follow_up_question.graph_builder import ( answer_follow_up_query_graph_builder, ) -from onyx.agent_search.refined_answers.edges import parallelize_follow_up_answer_queries -from onyx.agent_search.refined_answers.nodes import dummy_node -from onyx.agent_search.refined_answers.nodes import follow_up_decompose -from onyx.agent_search.refined_answers.nodes import ingest_follow_up_answers -from onyx.agent_search.refined_answers.states import RefinedAnswerInput -from onyx.agent_search.refined_answers.states import RefinedAnswerOutput -from onyx.agent_search.refined_answers.states import RefinedAnswerState +from onyx.agent_search.main.edges import parallelize_follow_up_answer_queries +from onyx.agent_search.main.nodes import dummy_node +from onyx.agent_search.main.nodes import follow_up_decompose +from onyx.agent_search.main.nodes import ingest_follow_up_answers +from onyx.agent_search.main.states import RefinedAnswerInput +from onyx.agent_search.main.states import RefinedAnswerOutput +from onyx.agent_search.main.states import RefinedAnswerState def refined_answers_graph_builder() -> StateGraph: diff --git a/backend/onyx/agent_search/refined_answers/models.py b/backend/onyx/agent_search/refined_answers/models.py index 23347f6398..e69de29bb2 100644 --- a/backend/onyx/agent_search/refined_answers/models.py +++ b/backend/onyx/agent_search/refined_answers/models.py @@ -1,8 +0,0 @@ -from pydantic import BaseModel - - -class FollowUpSubQuestion(BaseModel): - sub_question: str - verified: bool - answered: bool - answer: str diff --git a/backend/onyx/agent_search/refined_answers/nodes.py b/backend/onyx/agent_search/refined_answers/nodes.py index e9e5afa2fc..15a53d052a 100644 --- 
a/backend/onyx/agent_search/refined_answers/nodes.py +++ b/backend/onyx/agent_search/refined_answers/nodes.py @@ -1,104 +1,6 @@ -import json -import re - -from langchain_core.messages import HumanMessage - -from onyx.agent_search.answer_question.states import AnswerQuestionOutput from onyx.agent_search.answer_question.states import AnswerQuestionState from onyx.agent_search.answer_question.states import QuestionAnswerResults from onyx.agent_search.main.states import FollowUpAnswerQuestionOutput -from onyx.agent_search.main.states import FollowUpDecompAnswersUpdate -from onyx.agent_search.main.states import FollowUpSubQuestionsUpdate -from onyx.agent_search.main.states import MainState -from onyx.agent_search.refined_answers.models import FollowUpSubQuestion -from onyx.agent_search.refined_answers.states import RefinedAnswerInput -from onyx.agent_search.refined_answers.states import RefinedAnswerOutput -from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections -from onyx.agent_search.shared_graph_utils.prompts import DEEP_DECOMPOSE_PROMPT -from onyx.agent_search.shared_graph_utils.utils import format_entity_term_extraction - - -def dummy_node(state: RefinedAnswerInput) -> RefinedAnswerOutput: - print("---DUMMY NODE---") - return {"dummy_output": "this is a dummy output"} - - -def follow_up_decompose(state: MainState) -> FollowUpSubQuestionsUpdate: - """ """ - - question = state["search_request"].query - base_answer = state["initial_answer"] - - # get the entity term extraction dict and properly format it - entity_retlation_term_extractions = state["entity_retlation_term_extractions"] - - entity_term_extraction_str = format_entity_term_extraction( - entity_retlation_term_extractions - ) - - initial_question_answers = state["decomp_answer_results"] - - addressed_question_list = [ - x.question for x in initial_question_answers if "yes" in x.quality.lower() - ] - - failed_question_list = [ - x.question for x in initial_question_answers if "no" in x.quality.lower() - ] - - msg = [ - HumanMessage( - content=DEEP_DECOMPOSE_PROMPT.format( - question=question, - entity_term_extraction_str=entity_term_extraction_str, - base_answer=base_answer, - answered_sub_questions="\n - ".join(addressed_question_list), - failed_sub_questions="\n - ".join(failed_question_list), - ), - ) - ] - - # Grader - model = state["fast_llm"] - response = model.invoke(msg) - - if isinstance(response.content, str): - cleaned_response = re.sub(r"```json\n|\n```", "", response.content) - parsed_response = json.loads(cleaned_response) - else: - raise ValueError("LLM response is not a string") - - follow_up_sub_question_dict = {} - for sub_question_nr, sub_question_dict in enumerate( - parsed_response["sub_questions"] - ): - follow_up_sub_question = FollowUpSubQuestion( - sub_question=sub_question_dict["sub_question"], - verified=False, - answered=False, - answer="", - ) - - follow_up_sub_question_dict[sub_question_nr] = follow_up_sub_question - - return FollowUpSubQuestionsUpdate( - follow_up_sub_questions=follow_up_sub_question_dict - ) - - -def ingest_follow_up_answers( - state: AnswerQuestionOutput, -) -> FollowUpDecompAnswersUpdate: - documents = [] - answer_results = state.get("answer_results", []) - for answer_result in answer_results: - documents.extend(answer_result.documents) - return FollowUpDecompAnswersUpdate( - # Deduping is done by the documents operator for the main graph - # so we might not need to dedup here - follow_up_documents=dedup_inference_sections(documents, []), - 
follow_up_decomp_answer_results=answer_results, - ) def format_follow_up_answer(state: AnswerQuestionState) -> FollowUpAnswerQuestionOutput: diff --git a/backend/onyx/agent_search/refined_answers/states.py b/backend/onyx/agent_search/refined_answers/states.py index 2f393355b3..5b62335a72 100644 --- a/backend/onyx/agent_search/refined_answers/states.py +++ b/backend/onyx/agent_search/refined_answers/states.py @@ -1,19 +1,5 @@ from typing import TypedDict -from onyx.agent_search.main.states import MainState - - -class RefinedAnswerInput(MainState): - pass - - -class RefinedAnswerOutput(TypedDict): - dummy_output: str - class FollowUpSubQuestionsUpdate(TypedDict): follow_up_sub_question_dict: dict[str, dict[str, str]] - - -class RefinedAnswerState(RefinedAnswerInput, RefinedAnswerOutput): - pass From 13a5a86dec5b5f1e89a7f7cf1a26322f8a05d230 Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Thu, 2 Jan 2025 16:54:52 -0800 Subject: [PATCH 43/78] removed refined_search subgraph --- .../agent_search/refined_answers/edges.py | 0 .../refined_answers/graph_builder.py | 114 ------------------ .../agent_search/refined_answers/models.py | 0 .../agent_search/refined_answers/nodes.py | 31 ----- .../agent_search/refined_answers/states.py | 5 - 5 files changed, 150 deletions(-) delete mode 100644 backend/onyx/agent_search/refined_answers/edges.py delete mode 100644 backend/onyx/agent_search/refined_answers/graph_builder.py delete mode 100644 backend/onyx/agent_search/refined_answers/models.py delete mode 100644 backend/onyx/agent_search/refined_answers/nodes.py delete mode 100644 backend/onyx/agent_search/refined_answers/states.py diff --git a/backend/onyx/agent_search/refined_answers/edges.py b/backend/onyx/agent_search/refined_answers/edges.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/backend/onyx/agent_search/refined_answers/graph_builder.py b/backend/onyx/agent_search/refined_answers/graph_builder.py deleted file mode 100644 index 7a897bba44..0000000000 --- a/backend/onyx/agent_search/refined_answers/graph_builder.py +++ /dev/null @@ -1,114 +0,0 @@ -from langgraph.graph import END -from langgraph.graph import START -from langgraph.graph import StateGraph - -from onyx.agent_search.answer_follow_up_question.graph_builder import ( - answer_follow_up_query_graph_builder, -) -from onyx.agent_search.main.edges import parallelize_follow_up_answer_queries -from onyx.agent_search.main.nodes import dummy_node -from onyx.agent_search.main.nodes import follow_up_decompose -from onyx.agent_search.main.nodes import ingest_follow_up_answers -from onyx.agent_search.main.states import RefinedAnswerInput -from onyx.agent_search.main.states import RefinedAnswerOutput -from onyx.agent_search.main.states import RefinedAnswerState - - -def refined_answers_graph_builder() -> StateGraph: - graph = StateGraph( - state_schema=RefinedAnswerState, - input=RefinedAnswerInput, - output=RefinedAnswerOutput, - ) - - ### Add nodes ### - - graph.add_node( - node="dummy_node", - action=dummy_node, - ) - - graph.add_node( - node="follow_up_decompose", - action=follow_up_decompose, - ) - - answer_follow_up_question = answer_follow_up_query_graph_builder().compile() - graph.add_node( - node="answer_follow_up_question", - action=answer_follow_up_question, - ) - - graph.add_node( - node="ingest_follow_up_answers", - action=ingest_follow_up_answers, - ) - - # graph.add_node( - # node="format_follow_up_answer", - # action=format_follow_up_answer, - # ) - - ### Add edges ### - - graph.add_edge(start_key=START, 
end_key="dummy_node") - - graph.add_edge( - start_key="dummy_node", - end_key="follow_up_decompose", - ) - - graph.add_conditional_edges( - source="follow_up_decompose", - path=parallelize_follow_up_answer_queries, - path_map=["answer_follow_up_question"], - ) - graph.add_edge( - start_key="answer_follow_up_question", - end_key="ingest_follow_up_answers", - ) - - # graph.add_conditional_edges( - # start_key="answer_follow_up_question", - # end_key="ingest_follow_up_answers", - # ) - - # graph.add_conditional_edges( - # start_key="ingest_follow_up_answers", - # end_key="format_follow_up_answer", - # ) - - # graph.add_edge( - # start_key="format_follow_up_answer", - # end_key="generate_refined_answer", - # ) - - # graph.add_edge( - # start_key="generate_refined_answer", - # end_key="refined_answer_quality_check", - # ) - - # graph.add_edge( - # start_key="refined_answer_quality_check", - # end_key=END, - # ) - - # graph.add_edge( - # start_key="ingest_follow_up_answers", - # end_key="format_follow_up_answer", - # ) - # graph.add_edge( - # start_key="format_follow_up_answer", - # end_key=END, - # ) - - graph.add_edge( - start_key="ingest_follow_up_answers", - end_key=END, - ) - - return graph - - -if __name__ == "__main__": - pass diff --git a/backend/onyx/agent_search/refined_answers/models.py b/backend/onyx/agent_search/refined_answers/models.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/backend/onyx/agent_search/refined_answers/nodes.py b/backend/onyx/agent_search/refined_answers/nodes.py deleted file mode 100644 index 15a53d052a..0000000000 --- a/backend/onyx/agent_search/refined_answers/nodes.py +++ /dev/null @@ -1,31 +0,0 @@ -from onyx.agent_search.answer_question.states import AnswerQuestionState -from onyx.agent_search.answer_question.states import QuestionAnswerResults -from onyx.agent_search.main.states import FollowUpAnswerQuestionOutput - - -def format_follow_up_answer(state: AnswerQuestionState) -> FollowUpAnswerQuestionOutput: - return FollowUpAnswerQuestionOutput( - follow_up_answer_results=[ - QuestionAnswerResults( - question=state["question"], - quality=state.get("answer_quality", "No"), - answer=state["answer"], - # expanded_retrieval_results=state["expanded_retrieval_results"], - documents=state["documents"], - sub_question_retrieval_stats=state["sub_question_retrieval_stats"], - ) - ], - ) - - -# def ingest_follow_up_answers(state: AnswerQuestionOutput) -> DecompAnswersUpdate: -# documents = [] -# answer_results = state.get("answer_results", []) -# for answer_result in answer_results: -# documents.extend(answer_result.documents) -# return DecompAnswersUpdate( -# # Deduping is done by the documents operator for the main graph -# # so we might not need to dedup here -# documents=dedup_inference_sections(documents, []), -# decomp_answer_results=answer_results, -# ) diff --git a/backend/onyx/agent_search/refined_answers/states.py b/backend/onyx/agent_search/refined_answers/states.py deleted file mode 100644 index 5b62335a72..0000000000 --- a/backend/onyx/agent_search/refined_answers/states.py +++ /dev/null @@ -1,5 +0,0 @@ -from typing import TypedDict - - -class FollowUpSubQuestionsUpdate(TypedDict): - follow_up_sub_question_dict: dict[str, dict[str, str]] From 72e56aa4ca582b295d648b527fcd0992475e7315 Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Fri, 3 Jan 2025 07:49:41 -0800 Subject: [PATCH 44/78] assistant persona - untested --- .../nodes/answer_generation.py | 12 ++ backend/onyx/agent_search/main/nodes.py | 105 ++++++++++-------- 
backend/onyx/agent_search/run_graph.py | 4 + .../shared_graph_utils/prompts.py | 97 +++++++++++++--- .../agent_search/shared_graph_utils/utils.py | 10 ++ 5 files changed, 167 insertions(+), 61 deletions(-) diff --git a/backend/onyx/agent_search/answer_question/nodes/answer_generation.py b/backend/onyx/agent_search/answer_question/nodes/answer_generation.py index 0403583567..d5e4ff5d67 100644 --- a/backend/onyx/agent_search/answer_question/nodes/answer_generation.py +++ b/backend/onyx/agent_search/answer_question/nodes/answer_generation.py @@ -3,13 +3,24 @@ from onyx.agent_search.answer_question.states import AnswerQuestionState from onyx.agent_search.answer_question.states import QAGenerationUpdate +from onyx.agent_search.shared_graph_utils.prompts import ASSISTANT_SYSTEM_PROMPT_DEFAULT +from onyx.agent_search.shared_graph_utils.prompts import ASSISTANT_SYSTEM_PROMPT_PERSONA from onyx.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT from onyx.agent_search.shared_graph_utils.utils import format_docs +from onyx.agent_search.shared_graph_utils.utils import get_persona_prompt def answer_generation(state: AnswerQuestionState) -> QAGenerationUpdate: question = state["question"] docs = state["documents"] + persona_prompt = get_persona_prompt(state["subgraph_search_request"].persona) + + if len(persona_prompt) == 0: + persona_specification = ASSISTANT_SYSTEM_PROMPT_DEFAULT + else: + persona_specification = ASSISTANT_SYSTEM_PROMPT_PERSONA.format( + persona_prompt=persona_prompt + ) print(f"Number of verified retrieval docs: {len(docs)}") @@ -19,6 +30,7 @@ def answer_generation(state: AnswerQuestionState) -> QAGenerationUpdate: question=question, context=format_docs(docs), original_question=state["subgraph_search_request"].query, + persona_specification=persona_specification, ) ) ] diff --git a/backend/onyx/agent_search/main/nodes.py b/backend/onyx/agent_search/main/nodes.py index b758e828a0..4c99721e91 100644 --- a/backend/onyx/agent_search/main/nodes.py +++ b/backend/onyx/agent_search/main/nodes.py @@ -30,6 +30,8 @@ from onyx.agent_search.shared_graph_utils.models import InitialAgentResultStats from onyx.agent_search.shared_graph_utils.models import RefinedAgentStats from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections +from onyx.agent_search.shared_graph_utils.prompts import ASSISTANT_SYSTEM_PROMPT_DEFAULT +from onyx.agent_search.shared_graph_utils.prompts import ASSISTANT_SYSTEM_PROMPT_PERSONA from onyx.agent_search.shared_graph_utils.prompts import DEEP_DECOMPOSE_PROMPT from onyx.agent_search.shared_graph_utils.prompts import ENTITY_TERM_PROMPT from onyx.agent_search.shared_graph_utils.prompts import ( @@ -41,13 +43,19 @@ INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS, ) from onyx.agent_search.shared_graph_utils.prompts import REVISED_RAG_PROMPT +from onyx.agent_search.shared_graph_utils.prompts import ( + REVISED_RAG_PROMPT_NO_SUB_QUESTIONS, +) +from onyx.agent_search.shared_graph_utils.prompts import SUB_QUESTION_ANSWER_TEMPLATE from onyx.agent_search.shared_graph_utils.utils import clean_and_parse_list_string from onyx.agent_search.shared_graph_utils.utils import format_docs from onyx.agent_search.shared_graph_utils.utils import format_entity_term_extraction +from onyx.agent_search.shared_graph_utils.utils import get_persona_prompt def main_decomp_base(state: MainState) -> BaseDecompUpdate: question = state["search_request"].query + get_persona_prompt(state["search_request"].persona) msg = [ HumanMessage( @@ -158,8 +166,10 @@ def
generate_initial_answer(state: MainState) -> InitialAnswerUpdate: print("---GENERATE INITIAL---") question = state["search_request"].query + persona_prompt = get_persona_prompt(state["search_request"].persona) sub_question_docs = state["documents"] all_original_question_documents = state["all_original_question_documents"] + relevant_docs = dedup_inference_sections( sub_question_docs, all_original_question_documents ) @@ -174,9 +184,6 @@ def generate_initial_answer(state: MainState) -> InitialAnswerUpdate: good_qa_list: list[str] = [] decomp_questions = [] - _SUB_QUESTION_ANSWER_TEMPLATE = """ - Sub-Question:\n - {sub_question}\n --\nAnswer:\n - {sub_answer}\n\n - """ for decomp_answer_result in decomp_answer_results: decomp_questions.append(decomp_answer_result.question) if ( @@ -185,7 +192,7 @@ def generate_initial_answer(state: MainState) -> InitialAnswerUpdate: and decomp_answer_result.answer != "I don't know" ): good_qa_list.append( - _SUB_QUESTION_ANSWER_TEMPLATE.format( + SUB_QUESTION_ANSWER_TEMPLATE.format( sub_question=decomp_answer_result.question, sub_answer=decomp_answer_result.answer, ) @@ -193,27 +200,32 @@ def generate_initial_answer(state: MainState) -> InitialAnswerUpdate: sub_question_answer_str = "\n\n------\n\n".join(good_qa_list) + # Determine which persona-specification prompt to use + + if len(persona_prompt) == 0: + persona_specification = ASSISTANT_SYSTEM_PROMPT_DEFAULT + else: + persona_specification = ASSISTANT_SYSTEM_PROMPT_PERSONA.format( + persona_prompt=persona_prompt + ) + + # Determine which base prompt to use given the sub-question information if len(good_qa_list) > 0: - msg = [ - HumanMessage( - content=INITIAL_RAG_PROMPT.format( - question=question, - answered_sub_questions=sub_question_answer_str, - relevant_docs=format_docs(relevant_docs), - ) - ) - ] + base_prompt = INITIAL_RAG_PROMPT else: - msg = [ - HumanMessage( - content=INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS.format( - question=question, - relevant_docs=format_docs(relevant_docs), - ) + base_prompt = INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS + + msg = [ + HumanMessage( + content=base_prompt.format( + question=question, + answered_sub_questions=sub_question_answer_str, + relevant_docs=format_docs(relevant_docs), + persona_specification=persona_specification, ) - ] + ) + ] - # Grader model = state["fast_llm"] response = model.invoke(msg) answer = response.pretty_repr() @@ -430,6 +442,9 @@ def refined_answer_decision(state: MainState) -> RequireRefinedAnswerUpdate: def generate_refined_answer(state: MainState) -> RefinedAnswerUpdate: print("---GENERATE REFINED ANSWER---") + question = state["search_request"].query + persona_prompt = get_persona_prompt(state["search_request"].persona) + initial_documents = state["documents"] revised_documents = state["follow_up_documents"] @@ -442,18 +457,12 @@ def generate_refined_answer(state: MainState) -> RefinedAnswerUpdate: else: revision_doc_effectiveness = 10.0 - question = state["search_request"].query - decomp_answer_results = state["decomp_answer_results"] revised_answer_results = state["follow_up_decomp_answer_results"] good_qa_list: list[str] = [] decomp_questions = [] - _SUB_QUESTION_ANSWER_TEMPLATE = """ - Sub-Question:\n - {sub_question}\n --\nAnswer:\n - {sub_answer}\n\n - """ - initial_good_sub_questions: list[str] = [] new_revised_good_sub_questions: list[str] = []
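The hunks above replace the inline _SUB_QUESTION_ANSWER_TEMPLATE with the shared SUB_QUESTION_ANSWER_TEMPLATE from shared_graph_utils/prompts.py. A minimal, self-contained sketch of the formatting pattern they converge on follows; the QA tuples are invented sample data, and the filter is a simplified version of the node's quality check:

# Editorial sketch, not part of the patch; sample data is invented.
SUB_QUESTION_ANSWER_TEMPLATE = """
Sub-Question:\n - {sub_question}\n --\nAnswer:\n - {sub_answer}\n\n
"""

qa_results = [
    ("Who made Excel?", "yes", "Microsoft developed Excel."),
    ("What other products did they make?", "no", "I don't know"),
]

good_qa_list = [
    SUB_QUESTION_ANSWER_TEMPLATE.format(sub_question=question, sub_answer=answer)
    for question, quality, answer in qa_results
    # mirrors the node's filter: quality "yes", non-empty, not "I don't know"
    if quality.lower().startswith("yes") and answer and answer != "I don't know"
]
sub_question_answer_str = "\n\n------\n\n".join(good_qa_list)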
@@ -466,7 +475,7 @@ def generate_refined_answer(state: MainState) -> RefinedAnswerUpdate: and decomp_answer_result.answer != "I don't know" ): good_qa_list.append( - _SUB_QUESTION_ANSWER_TEMPLATE.format( + SUB_QUESTION_ANSWER_TEMPLATE.format( sub_question=decomp_answer_result.question, sub_answer=decomp_answer_result.answer, ) @@ -491,26 +500,32 @@ def generate_refined_answer(state: MainState) -> RefinedAnswerUpdate: initial_answer = state["initial_answer"] + # Determine which persona-specification prompt to use + + if len(persona_prompt) == 0: + persona_specification = ASSISTANT_SYSTEM_PROMPT_DEFAULT + else: + persona_specification = ASSISTANT_SYSTEM_PROMPT_PERSONA.format( + persona_prompt=persona_prompt + ) + + # Determine which base prompt to use given the sub-question information if len(good_qa_list) > 0: - msg = [ - HumanMessage( - content=REVISED_RAG_PROMPT.format( - question=question, - answered_sub_questions=sub_question_answer_str, - relevant_docs=format_docs(combined_documents), - initial_answer=initial_answer, - ) - ) - ] + base_prompt = REVISED_RAG_PROMPT else: - msg = [ - HumanMessage( - content=INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS.format( - question=question, - relevant_docs=format_docs(combined_documents), - ) + base_prompt = REVISED_RAG_PROMPT_NO_SUB_QUESTIONS + + msg = [ + HumanMessage( + content=base_prompt.format( + question=question, + answered_sub_questions=sub_question_answer_str, + relevant_docs=format_docs(combined_documents), + initial_answer=initial_answer, + persona_specification=persona_specification, ) - ] + ) + ] # Grader model = state["fast_llm"] diff --git a/backend/onyx/agent_search/run_graph.py b/backend/onyx/agent_search/run_graph.py index 5078a7c8b4..246377560a 100644 --- a/backend/onyx/agent_search/run_graph.py +++ b/backend/onyx/agent_search/run_graph.py @@ -92,6 +92,10 @@ def run_graph( fast_llm: LLM, ) -> AnswerStream: with get_session_context_manager() as db_session: + from onyx.db.persona import get_persona_by_id + + search_request.persona = get_persona_by_id(1, None, db_session) + input = MainInput( search_request=search_request, primary_llm=primary_llm, diff --git a/backend/onyx/agent_search/shared_graph_utils/prompts.py b/backend/onyx/agent_search/shared_graph_utils/prompts.py index d45c36b1ec..1beb961b98 100644 --- a/backend/onyx/agent_search/shared_graph_utils/prompts.py +++ b/backend/onyx/agent_search/shared_graph_utils/prompts.py @@ -18,7 +18,9 @@ \n ------- \n Formulate the sample documents separated by '--' (Do not say 'Document 1: ...', just write the text): """ +# The prompt is only used if there is no persona prompt, so the placeholder is '' BASE_RAG_PROMPT = """ \n + {persona_prompt} You are an assistant for question-answering tasks. Use the context provided below - and only the provided context - to answer the given question. (Note that the answer is in service of answering a broader question, given below as 'motivation'.) @@ -448,8 +450,27 @@ Answer:""" +### ANSWER GENERATION PROMPTS + +# Persona specification +ASSISTANT_SYSTEM_PROMPT_DEFAULT = """ +You are an assistant for question-answering tasks.""" + +ASSISTANT_SYSTEM_PROMPT_PERSONA = """ +You are an assistant for question-answering tasks. Here is more information about you: +\n ------- \n +{persona_prompt} +\n ------- \n +""" + +SUB_QUESTION_ANSWER_TEMPLATE = """ + Sub-Question:\n - {sub_question}\n --\nAnswer:\n - {sub_answer}\n\n + """ + INITIAL_RAG_PROMPT = """ \n -You are an assistant for question-answering tasks. Use the information provided below - and only the +{persona_specification} + +Use the information provided below - and only the provided information - to answer the provided question.
The information provided below consists of: @@ -485,8 +506,38 @@ \n--\n\n Answer:""" -REVISED_RAG_PROMPT = """ \n -You are an assistant for question-answering tasks. Use the information provided below - and only the +# answered_sub_questions is empty +INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS = """{answered_sub_questions} +{persona_specification} +Use the information provided below +- and only the provided information - to answer the provided question. +The information provided below consists of a number of documents that were deemed relevant for the question. + +IMPORTANT RULES: + - If you cannot reliably answer the question solely using the provided information, say that you cannot reliably answer. + You may give some additional facts you learned, but do not try to invent an answer. + - If the information is irrelevant, just say "I don't know". + - If the information is relevant but not fully conclusive, specify that the information is not conclusive and say why. + +Again, you should be sure that the answer is supported by the information provided! + +Try to keep your answer concise. + +Here is the relevant context information: +\n-------\n +{relevant_docs} +\n-------\n + +And here is the question I want you to answer based on the context above +\n--\n +{question} +\n--\n + +Answer:""" + +REVISED_RAG_PROMPT = """\n +{persona_specification} +Use the information provided below - and only the provided information - to answer the provided question. The information provided below consists of: @@ -529,33 +580,47 @@ \n--\n\n Answer:""" -INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS = """ -You are an assistant for question-answering tasks. Use the information provided below -- and only the provided information - to answer the provided question. -The information provided below consists of a number of documents that were deemed relevant for the question. +# answered_sub_questions is empty +REVISED_RAG_PROMPT_NO_SUB_QUESTIONS = """{answered_sub_questions}\n +{persona_specification} +Use the information provided below - and only the +provided information - to answer the provided question. + +The information provided below consists of: + 1) an initial answer that was given but found to be lacking in some way. + 2) a number of documents that were also deemed relevant for the question. IMPORTANT RULES: - If you cannot reliably answer the question solely using the provided information, say that you cannot reliably answer. - You may give some additional facts you learned, but do not try to invent an answer. - - If the information is irrelevant, just say "I don't know". - - If the information is relevant but not fully conclusive, specify that the information is not conclusive and say why. + You may give some additional facts you learned, but do not try to invent an answer. + - If the information is empty or irrelevant, just say "I don't know". + - If the information is relevant but not fully conclusive, provide an answer to the extent you can but also + specify that the information is not conclusive and why. Again, you should be sure that the answer is supported by the information provided! -Try to keep your answer concise. +Try to keep your answer concise. But also highlight uncertainties you may have should there be substantial ones, +or assumptions you made.
-Here are is the relevant context information: +Here is the contextual information: \n-------\n + +*Initial Answer that was found to be lacking: +{initial_answer} + +And here is the relevant document information that supports the sub-question answers, or that is relevant for the actual question:\n + {relevant_docs} -\n-------\n -And here is the question I want you to answer based on the context above +\n-------\n +\n +Lastly, here is the question I want you to answer based on the information above: \n--\n {question} -\n--\n - +\n--\n\n Answer:""" + ENTITY_TERM_PROMPT = """ \n Based on the original question and the context retrieved from a dataset, please generate a list of entities (e.g. companies, organizations, industries, products, locations, etc.), terms and concepts diff --git a/backend/onyx/agent_search/shared_graph_utils/utils.py b/backend/onyx/agent_search/shared_graph_utils/utils.py index ba4f6d4228..5f2fb77528 100644 --- a/backend/onyx/agent_search/shared_graph_utils/utils.py +++ b/backend/onyx/agent_search/shared_graph_utils/utils.py @@ -8,6 +8,7 @@ from onyx.agent_search.main.models import EntityRelationshipTermExtraction from onyx.context.search.models import InferenceSection +from onyx.db.persona import Persona def normalize_whitespace(text: str) -> str: @@ -107,3 +108,12 @@ def generate_log_message( node_time_str = _format_time_delta(current_time - node_start_time) return f"{graph_time_str} ({node_time_str} s): {message}" + + +def get_persona_prompt(persona: Persona | None) -> str: + if persona is None: + return "" + if len(persona.prompts) > 0: + return persona.prompts[0].system_prompt + else: + return "" From 46850cc2acb1bf83c9711d18376a39231901bed0 Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Fri, 3 Jan 2025 08:23:48 -0800 Subject: [PATCH 45/78] removal of print statements --- .../agent_search/deep_answer/nodes/answer_generation.py | 7 +++++-- backend/onyx/agent_search/main/edges.py | 3 ++- backend/onyx/agent_search/run_graph.py | 6 ++++-- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/backend/onyx/agent_search/deep_answer/nodes/answer_generation.py b/backend/onyx/agent_search/deep_answer/nodes/answer_generation.py index 67959efbd2..9651b62f88 100644 --- a/backend/onyx/agent_search/deep_answer/nodes/answer_generation.py +++ b/backend/onyx/agent_search/deep_answer/nodes/answer_generation.py @@ -7,6 +7,9 @@ from onyx.agent_search.shared_graph_utils.prompts import MODIFIED_RAG_PROMPT from onyx.agent_search.shared_graph_utils.utils import format_docs from onyx.agent_search.shared_graph_utils.utils import normalize_whitespace +from onyx.utils.logger import setup_logger + +logger = setup_logger() # aggregate sub questions and answers @@ -20,14 +23,14 @@ def deep_answer_generation(state: MainState) -> dict[str, Any]: Returns: dict: The updated state with re-phrased question """ - print("---DEEP GENERATE---") + logger.info("---DEEP GENERATE---") question = state["original_question"] docs = state["deduped_retrieval_docs"] deep_answer_context = state["core_answer_dynamic_context"] - print(f"Number of verified retrieval docs - deep: {len(docs)}") + logger.info(f"Number of verified retrieval docs - deep: {len(docs)}") combined_context = normalize_whitespace( COMBINED_CONTEXT.format( diff --git a/backend/onyx/agent_search/main/edges.py b/backend/onyx/agent_search/main/edges.py index 8ff525c464..b406697e05 100644 --- a/backend/onyx/agent_search/main/edges.py +++ b/backend/onyx/agent_search/main/edges.py @@ -21,9 +21,10 @@ def
parallelize_decompozed_answer_queries(state: MainState) -> list[Send | Hasha AnswerQuestionInput( **extract_core_fields_for_subgraph(state), question=question, + question_nr=question_nr, ), ) - for question in state["initial_decomp_questions"] + for question_nr, question in state["initial_decomp_questions"].items() ] else: diff --git a/backend/onyx/agent_search/run_graph.py b/backend/onyx/agent_search/run_graph.py index 246377560a..4a5cdc6c1c 100644 --- a/backend/onyx/agent_search/run_graph.py +++ b/backend/onyx/agent_search/run_graph.py @@ -15,6 +15,9 @@ from onyx.llm.interfaces import LLM from onyx.tools.models import ToolResponse from onyx.tools.tool_runner import ToolCallKickoff +from onyx.utils.logger import setup_logger + +logger = setup_logger() def _parse_agent_event( @@ -120,5 +123,4 @@ def run_graph( query="What are the guiding principles behind the development of cockroachDB?", ) for output in run_graph(compiled_graph, search_request, primary_llm, fast_llm): - print("a") - # print(output) + logger.debug(output) From 6422ad90a55a45924401a78960004c96e43c9067 Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Fri, 3 Jan 2025 08:26:54 -0800 Subject: [PATCH 46/78] print statements removed (and files saved) --- .../answer_follow_up_question/edges.py | 5 +- .../graph_builder.py | 7 +- .../nodes/answer_generation.py | 5 +- .../agent_search/answer_question/edges.py | 5 +- .../answer_question/graph_builder.py | 7 +- .../nodes/answer_generation.py | 5 +- .../nodes/format_raw_search_results.py | 5 +- .../nodes/generate_raw_search_data.py | 5 +- .../deep_answer/nodes/answer_generation.py | 27 ++--- .../expanded_retrieval/graph_builder.py | 5 +- backend/onyx/agent_search/main/edges.py | 47 +------- .../onyx/agent_search/main/graph_builder.py | 7 +- backend/onyx/agent_search/main/nodes.py | 112 +++++++----------- .../shared_graph_utils/calculations.py | 2 +- .../shared_graph_utils/prompts.py | 4 +- .../agent_search/shared_graph_utils/utils.py | 4 +- 16 files changed, 107 insertions(+), 145 deletions(-) diff --git a/backend/onyx/agent_search/answer_follow_up_question/edges.py b/backend/onyx/agent_search/answer_follow_up_question/edges.py index 34bc48b26b..e69b7a99e6 100644 --- a/backend/onyx/agent_search/answer_follow_up_question/edges.py +++ b/backend/onyx/agent_search/answer_follow_up_question/edges.py @@ -5,10 +5,13 @@ from onyx.agent_search.answer_question.states import AnswerQuestionInput from onyx.agent_search.core_state import in_subgraph_extract_core_fields from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput +from onyx.utils.logger import setup_logger + +logger = setup_logger() def send_to_expanded_follow_up_retrieval(state: AnswerQuestionInput) -> Send | Hashable: - print("sending to expanded retrieval for follow up question via edge") + logger.info("sending to expanded retrieval for follow up question via edge") return Send( "decomposed_follow_up_retrieval", diff --git a/backend/onyx/agent_search/answer_follow_up_question/graph_builder.py b/backend/onyx/agent_search/answer_follow_up_question/graph_builder.py index 47c75da4aa..6f1f77976b 100644 --- a/backend/onyx/agent_search/answer_follow_up_question/graph_builder.py +++ b/backend/onyx/agent_search/answer_follow_up_question/graph_builder.py @@ -15,6 +15,9 @@ from onyx.agent_search.expanded_retrieval.graph_builder import ( expanded_retrieval_graph_builder, ) +from onyx.utils.logger import setup_logger + +logger = setup_logger() def answer_follow_up_query_graph_builder() -> StateGraph: @@ -99,6 +102,6 @@ def 
answer_follow_up_query_graph_builder() -> StateGraph: # debug=True, # subgraphs=True, ): - print(thing) + logger.info(thing) # output = compiled_graph.invoke(inputs) - # print(output) + # logger.info(output) diff --git a/backend/onyx/agent_search/answer_follow_up_question/nodes/answer_generation.py b/backend/onyx/agent_search/answer_follow_up_question/nodes/answer_generation.py index 0403583567..b24564daa0 100644 --- a/backend/onyx/agent_search/answer_follow_up_question/nodes/answer_generation.py +++ b/backend/onyx/agent_search/answer_follow_up_question/nodes/answer_generation.py @@ -5,13 +5,16 @@ from onyx.agent_search.answer_question.states import QAGenerationUpdate from onyx.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT from onyx.agent_search.shared_graph_utils.utils import format_docs +from onyx.utils.logger import setup_logger + +logger = setup_logger() def answer_generation(state: AnswerQuestionState) -> QAGenerationUpdate: question = state["question"] docs = state["documents"] - print(f"Number of verified retrieval docs: {len(docs)}") + logger.info(f"Number of verified retrieval docs: {len(docs)}") msg = [ HumanMessage( diff --git a/backend/onyx/agent_search/answer_question/edges.py b/backend/onyx/agent_search/answer_question/edges.py index badfc02f24..441a6d75bb 100644 --- a/backend/onyx/agent_search/answer_question/edges.py +++ b/backend/onyx/agent_search/answer_question/edges.py @@ -5,10 +5,13 @@ from onyx.agent_search.answer_question.states import AnswerQuestionInput from onyx.agent_search.core_state import in_subgraph_extract_core_fields from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput +from onyx.utils.logger import setup_logger + +logger = setup_logger() def send_to_expanded_retrieval(state: AnswerQuestionInput) -> Send | Hashable: - print("sending to expanded retrieval via edge") + logger.info("sending to expanded retrieval via edge") return Send( "decomped_expanded_retrieval", diff --git a/backend/onyx/agent_search/answer_question/graph_builder.py b/backend/onyx/agent_search/answer_question/graph_builder.py index e01aa950cb..b86778ab36 100644 --- a/backend/onyx/agent_search/answer_question/graph_builder.py +++ b/backend/onyx/agent_search/answer_question/graph_builder.py @@ -13,6 +13,9 @@ from onyx.agent_search.expanded_retrieval.graph_builder import ( expanded_retrieval_graph_builder, ) +from onyx.utils.logger import setup_logger + +logger = setup_logger() def answer_query_graph_builder() -> StateGraph: @@ -97,6 +100,4 @@ def answer_query_graph_builder() -> StateGraph: # debug=True, # subgraphs=True, ): - print(thing) - # output = compiled_graph.invoke(inputs) - # print(output) + logger.info(thing) diff --git a/backend/onyx/agent_search/answer_question/nodes/answer_generation.py b/backend/onyx/agent_search/answer_question/nodes/answer_generation.py index d5e4ff5d67..d795e2fa1a 100644 --- a/backend/onyx/agent_search/answer_question/nodes/answer_generation.py +++ b/backend/onyx/agent_search/answer_question/nodes/answer_generation.py @@ -8,6 +8,9 @@ from onyx.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT from onyx.agent_search.shared_graph_utils.utils import format_docs from onyx.agent_search.shared_graph_utils.utils import get_persona_prompt +from onyx.utils.logger import setup_logger + +logger = setup_logger() def answer_generation(state: AnswerQuestionState) -> QAGenerationUpdate: @@ -22,7 +25,7 @@ def answer_generation(state: AnswerQuestionState) -> QAGenerationUpdate: persona_prompt=persona_prompt ) - 
print(f"Number of verified retrieval docs: {len(docs)}") + logger.info(f"Number of verified retrieval docs: {len(docs)}") msg = [ HumanMessage( diff --git a/backend/onyx/agent_search/base_raw_search/nodes/format_raw_search_results.py b/backend/onyx/agent_search/base_raw_search/nodes/format_raw_search_results.py index 42e0b45731..6b39fd4fe6 100644 --- a/backend/onyx/agent_search/base_raw_search/nodes/format_raw_search_results.py +++ b/backend/onyx/agent_search/base_raw_search/nodes/format_raw_search_results.py @@ -1,9 +1,12 @@ from onyx.agent_search.base_raw_search.states import BaseRawSearchOutput from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalOutput +from onyx.utils.logger import setup_logger + +logger = setup_logger() def format_raw_search_results(state: ExpandedRetrievalOutput) -> BaseRawSearchOutput: - print("format_raw_search_results") + logger.info("format_raw_search_results") return BaseRawSearchOutput( base_expanded_retrieval_result=state["expanded_retrieval_result"], # base_retrieval_results=[state["expanded_retrieval_result"]], diff --git a/backend/onyx/agent_search/base_raw_search/nodes/generate_raw_search_data.py b/backend/onyx/agent_search/base_raw_search/nodes/generate_raw_search_data.py index 9b6a7b1485..42065503e1 100644 --- a/backend/onyx/agent_search/base_raw_search/nodes/generate_raw_search_data.py +++ b/backend/onyx/agent_search/base_raw_search/nodes/generate_raw_search_data.py @@ -1,9 +1,12 @@ from onyx.agent_search.core_state import CoreState from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput +from onyx.utils.logger import setup_logger + +logger = setup_logger() def generate_raw_search_data(state: CoreState) -> ExpandedRetrievalInput: - print("generate_raw_search_data") + logger.info("generate_raw_search_data") return ExpandedRetrievalInput( subgraph_search_request=state["search_request"], subgraph_primary_llm=state["primary_llm"], diff --git a/backend/onyx/agent_search/deep_answer/nodes/answer_generation.py b/backend/onyx/agent_search/deep_answer/nodes/answer_generation.py index 9651b62f88..6c14665b0a 100644 --- a/backend/onyx/agent_search/deep_answer/nodes/answer_generation.py +++ b/backend/onyx/agent_search/deep_answer/nodes/answer_generation.py @@ -66,14 +66,13 @@ def final_stuff(state: MainState) -> dict[str, Any]: Returns: dict: The updated state with the agent response appended to messages """ - print("---FINAL---") + logger.info("---FINAL---") messages = state["log_messages"] time_ordered_messages = [x.pretty_repr() for x in messages] time_ordered_messages.sort() - print("Message Log:") - # print("\n".join(time_ordered_messages)) + logger.info("Message Log:") initial_sub_qas = state["initial_sub_qas"] initial_sub_qa_list = [] @@ -87,13 +86,13 @@ def final_stuff(state: MainState) -> dict[str, Any]: base_answer = state["base_answer"] - print(f"Final Base Answer:\n{base_answer}") - print("--------------------------------") - print(f"Initial Answered Sub Questions:\n{initial_sub_qa_context}") - print("--------------------------------") + logger.info(f"Final Base Answer:\n{base_answer}") + logger.info("--------------------------------") + logger.info(f"Initial Answered Sub Questions:\n{initial_sub_qa_context}") + logger.info("--------------------------------") if not state.get("deep_answer"): - print("No Deep Answer was required") + logger.info("No Deep Answer was required") return {} deep_answer = state["deep_answer"] @@ -107,11 +106,11 @@ def final_stuff(state: MainState) -> dict[str, Any]: sub_qa_context = 
"\n".join(sub_qa_list) - print(f"Final Base Answer:\n{base_answer}") - print("--------------------------------") - print(f"Final Deep Answer:\n{deep_answer}") - print("--------------------------------") - print("Sub Questions and Answers:") - print(sub_qa_context) + logger.info(f"Final Base Answer:\n{base_answer}") + logger.info("--------------------------------") + logger.info(f"Final Deep Answer:\n{deep_answer}") + logger.info("--------------------------------") + logger.info("Sub Questions and Answers:") + logger.info(sub_qa_context) return {} diff --git a/backend/onyx/agent_search/expanded_retrieval/graph_builder.py b/backend/onyx/agent_search/expanded_retrieval/graph_builder.py index 35de29a2f9..dcc67021f7 100644 --- a/backend/onyx/agent_search/expanded_retrieval/graph_builder.py +++ b/backend/onyx/agent_search/expanded_retrieval/graph_builder.py @@ -12,6 +12,9 @@ from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalOutput from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState +from onyx.utils.logger import setup_logger + +logger = setup_logger() def expanded_retrieval_graph_builder() -> StateGraph: @@ -101,4 +104,4 @@ def expanded_retrieval_graph_builder() -> StateGraph: # debug=True, subgraphs=True, ): - print(thing) + logger.info(thing) diff --git a/backend/onyx/agent_search/main/edges.py b/backend/onyx/agent_search/main/edges.py index b406697e05..d440b93053 100644 --- a/backend/onyx/agent_search/main/edges.py +++ b/backend/onyx/agent_search/main/edges.py @@ -11,6 +11,9 @@ from onyx.agent_search.main.states import MainInput from onyx.agent_search.main.states import MainState from onyx.agent_search.main.states import RequireRefinedAnswerUpdate +from onyx.utils.logger import setup_logger + +logger = setup_logger() def parallelize_decompozed_answer_queries(state: MainState) -> list[Send | Hashable]: @@ -39,7 +42,7 @@ def parallelize_decompozed_answer_queries(state: MainState) -> list[Send | Hasha def send_to_initial_retrieval(state: MainInput) -> list[Send | Hashable]: - print("sending to initial retrieval via edge") + logger.info("sending to initial retrieval via edge") return [ Send( "initial_retrieval", @@ -85,45 +88,3 @@ def parallelize_follow_up_answer_queries(state: MainState) -> list[Send | Hashab ), ) ] - - -# def continue_to_answer_sub_questions(state: QAState) -> Union[Hashable, list[Hashable]]: -# # Routes re-written queries to the (parallel) retrieval steps -# # Notice the 'Send()' API that takes care of the parallelization -# return [ -# Send( -# "sub_answers_graph", -# ResearchQAState( -# sub_question=sub_question["sub_question_str"], -# sub_question_nr=sub_question["sub_question_nr"], -# graph_start_time=state["graph_start_time"], -# primary_llm=state["primary_llm"], -# fast_llm=state["fast_llm"], -# ), -# ) -# for sub_question in state["sub_questions"] -# ] - - -# def continue_to_deep_answer(state: QAState) -> Union[Hashable, list[Hashable]]: -# print("---GO TO DEEP ANSWER OR END---") - -# base_answer = state["base_answer"] - -# question = state["original_question"] - -# BASE_CHECK_MESSAGE = [ -# HumanMessage( -# content=BASE_CHECK_PROMPT.format(question=question, base_answer=base_answer) -# ) -# ] - -# model = state["fast_llm"] -# response = model.invoke(BASE_CHECK_MESSAGE) - -# print(f"CAN WE CONTINUE W/O GENERATING A DEEP ANSWER? 
- {response.pretty_repr()}") - -# if response.pretty_repr() == "no": -# return "decompose" -# else: -# return "end" diff --git a/backend/onyx/agent_search/main/graph_builder.py b/backend/onyx/agent_search/main/graph_builder.py index d8c5e5e4ab..10d7aff449 100644 --- a/backend/onyx/agent_search/main/graph_builder.py +++ b/backend/onyx/agent_search/main/graph_builder.py @@ -24,9 +24,9 @@ from onyx.agent_search.main.nodes import refined_answer_decision from onyx.agent_search.main.states import MainInput from onyx.agent_search.main.states import MainState +from onyx.utils.logger import setup_logger -# from onyx.agent_search.main.nodes import check_refined_answer - +logger = setup_logger() test_mode = False @@ -507,5 +507,4 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph: # debug=True, subgraphs=True, ): - # print(thing) - print() + logger.info(thing) diff --git a/backend/onyx/agent_search/main/nodes.py b/backend/onyx/agent_search/main/nodes.py index 4c99721e91..b42da0e1c7 100644 --- a/backend/onyx/agent_search/main/nodes.py +++ b/backend/onyx/agent_search/main/nodes.py @@ -51,6 +51,9 @@ from onyx.agent_search.shared_graph_utils.utils import format_docs from onyx.agent_search.shared_graph_utils.utils import format_entity_term_extraction from onyx.agent_search.shared_graph_utils.utils import get_persona_prompt +from onyx.utils.logger import setup_logger + +logger = setup_logger() def main_decomp_base(state: MainState) -> BaseDecompUpdate: @@ -163,7 +166,7 @@ def _calculate_initial_agent_stats( def generate_initial_answer(state: MainState) -> InitialAnswerUpdate: - print("---GENERATE INITIAL---") + logger.info("---GENERATE INITIAL---") question = state["search_request"].query persona_prompt = get_persona_prompt(state["search_request"].persona) @@ -234,15 +237,15 @@ def generate_initial_answer(state: MainState) -> InitialAnswerUpdate: state["decomp_answer_results"], state["original_question_retrieval_stats"] ) - print(f"\n\n---INITIAL AGENT ANSWER START---\n\n Answer:\n Agent: {answer}") + logger.info(f"\n\n---INITIAL AGENT ANSWER START---\n\n Answer:\n Agent: {answer}") - print(f"\n\nSub-Questions:\n\n{sub_question_answer_str}\n\nStas:\n\n") + logger.info(f"\n\nSub-Questions:\n\n{sub_question_answer_str}\n\nStas:\n\n") if initial_agent_stats: - print(initial_agent_stats.original_question) - print(initial_agent_stats.sub_questions) - print(initial_agent_stats.agent_effectiveness) - print("\n\n ---INITIAL AGENT ANSWER END---\n\n") + logger.info(initial_agent_stats.original_question) + logger.info(initial_agent_stats.sub_questions) + logger.info(initial_agent_stats.agent_effectiveness) + logger.info("\n\n ---INITIAL AGENT ANSWER END---\n\n") return InitialAnswerUpdate( initial_answer=answer, @@ -262,28 +265,7 @@ def initial_answer_quality_check(state: MainState) -> InitialAnswerQualityUpdate InitialAnswerQualityUpdate """ - # print("---CHECK INITIAL QUTPUT QUALITY---") - - # question = state["search_request"].query - # initial_answer = state["initial_answer"] - - # msg = [ - # HumanMessage( - # content=BASE_CHECK_PROMPT.format(question=question, initial_answer=initial_answer) - # ) - # ] - - # model = state["fast_llm"] - # response = model.invoke(msg) - - # if 'yes' in response.content.lower(): - # verdict = True - # else: - # verdict = False - - # print(f"Verdict: {verdict}") - - print("Checking for base answer validity - for not set True/False manually") + logger.info("Checking for base answer validity - for not set True/False manually") verdict = True @@ -291,7 +273,7 @@ def 
initial_answer_quality_check(state: MainState) -> InitialAnswerQualityUpdate InitialAnswerQualityUpdate """ - # print("---CHECK INITIAL QUTPUT QUALITY---") - - # question = state["search_request"].query - # initial_answer = state["initial_answer"] - - # msg = [ - # HumanMessage( - # content=BASE_CHECK_PROMPT.format(question=question, initial_answer=initial_answer) - # ) - # ] - - # model = state["fast_llm"] - # response = model.invoke(msg) - - # if 'yes' in response.content.lower(): - # verdict = True - # else: - # verdict = False - - # print(f"Verdict: {verdict}") - - print("Checking for base answer validity - for not set True/False manually") + logger.info("Checking for base answer validity - for now set True/False manually") verdict = True @@ -291,7 +273,7 @@ def entity_term_extraction(state: MainState) -> EntityTermExtractionUpdate: - print("---GENERATE ENTITIES & TERMS---") + logger.info("---GENERATE ENTITIES & TERMS---") # first four lines duplicated from generate_initial_answer question = state["search_request"].query @@ -370,7 +352,7 @@ def entity_term_extraction(state: MainState) -> EntityTermExtractionUpdate: def generate_initial_base_answer(state: MainState) -> InitialAnswerBASEUpdate: - print("---GENERATE INITIAL BASE ANSWER---") + logger.info("---GENERATE INITIAL BASE ANSWER---") question = state["search_request"].query original_question_docs = state["all_original_question_documents"] @@ -389,8 +371,7 @@ def generate_initial_base_answer(state: MainState) -> InitialAnswerBASEUpdate: response = model.invoke(msg) answer = response.pretty_repr() - print() - print( + logger.info( f"\n\n---INITIAL BASE ANSWER START---\n\nBase: {answer}\n\n ---INITIAL BASE ANSWER END---\n\n" ) return InitialAnswerBASEUpdate(initial_base_answer=answer) @@ -430,7 +411,7 @@ def ingest_initial_retrieval(state: BaseRawSearchOutput) -> ExpandedRetrievalUpd def refined_answer_decision(state: MainState) -> RequireRefinedAnswerUpdate: - print("---REFINED ANSWER DECISION---") + logger.info("---REFINED ANSWER DECISION---") if False: return RequireRefinedAnswerUpdate(require_refined_answer=False) @@ -440,7 +421,7 @@ def generate_refined_answer(state: MainState) -> RefinedAnswerUpdate: - print("---GENERATE REFINED ANSWER---") + logger.info("---GENERATE REFINED ANSWER---") question = state["search_request"].query persona_prompt = get_persona_prompt(state["search_request"].persona) @@ -546,22 +527,24 @@ def generate_refined_answer(state: MainState) -> RefinedAnswerUpdate: revision_question_efficiency=revision_question_efficiency, ) - print(f"\n\n---INITIAL ANSWER START---\n\n Answer:\n Agent: {initial_answer}") - print("-" * 10) - print(f"\n\n---REVISED AGENT ANSWER START---\n\n Answer:\n Agent: {answer}") + logger.info(f"\n\n---INITIAL ANSWER START---\n\n Answer:\n Agent: {initial_answer}") + logger.info("-" * 10) + logger.info(f"\n\n---REVISED AGENT ANSWER START---\n\n Answer:\n Agent: {answer}") - print("-" * 100) - print(f"\n\nINITAL Sub-Questions\n\n{initial_good_sub_questions_str}\n\n") - print("-" * 10) - print(f"\n\nNEW REVISED Sub-Questions\n\n{new_revised_good_sub_questions_str}\n\n") + logger.info("-" * 100) + logger.info(f"\n\nINITIAL Sub-Questions\n\n{initial_good_sub_questions_str}\n\n") + logger.info("-" * 10) + logger.info( + f"\n\nNEW REVISED Sub-Questions\n\n{new_revised_good_sub_questions_str}\n\n" + ) - print("-" * 100) + logger.info("-" * 100) - print( + logger.info(
Verified Docs: {initial_sub_questions_verified_docs}") + logger.info("INITIAL AGENT STATS") + logger.info(f"Document Boost Factor: {initial_doc_boost_factor}") + logger.info(f"Support Boost Factor: {initial_support_boost_factor}") + logger.info(f"Originally Verified Docs: {num_initial_verified_docs}") + logger.info( + f"Originally Verified Docs Avg Score: {initial_verified_docs_avg_score}" + ) + logger.info( + f"Sub-Questions Verified Docs: {initial_sub_questions_verified_docs}" + ) if refined_agent_stats: - print("-" * 10) - print("REFINED AGENT STATS") - print(f"Revision Doc Factor: {refined_agent_stats.revision_doc_efficiency}") - print( + logger.info("-" * 10) + logger.info("REFINED AGENT STATS") + logger.info( + f"Revision Doc Factor: {refined_agent_stats.revision_doc_efficiency}" + ) + logger.info( f"Revision Question Factor: {refined_agent_stats.revision_question_efficiency}" ) - print("\n\n ---INITIAL AGENT ANSWER END---\n\n") + logger.info("\n\n ---INITIAL AGENT ANSWER END---\n\n") return RefinedAnswerUpdate( refined_answer=answer, @@ -682,14 +671,5 @@ def ingest_follow_up_answers( def dummy_node(state: RefinedAnswerInput) -> RefinedAnswerOutput: - print("---DUMMY NODE---") + logger.info("---DUMMY NODE---") return {"dummy_output": "this is a dummy output"} - - -# def check_refined_answer(state: MainState) -> RefinedAnswerUpdate: -# print("---CHECK REFINED ANSWER---") - -# return RefinedAnswerUpdate( -# refined_answer="", -# refined_answer_quality=True -# ) diff --git a/backend/onyx/agent_search/shared_graph_utils/calculations.py b/backend/onyx/agent_search/shared_graph_utils/calculations.py index 57deffe28c..b60441bed3 100644 --- a/backend/onyx/agent_search/shared_graph_utils/calculations.py +++ b/backend/onyx/agent_search/shared_graph_utils/calculations.py @@ -52,7 +52,7 @@ def get_fit_scores( ) for rank_type, docs in ranked_sections.items(): - print(f"rank_type: {rank_type}") + logger.info(f"rank_type: {rank_type}") for i in [1, 5, 10]: fit_eval.fit_scores[rank_type].scores[str(i)] = ( diff --git a/backend/onyx/agent_search/shared_graph_utils/prompts.py b/backend/onyx/agent_search/shared_graph_utils/prompts.py index 1beb961b98..e7776bc7eb 100644 --- a/backend/onyx/agent_search/shared_graph_utils/prompts.py +++ b/backend/onyx/agent_search/shared_graph_utils/prompts.py @@ -20,8 +20,8 @@ # The prompt is only used if there is no persona prompt, so the placeholder is '' BASE_RAG_PROMPT = """ \n - {persona_prompt} - You are an assistant for question-answering tasks. Use the context provided below - and only the + {persona_specification} + Use the context provided below - and only the provided context - to answer the given question. (Note that the answer is in service of anserwing a broader question, given below as 'motivation'.) 
diff --git a/backend/onyx/agent_search/shared_graph_utils/utils.py b/backend/onyx/agent_search/shared_graph_utils/utils.py index 5f2fb77528..5d8420b52c 100644 --- a/backend/onyx/agent_search/shared_graph_utils/utils.py +++ b/backend/onyx/agent_search/shared_graph_utils/utils.py @@ -113,7 +113,5 @@ def generate_log_message( def get_persona_prompt(persona: Persona | None) -> str: if persona is None: return "" - if len(persona.prompts) > 0: - return persona.prompts[0].system_prompt else: - return "" + return "\n".join([x.system_prompt for x in persona.prompts]) From 328f4758aeb096c5bda97bc18120c8d7cd872cdd Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Sat, 4 Jan 2025 19:55:09 -0800 Subject: [PATCH 47/78] persistence of metrics --- .../versions/98a5008d8711_agent_tracking.py | 42 +++++ .../nodes/answer_generation.py | 33 ++-- backend/onyx/agent_search/core_state.py | 2 + backend/onyx/agent_search/main/edges.py | 7 +- .../onyx/agent_search/main/graph_builder.py | 14 +- backend/onyx/agent_search/main/models.py | 34 ++++ backend/onyx/agent_search/main/nodes.py | 148 ++++++++++++++++-- backend/onyx/agent_search/main/states.py | 28 ++-- backend/onyx/agent_search/run_graph.py | 1 + .../shared_graph_utils/agent_prompt_ops.py | 44 ++++++ .../agent_search/shared_graph_utils/models.py | 4 +- .../shared_graph_utils/prompts.py | 27 +++- backend/onyx/db/chat.py | 33 ++++ backend/onyx/db/models.py | 19 +++ 14 files changed, 388 insertions(+), 48 deletions(-) create mode 100644 backend/alembic/versions/98a5008d8711_agent_tracking.py create mode 100644 backend/onyx/agent_search/shared_graph_utils/agent_prompt_ops.py diff --git a/backend/alembic/versions/98a5008d8711_agent_tracking.py b/backend/alembic/versions/98a5008d8711_agent_tracking.py new file mode 100644 index 0000000000..50e069c874 --- /dev/null +++ b/backend/alembic/versions/98a5008d8711_agent_tracking.py @@ -0,0 +1,42 @@ +"""agent_tracking + +Revision ID: 98a5008d8711 +Revises: 91a0a4d62b14 +Create Date: 2025-01-04 14:41:52.732238 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. 
+revision = "98a5008d8711" +down_revision = "91a0a4d62b14" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "agent_search_metrics", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=True), + sa.Column("persona_id", sa.Integer(), nullable=True), + sa.Column("agent_type", sa.String(), nullable=False), + sa.Column("start_time", sa.DateTime(timezone=True), nullable=False), + sa.Column("base_duration_s", sa.Float(), nullable=False), + sa.Column("full_duration_s", sa.Float(), nullable=False), + sa.Column("base_metrics", postgresql.JSONB(), nullable=True), + sa.Column("refined_metrics", postgresql.JSONB(), nullable=True), + sa.Column("all_metrics", postgresql.JSONB(), nullable=True), + sa.ForeignKeyConstraint( + ["persona_id"], + ["persona.id"], + ), + sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + ) + + +def downgrade() -> None: + op.drop_table("agent_search_metrics") diff --git a/backend/onyx/agent_search/answer_question/nodes/answer_generation.py b/backend/onyx/agent_search/answer_question/nodes/answer_generation.py index d795e2fa1a..74bc6e2168 100644 --- a/backend/onyx/agent_search/answer_question/nodes/answer_generation.py +++ b/backend/onyx/agent_search/answer_question/nodes/answer_generation.py @@ -1,12 +1,12 @@ -from langchain_core.messages import HumanMessage from langchain_core.messages import merge_message_runs from onyx.agent_search.answer_question.states import AnswerQuestionState from onyx.agent_search.answer_question.states import QAGenerationUpdate +from onyx.agent_search.shared_graph_utils.agent_prompt_ops import ( + build_sub_question_answer_prompt, +) from onyx.agent_search.shared_graph_utils.prompts import ASSISTANT_SYSTEM_PROMPT_DEFAULT from onyx.agent_search.shared_graph_utils.prompts import ASSISTANT_SYSTEM_PROMPT_PERSONA -from onyx.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT -from onyx.agent_search.shared_graph_utils.utils import format_docs from onyx.agent_search.shared_graph_utils.utils import get_persona_prompt from onyx.utils.logger import setup_logger @@ -27,16 +27,23 @@ def answer_generation(state: AnswerQuestionState) -> QAGenerationUpdate: logger.info(f"Number of verified retrieval docs: {len(docs)}") - msg = [ - HumanMessage( - content=BASE_RAG_PROMPT.format( - question=question, - context=format_docs(docs), - original_question=state["subgraph_search_request"].query, - persona_specification=persona_specification, - ) - ) - ] + msg = build_sub_question_answer_prompt( + question=question, + original_question=state["subgraph_search_request"].query, + docs=docs, + persona_specification=persona_specification, + ) + + # msg = [ + # HumanMessage( + # content=BASE_RAG_PROMPT.format( + # question=question, + # context=format_docs(docs), + # original_question=state["subgraph_search_request"].query, + # persona_specification=persona_specification, + # ) + # ) + # ] fast_llm = state["subgraph_fast_llm"] response = list( diff --git a/backend/onyx/agent_search/core_state.py b/backend/onyx/agent_search/core_state.py index a035d25b31..9c512c2843 100644 --- a/backend/onyx/agent_search/core_state.py +++ b/backend/onyx/agent_search/core_state.py @@ -6,6 +6,7 @@ from sqlalchemy.orm import Session from onyx.context.search.models import SearchRequest +from onyx.db.models import User from onyx.llm.interfaces import LLM @@ -20,6 +21,7 @@ class CoreState(TypedDict, total=False): # a single session for the 
entire agent search # is fine if we are only reading db_session: Session + user: User | None log_messages: Annotated[list[str], add] diff --git a/backend/onyx/agent_search/main/edges.py b/backend/onyx/agent_search/main/edges.py index d440b93053..0179809463 100644 --- a/backend/onyx/agent_search/main/edges.py +++ b/backend/onyx/agent_search/main/edges.py @@ -1,7 +1,6 @@ from collections.abc import Hashable from typing import Literal -from langgraph.graph import END from langgraph.types import Send from onyx.agent_search.answer_question.states import AnswerQuestionInput @@ -27,7 +26,7 @@ def parallelize_decompozed_answer_queries(state: MainState) -> list[Send | Hasha question_nr=question_nr, ), ) - for question_nr, question in state["initial_decomp_questions"].items() + for question_nr, question in enumerate(state["initial_decomp_questions"]) ] else: @@ -58,11 +57,11 @@ def send_to_initial_retrieval(state: MainInput) -> list[Send | Hashable]: # Define the function that determines whether to continue or not def continue_to_refined_answer_or_end( state: RequireRefinedAnswerUpdate, -) -> Literal["follow_up_decompose", "END"]: +) -> Literal["follow_up_decompose", "logging_node"]: if state["require_refined_answer"]: return "follow_up_decompose" else: - return END + return "logging_node" def parallelize_follow_up_answer_queries(state: MainState) -> list[Send | Hashable]: diff --git a/backend/onyx/agent_search/main/graph_builder.py b/backend/onyx/agent_search/main/graph_builder.py index 10d7aff449..0386bfb062 100644 --- a/backend/onyx/agent_search/main/graph_builder.py +++ b/backend/onyx/agent_search/main/graph_builder.py @@ -20,6 +20,7 @@ from onyx.agent_search.main.nodes import ingest_follow_up_answers from onyx.agent_search.main.nodes import ingest_initial_retrieval from onyx.agent_search.main.nodes import initial_answer_quality_check +from onyx.agent_search.main.nodes import logging_node from onyx.agent_search.main.nodes import main_decomp_base from onyx.agent_search.main.nodes import refined_answer_decision from onyx.agent_search.main.states import MainInput @@ -382,6 +383,11 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph: node="refined_answer_decision", action=refined_answer_decision, ) + + graph.add_node( + node="logging_node", + action=logging_node, + ) # if test_mode: # graph.add_node( # node="generate_initial_base_answer", @@ -434,7 +440,7 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph: graph.add_conditional_edges( source="refined_answer_decision", path=continue_to_refined_answer_or_end, - path_map=["follow_up_decompose", END], + path_map=["follow_up_decompose", "logging_node"], ) graph.add_conditional_edges( @@ -462,8 +468,14 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph: # start_key="refined_answer_subgraph", # end_key="generate_refined_answer", # ) + graph.add_edge( start_key="generate_refined_answer", + end_key="logging_node", + ) + + graph.add_edge( + start_key="logging_node", end_key=END, ) diff --git a/backend/onyx/agent_search/main/models.py b/backend/onyx/agent_search/main/models.py index f23611011b..c59816a10b 100644 --- a/backend/onyx/agent_search/main/models.py +++ b/backend/onyx/agent_search/main/models.py @@ -32,3 +32,37 @@ class FollowUpSubQuestion(BaseModel): verified: bool answered: bool answer: str + + +class AgentTimings(BaseModel): + base_duration_s: float | None + refined_duration_s: float | None + full_duration_s: float | None + + +class AgentBaseMetrics(BaseModel): + num_verified_documents_total: int | None + 
num_verified_documents_core: int | None + verified_avg_score_core: float | None + num_verified_documents_base: int | float | None + verified_avg_score_base: float | None + base_doc_boost_factor: float | None + support_boost_factor: float | None + duration_s: float | None + + +class AgentRefinedMetrics(BaseModel): + refined_doc_boost_factor: float | None + refined_question_boost_factor: float | None + duration_s: float | None + + +class AgentAdditionalMetrics(BaseModel): + pass + + +class CombinedAgentMetrics(BaseModel): + timings: AgentTimings + base_metrics: AgentBaseMetrics + refined_metrics: AgentRefinedMetrics + additional_metrics: AgentAdditionalMetrics diff --git a/backend/onyx/agent_search/main/nodes.py b/backend/onyx/agent_search/main/nodes.py index b42da0e1c7..f2663a5ef8 100644 --- a/backend/onyx/agent_search/main/nodes.py +++ b/backend/onyx/agent_search/main/nodes.py @@ -1,5 +1,6 @@ import json import re +from datetime import datetime from langchain_core.messages import HumanMessage from langchain_core.messages import merge_message_runs @@ -7,6 +8,11 @@ from onyx.agent_search.answer_question.states import AnswerQuestionOutput from onyx.agent_search.answer_question.states import QuestionAnswerResults from onyx.agent_search.base_raw_search.states import BaseRawSearchOutput +from onyx.agent_search.main.models import AgentAdditionalMetrics +from onyx.agent_search.main.models import AgentBaseMetrics +from onyx.agent_search.main.models import AgentRefinedMetrics +from onyx.agent_search.main.models import AgentTimings +from onyx.agent_search.main.models import CombinedAgentMetrics from onyx.agent_search.main.models import Entity from onyx.agent_search.main.models import EntityRelationshipTermExtraction from onyx.agent_search.main.models import FollowUpSubQuestion @@ -21,9 +27,8 @@ from onyx.agent_search.main.states import InitialAnswerBASEUpdate from onyx.agent_search.main.states import InitialAnswerQualityUpdate from onyx.agent_search.main.states import InitialAnswerUpdate +from onyx.agent_search.main.states import MainOutput from onyx.agent_search.main.states import MainState -from onyx.agent_search.main.states import RefinedAnswerInput -from onyx.agent_search.main.states import RefinedAnswerOutput from onyx.agent_search.main.states import RefinedAnswerUpdate from onyx.agent_search.main.states import RequireRefinedAnswerUpdate from onyx.agent_search.shared_graph_utils.models import AgentChunkStats @@ -51,12 +56,15 @@ from onyx.agent_search.shared_graph_utils.utils import format_docs from onyx.agent_search.shared_graph_utils.utils import format_entity_term_extraction from onyx.agent_search.shared_graph_utils.utils import get_persona_prompt +from onyx.db.chat import log_agent_metrics from onyx.utils.logger import setup_logger logger = setup_logger() def main_decomp_base(state: MainState) -> BaseDecompUpdate: + agent_start_time = datetime.now() + question = state["search_request"].query get_persona_prompt(state["search_request"].persona) @@ -79,6 +87,7 @@ def main_decomp_base(state: MainState) -> BaseDecompUpdate: return BaseDecompUpdate( initial_decomp_questions=decomp_list, + agent_start_time=agent_start_time, ) @@ -201,7 +210,10 @@ def generate_initial_answer(state: MainState) -> InitialAnswerUpdate: ) ) - sub_question_answer_str = "\n\n------\n\n".join(good_qa_list) + if len(good_qa_list) > 0: + sub_question_answer_str = "\n\n------\n\n".join(good_qa_list) + else: + sub_question_answer_str = "" # Determine which persona-specification prompt to use @@ -247,10 +259,37 @@ def 
generate_initial_answer(state: MainState) -> InitialAnswerUpdate: logger.info(initial_agent_stats.agent_effectiveness) logger.info("\n\n ---INITIAL AGENT ANSWER END---\n\n") + agent_base_end_time = datetime.now() + + agent_base_metrics = AgentBaseMetrics( + num_verified_documents_total=len(relevant_docs), + num_verified_documents_core=state[ + "original_question_retrieval_stats" + ].verified_count, + verified_avg_score_core=state[ + "original_question_retrieval_stats" + ].verified_avg_scores, + num_verified_documents_base=initial_agent_stats.sub_questions.get( + "num_verified_documents", None + ), + verified_avg_score_base=initial_agent_stats.sub_questions.get( + "verified_avg_score", None + ), + base_doc_boost_factor=initial_agent_stats.agent_effectiveness.get( + "utilized_chunk_ratio", None + ), + support_boost_factor=initial_agent_stats.agent_effectiveness.get( + "support_ratio", None + ), + duration_s=(agent_base_end_time - state["agent_start_time"]).total_seconds(), + ) + return InitialAnswerUpdate( initial_answer=answer, initial_agent_stats=initial_agent_stats, generated_sub_questions=decomp_questions, + agent_base_end_time=agent_base_end_time, + agent_base_metrics=agent_base_metrics, ) @@ -471,9 +510,14 @@ def generate_refined_answer(state: MainState) -> RefinedAnswerUpdate: total_good_sub_questions = list( set(initial_good_sub_questions + new_revised_good_sub_questions) ) - revision_question_efficiency = len(total_good_sub_questions) / len( - initial_good_sub_questions - ) + revision_question_efficiency: float + if len(initial_good_sub_questions) > 0: + revision_question_efficiency = len(total_good_sub_questions) / len( + initial_good_sub_questions + ) + elif len(new_revised_good_sub_questions) > 0: + revision_question_efficiency = 10.0 + else: + revision_question_efficiency = 1.0 sub_question_answer_str = "\n\n------\n\n".join(list(set(good_qa_list))) @@ -585,16 +629,31 @@ def generate_refined_answer(state: MainState) -> RefinedAnswerUpdate: logger.info("\n\n ---REFINED AGENT ANSWER END---\n\n") + agent_refined_end_time = datetime.now() + agent_refined_duration = ( + agent_refined_end_time - state["agent_refined_start_time"] + ).total_seconds() + + agent_refined_metrics = AgentRefinedMetrics( + refined_doc_boost_factor=refined_agent_stats.revision_doc_efficiency, + refined_question_boost_factor=refined_agent_stats.revision_question_efficiency, + duration_s=agent_refined_duration, + ) + return RefinedAnswerUpdate( refined_answer=answer, refined_answer_quality=True, # TODO: replace this with the actual check value refined_agent_stats=refined_agent_stats, + agent_refined_end_time=agent_refined_end_time, + agent_refined_metrics=agent_refined_metrics, ) def follow_up_decompose(state: MainState) -> FollowUpSubQuestionsUpdate: """ """ + agent_refined_start_time = datetime.now() + question = state["search_request"].query base_answer = state["initial_answer"] @@ -651,7 +710,8 @@ def follow_up_decompose(state: MainState) -> FollowUpSubQuestionsUpdate: follow_up_sub_question_dict[sub_question_nr] = follow_up_sub_question return FollowUpSubQuestionsUpdate( - follow_up_sub_questions=follow_up_sub_question_dict + follow_up_sub_questions=follow_up_sub_question_dict, + agent_refined_start_time=agent_refined_start_time, ) @@ -670,6 +730,74 @@ def ingest_follow_up_answers( ) -def dummy_node(state: RefinedAnswerInput) -> RefinedAnswerOutput: - logger.info("---DUMMY NODE---") - return {"dummy_output": "this is a dummy output"} +def logging_node(state: MainState) -> MainOutput: + logger.info("---LOGGING NODE---") + +
agent_start_time = state["agent_start_time"] + agent_base_end_time = state["agent_base_end_time"] + agent_refined_start_time = state["agent_refined_start_time"] + agent_refined_end_time = state["agent_refined_end_time"] + if agent_base_end_time and agent_refined_end_time: + agent_end_time = max(agent_base_end_time, agent_refined_end_time) + else: + # the refined flow may not have run, so fall back to whichever timestamp exists + agent_end_time = agent_refined_end_time or agent_base_end_time + + if agent_base_end_time: + agent_base_duration = (agent_base_end_time - agent_start_time).total_seconds() + else: + agent_base_duration = None + + if agent_refined_end_time: + agent_refined_duration = ( + agent_refined_end_time - agent_refined_start_time + ).total_seconds() + else: + agent_refined_duration = None + + if agent_end_time: + agent_full_duration = (agent_end_time - agent_start_time).total_seconds() + else: + agent_full_duration = None + + if agent_refined_duration is not None: + agent_type = "refined" + else: + agent_type = "base" + + agent_base_metrics = state["agent_base_metrics"] + agent_refined_metrics = state["agent_refined_metrics"] + + combined_agent_metrics = CombinedAgentMetrics( + timings=AgentTimings( + base_duration_s=agent_base_duration, + refined_duration_s=agent_refined_duration, + full_duration_s=agent_full_duration, + ), + base_metrics=agent_base_metrics, + refined_metrics=agent_refined_metrics, + additional_metrics=AgentAdditionalMetrics(), + ) + + if state["search_request"].persona: + persona_id = state["search_request"].persona.id + else: + persona_id = None + + user_id = state["user"].id if state.get("user") else None + + # log the agent metrics + log_agent_metrics( + db_session=state["db_session"], + user_id=user_id, + persona_id=persona_id, + agent_type=agent_type, + start_time=agent_start_time, + agent_metrics=combined_agent_metrics, + ) + + main_output = MainOutput() + + return main_output diff --git a/backend/onyx/agent_search/main/states.py b/backend/onyx/agent_search/main/states.py index dabc195745..a89495d955 100644 --- a/backend/onyx/agent_search/main/states.py +++ b/backend/onyx/agent_search/main/states.py @@ -1,3 +1,4 @@ +from datetime import datetime from operator import add from typing import Annotated from typing import TypedDict @@ -6,6 +7,8 @@ from onyx.agent_search.core_state import CoreState from onyx.agent_search.expanded_retrieval.models import ExpandedRetrievalResult from onyx.agent_search.expanded_retrieval.models import QueryResult +from onyx.agent_search.main.models import AgentBaseMetrics +from onyx.agent_search.main.models import AgentRefinedMetrics from onyx.agent_search.main.models import EntityRelationshipTermExtraction from onyx.agent_search.main.models import FollowUpSubQuestion from onyx.agent_search.shared_graph_utils.models import AgentChunkStats @@ -14,12 +17,14 @@ from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections from onyx.context.search.models import InferenceSection + ### States ### ## Update States class BaseDecompUpdate(TypedDict): + agent_start_time: datetime initial_decomp_questions: list[str] @@ -31,12 +36,16 @@ class InitialAnswerUpdate(TypedDict): initial_answer: str initial_agent_stats: InitialAgentResultStats | None generated_sub_questions: list[str] + agent_base_end_time: datetime + agent_base_metrics: AgentBaseMetrics class RefinedAnswerUpdate(TypedDict): refined_answer: str refined_agent_stats: RefinedAgentStats | None refined_answer_quality: bool + agent_refined_end_time: datetime + agent_refined_metrics: AgentRefinedMetrics class InitialAnswerQualityUpdate(TypedDict): @@ -71,6 +80,7 @@ class EntityTermExtractionUpdate(TypedDict): class
FollowUpSubQuestionsUpdate(TypedDict): follow_up_sub_questions: dict[int, FollowUpSubQuestion] + agent_refined_start_time: datetime class FollowUpAnswerQuestionOutput(TypedDict): @@ -113,24 +123,8 @@ class MainState( base_raw_search_result: Annotated[list[ExpandedRetrievalResult], add] -## Graph Output State +## Graph Output State - presently not used class MainOutput(TypedDict): - initial_answer: str - initial_base_answer: str - initial_agent_stats: dict - generated_sub_questions: list[str] - require_refined_answer: bool - - -class RefinedAnswerInput(MainState): - pass - - -class RefinedAnswerOutput(TypedDict): - dummy_output: str - - -class RefinedAnswerState(RefinedAnswerInput, RefinedAnswerOutput): pass diff --git a/backend/onyx/agent_search/run_graph.py b/backend/onyx/agent_search/run_graph.py index 4a5cdc6c1c..e498d494ad 100644 --- a/backend/onyx/agent_search/run_graph.py +++ b/backend/onyx/agent_search/run_graph.py @@ -121,6 +121,7 @@ def run_graph( primary_llm, fast_llm = get_default_llms() search_request = SearchRequest( query="What are the guiding principles behind the development of cockroachDB?", + # query="What are the temperatures in Munich and New York?", ) for output in run_graph(compiled_graph, search_request, primary_llm, fast_llm): logger.debug(output) diff --git a/backend/onyx/agent_search/shared_graph_utils/agent_prompt_ops.py b/backend/onyx/agent_search/shared_graph_utils/agent_prompt_ops.py new file mode 100644 index 0000000000..29bbeffaee --- /dev/null +++ b/backend/onyx/agent_search/shared_graph_utils/agent_prompt_ops.py @@ -0,0 +1,44 @@ +from langchain.schema import AIMessage +from langchain.schema import HumanMessage +from langchain.schema import SystemMessage +from langchain_core.messages.tool import ToolMessage + +from onyx.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT_v2 +from onyx.context.search.models import InferenceSection + + +def build_sub_question_answer_prompt( + question: str, + original_question: str, + docs: list[InferenceSection], + persona_specification: str, +) -> list[SystemMessage | HumanMessage | AIMessage | ToolMessage]: + system_message = SystemMessage( + content=persona_specification, + ) + + docs_format_list = [ + f"""Document Number: [{doc_nr + 1}]\n + Content: {doc.combined_content}\n\n""" + for doc_nr, doc in enumerate(docs) + ] + + docs_str = "\n\n".join(docs_format_list) + + human_message = HumanMessage( + content=BASE_RAG_PROMPT_v2.format( + question=question, original_question=original_question, context=docs_str + ) + ) + + # ai_message = AIMessage(content='' + # ) + + # tool_message = ToolMessage( + # content=docs_str, + # tool_call_id='agent_search_call', + # name="search_results" + # ) + + return [system_message, human_message] + # return [system_message, human_message, ai_message, tool_message] diff --git a/backend/onyx/agent_search/shared_graph_utils/models.py b/backend/onyx/agent_search/shared_graph_utils/models.py index 61e44d03f4..0555f03e80 100644 --- a/backend/onyx/agent_search/shared_graph_utils/models.py +++ b/backend/onyx/agent_search/shared_graph_utils/models.py @@ -48,5 +48,5 @@ class InitialAgentResultStats(BaseModel): class RefinedAgentStats(BaseModel): - revision_doc_efficiency: float | int | None - revision_question_efficiency: float | int | None + revision_doc_efficiency: float | None + revision_question_efficiency: float | None diff --git a/backend/onyx/agent_search/shared_graph_utils/prompts.py b/backend/onyx/agent_search/shared_graph_utils/prompts.py index e7776bc7eb..a02e69801f 100644 ---
a/backend/onyx/agent_search/shared_graph_utils/prompts.py +++ b/backend/onyx/agent_search/shared_graph_utils/prompts.py @@ -40,6 +40,31 @@ \n--\n {question} \n--\n """ +BASE_RAG_PROMPT_v2 = """ \n + Use the context provided below - and only the + provided context - to answer the given question. (Note that the answer is in service of answering a broader + question, given below as 'motivation'.) + + Again, only use the provided context and do not use your internal knowledge! If you cannot answer the + question based on the context, say "I don't know". It is a matter of life and death that you do NOT + use your internal knowledge, just the provided information! + + Make sure that you keep all relevant information, specifically as it concerns the ultimate goal. + (But keep other details as well.) + + Remember to provide inline citations in the format [1], [2], [3], etc.\n\n\n + + For your general information, here is the ultimate motivation: + \n--\n {original_question} \n--\n + \n\n + And here is the actual question I want you to answer based on the context above (with the motivation in mind): + \n--\n {question} \n--\n + + Here is the context: + \n\n\n--\n {context} \n--\n + """ + + SUB_CHECK_PROMPT = """ Your task is to see whether a given answer addresses a given question. Please do not use any internal knowledge you may have - just focus on whether the answer @@ -507,7 +532,7 @@ Answer:""" # sub_question_answer_str is empty -INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS = """{sub_question_answer_str} +INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS = """{answered_sub_questions} {persona_specification} Use the information provided below - and only the provided information - to answer the provided question. diff --git a/backend/onyx/db/chat.py b/backend/onyx/db/chat.py index a2cda0a30c..f432f3b444 100644 --- a/backend/onyx/db/chat.py +++ b/backend/onyx/db/chat.py @@ -15,6 +15,7 @@ from sqlalchemy.orm import joinedload from sqlalchemy.orm import Session +from onyx.agent_search.main.models import CombinedAgentMetrics from onyx.auth.schemas import UserRole from onyx.chat.models import DocumentRelevance from onyx.configs.chat_configs import HARD_DELETE_CHATS @@ -22,6 +23,7 @@ from onyx.context.search.models import RetrievalDocs from onyx.context.search.models import SavedSearchDoc from onyx.context.search.models import SearchDoc as ServerSearchDoc +from onyx.db.models import AgentSearchMetrics from onyx.db.models import ChatMessage from onyx.db.models import ChatMessage__SearchDoc from onyx.db.models import ChatSession @@ -863,3 +865,34 @@ def translate_db_message_to_chat_message_detail( ) return chat_msg_detail + + +def log_agent_metrics( + db_session: Session, + user_id: UUID | None, + persona_id: int | None, # Can be None if a temporary persona is used + agent_type: str, + start_time: datetime, + agent_metrics: CombinedAgentMetrics, +) -> AgentSearchMetrics: + agent_timings = agent_metrics.timings + agent_base_metrics = agent_metrics.base_metrics + agent_refined_metrics = agent_metrics.refined_metrics + agent_additional_metrics = agent_metrics.additional_metrics + + agent_metric_tracking = AgentSearchMetrics( + user_id=user_id, + persona_id=persona_id, + agent_type=agent_type, + start_time=start_time, + base_duration_s=agent_timings.base_duration_s, + full_duration_s=agent_timings.full_duration_s, + base_metrics=vars(agent_base_metrics), + refined_metrics=vars(agent_refined_metrics), + all_metrics=vars(agent_additional_metrics), + ) + + db_session.add(agent_metric_tracking) + db_session.commit() + + return
agent_metric_tracking diff --git a/backend/onyx/db/models.py b/backend/onyx/db/models.py index 353eba9c9d..d14f2e4892 100644 --- a/backend/onyx/db/models.py +++ b/backend/onyx/db/models.py @@ -1598,6 +1598,25 @@ class PGFileStore(Base): lobj_oid: Mapped[int] = mapped_column(Integer, nullable=False) +class AgentSearchMetrics(Base): + __tablename__ = "agent_search_metrics" + + id: Mapped[int] = mapped_column(primary_key=True) + user_id: Mapped[UUID | None] = mapped_column( + ForeignKey("user.id", ondelete="CASCADE"), nullable=True + ) + persona_id: Mapped[int | None] = mapped_column( + ForeignKey("persona.id"), nullable=True + ) + agent_type: Mapped[str] = mapped_column(String) + start_time: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True)) + base_duration_s: Mapped[float] = mapped_column(Float) + full_duration_s: Mapped[float] = mapped_column(Float) + base_metrics: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=True) + refined_metrics: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=True) + all_metrics: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=True) + + """ ************************************************************************ Enterprise Edition Models From ae65c739de13c7ec5a3378409b81b850382cedd0 Mon Sep 17 00:00:00 2001 From: Evan Lohn Date: Sun, 5 Jan 2025 10:55:32 -0800 Subject: [PATCH 48/78] WIP, deleted persistence code, partway through streaming code --- .../answer_question/graph_builder.py | 9 ++ .../nodes/answer_generation.py | 20 ++- .../nodes/generate_raw_search_data.py | 4 +- backend/onyx/agent_search/core_state.py | 6 +- backend/onyx/agent_search/db_operations.py | 64 +++++++++ .../agent_search/expanded_retrieval/edges.py | 2 +- .../expanded_retrieval/graph_builder.py | 10 ++ .../agent_search/expanded_retrieval/nodes.py | 46 +++++- backend/onyx/agent_search/main/edges.py | 27 ++-- .../onyx/agent_search/main/graph_builder.py | 12 +- backend/onyx/agent_search/main/nodes.py | 62 +++++++-- backend/onyx/agent_search/run_graph.py | 131 ++++++------------ .../shared_graph_utils/prompts.py | 8 +- .../agent_search/shared_graph_utils/utils.py | 87 +++++++++++- backend/onyx/chat/answer.py | 20 +-- backend/onyx/chat/models.py | 26 +++- backend/onyx/chat/process_message.py | 15 +- backend/onyx/db/models.py | 74 ++++++++++ 18 files changed, 480 insertions(+), 143 deletions(-) create mode 100644 backend/onyx/agent_search/db_operations.py diff --git a/backend/onyx/agent_search/answer_question/graph_builder.py b/backend/onyx/agent_search/answer_question/graph_builder.py index e01aa950cb..834fe9f498 100644 --- a/backend/onyx/agent_search/answer_question/graph_builder.py +++ b/backend/onyx/agent_search/answer_question/graph_builder.py @@ -13,6 +13,7 @@ from onyx.agent_search.expanded_retrieval.graph_builder import ( expanded_retrieval_graph_builder, ) +from onyx.agent_search.shared_graph_utils.utils import get_test_config def answer_query_graph_builder() -> StateGraph: @@ -89,8 +90,16 @@ def answer_query_graph_builder() -> StateGraph: query="what can you do with onyx or danswer?", ) with get_session_context_manager() as db_session: + pro_search_config, search_tool = get_test_config( + db_session, primary_llm, fast_llm, search_request + ) inputs = AnswerQuestionInput( question="what can you do with onyx?", + subgraph_fast_llm=fast_llm, + subgraph_primary_llm=primary_llm, + subgraph_config=pro_search_config, + subgraph_search_tool=search_tool, + subgraph_db_session=db_session, ) for thing in compiled_graph.stream( input=inputs, diff 
--git a/backend/onyx/agent_search/answer_question/nodes/answer_generation.py b/backend/onyx/agent_search/answer_question/nodes/answer_generation.py index b25c9f6072..5aaf7f0323 100644 --- a/backend/onyx/agent_search/answer_question/nodes/answer_generation.py +++ b/backend/onyx/agent_search/answer_question/nodes/answer_generation.py @@ -21,7 +21,7 @@ def answer_generation(state: AnswerQuestionState) -> QAGenerationUpdate: content=BASE_RAG_PROMPT.format( question=question, context=format_docs(docs), - original_question=state["subgraph_search_request"].query, + original_question=state["subgraph_config"].search_request.query, ) ) ] @@ -38,6 +38,24 @@ def answer_generation(state: AnswerQuestionState) -> QAGenerationUpdate: response.append(message.content) answer_str = merge_message_runs(response, chunk_separator="")[0].content + + if state["subgraph_config"].use_persistence: + # Persist the sub-answer in the database + # db_session = state["subgraph_db_session"] + # chat_session_id = state["subgraph_config"].chat_session_id + # primary_message_id = state["subgraph_config"].message_id + # sub_question_id = state["sub_question_id"] + + # if chat_session_id is not None and primary_message_id is not None and sub_question_id is not None: + # create_sub_answer( + # db_session=db_session, + # chat_session_id=chat_session_id, + # primary_message_id=primary_message_id, + # sub_question_id=sub_question_id, + # answer=answer_str, + # ) + pass + return QAGenerationUpdate( answer=answer_str, ) diff --git a/backend/onyx/agent_search/base_raw_search/nodes/generate_raw_search_data.py b/backend/onyx/agent_search/base_raw_search/nodes/generate_raw_search_data.py index 259105e428..5fd5622c7c 100644 --- a/backend/onyx/agent_search/base_raw_search/nodes/generate_raw_search_data.py +++ b/backend/onyx/agent_search/base_raw_search/nodes/generate_raw_search_data.py @@ -5,11 +5,11 @@ def generate_raw_search_data(state: CoreState) -> ExpandedRetrievalInput: print("generate_raw_search_data") return ExpandedRetrievalInput( - subgraph_search_request=state["search_request"], + subgraph_config=state["config"], subgraph_primary_llm=state["primary_llm"], subgraph_fast_llm=state["fast_llm"], subgraph_db_session=state["db_session"], - question=state["search_request"].query, + question=state["config"].search_request.query, base_search=True, subgraph_search_tool=state["search_tool"], ) diff --git a/backend/onyx/agent_search/core_state.py b/backend/onyx/agent_search/core_state.py index 2868e2176f..52d7e3a720 100644 --- a/backend/onyx/agent_search/core_state.py +++ b/backend/onyx/agent_search/core_state.py @@ -5,7 +5,7 @@ from sqlalchemy.orm import Session -from onyx.context.search.models import SearchRequest +from onyx.chat.models import ProSearchConfig from onyx.llm.interfaces import LLM from onyx.tools.tool_implementations.search.search_tool import SearchTool @@ -15,7 +15,7 @@ class CoreState(TypedDict, total=False): This is the core state that is shared across all subgraphs. """ - search_request: SearchRequest + config: ProSearchConfig primary_llm: LLM fast_llm: LLM # a single session for the entire agent search @@ -30,7 +30,7 @@ class SubgraphCoreState(TypedDict, total=False): This is the core state that is shared across all subgraphs. 
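Fields are prefixed with subgraph_ so a subgraph's copy of the core state can coexist with the parent graph's keys when results are merged back.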
""" - subgraph_search_request: SearchRequest + subgraph_config: ProSearchConfig subgraph_primary_llm: LLM subgraph_fast_llm: LLM # a single session for the entire agent search diff --git a/backend/onyx/agent_search/db_operations.py b/backend/onyx/agent_search/db_operations.py new file mode 100644 index 0000000000..999fd40d5f --- /dev/null +++ b/backend/onyx/agent_search/db_operations.py @@ -0,0 +1,64 @@ +from uuid import UUID + +from sqlalchemy.orm import Session + +from onyx.db.models import SubQuery +from onyx.db.models import SubQuestion + + +def create_sub_question( + db_session: Session, + chat_session_id: UUID, + primary_message_id: int, + sub_question: str, +) -> SubQuestion: + """Create a new sub-question record in the database.""" + sub_q = SubQuestion( + chat_session_id=chat_session_id, + primary_question_id=primary_message_id, + sub_question=sub_question, + ) + db_session.add(sub_q) + db_session.flush() + return sub_q + + +def create_sub_query( + db_session: Session, + chat_session_id: UUID, + parent_question_id: int, + sub_query: str, +) -> SubQuery: + """Create a new sub-query record in the database.""" + sub_q = SubQuery( + chat_session_id=chat_session_id, + parent_question_id=parent_question_id, + sub_query=sub_query, + ) + db_session.add(sub_q) + db_session.flush() + return sub_q + + +def get_sub_questions_for_message( + db_session: Session, + primary_message_id: int, +) -> list[SubQuestion]: + """Get all sub-questions for a given primary message.""" + return ( + db_session.query(SubQuestion) + .filter(SubQuestion.primary_question_id == primary_message_id) + .all() + ) + + +def get_sub_queries_for_question( + db_session: Session, + sub_question_id: int, +) -> list[SubQuery]: + """Get all sub-queries for a given sub-question.""" + return ( + db_session.query(SubQuery) + .filter(SubQuery.parent_question_id == sub_question_id) + .all() + ) diff --git a/backend/onyx/agent_search/expanded_retrieval/edges.py b/backend/onyx/agent_search/expanded_retrieval/edges.py index eb9d1bc2b5..3d3b06374d 100644 --- a/backend/onyx/agent_search/expanded_retrieval/edges.py +++ b/backend/onyx/agent_search/expanded_retrieval/edges.py @@ -8,7 +8,7 @@ def parallel_retrieval_edge(state: ExpandedRetrievalState) -> list[Send | Hashable]: - question = state.get("question", state["subgraph_search_request"].query) + question = state.get("question", state["subgraph_config"].search_request.query) query_expansions = state.get("expanded_queries", []) + [question] return [ diff --git a/backend/onyx/agent_search/expanded_retrieval/graph_builder.py b/backend/onyx/agent_search/expanded_retrieval/graph_builder.py index 35de29a2f9..a5325211c4 100644 --- a/backend/onyx/agent_search/expanded_retrieval/graph_builder.py +++ b/backend/onyx/agent_search/expanded_retrieval/graph_builder.py @@ -12,6 +12,7 @@ from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalOutput from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState +from onyx.agent_search.shared_graph_utils.utils import get_test_config def expanded_retrieval_graph_builder() -> StateGraph: @@ -91,10 +92,19 @@ def expanded_retrieval_graph_builder() -> StateGraph: search_request = SearchRequest( query="what can you do with onyx or danswer?", ) + with get_session_context_manager() as db_session: + pro_search_config, search_tool = get_test_config( + db_session, primary_llm, fast_llm, search_request + ) inputs = ExpandedRetrievalInput( question="what 
can you do with onyx?", base_search=False, + subgraph_fast_llm=fast_llm, + subgraph_primary_llm=primary_llm, + subgraph_db_session=db_session, + subgraph_config=pro_search_config, + subgraph_search_tool=search_tool, ) for thing in compiled_graph.stream( input=inputs, diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes.py b/backend/onyx/agent_search/expanded_retrieval/nodes.py index 927b6b379c..7a00c8a318 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes.py @@ -42,8 +42,33 @@ def expand_queries(state: ExpandedRetrievalInput) -> QueryExpansionUpdate: - question = state.get("question") + # Sometimes we want to expand the original question, sometimes we want to expand a sub-question. + # When we are running this node on the original question, no question is explictly passed in. + # Instead, we use the original question from the search request. + question = state.get("question", state["subgraph_config"].search_request.query) llm: LLM = state["subgraph_fast_llm"] + state["subgraph_db_session"] + chat_session_id = state["subgraph_config"].chat_session_id + sub_question_id = state.get("sub_question_id") + + if chat_session_id is None: + raise ValueError("chat_session_id must be provided for agent search") + + if sub_question_id is None: + if state["subgraph_config"].use_persistence: + # in this case, we are doing retrieval on the original question. + # to make all the logic consistent (i.e. all subqueries have a + # subquestion as a parent), we create a new sub-question + # with the same content as the original question. + # if state["subgraph_config"].message_id is None: + # raise ValueError("message_id must be provided for agent search with persistence") + # sub_question_id = create_sub_question(db_session, + # chat_session_id, + # state["subgraph_config"].message_id, + # question).id + pass + else: + sub_question_id = 1 msg = [ HumanMessage( @@ -64,6 +89,20 @@ def expand_queries(state: ExpandedRetrievalInput) -> QueryExpansionUpdate: rewritten_queries = llm_response.split("--") + if state["subgraph_config"].use_persistence: + # Persist sub-queries to database + + # for query in rewritten_queries: + # sub_queries.append( + # create_sub_query( + # db_session=db_session, + # chat_session_id=chat_session_id, + # parent_question_id=sub_question_id, + # sub_query=query.strip(), + # ) + # ) + pass + return QueryExpansionUpdate( expanded_queries=rewritten_queries, ) @@ -123,9 +162,10 @@ def doc_retrieval(state: RetrievalInput) -> DocRetrievalUpdate: def verification_kickoff( state: ExpandedRetrievalState, ) -> Command[Literal["doc_verification"]]: + # TODO: stream deduped docs? documents = state["retrieved_documents"] verification_question = state.get( - "question", state["subgraph_search_request"].query + "question", state["subgraph_config"].search_request.query ) return Command( update={}, @@ -186,7 +226,7 @@ def doc_reranking(state: ExpandedRetrievalState) -> DocRerankingUpdate: # Rerank post retrieval and verification. 
First, create a search query # then create the list of reranked sections - question = state.get("question", state["subgraph_search_request"].query) + question = state.get("question", state["subgraph_config"].search_request.query) _search_query = retrieval_preprocessing( search_request=SearchRequest(query=question), user=None, diff --git a/backend/onyx/agent_search/main/edges.py b/backend/onyx/agent_search/main/edges.py index 22cba21933..e03c146ea2 100644 --- a/backend/onyx/agent_search/main/edges.py +++ b/backend/onyx/agent_search/main/edges.py @@ -5,19 +5,28 @@ from onyx.agent_search.answer_question.states import AnswerQuestionInput from onyx.agent_search.answer_question.states import AnswerQuestionOutput from onyx.agent_search.core_state import extract_core_fields_for_subgraph -from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput -from onyx.agent_search.main.states import MainInput from onyx.agent_search.main.states import MainState def parallelize_decompozed_answer_queries(state: MainState) -> list[Send | Hashable]: if len(state["initial_decomp_questions"]) > 0: + # sub_question_record_ids = [subq_record.id for subq_record in state["sub_question_records"]] + # if len(state["sub_question_records"]) == 0: + # if state["config"].use_persistence: + # raise ValueError("No sub-questions found for initial decompozed questions") + # else: + # # in this case, we are doing retrieval on the original question. + # # to make all the logic consistent, we create a new sub-question + # # with the same content as the original question + # sub_question_record_ids = [1] * len(state["initial_decomp_questions"]) + return [ Send( "answer_query", AnswerQuestionInput( **extract_core_fields_for_subgraph(state), question=question, + # sub_question_id=sub_question_record_id, ), ) for question in state["initial_decomp_questions"] @@ -34,20 +43,6 @@ def parallelize_decompozed_answer_queries(state: MainState) -> list[Send | Hasha ] -def send_to_initial_retrieval(state: MainInput) -> list[Send | Hashable]: - print("sending to initial retrieval via edge") - return [ - Send( - "initial_retrieval", - ExpandedRetrievalInput( - question=state["search_request"].query, - **extract_core_fields_for_subgraph(state), - base_search=False, - ), - ) - ] - - # def continue_to_answer_sub_questions(state: QAState) -> Union[Hashable, list[Hashable]]: # # Routes re-written queries to the (parallel) retrieval steps # # Notice the 'Send()' API that takes care of the parallelization diff --git a/backend/onyx/agent_search/main/graph_builder.py b/backend/onyx/agent_search/main/graph_builder.py index 50a839cb21..3600b9df4b 100644 --- a/backend/onyx/agent_search/main/graph_builder.py +++ b/backend/onyx/agent_search/main/graph_builder.py @@ -13,7 +13,7 @@ from onyx.agent_search.main.nodes import main_decomp_base from onyx.agent_search.main.states import MainInput from onyx.agent_search.main.states import MainState - +from onyx.agent_search.shared_graph_utils.utils import get_test_config test_mode = False @@ -437,12 +437,16 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph: with get_session_context_manager() as db_session: search_request = SearchRequest(query="Who created Excel?") + pro_search_config, search_tool = get_test_config( + db_session, primary_llm, fast_llm, search_request + ) inputs = MainInput( - search_request=search_request, primary_llm=primary_llm, fast_llm=fast_llm, db_session=db_session, + config=pro_search_config, + search_tool=search_tool, ) for thing in compiled_graph.stream( @@ 
-451,5 +455,5 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph: # debug=True, subgraphs=True, ): - # print(thing) - print() + print(thing) + # print() diff --git a/backend/onyx/agent_search/main/nodes.py b/backend/onyx/agent_search/main/nodes.py index 2848284998..5ddbfd8da2 100644 --- a/backend/onyx/agent_search/main/nodes.py +++ b/backend/onyx/agent_search/main/nodes.py @@ -25,12 +25,30 @@ from onyx.agent_search.shared_graph_utils.prompts import ( INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS, ) -from onyx.agent_search.shared_graph_utils.utils import clean_and_parse_list_string from onyx.agent_search.shared_graph_utils.utils import format_docs +from onyx.chat.models import SubQuestion + + +def dispatch_subquestion(sub_question_part: str, subq_id: int) -> None: + dispatch_custom_event( + "decomp_qs", + SubQuestion( + sub_question=sub_question_part, + question_id=subq_id, + ), + ) def main_decomp_base(state: MainState) -> BaseDecompUpdate: - question = state["search_request"].query + question = state["config"].search_request.query + state["db_session"] + chat_session_id = state["config"].chat_session_id + primary_message_id = state["config"].message_id + + if not chat_session_id or not primary_message_id: + raise ValueError( + "chat_session_id and message_id must be provided for agent search" + ) msg = [ HumanMessage( @@ -41,12 +59,18 @@ def main_decomp_base(state: MainState) -> BaseDecompUpdate: # Get the rewritten queries in a defined format model = state["fast_llm"] streamed_tokens: list[str | list[str | dict[str, Any]]] = [""] + subq_id = 1 for message in model.stream(msg): - dispatch_custom_event( - "decomp_qs", - message.content, - ) - streamed_tokens.append(message.content) + content = cast(str, message.content) + if "\n" in content: + for sub_question_part in content.split("\n"): + dispatch_subquestion(sub_question_part, subq_id) + subq_id += 1 + subq_id -= 1 # fencepost; extra increment at end of loop + else: + dispatch_subquestion(content, subq_id) + + streamed_tokens.append(content) response = merge_content(*streamed_tokens) @@ -54,12 +78,28 @@ def main_decomp_base(state: MainState) -> BaseDecompUpdate: # assert [type(tok) == str for tok in streamed_tokens] # use no-op cast() instead of str() which runs code - list_of_subquestions = clean_and_parse_list_string(cast(str, response)) + # list_of_subquestions = clean_and_parse_list_string(cast(str, response)) + list_of_subquestions = cast(str, response).split("\n") decomp_list: list[str] = [ - sub_question["sub_question"].strip() for sub_question in list_of_subquestions + sub_question.strip() for sub_question in list_of_subquestions ] + # Persist sub-questions to database + # from onyx.agent_search.db_operations import create_sub_question + + if state["config"].use_persistence: + # for sub_q in decomp_list: + # sub_questions.append( + # create_sub_question( + # db_session=db_session, + # chat_session_id=chat_session_id, + # primary_message_id=primary_message_id, + # sub_question=sub_q, + # ) + # ) + pass + return BaseDecompUpdate( initial_decomp_questions=decomp_list, ) @@ -150,7 +190,7 @@ def _calculate_initial_agent_stats( def generate_initial_answer(state: MainState) -> InitialAnswerUpdate: print("---GENERATE INITIAL---") - question = state["search_request"].query + question = state["config"].search_request.query sub_question_docs = state["documents"] all_original_question_documents = state["all_original_question_documents"] relevant_docs = dedup_inference_sections( @@ -242,7 +282,7 @@ def generate_initial_answer(state: 
MainState) -> InitialAnswerUpdate: def generate_initial_base_answer(state: MainState) -> InitialAnswerBASEUpdate: print("---GENERATE INITIAL BASE ANSWER---") - question = state["search_request"].query + question = state["config"].search_request.query original_question_docs = state["all_original_question_documents"] msg = [ diff --git a/backend/onyx/agent_search/run_graph.py b/backend/onyx/agent_search/run_graph.py index 2b437c8fc6..77309dea4c 100644 --- a/backend/onyx/agent_search/run_graph.py +++ b/backend/onyx/agent_search/run_graph.py @@ -5,30 +5,25 @@ from langchain_core.runnables.schema import StreamEvent from langgraph.graph.state import CompiledStateGraph +from sqlalchemy.orm import Session from onyx.agent_search.main.graph_builder import main_graph_builder from onyx.agent_search.main.states import MainInput +from onyx.agent_search.shared_graph_utils.utils import get_test_config from onyx.chat.models import AnswerPacket from onyx.chat.models import AnswerStream -from onyx.chat.models import AnswerStyleConfig -from onyx.chat.models import CitationConfig -from onyx.chat.models import DocumentPruningConfig from onyx.chat.models import OnyxAnswerPiece -from onyx.chat.models import PromptConfig -from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE -from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT -from onyx.configs.constants import DEFAULT_PERSONA_ID -from onyx.context.search.enums import LLMEvaluationType -from onyx.context.search.models import RetrievalDetails +from onyx.chat.models import ProSearchConfig +from onyx.chat.models import SubQuestion from onyx.context.search.models import SearchRequest from onyx.db.engine import get_session_context_manager -from onyx.db.persona import get_persona_by_id from onyx.llm.interfaces import LLM from onyx.tools.models import ToolResponse -from onyx.tools.tool_constructor import SearchToolConfig from onyx.tools.tool_implementations.search.search_tool import SearchTool from onyx.tools.tool_runner import ToolCallKickoff +_COMPILED_GRAPH: CompiledStateGraph | None = None + def _parse_agent_event( event: StreamEvent, @@ -43,11 +38,14 @@ def _parse_agent_event( if event_type == "on_custom_event": # TODO: different AnswerStream types for different events if event["name"] == "decomp_qs": - return OnyxAnswerPiece(answer_piece=cast(str, event["data"])) + # return OnyxAnswerPiece(answer_piece=cast(str, event["data"])) + return cast(SubQuestion, event["data"]) elif event["name"] == "subqueries": - return OnyxAnswerPiece(answer_piece=cast(str, event["data"])) + # return OnyxAnswerPiece(answer_piece=cast(str, event["data"])) + return None elif event["name"] == "sub_answers": - return OnyxAnswerPiece(answer_piece=cast(str, event["data"])) + # return OnyxAnswerPiece(answer_piece=cast(str, event["data"])) + return None elif event["name"] == "main_answer": return OnyxAnswerPiece(answer_piece=cast(str, event["data"])) elif event["name"] == "tool_response": @@ -92,40 +90,51 @@ def _yield_async_to_sync() -> Iterable[StreamEvent]: def run_graph( compiled_graph: CompiledStateGraph, - search_request: SearchRequest, + config: ProSearchConfig, search_tool: SearchTool, primary_llm: LLM, fast_llm: LLM, + db_session: Session, ) -> AnswerStream: - with get_session_context_manager() as db_session: - input = MainInput( - search_request=search_request, - primary_llm=primary_llm, - fast_llm=fast_llm, - db_session=db_session, - search_tool=search_tool, - ) - for event in _manage_async_event_streaming( - compiled_graph=compiled_graph, graph_input=input - ): 
- if parsed_object := _parse_agent_event(event): - yield parsed_object + input = MainInput( + config=config, + primary_llm=primary_llm, + fast_llm=fast_llm, + db_session=db_session, + search_tool=search_tool, + ) + for event in _manage_async_event_streaming( + compiled_graph=compiled_graph, graph_input=input + ): + if parsed_object := _parse_agent_event(event): + yield parsed_object + + +# TODO: call this once on startup, TBD where and if it should be gated based +# on dev mode or not +def load_compiled_graph() -> CompiledStateGraph: + global _COMPILED_GRAPH + if _COMPILED_GRAPH is None: + graph = main_graph_builder() + _COMPILED_GRAPH = graph.compile() + return _COMPILED_GRAPH def run_main_graph( - search_request: SearchRequest, + config: ProSearchConfig, search_tool: SearchTool, primary_llm: LLM, fast_llm: LLM, + db_session: Session, ) -> AnswerStream: - graph = main_graph_builder() - compiled_graph = graph.compile() - return run_graph(compiled_graph, search_request, search_tool, primary_llm, fast_llm) + compiled_graph = load_compiled_graph() + return run_graph( + compiled_graph, config, search_tool, primary_llm, fast_llm, db_session + ) if __name__ == "__main__": from onyx.llm.factory import get_default_llms - from onyx.context.search.models import SearchRequest graph = main_graph_builder() compiled_graph = graph.compile() @@ -134,64 +143,14 @@ def run_main_graph( query="what can you do with gitlab?", ) with get_session_context_manager() as db_session: - persona = get_persona_by_id(DEFAULT_PERSONA_ID, None, db_session) - document_pruning_config = DocumentPruningConfig( - max_chunks=int( - persona.num_chunks - if persona.num_chunks is not None - else MAX_CHUNKS_FED_TO_CHAT - ), - max_window_percentage=CHAT_TARGET_CHUNK_PERCENTAGE, - ) - - answer_style_config = AnswerStyleConfig( - citation_config=CitationConfig( - # The docs retrieved by this flow are already relevance-filtered - all_docs_useful=True - ), - document_pruning_config=document_pruning_config, - structured_response_format=None, - ) - - search_tool_config = SearchToolConfig( - answer_style_config=answer_style_config, - document_pruning_config=document_pruning_config, - retrieval_options=RetrievalDetails(), # may want to set dedupe_docs=True - rerank_settings=None, # Can use this to change reranking model - selected_sections=None, - latest_query_files=None, - bypass_acl=False, - ) - - prompt_config = PromptConfig.from_model(persona.prompts[0]) - - search_tool = SearchTool( - db_session=db_session, - user=None, - persona=persona, - retrieval_options=search_tool_config.retrieval_options, - prompt_config=prompt_config, - llm=primary_llm, - fast_llm=fast_llm, - pruning_config=search_tool_config.document_pruning_config, - answer_style_config=search_tool_config.answer_style_config, - selected_sections=search_tool_config.selected_sections, - chunks_above=search_tool_config.chunks_above, - chunks_below=search_tool_config.chunks_below, - full_doc=search_tool_config.full_doc, - evaluation_type=( - LLMEvaluationType.BASIC - if persona.llm_relevance_filter - else LLMEvaluationType.SKIP - ), - rerank_settings=search_tool_config.rerank_settings, - bypass_acl=search_tool_config.bypass_acl, + config, search_tool = get_test_config( + db_session, primary_llm, fast_llm, search_request ) with open("output.txt", "w") as f: tool_responses = [] for output in run_graph( - compiled_graph, search_request, search_tool, primary_llm, fast_llm + compiled_graph, config, search_tool, primary_llm, fast_llm, db_session ): if isinstance(output, OnyxAnswerPiece): 
f.write(str(output.answer_piece) + "|") diff --git a/backend/onyx/agent_search/shared_graph_utils/prompts.py b/backend/onyx/agent_search/shared_graph_utils/prompts.py index 07b935d91b..f3a45c1083 100644 --- a/backend/onyx/agent_search/shared_graph_utils/prompts.py +++ b/backend/onyx/agent_search/shared_graph_utils/prompts.py @@ -385,14 +385,16 @@ (i.e., 'what do we do to improve scalability of product X', 'what do we to to improve scalability of product X', 'what do we do to improve stability of product X', ...]) -If you think that a decomposition is not needed or helpful, please just return an empty string. That is ok too. +If you think that a decomposition is not needed or helpful, please just return an empty string. That is ok too. Here is the initial question: ------- {question} ------- -Please formulate your answer as a list of json objects with the following format: -[{{"sub_question": <sub-question>}}, ...] +Please formulate your answer as a newline-separated list of questions like so: +<sub-question 1> +<sub-question 2> +<sub-question 3> Answer:""" diff --git a/backend/onyx/agent_search/shared_graph_utils/utils.py b/backend/onyx/agent_search/shared_graph_utils/utils.py index a435860320..fe7a205f75 100644 --- a/backend/onyx/agent_search/shared_graph_utils/utils.py +++ b/backend/onyx/agent_search/shared_graph_utils/utils.py @@ -5,8 +5,26 @@ from datetime import datetime from datetime import timedelta from typing import Any - +from uuid import UUID + +from sqlalchemy.orm import Session + +from onyx.chat.models import AnswerStyleConfig +from onyx.chat.models import CitationConfig +from onyx.chat.models import DocumentPruningConfig +from onyx.chat.models import PromptConfig +from onyx.chat.models import ProSearchConfig +from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE +from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT +from onyx.configs.constants import DEFAULT_PERSONA_ID +from onyx.context.search.enums import LLMEvaluationType +from onyx.context.search.models import InferenceSection +from onyx.context.search.models import RetrievalDetails +from onyx.context.search.models import SearchRequest +from onyx.db.persona import get_persona_by_id +from onyx.llm.interfaces import LLM +from onyx.tools.tool_constructor import SearchToolConfig +from onyx.tools.tool_implementations.search.search_tool import SearchTool def normalize_whitespace(text: str) -> str: @@ -99,3 +117,70 @@ def generate_log_message( node_time_str = _format_time_delta(current_time - node_start_time) return f"{graph_time_str} ({node_time_str} s): {message}" + + +def get_test_config( + db_session: Session, primary_llm: LLM, fast_llm: LLM, search_request: SearchRequest +) -> tuple[ProSearchConfig, SearchTool]: + persona = get_persona_by_id(DEFAULT_PERSONA_ID, None, db_session) + document_pruning_config = DocumentPruningConfig( + max_chunks=int( + persona.num_chunks + if persona.num_chunks is not None + else MAX_CHUNKS_FED_TO_CHAT + ), + max_window_percentage=CHAT_TARGET_CHUNK_PERCENTAGE, + ) + + answer_style_config = AnswerStyleConfig( + citation_config=CitationConfig( + # The docs retrieved by this flow are already relevance-filtered + all_docs_useful=True + ), + document_pruning_config=document_pruning_config, + structured_response_format=None, + ) + + search_tool_config = SearchToolConfig( + answer_style_config=answer_style_config, + document_pruning_config=document_pruning_config, + retrieval_options=RetrievalDetails(), # may want to set dedupe_docs=True + rerank_settings=None, # Can use this to change reranking model + selected_sections=None, +
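# no preselected sections or uploaded query files in this test config +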
latest_query_files=None, + bypass_acl=False, + ) + + prompt_config = PromptConfig.from_model(persona.prompts[0]) + + search_tool = SearchTool( + db_session=db_session, + user=None, + persona=persona, + retrieval_options=search_tool_config.retrieval_options, + prompt_config=prompt_config, + llm=primary_llm, + fast_llm=fast_llm, + pruning_config=search_tool_config.document_pruning_config, + answer_style_config=search_tool_config.answer_style_config, + selected_sections=search_tool_config.selected_sections, + chunks_above=search_tool_config.chunks_above, + chunks_below=search_tool_config.chunks_below, + full_doc=search_tool_config.full_doc, + evaluation_type=( + LLMEvaluationType.BASIC + if persona.llm_relevance_filter + else LLMEvaluationType.SKIP + ), + rerank_settings=search_tool_config.rerank_settings, + bypass_acl=search_tool_config.bypass_acl, + ) + + config = ProSearchConfig( + search_request=search_request, + chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"), + message_id=1, + use_persistence=False, + ) + + return config, search_tool diff --git a/backend/onyx/chat/answer.py b/backend/onyx/chat/answer.py index 8a1c65638b..3873f8dcd6 100644 --- a/backend/onyx/chat/answer.py +++ b/backend/onyx/chat/answer.py @@ -4,6 +4,7 @@ from langchain.schema.messages import BaseMessage from langchain_core.messages import AIMessageChunk from langchain_core.messages import ToolCall +from sqlalchemy.orm import Session from onyx.agent_search.run_graph import run_main_graph from onyx.chat.llm_response_handler import LLMResponseHandlerManager @@ -13,6 +14,7 @@ from onyx.chat.models import CitationInfo from onyx.chat.models import OnyxAnswerPiece from onyx.chat.models import PromptConfig +from onyx.chat.models import ProSearchConfig from onyx.chat.prompt_builder.build import AnswerPromptBuilder from onyx.chat.prompt_builder.build import default_build_system_message from onyx.chat.prompt_builder.build import default_build_user_message @@ -25,7 +27,6 @@ ) from onyx.chat.stream_processing.utils import map_document_id_order from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler -from onyx.context.search.models import SearchRequest from onyx.file_store.utils import InMemoryChatFile from onyx.llm.interfaces import LLM from onyx.llm.models import PreviousMessage @@ -64,8 +65,8 @@ def __init__( return_contexts: bool = False, skip_gen_ai_answer_generation: bool = False, is_connected: Callable[[], bool] | None = None, - use_pro_search: bool = False, - search_request: SearchRequest | None = None, + pro_search_config: ProSearchConfig | None = None, + db_session: Session | None = None, ) -> None: if single_message_history and message_history: raise ValueError( @@ -111,8 +112,10 @@ def __init__( and not skip_explicit_tool_calling ) - self.use_pro_search = use_pro_search - self.pro_search_request = search_request + self.pro_search_config = pro_search_config + if db_session is None: + raise ValueError("db_session must be provided") + self.db_session = db_session def _get_tools_list(self) -> list[Tool]: if not self.force_use_tool.force_use: @@ -259,8 +262,8 @@ def processed_streamed_output(self) -> AnswerStream: yield from self._processed_stream return - if self.use_pro_search: - if self.pro_search_request is None: + if self.pro_search_config: + if self.pro_search_config.search_request is None: raise ValueError("Search request must be provided for pro search") search_tools = [tool for tool in self.tools if isinstance(tool, SearchTool)] if len(search_tools) == 0: @@ -271,10 +274,11 @@ def 
processed_streamed_output(self) -> AnswerStream: search_tool = search_tools[0] yield from run_main_graph( - search_request=self.pro_search_request, + config=self.pro_search_config, primary_llm=self.llm, fast_llm=self.fast_llm, search_tool=search_tool, + db_session=self.db_session, ) return diff --git a/backend/onyx/chat/models.py b/backend/onyx/chat/models.py index 91a5689d75..8e1beb3bc7 100644 --- a/backend/onyx/chat/models.py +++ b/backend/onyx/chat/models.py @@ -4,6 +4,7 @@ from enum import Enum from typing import Any from typing import TYPE_CHECKING +from uuid import UUID from pydantic import BaseModel from pydantic import ConfigDict @@ -16,6 +17,7 @@ from onyx.context.search.enums import RecencyBiasSetting from onyx.context.search.enums import SearchType from onyx.context.search.models import RetrievalDocs +from onyx.context.search.models import SearchRequest from onyx.llm.override_models import PromptOverride from onyx.tools.models import ToolCallFinalResult from onyx.tools.models import ToolCallKickoff @@ -204,6 +206,22 @@ class PersonaOverrideConfig(BaseModel): custom_tools_openapi: list[dict[str, Any]] = Field(default_factory=list) +class ProSearchConfig(BaseModel): + """ + Configuration for the Pro Search feature. + """ + + # For persisting agent search data + chat_session_id: UUID | None = None + # The message ID of the user message that triggered the Pro Search + message_id: int | None = None + # The search request that was used to generate the Pro Search + search_request: SearchRequest + + # Whether to persist data for the Pro Search (turned off for testing) + use_persistence: bool = True + + AnswerQuestionPossibleReturn = ( OnyxAnswerPiece | CitationInfo @@ -331,19 +349,23 @@ def from_model( class SubQuery(BaseModel): sub_query: str + sub_question_id: int class SubAnswer(BaseModel): sub_answer: str + sub_question_id: int class SubQuestion(BaseModel): - question_id: str + question_id: int sub_question: str ProSearchPacket = SubQuestion | SubAnswer | SubQuery -AnswerPacket = AnswerQuestionPossibleReturn | ToolCallKickoff | ToolResponse +AnswerPacket = ( + AnswerQuestionPossibleReturn | ProSearchPacket | ToolCallKickoff | ToolResponse +) AnswerStream = Iterator[AnswerPacket] diff --git a/backend/onyx/chat/process_message.py b/backend/onyx/chat/process_message.py index 96820b7c3d..f0060e6f27 100644 --- a/backend/onyx/chat/process_message.py +++ b/backend/onyx/chat/process_message.py @@ -24,6 +24,7 @@ from onyx.chat.models import OnyxAnswerPiece from onyx.chat.models import OnyxContexts from onyx.chat.models import PromptConfig +from onyx.chat.models import ProSearchConfig from onyx.chat.models import QADocsResponse from onyx.chat.models import StreamingError from onyx.chat.models import StreamStopInfo @@ -686,6 +687,7 @@ def stream_chat_message_objects( tools.extend(tool_list) search_request = None + pro_search_config = None if new_msg_req.use_pro_search: search_request = SearchRequest( query=final_msg.message, @@ -710,6 +712,15 @@ def stream_chat_message_objects( else None ), ) + pro_search_config = ( + ProSearchConfig( + search_request=search_request, + chat_session_id=chat_session_id, + message_id=user_message.id if user_message else None, + ) + if new_msg_req.use_pro_search + else None + ) # TODO: add previous messages, answer style config, tools, etc. # LLM prompt building, response capturing, etc.
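(Illustrative sketch, not part of the patch: assuming search_tool, primary_llm, fast_llm, and db_session are already constructed, e.g. via get_test_config, the pro-search entry point wired up above is used roughly like this; the id values are placeholders.)

from uuid import UUID

from onyx.agent_search.run_graph import run_main_graph
from onyx.chat.models import ProSearchConfig
from onyx.context.search.models import SearchRequest

# Bundle the search request with the chat-session coordinates. The graph
# requires these even when use_persistence is False; placeholder values here.
config = ProSearchConfig(
    search_request=SearchRequest(query="what can you do with onyx?"),
    chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"),  # placeholder
    message_id=1,  # placeholder
    use_persistence=False,
)

# run_main_graph compiles (and caches) the main graph, then yields AnswerPacket
# objects: SubQuestion packets from decomposition, tool responses from retrieval,
# and OnyxAnswerPiece chunks carrying the final answer text.
for packet in run_main_graph(config, search_tool, primary_llm, fast_llm, db_session):
    print(packet)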
@@ -737,8 +748,8 @@ def stream_chat_message_objects( ], tools=tools, force_use_tool=_get_force_search_settings(new_msg_req, tools), - search_request=search_request, - use_pro_search=new_msg_req.use_pro_search, + pro_search_config=pro_search_config, + db_session=db_session, ) reference_db_search_docs = None diff --git a/backend/onyx/db/models.py b/backend/onyx/db/models.py index 353eba9c9d..b4a65d818c 100644 --- a/backend/onyx/db/models.py +++ b/backend/onyx/db/models.py @@ -295,6 +295,17 @@ class ChatMessage__SearchDoc(Base): ) +class SubQuery__SearchDoc(Base): + __tablename__ = "sub_query__search_doc" + + sub_query_id: Mapped[int] = mapped_column( + ForeignKey("sub_query.id"), primary_key=True + ) + search_doc_id: Mapped[int] = mapped_column( + ForeignKey("search_doc.id"), primary_key=True + ) + + class Document__Tag(Base): __tablename__ = "document__tag" @@ -924,6 +935,11 @@ class SearchDoc(Base): secondary=ChatMessage__SearchDoc.__table__, back_populates="search_docs", ) + sub_queries = relationship( + "SubQuery", + secondary=SubQuery__SearchDoc.__table__, + back_populates="search_docs", + ) class ToolCall(Base): @@ -1118,6 +1134,64 @@ def __lt__(self, other: Any) -> bool: return self.display_priority < other.display_priority +class SubQuestion(Base): + """ + A sub-question is a question that is asked of the LLM to gather supporting + information to answer a primary question. + """ + + __tablename__ = "sub_question" + + id: Mapped[int] = mapped_column(primary_key=True) + primary_question_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id")) + chat_session_id: Mapped[UUID] = mapped_column( + PGUUID(as_uuid=True), ForeignKey("chat_session.id") + ) + sub_question: Mapped[str] = mapped_column(Text) + time_created: Mapped[datetime.datetime] = mapped_column( + DateTime(timezone=True), server_default=func.now() + ) + answer: Mapped[str] = mapped_column(Text) + + # Relationships + primary_message: Mapped["ChatMessage"] = relationship( + "ChatMessage", foreign_keys=[primary_question_id] + ) + chat_session: Mapped["ChatSession"] = relationship("ChatSession") + sub_queries: Mapped[list["SubQuery"]] = relationship( + "SubQuery", back_populates="parent_question" + ) + + +class SubQuery(Base): + """ + A sub-query is a vector DB query that gathers supporting information to answer a sub-question. 
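+ Each sub-query is linked to its parent sub-question and to the search docs + it retrieved (via the sub_query__search_doc association table).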
+ """ + + __tablename__ = "sub_query" + + id: Mapped[int] = mapped_column(primary_key=True) + parent_question_id: Mapped[int] = mapped_column(ForeignKey("sub_question.id")) + chat_session_id: Mapped[UUID] = mapped_column( + PGUUID(as_uuid=True), ForeignKey("chat_session.id") + ) + sub_query: Mapped[str] = mapped_column(Text) + time_created: Mapped[datetime.datetime] = mapped_column( + DateTime(timezone=True), server_default=func.now() + ) + + # Relationships + parent_question: Mapped["SubQuestion"] = relationship( + "SubQuestion", back_populates="sub_queries" + ) + chat_session: Mapped["ChatSession"] = relationship("ChatSession") + search_docs: Mapped[list["SearchDoc"]] = relationship( + "SearchDoc", + secondary=SubQuery__SearchDoc.__table__, + back_populates="sub_queries", + ) + + """ Feedback, Logging, Metrics Tables """ From d35aa1eab9a5843f59c8c7b28d5ac1934336e009 Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Sun, 5 Jan 2025 11:05:22 -0800 Subject: [PATCH 49/78] question numbers & expanded_retrieval results --- backend/onyx/agent_search/answer_question/models.py | 4 +++- .../onyx/agent_search/answer_question/nodes/format_answer.py | 3 ++- backend/onyx/agent_search/answer_question/states.py | 2 +- backend/onyx/agent_search/main/edges.py | 4 ++-- backend/onyx/agent_search/main/models.py | 1 + backend/onyx/agent_search/main/nodes.py | 5 +++-- 6 files changed, 12 insertions(+), 7 deletions(-) diff --git a/backend/onyx/agent_search/answer_question/models.py b/backend/onyx/agent_search/answer_question/models.py index ea9fb8f971..69e6facf0a 100644 --- a/backend/onyx/agent_search/answer_question/models.py +++ b/backend/onyx/agent_search/answer_question/models.py @@ -1,5 +1,6 @@ from pydantic import BaseModel +from onyx.agent_search.expanded_retrieval.models import QueryResult from onyx.agent_search.shared_graph_utils.models import AgentChunkStats from onyx.context.search.models import InferenceSection @@ -12,8 +13,9 @@ class AnswerRetrievalStats(BaseModel): class QuestionAnswerResults(BaseModel): question: str + question_nr: str answer: str quality: str - # expanded_retrieval_results: list[QueryResult] + expanded_retrieval_results: list[QueryResult] documents: list[InferenceSection] sub_question_retrieval_stats: AgentChunkStats diff --git a/backend/onyx/agent_search/answer_question/nodes/format_answer.py b/backend/onyx/agent_search/answer_question/nodes/format_answer.py index 902e0d4924..8cd0b8e771 100644 --- a/backend/onyx/agent_search/answer_question/nodes/format_answer.py +++ b/backend/onyx/agent_search/answer_question/nodes/format_answer.py @@ -18,9 +18,10 @@ def format_answer(state: AnswerQuestionState) -> AnswerQuestionOutput: answer_results=[ QuestionAnswerResults( question=state["question"], + question_nr=state["question_nr"], quality=state.get("answer_quality", "No"), answer=state["answer"], - # expanded_retrieval_results=state["expanded_retrieval_results"], + expanded_retrieval_results=state["expanded_retrieval_results"], documents=state["documents"], sub_question_retrieval_stats=state["sub_question_retrieval_stats"], ) diff --git a/backend/onyx/agent_search/answer_question/states.py b/backend/onyx/agent_search/answer_question/states.py index b5b9f0880d..a880bfe9eb 100644 --- a/backend/onyx/agent_search/answer_question/states.py +++ b/backend/onyx/agent_search/answer_question/states.py @@ -31,7 +31,7 @@ class RetrievalIngestionUpdate(TypedDict): class AnswerQuestionInput(SubgraphCoreState): question: str - question_nr: int + question_nr: str ## Graph State diff --git 
a/backend/onyx/agent_search/main/edges.py b/backend/onyx/agent_search/main/edges.py index 0179809463..0538c439a0 100644 --- a/backend/onyx/agent_search/main/edges.py +++ b/backend/onyx/agent_search/main/edges.py @@ -23,7 +23,7 @@ def parallelize_decompozed_answer_queries(state: MainState) -> list[Send | Hasha AnswerQuestionInput( **extract_core_fields_for_subgraph(state), question=question, - question_nr=question_nr, + question_nr="0_" + str(question_nr), ), ) for question_nr, question in enumerate(state["initial_decomp_questions"]) @@ -72,7 +72,7 @@ def parallelize_follow_up_answer_queries(state: MainState) -> list[Send | Hashab AnswerQuestionInput( **extract_core_fields_for_subgraph(state), question=question_data.sub_question, - question_nr=question_nr, + question_nr="1_" + str(question_nr), ), ) for question_nr, question_data in state["follow_up_sub_questions"].items() diff --git a/backend/onyx/agent_search/main/models.py b/backend/onyx/agent_search/main/models.py index c59816a10b..03821cbb6e 100644 --- a/backend/onyx/agent_search/main/models.py +++ b/backend/onyx/agent_search/main/models.py @@ -29,6 +29,7 @@ class EntityRelationshipTermExtraction(BaseModel): class FollowUpSubQuestion(BaseModel): sub_question: str + sub_question_nr: str verified: bool answered: bool answer: str diff --git a/backend/onyx/agent_search/main/nodes.py b/backend/onyx/agent_search/main/nodes.py index f2663a5ef8..3add02f06c 100644 --- a/backend/onyx/agent_search/main/nodes.py +++ b/backend/onyx/agent_search/main/nodes.py @@ -515,9 +515,9 @@ def generate_refined_answer(state: MainState) -> RefinedAnswerUpdate: initial_good_sub_questions ) elif len(new_revised_good_sub_questions) > 0: - revision_question_efficiency: float = 10.0 + revision_question_efficiency = 10.0 else: - revision_question_efficiency: float = 1.0 + revision_question_efficiency = 1.0 sub_question_answer_str = "\n\n------\n\n".join(list(set(good_qa_list))) @@ -702,6 +702,7 @@ def follow_up_decompose(state: MainState) -> FollowUpSubQuestionsUpdate: ): follow_up_sub_question = FollowUpSubQuestion( sub_question=sub_question_dict["sub_question"], + sub_question_nr="1_" + str(sub_question_nr), verified=False, answered=False, answer="", From dd2c9425bd49b94cae8299791e50a286460c5b18 Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Sun, 5 Jan 2025 13:23:38 -0800 Subject: [PATCH 50/78] new columns --- .../1adf5ea20d2b_agent_doc_result_col.py | 29 +++++++++++++++++++ ...ed_create_pro_search_persistence_tables.py | 1 + .../graph_builder.py | 1 + backend/onyx/agent_search/db_operations.py | 2 ++ backend/onyx/agent_search/run_graph.py | 8 +++-- backend/onyx/db/models.py | 3 +- 6 files changed, 41 insertions(+), 3 deletions(-) create mode 100644 backend/alembic/versions/1adf5ea20d2b_agent_doc_result_col.py diff --git a/backend/alembic/versions/1adf5ea20d2b_agent_doc_result_col.py b/backend/alembic/versions/1adf5ea20d2b_agent_doc_result_col.py new file mode 100644 index 0000000000..62db727f97 --- /dev/null +++ b/backend/alembic/versions/1adf5ea20d2b_agent_doc_result_col.py @@ -0,0 +1,29 @@ +"""agent_doc_result_col + +Revision ID: 1adf5ea20d2b +Revises: e9cf2bd7baed +Create Date: 2025-01-05 13:14:58.344316 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. 
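+# This revision chains off e9cf2bd7baed (the pro-search persistence tables +# migration, which this same commit extends with a sub_answer column).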
+revision = "1adf5ea20d2b" +down_revision = "e9cf2bd7baed" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # Add the new column with JSONB type + op.add_column( + "sub_question", + sa.Column("sub_question_doc_results", postgresql.JSONB(), nullable=True), + ) + + +def downgrade() -> None: + # Drop the column + op.drop_column("sub_question", "sub_question_doc_results") diff --git a/backend/alembic/versions/e9cf2bd7baed_create_pro_search_persistence_tables.py b/backend/alembic/versions/e9cf2bd7baed_create_pro_search_persistence_tables.py index 2c7821c93f..1b56d0c91a 100644 --- a/backend/alembic/versions/e9cf2bd7baed_create_pro_search_persistence_tables.py +++ b/backend/alembic/versions/e9cf2bd7baed_create_pro_search_persistence_tables.py @@ -30,6 +30,7 @@ def upgrade() -> None: sa.Column( "time_created", sa.DateTime(timezone=True), server_default=sa.func.now() ), + sa.Column("sub_answer", sa.Text), ) # Create sub_query table diff --git a/backend/onyx/agent_search/answer_follow_up_question/graph_builder.py b/backend/onyx/agent_search/answer_follow_up_question/graph_builder.py index 6f1f77976b..1f82cc3f5f 100644 --- a/backend/onyx/agent_search/answer_follow_up_question/graph_builder.py +++ b/backend/onyx/agent_search/answer_follow_up_question/graph_builder.py @@ -96,6 +96,7 @@ def answer_follow_up_query_graph_builder() -> StateGraph: with get_session_context_manager() as db_session: inputs = AnswerQuestionInput( question="what can you do with onyx?", + question_nr="0", ) for thing in compiled_graph.stream( input=inputs, diff --git a/backend/onyx/agent_search/db_operations.py b/backend/onyx/agent_search/db_operations.py index 999fd40d5f..34bb77cc6c 100644 --- a/backend/onyx/agent_search/db_operations.py +++ b/backend/onyx/agent_search/db_operations.py @@ -11,12 +11,14 @@ def create_sub_question( chat_session_id: UUID, primary_message_id: int, sub_question: str, + sub_answer: str, ) -> SubQuestion: """Create a new sub-question record in the database.""" sub_q = SubQuestion( chat_session_id=chat_session_id, primary_question_id=primary_message_id, sub_question=sub_question, + sub_answer=sub_answer, ) db_session.add(sub_q) db_session.flush() diff --git a/backend/onyx/agent_search/run_graph.py b/backend/onyx/agent_search/run_graph.py index 5c948f30c3..feffb421cf 100644 --- a/backend/onyx/agent_search/run_graph.py +++ b/backend/onyx/agent_search/run_graph.py @@ -17,6 +17,7 @@ from onyx.chat.models import SubQuestion from onyx.context.search.models import SearchRequest from onyx.db.engine import get_session_context_manager +from onyx.db.persona import get_persona_by_id from onyx.llm.interfaces import LLM from onyx.tools.models import ToolResponse from onyx.tools.tool_implementations.search.search_tool import SearchTool @@ -143,15 +144,18 @@ def run_main_graph( compiled_graph = graph.compile() primary_llm, fast_llm = get_default_llms() search_request = SearchRequest( - query="what can you do with gitlab?", + # query="what can you do with gitlab?", + query="What are the guiding principles behind the development of cockroachDB?", ) # Joachim custom persona - # search_request.persona = get_persona_by_id(1, None, db_session) + with get_session_context_manager() as db_session: config, search_tool = get_test_config( db_session, primary_llm, fast_llm, search_request ) + search_request.persona = get_persona_by_id(1, None, db_session) + with open("output.txt", "w") as f: tool_responses = [] for output in run_graph( diff --git a/backend/onyx/db/models.py b/backend/onyx/db/models.py index 
99f41bf645..c2b1164d2b 100644 --- a/backend/onyx/db/models.py +++ b/backend/onyx/db/models.py @@ -1151,7 +1151,8 @@ class SubQuestion(Base): time_created: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now() ) - answer: Mapped[str] = mapped_column(Text) + sub_answer: Mapped[str] = mapped_column(Text) + sub_question_doc_results: Mapped[JSON_ro] = mapped_column(postgresql.JSONB()) # Relationships primary_message: Mapped["ChatMessage"] = relationship( From 695d07f0f90e0feb43b7ad05f5dadb59ae52c264 Mon Sep 17 00:00:00 2001 From: Evan Lohn Date: Sun, 5 Jan 2025 17:37:49 -0800 Subject: [PATCH 51/78] non-duplicated work changes, plus finished v1 of backend streaming --- .../answer_follow_up_question/edges.py | 1 + .../graph_builder.py | 2 +- .../nodes/format_answer.py | 2 +- .../agent_search/answer_question/edges.py | 3 +- .../answer_question/graph_builder.py | 2 +- .../agent_search/answer_question/models.py | 2 +- .../nodes/answer_generation.py | 14 ++- .../answer_question/nodes/format_answer.py | 2 +- .../agent_search/answer_question/states.py | 4 +- .../nodes/generate_raw_search_data.py | 1 + .../agent_search/expanded_retrieval/edges.py | 1 + .../expanded_retrieval/graph_builder.py | 1 + .../agent_search/expanded_retrieval/nodes.py | 95 +++++++++++-------- .../agent_search/expanded_retrieval/states.py | 1 + backend/onyx/agent_search/main/edges.py | 5 +- backend/onyx/agent_search/main/nodes.py | 55 +++++------ backend/onyx/agent_search/run_graph.py | 14 ++- .../shared_graph_utils/prompts.py | 17 +++- .../agent_search/shared_graph_utils/utils.py | 34 +++++++ backend/onyx/chat/answer.py | 16 +++- backend/onyx/chat/models.py | 21 +++- .../search/search_tool.py | 13 ++- .../regression/answer_quality/agent_test.py | 9 +- 23 files changed, 209 insertions(+), 106 deletions(-) diff --git a/backend/onyx/agent_search/answer_follow_up_question/edges.py b/backend/onyx/agent_search/answer_follow_up_question/edges.py index e69b7a99e6..87c550cf6b 100644 --- a/backend/onyx/agent_search/answer_follow_up_question/edges.py +++ b/backend/onyx/agent_search/answer_follow_up_question/edges.py @@ -18,6 +18,7 @@ def send_to_expanded_follow_up_retrieval(state: AnswerQuestionInput) -> Send | H ExpandedRetrievalInput( **in_subgraph_extract_core_fields(state), question=state["question"], + sub_question_id=state["question_id"], base_search=False ), ) diff --git a/backend/onyx/agent_search/answer_follow_up_question/graph_builder.py b/backend/onyx/agent_search/answer_follow_up_question/graph_builder.py index 1f82cc3f5f..dfee1a6a13 100644 --- a/backend/onyx/agent_search/answer_follow_up_question/graph_builder.py +++ b/backend/onyx/agent_search/answer_follow_up_question/graph_builder.py @@ -96,7 +96,7 @@ def answer_follow_up_query_graph_builder() -> StateGraph: with get_session_context_manager() as db_session: inputs = AnswerQuestionInput( question="what can you do with onyx?", - question_nr="0", + question_id="0_0", ) for thing in compiled_graph.stream( input=inputs, diff --git a/backend/onyx/agent_search/answer_follow_up_question/nodes/format_answer.py b/backend/onyx/agent_search/answer_follow_up_question/nodes/format_answer.py index 0a07f630ff..4350f20165 100644 --- a/backend/onyx/agent_search/answer_follow_up_question/nodes/format_answer.py +++ b/backend/onyx/agent_search/answer_follow_up_question/nodes/format_answer.py @@ -23,7 +23,7 @@ def format_answer(state: AnswerQuestionState) -> AnswerQuestionOutput: expanded_retrieval_results=state["expanded_retrieval_results"], 
documents=state["documents"], sub_question_retrieval_stats=state["sub_question_retrieval_stats"], - question_nr=state["question_nr"], + question_id=state["question_id"], ) ], ) diff --git a/backend/onyx/agent_search/answer_question/edges.py b/backend/onyx/agent_search/answer_question/edges.py index 441a6d75bb..569f6437c3 100644 --- a/backend/onyx/agent_search/answer_question/edges.py +++ b/backend/onyx/agent_search/answer_question/edges.py @@ -18,6 +18,7 @@ def send_to_expanded_retrieval(state: AnswerQuestionInput) -> Send | Hashable: ExpandedRetrievalInput( **in_subgraph_extract_core_fields(state), question=state["question"], - base_search=False + base_search=False, + sub_question_id=state["question_id"], ), ) diff --git a/backend/onyx/agent_search/answer_question/graph_builder.py b/backend/onyx/agent_search/answer_question/graph_builder.py index 977090f1c4..9526853c3c 100644 --- a/backend/onyx/agent_search/answer_question/graph_builder.py +++ b/backend/onyx/agent_search/answer_question/graph_builder.py @@ -103,7 +103,7 @@ def answer_query_graph_builder() -> StateGraph: subgraph_config=pro_search_config, subgraph_search_tool=search_tool, subgraph_db_session=db_session, - question_nr="0", # TODO does this make sense? doesn't matter too much + question_id="0_0", ) for thing in compiled_graph.stream( input=inputs, diff --git a/backend/onyx/agent_search/answer_question/models.py b/backend/onyx/agent_search/answer_question/models.py index 69e6facf0a..60bda54fc9 100644 --- a/backend/onyx/agent_search/answer_question/models.py +++ b/backend/onyx/agent_search/answer_question/models.py @@ -13,7 +13,7 @@ class AnswerRetrievalStats(BaseModel): class QuestionAnswerResults(BaseModel): question: str - question_nr: str + question_id: str answer: str quality: str expanded_retrieval_results: list[QueryResult] diff --git a/backend/onyx/agent_search/answer_question/nodes/answer_generation.py b/backend/onyx/agent_search/answer_question/nodes/answer_generation.py index 5e56df2535..b446bd0efb 100644 --- a/backend/onyx/agent_search/answer_question/nodes/answer_generation.py +++ b/backend/onyx/agent_search/answer_question/nodes/answer_generation.py @@ -11,6 +11,7 @@ from onyx.agent_search.shared_graph_utils.prompts import ASSISTANT_SYSTEM_PROMPT_DEFAULT from onyx.agent_search.shared_graph_utils.prompts import ASSISTANT_SYSTEM_PROMPT_PERSONA from onyx.agent_search.shared_graph_utils.utils import get_persona_prompt +from onyx.chat.models import SubAnswer from onyx.utils.logger import setup_logger logger = setup_logger() @@ -53,11 +54,20 @@ def answer_generation(state: AnswerQuestionState) -> QAGenerationUpdate: for message in fast_llm.stream( prompt=msg, ): + # TODO: in principle, the answer here COULD contain images, but we don't support that yet + content = message.content + if not isinstance(content, str): + raise ValueError( + f"Expected content to be a string, but got {type(content)}" + ) dispatch_custom_event( "sub_answers", - message.content, + SubAnswer( + sub_answer=content, + sub_question_id=state["question_id"], + ), ) - response.append(message.content) + response.append(content) answer_str = merge_message_runs(response, chunk_separator="")[0].content diff --git a/backend/onyx/agent_search/answer_question/nodes/format_answer.py b/backend/onyx/agent_search/answer_question/nodes/format_answer.py index 8cd0b8e771..3bd1e9569f 100644 --- a/backend/onyx/agent_search/answer_question/nodes/format_answer.py +++ b/backend/onyx/agent_search/answer_question/nodes/format_answer.py @@ -18,7 +18,7 @@ def 
format_answer(state: AnswerQuestionState) -> AnswerQuestionOutput: answer_results=[ QuestionAnswerResults( question=state["question"], - question_nr=state["question_nr"], + question_id=state["question_id"], quality=state.get("answer_quality", "No"), answer=state["answer"], expanded_retrieval_results=state["expanded_retrieval_results"], diff --git a/backend/onyx/agent_search/answer_question/states.py b/backend/onyx/agent_search/answer_question/states.py index a880bfe9eb..80bdaa80f4 100644 --- a/backend/onyx/agent_search/answer_question/states.py +++ b/backend/onyx/agent_search/answer_question/states.py @@ -31,7 +31,9 @@ class RetrievalIngestionUpdate(TypedDict): class AnswerQuestionInput(SubgraphCoreState): question: str - question_nr: str + question_id: str # 0_0 is original question, everything else is <level>_<question_num>. + # level 0 is original question and first decomposition, level 1 is follow up, etc. + # question_num is a unique number per original question per level. ## Graph State diff --git a/backend/onyx/agent_search/base_raw_search/nodes/generate_raw_search_data.py b/backend/onyx/agent_search/base_raw_search/nodes/generate_raw_search_data.py index c9fba2b6ab..cd9c003f47 100644 --- a/backend/onyx/agent_search/base_raw_search/nodes/generate_raw_search_data.py +++ b/backend/onyx/agent_search/base_raw_search/nodes/generate_raw_search_data.py @@ -15,4 +15,5 @@ def generate_raw_search_data(state: CoreState) -> ExpandedRetrievalInput: question=state["config"].search_request.query, base_search=True, subgraph_search_tool=state["search_tool"], + sub_question_id=None, # This graph is always and only used for the original question ) diff --git a/backend/onyx/agent_search/expanded_retrieval/edges.py b/backend/onyx/agent_search/expanded_retrieval/edges.py index 3d3b06374d..73a0fc43ef 100644 --- a/backend/onyx/agent_search/expanded_retrieval/edges.py +++ b/backend/onyx/agent_search/expanded_retrieval/edges.py @@ -19,6 +19,7 @@ def parallel_retrieval_edge(state: ExpandedRetrievalState) -> list[Send | Hashab question=question, **in_subgraph_extract_core_fields(state), base_search=False, + sub_question_id=state.get("sub_question_id"), ), ) for query in query_expansions diff --git a/backend/onyx/agent_search/expanded_retrieval/graph_builder.py b/backend/onyx/agent_search/expanded_retrieval/graph_builder.py index bb5a7a84e9..e5988c7bcb 100644 --- a/backend/onyx/agent_search/expanded_retrieval/graph_builder.py +++ b/backend/onyx/agent_search/expanded_retrieval/graph_builder.py @@ -108,6 +108,7 @@ def expanded_retrieval_graph_builder() -> StateGraph: subgraph_db_session=db_session, subgraph_config=pro_search_config, subgraph_search_tool=search_tool, + sub_question_id=None, ) for thing in compiled_graph.stream( input=inputs, diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes.py b/backend/onyx/agent_search/expanded_retrieval/nodes.py index 7a00c8a318..04a3221e71 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import Any +from collections.abc import Callable from typing import cast from typing import Literal @@ -28,17 +28,34 @@ from onyx.agent_search.shared_graph_utils.models import RetrievalFitStats from onyx.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI_ORIGINAL from onyx.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT +from onyx.agent_search.shared_graph_utils.utils import dispatch_separated +from
onyx.agent_search.shared_graph_utils.utils import make_question_id +from onyx.chat.models import ExtendedToolResponse +from onyx.chat.models import SubQuery from onyx.configs.dev_configs import AGENT_MAX_QUERY_RETRIEVAL_RESULTS from onyx.configs.dev_configs import AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS from onyx.configs.dev_configs import AGENT_RERANKING_STATS from onyx.configs.dev_configs import AGENT_RETRIEVAL_STATS from onyx.context.search.models import SearchRequest from onyx.context.search.pipeline import retrieval_preprocessing -from onyx.context.search.pipeline import search_postprocessing +from onyx.context.search.postprocessing.postprocessing import rerank_sections from onyx.llm.interfaces import LLM from onyx.tools.tool_implementations.search.search_tool import ( SEARCH_RESPONSE_SUMMARY_ID, ) +from onyx.utils.logger import setup_logger + +logger = setup_logger() + + +def dispatch_subquery(subquestion_id: str) -> Callable[[str, int], None]: + def helper(token: str, num: int) -> None: + dispatch_custom_event( + "subqueries", + SubQuery(sub_query=token, sub_question_id=subquestion_id, query_id=num), + ) + + return helper def expand_queries(state: ExpandedRetrievalInput) -> QueryExpansionUpdate: @@ -50,44 +67,25 @@ def expand_queries(state: ExpandedRetrievalInput) -> QueryExpansionUpdate: state["subgraph_db_session"] chat_session_id = state["subgraph_config"].chat_session_id sub_question_id = state.get("sub_question_id") + if sub_question_id is None: + sub_question_id = make_question_id(0, 0) # 0_0 for original question if chat_session_id is None: raise ValueError("chat_session_id must be provided for agent search") - if sub_question_id is None: - if state["subgraph_config"].use_persistence: - # in this case, we are doing retrieval on the original question. - # to make all the logic consistent (i.e. all subqueries have a - # subquestion as a parent), we create a new sub-question - # with the same content as the original question. 
- # if state["subgraph_config"].message_id is None: - # raise ValueError("message_id must be provided for agent search with persistence") - # sub_question_id = create_sub_question(db_session, - # chat_session_id, - # state["subgraph_config"].message_id, - # question).id - pass - else: - sub_question_id = 1 - msg = [ HumanMessage( content=REWRITE_PROMPT_MULTI_ORIGINAL.format(question=question), ) ] - llm_response_list: list[str | list[str | dict[str, Any]]] = [] - for message in llm.stream( - prompt=msg, - ): - dispatch_custom_event( - "subqueries", - message.content, - ) - llm_response_list.append(message.content) + + llm_response_list = dispatch_separated( + llm.stream(prompt=msg), dispatch_subquery(sub_question_id) + ) llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content - rewritten_queries = llm_response.split("--") + rewritten_queries = llm_response.split("\n") if state["subgraph_config"].use_persistence: # Persist sub-queries to database @@ -123,14 +121,26 @@ def doc_retrieval(state: RetrievalInput) -> DocRetrievalUpdate: search_tool = state["subgraph_search_tool"] retrieved_docs: list[InferenceSection] = [] - for tool_response in search_tool.run(query=query_to_retrieve): + if not query_to_retrieve.strip(): + logger.warning("Empty query, skipping retrieval") + return DocRetrievalUpdate( + expanded_retrieval_results=[], + retrieved_documents=[], + ) + for tool_response in search_tool.run( + query=query_to_retrieve, force_no_rerank="True" + ): if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID: retrieved_docs = cast( list[InferenceSection], tool_response.response.top_sections ) dispatch_custom_event( "tool_response", - tool_response, + ExtendedToolResponse( + id=tool_response.id, + sub_question_id=state["sub_question_id"] or make_question_id(0, 0), + response=tool_response.response, + ), ) retrieved_docs = retrieved_docs[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS] @@ -167,6 +177,7 @@ def verification_kickoff( verification_question = state.get( "question", state["subgraph_config"].search_request.query ) + sub_question_id = state.get("sub_question_id") return Command( update={}, goto=[ @@ -176,6 +187,7 @@ def verification_kickoff( doc_to_verify=doc, question=verification_question, base_search=False, + sub_question_id=sub_question_id, **in_subgraph_extract_core_fields(state), ), ) @@ -229,20 +241,25 @@ def doc_reranking(state: ExpandedRetrievalState) -> DocRerankingUpdate: question = state.get("question", state["subgraph_config"].search_request.query) _search_query = retrieval_preprocessing( search_request=SearchRequest(query=question), - user=None, + user=state["subgraph_search_tool"].user, # bit of a hack llm=state["subgraph_fast_llm"], db_session=state["subgraph_db_session"], ) - reranked_documents = list( - search_postprocessing( - search_query=_search_query, - retrieved_sections=verified_documents, - llm=state["subgraph_fast_llm"], + # skip section filtering + + if ( + _search_query.rerank_settings + and _search_query.rerank_settings.rerank_model_name + and _search_query.rerank_settings.num_rerank > 0 + ): + reranked_documents = rerank_sections( + _search_query, + verified_documents, ) - )[ - 0 - ] # only get the reranked szections, not the SectionRelevancePiece + else: + logger.warning("No reranking settings found, using unranked documents") + reranked_documents = verified_documents if AGENT_RERANKING_STATS: fit_scores = get_fit_scores(verified_documents, reranked_documents) diff --git a/backend/onyx/agent_search/expanded_retrieval/states.py 
b/backend/onyx/agent_search/expanded_retrieval/states.py index 2129c22eda..cfed4cc78a 100644 --- a/backend/onyx/agent_search/expanded_retrieval/states.py +++ b/backend/onyx/agent_search/expanded_retrieval/states.py @@ -18,6 +18,7 @@ class ExpandedRetrievalInput(SubgraphCoreState): question: str base_search: bool + sub_question_id: str | None ## Update/Return States diff --git a/backend/onyx/agent_search/main/edges.py b/backend/onyx/agent_search/main/edges.py index 2ab78a80a3..7c0ac5c110 100644 --- a/backend/onyx/agent_search/main/edges.py +++ b/backend/onyx/agent_search/main/edges.py @@ -8,6 +8,7 @@ from onyx.agent_search.core_state import extract_core_fields_for_subgraph from onyx.agent_search.main.states import MainState from onyx.agent_search.main.states import RequireRefinedAnswerUpdate +from onyx.agent_search.shared_graph_utils.utils import make_question_id from onyx.utils.logger import setup_logger logger = setup_logger() @@ -31,7 +32,7 @@ def parallelize_decompozed_answer_queries(state: MainState) -> list[Send | Hasha AnswerQuestionInput( **extract_core_fields_for_subgraph(state), question=question, - question_nr="0_" + str(question_nr), + question_id=make_question_id(0, question_nr), ), ) for question_nr, question in enumerate(state["initial_decomp_questions"]) @@ -66,7 +67,7 @@ def parallelize_follow_up_answer_queries(state: MainState) -> list[Send | Hashab AnswerQuestionInput( **extract_core_fields_for_subgraph(state), question=question_data.sub_question, - question_nr="1_" + str(question_nr), + question_id=make_question_id(1, question_nr), ), ) for question_nr, question_data in state["follow_up_sub_questions"].items() diff --git a/backend/onyx/agent_search/main/nodes.py b/backend/onyx/agent_search/main/nodes.py index 9bad6d2768..3204e550a7 100644 --- a/backend/onyx/agent_search/main/nodes.py +++ b/backend/onyx/agent_search/main/nodes.py @@ -1,5 +1,6 @@ import json import re +from collections.abc import Callable from datetime import datetime from typing import Any from typing import cast @@ -56,9 +57,11 @@ REVISED_RAG_PROMPT_NO_SUB_QUESTIONS, ) from onyx.agent_search.shared_graph_utils.prompts import SUB_QUESTION_ANSWER_TEMPLATE +from onyx.agent_search.shared_graph_utils.utils import dispatch_separated from onyx.agent_search.shared_graph_utils.utils import format_docs from onyx.agent_search.shared_graph_utils.utils import format_entity_term_extraction from onyx.agent_search.shared_graph_utils.utils import get_persona_prompt +from onyx.agent_search.shared_graph_utils.utils import make_question_id from onyx.chat.models import SubQuestion from onyx.db.chat import log_agent_metrics from onyx.utils.logger import setup_logger @@ -66,14 +69,19 @@ logger = setup_logger() -def dispatch_subquestion(sub_question_part: str, subq_id: int) -> None: - dispatch_custom_event( - "decomp_qs", - SubQuestion( - sub_question=sub_question_part, - question_id=subq_id, - ), - ) +def dispatch_subquestion(level: int) -> Callable[[str, int], None]: + def helper(sub_question_part: str, num: int) -> None: + dispatch_custom_event( + "decomp_qs", + SubQuestion( + sub_question=sub_question_part, + question_id=make_question_id( + level, num + 1 + ), # question 0 reserved for original question if used + ), + ) + + return helper def main_decomp_base(state: MainState) -> BaseDecompUpdate: @@ -95,19 +103,9 @@ def main_decomp_base(state: MainState) -> BaseDecompUpdate: # Get the rewritten queries in a defined format model = state["fast_llm"] - streamed_tokens: list[str | list[str | dict[str, Any]]] = [""] - subq_id = 
1 - for message in model.stream(msg): - content = cast(str, message.content) - if "\n" in content: - for sub_question_part in content.split("\n"): - dispatch_subquestion(sub_question_part, subq_id) - subq_id += 1 - subq_id -= 1 # fencepost; extra increment at end of loop - else: - dispatch_subquestion(content, subq_id) - streamed_tokens.append(content) + # dispatches custom events for subquestion tokens, adding in subquestion ids. + streamed_tokens = dispatch_separated(model.stream(msg), dispatch_subquestion(0)) response = merge_content(*streamed_tokens) @@ -747,21 +745,20 @@ def follow_up_decompose(state: MainState) -> FollowUpSubQuestionsUpdate: # Grader model = state["fast_llm"] - response = model.invoke(msg) - if isinstance(response.content, str): - cleaned_response = re.sub(r"```json\n|\n```", "", response.content) - parsed_response = json.loads(cleaned_response) + streamed_tokens = dispatch_separated(model.stream(msg), dispatch_subquestion(1)) + response = merge_content(*streamed_tokens) + + if isinstance(response, str): + parsed_response = response.split("\n") else: raise ValueError("LLM response is not a string") follow_up_sub_question_dict = {} - for sub_question_nr, sub_question_dict in enumerate( - parsed_response["sub_questions"] - ): + for sub_question_nr, sub_question in enumerate(parsed_response): follow_up_sub_question = FollowUpSubQuestion( - sub_question=sub_question_dict["sub_question"], - sub_question_nr="1_" + str(sub_question_nr), + sub_question=sub_question, + sub_question_nr=make_question_id(1, sub_question_nr), verified=False, answered=False, answer="", diff --git a/backend/onyx/agent_search/run_graph.py b/backend/onyx/agent_search/run_graph.py index feffb421cf..416bf5ea2f 100644 --- a/backend/onyx/agent_search/run_graph.py +++ b/backend/onyx/agent_search/run_graph.py @@ -14,12 +14,13 @@ from onyx.chat.models import AnswerStream from onyx.chat.models import OnyxAnswerPiece from onyx.chat.models import ProSearchConfig +from onyx.chat.models import SubAnswer +from onyx.chat.models import SubQuery from onyx.chat.models import SubQuestion +from onyx.chat.models import ToolResponse from onyx.context.search.models import SearchRequest from onyx.db.engine import get_session_context_manager -from onyx.db.persona import get_persona_by_id from onyx.llm.interfaces import LLM -from onyx.tools.models import ToolResponse from onyx.tools.tool_implementations.search.search_tool import SearchTool from onyx.tools.tool_runner import ToolCallKickoff from onyx.utils.logger import setup_logger @@ -42,14 +43,11 @@ def _parse_agent_event( if event_type == "on_custom_event": # TODO: different AnswerStream types for different events if event["name"] == "decomp_qs": - # return OnyxAnswerPiece(answer_piece=cast(str, event["data"])) return cast(SubQuestion, event["data"]) elif event["name"] == "subqueries": - # return OnyxAnswerPiece(answer_piece=cast(str, event["data"])) - return None + return cast(SubQuery, event["data"]) elif event["name"] == "sub_answers": - # return OnyxAnswerPiece(answer_piece=cast(str, event["data"])) - return None + return cast(SubAnswer, event["data"]) elif event["name"] == "main_answer": return OnyxAnswerPiece(answer_piece=cast(str, event["data"])) elif event["name"] == "tool_response": @@ -154,7 +152,7 @@ def run_main_graph( db_session, primary_llm, fast_llm, search_request ) - search_request.persona = get_persona_by_id(1, None, db_session) + # search_request.persona = get_persona_by_id(1, None, db_session) with open("output.txt", "w") as f: tool_responses = [] diff 
--git a/backend/onyx/agent_search/shared_graph_utils/prompts.py b/backend/onyx/agent_search/shared_graph_utils/prompts.py index 707f1f41c1..83d29fcc00 100644 --- a/backend/onyx/agent_search/shared_graph_utils/prompts.py +++ b/backend/onyx/agent_search/shared_graph_utils/prompts.py @@ -7,7 +7,11 @@ \n ------- \n {question} \n ------- \n - Formulate the queries separated by '--' (Do not say 'Query 1: ...', just write the querytext): """ + Formulate the queries separated by newlines (Do not say 'Query 1: ...', just write the querytext) as follows: + + <query 1> + ... + queries: """ REWRITE_PROMPT_MULTI = """ \n Please create a list of 2-3 sample documents that could answer an original question. Each document @@ -42,7 +46,7 @@ BASE_RAG_PROMPT_v2 = """ \n Use the context provided below - and only the - provided context - to answer the given question. (Note that the answer is in service of anserwing a broader + provided context - to answer the given question. (Note that the answer is in service of answering a broader question, given below as 'motivation'.) Again, only use the provided context and do not use your internal knowledge! If you cannot answer the @@ -315,10 +319,13 @@ sub-questions or those that already were suggested and failed. In other words - what can we try in addition to what has been tried so far? - Generate the list of json dictionaries with the following format: + Generate the list of questions separated by new lines like this: - {{"sub_questions": [{{"sub_question": <sub_question>}}, - ...]}} """ + + <sub-question 1> + <sub-question 2> + ... + """ DECOMPOSE_PROMPT = """ \n For an initial user question, please generate at 5-10 individual sub-questions whose answers would help diff --git a/backend/onyx/agent_search/shared_graph_utils/utils.py b/backend/onyx/agent_search/shared_graph_utils/utils.py index 6ff0c691d3..fa03e9430a 100644 --- a/backend/onyx/agent_search/shared_graph_utils/utils.py +++ b/backend/onyx/agent_search/shared_graph_utils/utils.py @@ -1,12 +1,16 @@ import ast import json import re +from collections.abc import Callable +from collections.abc import Iterator from collections.abc import Sequence from datetime import datetime from datetime import timedelta from typing import Any +from typing import cast from uuid import UUID +from langchain_core.messages import BaseMessage from sqlalchemy.orm import Session from onyx.agent_search.main.models import EntityRelationshipTermExtraction @@ -200,3 +204,33 @@ def get_persona_prompt(persona: Persona | None) -> str: return "" else: return "\n".join([x.system_prompt for x in persona.prompts]) + + +def make_question_id(level: int, question_nr: int) -> str: + return f"{level}_{question_nr}" + + +def parse_question_id(question_id: str) -> tuple[int, int]: + level, question_nr = question_id.split("_") + return int(level), int(question_nr) + + +def dispatch_separated( + token_itr: Iterator[BaseMessage], + dispatch_event: Callable[[str, int], None], + sep: str = "\n", +) -> list[str | list[str | dict[str, Any]]]: + num = 0 + streamed_tokens: list[str | list[str | dict[str, Any]]] = [""] + for message in token_itr: + content = cast(str, message.content) + if sep in content: + for sub_question_part in content.split(sep): + dispatch_event(sub_question_part, num) + num += 1 + num -= 1 # fencepost; extra increment at end of loop + else: + dispatch_event(content, num) + streamed_tokens.append(content) + + return streamed_tokens diff --git a/backend/onyx/chat/answer.py b/backend/onyx/chat/answer.py index 3873f8dcd6..68d2cb5855 100644 --- a/backend/onyx/chat/answer.py +++
b/backend/onyx/chat/answer.py @@ -46,7 +46,6 @@ def __init__( question: str, answer_style_config: AnswerStyleConfig, llm: LLM, - fast_llm: LLM, prompt_config: PromptConfig, force_use_tool: ForceUseTool, # must be the same length as `docs`. If None, all docs are considered "relevant" @@ -66,6 +65,7 @@ def __init__( skip_gen_ai_answer_generation: bool = False, is_connected: Callable[[], bool] | None = None, pro_search_config: ProSearchConfig | None = None, + fast_llm: LLM | None = None, db_session: Session | None = None, ) -> None: if single_message_history and message_history: @@ -113,8 +113,6 @@ def __init__( ) self.pro_search_config = pro_search_config - if db_session is None: - raise ValueError("db_session must be provided") self.db_session = db_session def _get_tools_list(self) -> list[Tool]: @@ -273,13 +271,21 @@ def processed_streamed_output(self) -> AnswerStream: raise ValueError("Multiple search tools found") search_tool = search_tools[0] - yield from run_main_graph( + processed_stream = [] + if self.db_session is None: + raise ValueError("db_session must be provided for pro search") + if self.fast_llm is None: + raise ValueError("fast_llm must be provided for pro search") + for packet in run_main_graph( config=self.pro_search_config, primary_llm=self.llm, fast_llm=self.fast_llm, search_tool=search_tool, db_session=self.db_session, - ) + ): + processed_stream.append(packet) + yield packet + self._processed_stream = processed_stream return prompt_builder = AnswerPromptBuilder( diff --git a/backend/onyx/chat/models.py b/backend/onyx/chat/models.py index 8e1beb3bc7..233d998211 100644 --- a/backend/onyx/chat/models.py +++ b/backend/onyx/chat/models.py @@ -349,20 +349,33 @@ def from_model( class SubQuery(BaseModel): sub_query: str - sub_question_id: int + sub_question_id: str # <level>_<question_num> + query_id: int + + @model_validator(mode="after") + def check_sub_question_id(self) -> "SubQuery": + if len(self.sub_question_id.split("_")) != 2: + raise ValueError( + "sub_question_id must be in the format <level>_<question_num>" + ) + return self class SubAnswer(BaseModel): sub_answer: str - sub_question_id: int + sub_question_id: str # <level>_<question_num> class SubQuestion(BaseModel): - question_id: int + question_id: str # <level>_<question_num> sub_question: str -ProSearchPacket = SubQuestion | SubAnswer | SubQuery +class ExtendedToolResponse(ToolResponse): + sub_question_id: str # <level>_<question_num> + + +ProSearchPacket = SubQuestion | SubAnswer | SubQuery | ExtendedToolResponse AnswerPacket = ( AnswerQuestionPossibleReturn | ProSearchPacket | ToolCallKickoff | ToolResponse diff --git a/backend/onyx/tools/tool_implementations/search/search_tool.py b/backend/onyx/tools/tool_implementations/search/search_tool.py index 7dcfde4d76..5b2fe4bebc 100644 --- a/backend/onyx/tools/tool_implementations/search/search_tool.py +++ b/backend/onyx/tools/tool_implementations/search/search_tool.py @@ -285,6 +285,8 @@ def _build_response_for_specified_sections( def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]: query = cast(str, kwargs["query"]) + # kind of awkward to require this to be str, but it's "True" or "False" + force_no_rerank = kwargs.get("force_no_rerank", "False") if self.selected_sections: yield from self._build_response_for_specified_sections(query) @@ -293,7 +295,9 @@ def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]: search_pipeline = SearchPipeline( search_request=SearchRequest( query=query, - evaluation_type=self.evaluation_type, + evaluation_type=LLMEvaluationType.SKIP + if force_no_rerank == "True" + else self.evaluation_type,
human_selected_filters=( self.retrieval_options.filters if self.retrieval_options else None ), @@ -302,7 +306,9 @@ def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]: self.retrieval_options.offset if self.retrieval_options else None ), limit=self.retrieval_options.limit if self.retrieval_options else None, - rerank_settings=self.rerank_settings, + rerank_settings=None + if force_no_rerank == "True" + else self.rerank_settings, chunks_above=self.chunks_above, chunks_below=self.chunks_below, full_doc=self.full_doc, @@ -319,6 +325,7 @@ def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]: db_session=self.db_session, prompt_config=self.prompt_config, ) + self.search_pipeline = search_pipeline # used for agent_search metrics yield ToolResponse( id=SEARCH_RESPONSE_SUMMARY_ID, @@ -332,8 +339,6 @@ def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]: ), ) - self.search_pipeline = search_pipeline - yield ToolResponse( id=SEARCH_DOC_CONTENT_ID, response=OnyxContexts( diff --git a/backend/tests/regression/answer_quality/agent_test.py b/backend/tests/regression/answer_quality/agent_test.py index d3087c8a6f..c3bcd95f38 100644 --- a/backend/tests/regression/answer_quality/agent_test.py +++ b/backend/tests/regression/answer_quality/agent_test.py @@ -7,6 +7,7 @@ from onyx.agent_search.main.graph_builder import main_graph_builder from onyx.agent_search.main.states import MainInput +from onyx.chat.models import ProSearchConfig from onyx.context.search.models import SearchRequest from onyx.db.engine import get_session_context_manager from onyx.llm.factory import get_default_llms @@ -49,8 +50,14 @@ num_target_sub_questions = len(target_sub_questions) search_request = SearchRequest(query=example_question) - inputs = MainInput( + config = ProSearchConfig( search_request=search_request, + message_id=None, + chat_session_id=None, + use_persistence=False, + ) + inputs = MainInput( + config=config, primary_llm=primary_llm, fast_llm=fast_llm, db_session=db_session, From 2f7f4917e30521a4daf38043cb4d0fa6f3fd33de Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Mon, 6 Jan 2025 08:00:30 -0800 Subject: [PATCH 52/78] persistence sans Citations TODO: change commit to flush --- ...18a25b_agent_table_changes_rename_level.py | 40 +++++++++ .../nodes/answer_generation.py | 17 ---- backend/onyx/agent_search/main/nodes.py | 25 ++++++ backend/onyx/agent_search/run_graph.py | 27 ++++--- .../agent_search/shared_graph_utils/utils.py | 5 +- backend/onyx/db/chat.py | 81 ++++++++++++++++++- backend/onyx/db/models.py | 2 + 7 files changed, 164 insertions(+), 33 deletions(-) create mode 100644 backend/alembic/versions/c0132518a25b_agent_table_changes_rename_level.py diff --git a/backend/alembic/versions/c0132518a25b_agent_table_changes_rename_level.py b/backend/alembic/versions/c0132518a25b_agent_table_changes_rename_level.py new file mode 100644 index 0000000000..e845380991 --- /dev/null +++ b/backend/alembic/versions/c0132518a25b_agent_table_changes_rename_level.py @@ -0,0 +1,40 @@ +"""agent_table_changes_rename_level + +Revision ID: c0132518a25b +Revises: 1adf5ea20d2b +Create Date: 2025-01-05 16:38:37.660152 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision = "c0132518a25b" +down_revision = "1adf5ea20d2b" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # Add level and level_question_nr columns with NOT NULL constraint + op.add_column( + "sub_question", + sa.Column("level", sa.Integer(), nullable=False, server_default="0"), + ) + op.add_column( + "sub_question", + sa.Column( + "level_question_nr", sa.Integer(), nullable=False, server_default="0" + ), + ) + + # Remove the server_default after the columns are created + op.alter_column("sub_question", "level", server_default=None) + op.alter_column("sub_question", "level_question_nr", server_default=None) + + +def downgrade() -> None: + # Remove the columns + op.drop_column("sub_question", "level_question_nr") + op.drop_column("sub_question", "level") diff --git a/backend/onyx/agent_search/answer_question/nodes/answer_generation.py b/backend/onyx/agent_search/answer_question/nodes/answer_generation.py index 5e56df2535..49d2f64521 100644 --- a/backend/onyx/agent_search/answer_question/nodes/answer_generation.py +++ b/backend/onyx/agent_search/answer_question/nodes/answer_generation.py @@ -61,23 +61,6 @@ def answer_generation(state: AnswerQuestionState) -> QAGenerationUpdate: answer_str = merge_message_runs(response, chunk_separator="")[0].content - if state["subgraph_config"].use_persistence: - # Persist the sub-answer in the database - # db_session = state["subgraph_db_session"] - # chat_session_id = state["subgraph_config"].chat_session_id - # primary_message_id = state["subgraph_config"].message_id - # sub_question_id = state["sub_question_id"] - - # if chat_session_id is not None and primary_message_id is not None and sub_question_id is not None: - # create_sub_answer( - # db_session=db_session, - # chat_session_id=chat_session_id, - # primary_message_id=primary_message_id, - # sub_question_id=sub_question_id, - # answer=answer_str, - # ) - pass - return QAGenerationUpdate( answer=answer_str, ) diff --git a/backend/onyx/agent_search/main/nodes.py b/backend/onyx/agent_search/main/nodes.py index 9bad6d2768..eb4c8bf134 100644 --- a/backend/onyx/agent_search/main/nodes.py +++ b/backend/onyx/agent_search/main/nodes.py @@ -61,6 +61,7 @@ from onyx.agent_search.shared_graph_utils.utils import get_persona_prompt from onyx.chat.models import SubQuestion from onyx.db.chat import log_agent_metrics +from onyx.db.chat import log_agent_sub_question_results from onyx.utils.logger import setup_logger logger = setup_logger() @@ -858,6 +859,30 @@ def logging_node(state: MainState) -> MainOutput: agent_metrics=combined_agent_metrics, ) + if state["config"].use_persistence: + # Persist the sub-answer in the database + db_session = state["db_session"] + chat_session_id = state["config"].chat_session_id + primary_message_id = state["config"].message_id + sub_question_answer_results = state["follow_up_decomp_answer_results"] + + log_agent_sub_question_results( + db_session=db_session, + chat_session_id=chat_session_id, + primary_message_id=primary_message_id, + sub_question_answer_results=sub_question_answer_results, + ) + + # if chat_session_id is not None and primary_message_id is not None and sub_question_id is not None: + # create_sub_answer( + # db_session=db_session, + # chat_session_id=chat_session_id, + # primary_message_id=primary_message_id, + # sub_question_id=sub_question_id, + # answer=answer_str, + # # ) + # pass + main_output = MainOutput() return main_output diff --git a/backend/onyx/agent_search/run_graph.py b/backend/onyx/agent_search/run_graph.py index 
feffb421cf..14af2edc24 100644 --- a/backend/onyx/agent_search/run_graph.py +++ b/backend/onyx/agent_search/run_graph.py @@ -155,17 +155,18 @@ def run_main_graph( ) search_request.persona = get_persona_by_id(1, None, db_session) + config.use_persistence = True - with open("output.txt", "w") as f: - tool_responses = [] - for output in run_graph( - compiled_graph, config, search_tool, primary_llm, fast_llm, db_session - ): - if isinstance(output, OnyxAnswerPiece): - f.write(str(output.answer_piece) + "|") - elif isinstance(output, ToolCallKickoff): - pass - elif isinstance(output, ToolResponse): - tool_responses.append(output) - for tool_response in tool_responses: - f.write("tool response: " + str(tool_response.response) + "\n") + # with open("output.txt", "w") as f: + tool_responses = [] + for output in run_graph( + compiled_graph, config, search_tool, primary_llm, fast_llm, db_session + ): + if isinstance(output, OnyxAnswerPiece): + tool_responses.append("|") + elif isinstance(output, ToolCallKickoff): + pass + elif isinstance(output, ToolResponse): + tool_responses.append(output.response) + for tool_response in tool_responses: + logger.info(tool_response) diff --git a/backend/onyx/agent_search/shared_graph_utils/utils.py b/backend/onyx/agent_search/shared_graph_utils/utils.py index 6ff0c691d3..f8d4a7af0c 100644 --- a/backend/onyx/agent_search/shared_graph_utils/utils.py +++ b/backend/onyx/agent_search/shared_graph_utils/utils.py @@ -187,9 +187,10 @@ def get_test_config( config = ProSearchConfig( search_request=search_request, - chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"), + # chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"), + chat_session_id=UUID("edda10d5-6cef-45d8-acfb-39317552a1f4"), message_id=1, - use_persistence=False, + use_persistence=True, ) return config, search_tool diff --git a/backend/onyx/db/chat.py b/backend/onyx/db/chat.py index f432f3b444..7f763cc057 100644 --- a/backend/onyx/db/chat.py +++ b/backend/onyx/db/chat.py @@ -1,6 +1,7 @@ from collections.abc import Sequence from datetime import datetime from datetime import timedelta +from typing import Any from uuid import UUID from fastapi import HTTPException @@ -15,11 +16,13 @@ from sqlalchemy.orm import joinedload from sqlalchemy.orm import Session +from onyx.agent_search.answer_question.models import QuestionAnswerResults from onyx.agent_search.main.models import CombinedAgentMetrics from onyx.auth.schemas import UserRole from onyx.chat.models import DocumentRelevance from onyx.configs.chat_configs import HARD_DELETE_CHATS from onyx.configs.constants import MessageType +from onyx.context.search.models import InferenceSection from onyx.context.search.models import RetrievalDocs from onyx.context.search.models import SavedSearchDoc from onyx.context.search.models import SearchDoc as ServerSearchDoc @@ -31,6 +34,8 @@ from onyx.db.models import Prompt from onyx.db.models import SearchDoc from onyx.db.models import SearchDoc as DBSearchDoc +from onyx.db.models import SubQuery +from onyx.db.models import SubQuestion from onyx.db.models import ToolCall from onyx.db.models import User from onyx.db.persona import get_best_persona_id_for_user @@ -893,6 +898,80 @@ def log_agent_metrics( ) db_session.add(agent_metric_tracking) - db_session.commit() + db_session.flush() return agent_metric_tracking + + +def log_agent_sub_question_results( + db_session: Session, + chat_session_id: UUID | None, + primary_message_id: int | None, + sub_question_answer_results: list[QuestionAnswerResults], +) -> None: + def 
_create_citation_format_list( + document_citations: list[InferenceSection], + ) -> list[dict[str, Any]]: + citation_list: list[dict[str, Any]] = [] + for document_citation in document_citations: + document_citation_dict = { + "link": "", + "blurb": document_citation.center_chunk.blurb, + "content": document_citation.center_chunk.content, + "metadata": document_citation.center_chunk.metadata, + "updated_at": str(document_citation.center_chunk.updated_at), + "document_id": document_citation.center_chunk.document_id, + "source_type": "file", + "source_links": document_citation.center_chunk.source_links, + "match_highlights": document_citation.center_chunk.match_highlights, + "semantic_identifier": document_citation.center_chunk.semantic_identifier, + } + + citation_list.append(document_citation_dict) + + return citation_list + + now = datetime.now() + + for sub_question_answer_result in sub_question_answer_results: + level, level_question_nr = [ + int(x) for x in sub_question_answer_result.question_nr.split("_") + ] + sub_question = sub_question_answer_result.question + sub_answer = sub_question_answer_result.answer + sub_document_results = _create_citation_format_list( + sub_question_answer_result.documents + ) + sub_queries = [ + x.query for x in sub_question_answer_result.expanded_retrieval_results + ] + + sub_question_object = SubQuestion( + chat_session_id=chat_session_id, + primary_question_id=primary_message_id, + level=level, + level_question_nr=level_question_nr, + sub_question=sub_question, + sub_answer=sub_answer, + sub_question_doc_results=sub_document_results, + ) + + db_session.add(sub_question_object) + db_session.commit() + # db_session.flush() + + sub_question_id = sub_question_object.id + + for sub_query in sub_queries: + sub_query_object = SubQuery( + parent_question_id=sub_question_id, + chat_session_id=chat_session_id, + sub_query=sub_query, + time_created=now, + ) + + db_session.add(sub_query_object) + db_session.commit() + # db_session.flush() + + return None diff --git a/backend/onyx/db/models.py b/backend/onyx/db/models.py index c2b1164d2b..7ed915970b 100644 --- a/backend/onyx/db/models.py +++ b/backend/onyx/db/models.py @@ -1148,6 +1148,8 @@ class SubQuestion(Base): PGUUID(as_uuid=True), ForeignKey("chat_session.id") ) sub_question: Mapped[str] = mapped_column(Text) + level: Mapped[int] = mapped_column(Integer) + level_question_nr: Mapped[int] = mapped_column(Integer) time_created: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now() ) From d697ad0fc8f2ff685702b8d727b7425cc18845ca Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Mon, 6 Jan 2025 11:13:42 -0800 Subject: [PATCH 53/78] renamed tables + additional logging --- ...e58_agent_metric_table_renames__agent__.py | 25 ++++ ...eb76d618ec_agent_table_renames__agent__.py | 84 ++++++++++++ backend/onyx/agent_search/db_operations.py | 24 ++-- backend/onyx/agent_search/main/nodes.py | 129 ++++++++++++++++-- backend/onyx/db/chat.py | 8 +- backend/onyx/db/models.py | 34 ++--- 6 files changed, 259 insertions(+), 45 deletions(-) create mode 100644 backend/alembic/versions/9787be927e58_agent_metric_table_renames__agent__.py create mode 100644 backend/alembic/versions/bceb76d618ec_agent_table_renames__agent__.py diff --git a/backend/alembic/versions/9787be927e58_agent_metric_table_renames__agent__.py b/backend/alembic/versions/9787be927e58_agent_metric_table_renames__agent__.py new file mode 100644 index 0000000000..2b605f5b3d --- /dev/null +++ 
b/backend/alembic/versions/9787be927e58_agent_metric_table_renames__agent__.py @@ -0,0 +1,25 @@ +"""agent_metric_table_renames__agent__ + +Revision ID: 9787be927e58 +Revises: bceb76d618ec +Create Date: 2025-01-06 11:01:44.210160 + +""" +from alembic import op + + +# revision identifiers, used by Alembic. +revision = "9787be927e58" +down_revision = "bceb76d618ec" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # Rename table from agent_search_metrics to agent__search_metrics + op.rename_table("agent_search_metrics", "agent__search_metrics") + + +def downgrade() -> None: + # Rename table back from agent__search_metrics to agent_search_metrics + op.rename_table("agent__search_metrics", "agent_search_metrics") diff --git a/backend/alembic/versions/bceb76d618ec_agent_table_renames__agent__.py b/backend/alembic/versions/bceb76d618ec_agent_table_renames__agent__.py new file mode 100644 index 0000000000..1c1cb2e0d8 --- /dev/null +++ b/backend/alembic/versions/bceb76d618ec_agent_table_renames__agent__.py @@ -0,0 +1,84 @@ +"""agent_table_renames__agent__ + +Revision ID: bceb76d618ec +Revises: c0132518a25b +Create Date: 2025-01-06 10:50:48.109285 + +""" +from alembic import op + + +# revision identifiers, used by Alembic. +revision = "bceb76d618ec" +down_revision = "c0132518a25b" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.drop_constraint( + "sub_query__search_doc_sub_query_id_fkey", + "sub_query__search_doc", + type_="foreignkey", + ) + op.drop_constraint( + "sub_query__search_doc_search_doc_id_fkey", + "sub_query__search_doc", + type_="foreignkey", + ) + # Rename tables + op.rename_table("sub_query", "agent__sub_query") + op.rename_table("sub_question", "agent__sub_question") + op.rename_table("sub_query__search_doc", "agent__sub_query__search_doc") + + # Update both foreign key constraints for agent__sub_query__search_doc + + # Create new foreign keys with updated names + op.create_foreign_key( + "agent__sub_query__search_doc_sub_query_id_fkey", + "agent__sub_query__search_doc", + "agent__sub_query", + ["sub_query_id"], + ["id"], + ) + op.create_foreign_key( + "agent__sub_query__search_doc_search_doc_id_fkey", + "agent__sub_query__search_doc", + "search_doc", # This table name doesn't change + ["search_doc_id"], + ["id"], + ) + + +def downgrade() -> None: + # Update foreign key constraints for sub_query__search_doc + op.drop_constraint( + "agent__sub_query__search_doc_sub_query_id_fkey", + "agent__sub_query__search_doc", + type_="foreignkey", + ) + op.drop_constraint( + "agent__sub_query__search_doc_search_doc_id_fkey", + "agent__sub_query__search_doc", + type_="foreignkey", + ) + + # Rename tables back + op.rename_table("agent__sub_query__search_doc", "sub_query__search_doc") + op.rename_table("agent__sub_question", "sub_question") + op.rename_table("agent__sub_query", "sub_query") + + op.create_foreign_key( + "sub_query__search_doc_sub_query_id_fkey", + "sub_query__search_doc", + "sub_query", + ["sub_query_id"], + ["id"], + ) + op.create_foreign_key( + "sub_query__search_doc_search_doc_id_fkey", + "sub_query__search_doc", + "search_doc", # This table name doesn't change + ["search_doc_id"], + ["id"], + ) diff --git a/backend/onyx/agent_search/db_operations.py b/backend/onyx/agent_search/db_operations.py index 34bb77cc6c..3df137b114 100644 --- a/backend/onyx/agent_search/db_operations.py +++ b/backend/onyx/agent_search/db_operations.py @@ -2,8 +2,8 @@ from sqlalchemy.orm import Session -from onyx.db.models import SubQuery -from onyx.db.models 
import SubQuestion +from onyx.db.models import AgentSubQuery +from onyx.db.models import AgentSubQuestion def create_sub_question( @@ -12,9 +12,9 @@ def create_sub_question( primary_message_id: int, sub_question: str, sub_answer: str, -) -> SubQuestion: +) -> AgentSubQuestion: """Create a new sub-question record in the database.""" - sub_q = SubQuestion( + sub_q = AgentSubQuestion( chat_session_id=chat_session_id, primary_question_id=primary_message_id, sub_question=sub_question, @@ -30,9 +30,9 @@ def create_sub_query( chat_session_id: UUID, parent_question_id: int, sub_query: str, -) -> SubQuery: +) -> AgentSubQuery: """Create a new sub-query record in the database.""" - sub_q = SubQuery( + sub_q = AgentSubQuery( chat_session_id=chat_session_id, parent_question_id=parent_question_id, sub_query=sub_query, @@ -45,11 +45,11 @@ def create_sub_query( def get_sub_questions_for_message( db_session: Session, primary_message_id: int, -) -> list[SubQuestion]: +) -> list[AgentSubQuestion]: """Get all sub-questions for a given primary message.""" return ( - db_session.query(SubQuestion) - .filter(SubQuestion.primary_question_id == primary_message_id) + db_session.query(AgentSubQuestion) + .filter(AgentSubQuestion.primary_question_id == primary_message_id) .all() ) @@ -57,10 +57,10 @@ def get_sub_questions_for_message( def get_sub_queries_for_question( db_session: Session, sub_question_id: int, -) -> list[SubQuery]: +) -> list[AgentSubQuery]: """Get all sub-queries for a given sub-question.""" return ( - db_session.query(SubQuery) - .filter(SubQuery.parent_question_id == sub_question_id) + db_session.query(AgentSubQuery) + .filter(AgentSubQuery.parent_question_id == sub_question_id) .all() ) diff --git a/backend/onyx/agent_search/main/nodes.py b/backend/onyx/agent_search/main/nodes.py index 4d1402b545..2d4505acca 100644 --- a/backend/onyx/agent_search/main/nodes.py +++ b/backend/onyx/agent_search/main/nodes.py @@ -86,6 +86,10 @@ def helper(sub_question_part: str, num: int) -> None: def main_decomp_base(state: MainState) -> BaseDecompUpdate: + now_start = datetime.now() + + logger.info(f"XXXXXX--{now_start}--XXX---BASE DECOMP START---") + question = state["config"].search_request.query state["db_session"] chat_session_id = state["config"].chat_session_id @@ -136,6 +140,10 @@ def main_decomp_base(state: MainState) -> BaseDecompUpdate: # ) pass + now_end = datetime.now() + + logger.info(f"XXXXXX--{now_end}--{now_end - now_start}--XXX---BASE DECOMP END---") + return BaseDecompUpdate( initial_decomp_questions=decomp_list, agent_start_time=agent_start_time, @@ -226,7 +234,9 @@ def _calculate_initial_agent_stats( def generate_initial_answer(state: MainState) -> InitialAnswerUpdate: - logger.info("---GENERATE INITIAL---") + now_start = datetime.now() + + logger.info(f"XXXXXX--{now_start}--XXX---GENERATE INITIAL---") question = state["config"].search_request.query persona_prompt = get_persona_prompt(state["config"].search_request.persona) @@ -307,15 +317,18 @@ def generate_initial_answer(state: MainState) -> InitialAnswerUpdate: state["decomp_answer_results"], state["original_question_retrieval_stats"] ) - logger.info(f"\n\n---INITIAL AGENT ANSWER START---\n\n Answer:\n Agent: {answer}") - - logger.info(f"\n\nSub-Questions:\n\n{sub_question_answer_str}\n\nStats:\n\n") + logger.info(f"\n\nYYYYY--Sub-Questions:\n\n{sub_question_answer_str}\n\nStats:\n\n") if initial_agent_stats: logger.info(initial_agent_stats.original_question) logger.info(initial_agent_stats.sub_questions) 
logger.info(initial_agent_stats.agent_effectiveness) - logger.info("\n\n ---INITIAL AGENT ANSWER END---\n\n") + + now_end = datetime.now() + + logger.info( + f"XXXXXX--{now_end}--{now_end - now_start}--XXX---INITIAL AGENT ANSWER END---\n\n" + ) agent_base_end_time = datetime.now() @@ -362,15 +375,27 @@ def initial_answer_quality_check(state: MainState) -> InitialAnswerQualityUpdate InitialAnswerQualityUpdate """ - logger.info("Checking for base answer validity - for not set True/False manually") + now_start = datetime.now() + + logger.info( + f"XXXXXX--{now_start}--XXX---Checking for base answer validity - for not set True/False manually" + ) verdict = True + now_end = datetime.now() + + logger.info( + f"XXXXXX--{now_end}--{now_end - now_start}--XXX---INITIAL ANSWER QUALITY CHECK END---" + ) + return InitialAnswerQualityUpdate(initial_answer_quality=verdict) def entity_term_extraction(state: MainState) -> EntityTermExtractionUpdate: - logger.info("---GENERATE ENTITIES & TERMS---") + now_start = datetime.now() + + logger.info(f"XXXXXX--{now_start}--XXX---GENERATE ENTITIES & TERMS---") # first four lines duplicates from generate_initial_answer question = state["config"].search_request.query @@ -439,6 +464,12 @@ def entity_term_extraction(state: MainState) -> EntityTermExtractionUpdate: ) ) + now_end = datetime.now() + + logger.info( + f"XXXXXX--{now_end}--{now_end - now_start}--XXX---ENTITY TERM EXTRACTION END---" + ) + return EntityTermExtractionUpdate( entity_retlation_term_extractions=EntityRelationshipTermExtraction( entities=entities, @@ -449,7 +480,9 @@ def entity_term_extraction(state: MainState) -> EntityTermExtractionUpdate: def generate_initial_base_answer(state: MainState) -> InitialAnswerBASEUpdate: - logger.info("---GENERATE INITIAL BASE ANSWER---") + now_start = datetime.now() + + logger.info(f"XXXXXX--{now_start}--XXX---GENERATE INITIAL BASE ANSWER---") question = state["config"].search_request.query original_question_docs = state["all_original_question_documents"] @@ -468,17 +501,30 @@ def generate_initial_base_answer(state: MainState) -> InitialAnswerBASEUpdate: response = model.invoke(msg) answer = response.pretty_repr() + now_end = datetime.now() + logger.info( - f"\n\n---INITIAL BASE ANSWER START---\n\nBase: {answer}\n\n ---INITIAL BASE ANSWER END---\n\n" + f"XXXXXX--{now_end}--{now_end - now_start}--XXX---INITIAL BASE ANSWER END---\n\n" ) + return InitialAnswerBASEUpdate(initial_base_answer=answer) def ingest_answers(state: AnswerQuestionOutput) -> DecompAnswersUpdate: + now_start = datetime.now() + + logger.info(f"XXXXXX--{now_start}--XXX---INGEST ANSWERS---") documents = [] answer_results = state.get("answer_results", []) for answer_result in answer_results: documents.extend(answer_result.documents) + + now_end = datetime.now() + + logger.info( + f"XXXXXX--{now_end}--{now_end - now_start}--XXX---INGEST ANSWERS END---" + ) + return DecompAnswersUpdate( # Deduping is done by the documents operator for the main graph # so we might not need to dedup here @@ -488,6 +534,10 @@ def ingest_answers(state: AnswerQuestionOutput) -> DecompAnswersUpdate: def ingest_initial_retrieval(state: BaseRawSearchOutput) -> ExpandedRetrievalUpdate: + now_start = datetime.now() + + logger.info(f"XXXXXX--{now_start}--XXX---INGEST INITIAL RETRIEVAL---") + sub_question_retrieval_stats = state[ "base_expanded_retrieval_result" ].sub_question_retrieval_stats @@ -496,6 +546,12 @@ def ingest_initial_retrieval(state: BaseRawSearchOutput) -> ExpandedRetrievalUpd else: sub_question_retrieval_stats = 
sub_question_retrieval_stats + now_end = datetime.now() + + logger.info( + f"XXXXXX--{now_end}--{now_end - now_start}--XXX---INGEST INITIAL RETRIEVAL END---" + ) + return ExpandedRetrievalUpdate( original_question_retrieval_results=state[ "base_expanded_retrieval_result" @@ -508,7 +564,15 @@ def ingest_initial_retrieval(state: BaseRawSearchOutput) -> ExpandedRetrievalUpd def refined_answer_decision(state: MainState) -> RequireRefinedAnswerUpdate: - logger.info("---REFINED ANSWER DECISION---") + now_start = datetime.now() + + logger.info(f"XXXXXX--{now_start}--XXX---REFINED ANSWER DECISION---") + + now_end = datetime.now() + + logger.info( + f"XXXXXX--{now_end}--{now_end - now_start}--XXX---REFINED ANSWER DECISION END---" + ) if False: return RequireRefinedAnswerUpdate(require_refined_answer=False) @@ -518,7 +582,9 @@ def refined_answer_decision(state: MainState) -> RequireRefinedAnswerUpdate: def generate_refined_answer(state: MainState) -> RefinedAnswerUpdate: - logger.info("---GENERATE REFINED ANSWER---") + now_start = datetime.now() + + logger.info(f"XXXXXX--{now_start}--XXX---GENERATE REFINED ANSWER---") question = state["config"].search_request.query persona_prompt = get_persona_prompt(state["config"].search_request.persona) @@ -685,7 +751,11 @@ def generate_refined_answer(state: MainState) -> RefinedAnswerUpdate: f"Revision Question Factor: {refined_agent_stats.revision_question_efficiency}" ) - logger.info("\n\n ---INITIAL AGENT ANSWER END---\n\n") + now_end = datetime.now() + + logger.info( + f"XXXXXX--{now_end}--{now_end - now_start}--XXX---INITIAL AGENT ANSWER END---\n\n" + ) agent_refined_end_time = datetime.now() agent_refined_duration = ( @@ -698,6 +768,12 @@ def generate_refined_answer(state: MainState) -> RefinedAnswerUpdate: duration_s=agent_refined_duration, ) + now_end = datetime.now() + + logger.info( + f"XXXXXX--{now_end}--{now_end - now_start}--XXX---REFINED ANSWER UPDATE END---" + ) + return RefinedAnswerUpdate( refined_answer=answer, refined_answer_quality=True, # TODO: replace this with the actual check value @@ -710,6 +786,10 @@ def generate_refined_answer(state: MainState) -> RefinedAnswerUpdate: def follow_up_decompose(state: MainState) -> FollowUpSubQuestionsUpdate: """ """ + now_start = datetime.now() + + logger.info(f"XXXXXX--{now_start}--XXX---FOLLOW UP DECOMPOSE---") + agent_refined_start_time = datetime.now() question = state["config"].search_request.query @@ -767,6 +847,12 @@ def follow_up_decompose(state: MainState) -> FollowUpSubQuestionsUpdate: follow_up_sub_question_dict[sub_question_nr] = follow_up_sub_question + now_end = datetime.now() + + logger.info( + f"XXXXXX--{now_end}--{now_end - now_start}--XXX---FOLLOW UP DECOMPOSE END---" + ) + return FollowUpSubQuestionsUpdate( follow_up_sub_questions=follow_up_sub_question_dict, agent_refined_start_time=agent_refined_start_time, @@ -776,10 +862,21 @@ def follow_up_decompose(state: MainState) -> FollowUpSubQuestionsUpdate: def ingest_follow_up_answers( state: AnswerQuestionOutput, ) -> FollowUpDecompAnswersUpdate: + now_start = datetime.now() + + logger.info(f"XXXXXX--{now_start}--XXX---INGEST FOLLOW UP ANSWERS---") + documents = [] answer_results = state.get("answer_results", []) for answer_result in answer_results: documents.extend(answer_result.documents) + + now_end = datetime.now() + + logger.info( + f"XXXXXX--{now_end}--{now_end - now_start}--XXX---INGEST FOLLOW UP ANSWERS END---" + ) + return FollowUpDecompAnswersUpdate( # Deduping is done by the documents operator for the main graph # so we might 
not need to dedup here @@ -789,7 +886,9 @@ def ingest_follow_up_answers( def logging_node(state: MainState) -> MainOutput: - logger.info("---LOGGING NODE---") + now_start = datetime.now() + + logger.info(f"XXXXXX--{now_start}--XXX---LOGGING NODE---") agent_start_time = state["agent_start_time"] agent_base_end_time = state["agent_base_end_time"] @@ -882,4 +981,8 @@ def logging_node(state: MainState) -> MainOutput: main_output = MainOutput() + now_end = datetime.now() + + logger.info(f"XXXXXX--{now_end}--{now_end - now_start}--XXX---LOGGING NODE END---") + return main_output diff --git a/backend/onyx/db/chat.py b/backend/onyx/db/chat.py index 5fc5fcbe4c..45de030a1d 100644 --- a/backend/onyx/db/chat.py +++ b/backend/onyx/db/chat.py @@ -27,6 +27,8 @@ from onyx.context.search.models import SavedSearchDoc from onyx.context.search.models import SearchDoc as ServerSearchDoc from onyx.db.models import AgentSearchMetrics +from onyx.db.models import AgentSubQuery +from onyx.db.models import AgentSubQuestion from onyx.db.models import ChatMessage from onyx.db.models import ChatMessage__SearchDoc from onyx.db.models import ChatSession @@ -34,8 +36,6 @@ from onyx.db.models import Prompt from onyx.db.models import SearchDoc from onyx.db.models import SearchDoc as DBSearchDoc -from onyx.db.models import SubQuery -from onyx.db.models import SubQuestion from onyx.db.models import ToolCall from onyx.db.models import User from onyx.db.persona import get_best_persona_id_for_user @@ -946,7 +946,7 @@ def _create_citation_format_list( x.query for x in sub_question_answer_result.expanded_retrieval_results ] - sub_question_object = SubQuestion( + sub_question_object = AgentSubQuestion( chat_session_id=chat_session_id, primary_question_id=primary_message_id, level=level, @@ -963,7 +963,7 @@ def _create_citation_format_list( sub_question_id = sub_question_object.id for sub_query in sub_queries: - sub_query_object = SubQuery( + sub_query_object = AgentSubQuery( parent_question_id=sub_question_id, chat_session_id=chat_session_id, sub_query=sub_query, diff --git a/backend/onyx/db/models.py b/backend/onyx/db/models.py index 7ed915970b..1f817254ec 100644 --- a/backend/onyx/db/models.py +++ b/backend/onyx/db/models.py @@ -295,11 +295,11 @@ class ChatMessage__SearchDoc(Base): ) -class SubQuery__SearchDoc(Base): - __tablename__ = "sub_query__search_doc" +class AgentSubQuery__SearchDoc(Base): + __tablename__ = "agent__sub_query__search_doc" sub_query_id: Mapped[int] = mapped_column( - ForeignKey("sub_query.id"), primary_key=True + ForeignKey("agent__sub_query.id"), primary_key=True ) search_doc_id: Mapped[int] = mapped_column( ForeignKey("search_doc.id"), primary_key=True @@ -936,8 +936,8 @@ class SearchDoc(Base): back_populates="search_docs", ) sub_queries = relationship( - "SubQuery", - secondary=SubQuery__SearchDoc.__table__, + "AgentSubQuery", + secondary=AgentSubQuery__SearchDoc.__table__, back_populates="search_docs", ) @@ -1134,13 +1134,13 @@ def __lt__(self, other: Any) -> bool: return self.display_priority < other.display_priority -class SubQuestion(Base): +class AgentSubQuestion(Base): """ A sub-question is a question that is asked of the LLM to gather supporting information to answer a primary question. 
""" - __tablename__ = "sub_question" + __tablename__ = "agent__sub_question" id: Mapped[int] = mapped_column(primary_key=True) primary_question_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id")) @@ -1161,20 +1161,22 @@ class SubQuestion(Base): "ChatMessage", foreign_keys=[primary_question_id] ) chat_session: Mapped["ChatSession"] = relationship("ChatSession") - sub_queries: Mapped[list["SubQuery"]] = relationship( - "SubQuery", back_populates="parent_question" + sub_queries: Mapped[list["AgentSubQuery"]] = relationship( + "AgentSubQuery", back_populates="parent_question" ) -class SubQuery(Base): +class AgentSubQuery(Base): """ A sub-query is a vector DB query that gathers supporting information to answer a sub-question. """ - __tablename__ = "sub_query" + __tablename__ = "agent__sub_query" id: Mapped[int] = mapped_column(primary_key=True) - parent_question_id: Mapped[int] = mapped_column(ForeignKey("sub_question.id")) + parent_question_id: Mapped[int] = mapped_column( + ForeignKey("agent__sub_question.id") + ) chat_session_id: Mapped[UUID] = mapped_column( PGUUID(as_uuid=True), ForeignKey("chat_session.id") ) @@ -1184,13 +1186,13 @@ class SubQuery(Base): ) # Relationships - parent_question: Mapped["SubQuestion"] = relationship( - "SubQuestion", back_populates="sub_queries" + parent_question: Mapped["AgentSubQuestion"] = relationship( + "AgentSubQuestion", back_populates="sub_queries" ) chat_session: Mapped["ChatSession"] = relationship("ChatSession") search_docs: Mapped[list["SearchDoc"]] = relationship( "SearchDoc", - secondary=SubQuery__SearchDoc.__table__, + secondary=AgentSubQuery__SearchDoc.__table__, back_populates="sub_queries", ) @@ -1676,7 +1678,7 @@ class PGFileStore(Base): class AgentSearchMetrics(Base): - __tablename__ = "agent_search_metrics" + __tablename__ = "agent__search_metrics" id: Mapped[int] = mapped_column(primary_key=True) user_id: Mapped[UUID | None] = mapped_column( From 0ff44c76615251b7dbac10b143fe2ef0047120d3 Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Mon, 6 Jan 2025 11:25:34 -0800 Subject: [PATCH 54/78] agent metric column rename duration_s -> duration__s --- ...925b58bd75b6_agent_metric_col_rename__s.py | 35 +++++++++++++++++++ backend/onyx/agent_search/main/models.py | 10 +++--- backend/onyx/agent_search/main/nodes.py | 10 +++--- backend/onyx/db/chat.py | 4 +-- backend/onyx/db/models.py | 4 +-- 5 files changed, 49 insertions(+), 14 deletions(-) create mode 100644 backend/alembic/versions/925b58bd75b6_agent_metric_col_rename__s.py diff --git a/backend/alembic/versions/925b58bd75b6_agent_metric_col_rename__s.py b/backend/alembic/versions/925b58bd75b6_agent_metric_col_rename__s.py new file mode 100644 index 0000000000..6bf5016084 --- /dev/null +++ b/backend/alembic/versions/925b58bd75b6_agent_metric_col_rename__s.py @@ -0,0 +1,35 @@ +"""agent_metric_col_rename__s + +Revision ID: 925b58bd75b6 +Revises: 9787be927e58 +Create Date: 2025-01-06 11:20:26.752441 + +""" +from alembic import op + + +# revision identifiers, used by Alembic. 
+revision = "925b58bd75b6" +down_revision = "9787be927e58" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # Rename columns using PostgreSQL syntax + op.alter_column( + "agent__search_metrics", "base_duration_s", new_column_name="base_duration__s" + ) + op.alter_column( + "agent__search_metrics", "full_duration_s", new_column_name="full_duration__s" + ) + + +def downgrade() -> None: + # Revert the column renames + op.alter_column( + "agent__search_metrics", "base_duration__s", new_column_name="base_duration_s" + ) + op.alter_column( + "agent__search_metrics", "full_duration__s", new_column_name="full_duration_s" + ) diff --git a/backend/onyx/agent_search/main/models.py b/backend/onyx/agent_search/main/models.py index 40bf9d97fc..8a4cb04e4b 100644 --- a/backend/onyx/agent_search/main/models.py +++ b/backend/onyx/agent_search/main/models.py @@ -36,9 +36,9 @@ class FollowUpSubQuestion(BaseModel): class AgentTimings(BaseModel): - base_duration_s: float | None - refined_duration_s: float | None - full_duration_s: float | None + base_duration__s: float | None + refined_duration__s: float | None + full_duration__s: float | None class AgentBaseMetrics(BaseModel): @@ -49,13 +49,13 @@ class AgentBaseMetrics(BaseModel): verified_avg_score_base: float | None base_doc_boost_factor: float | None support_boost_factor: float | None - duration_s: float | None + duration__s: float | None class AgentRefinedMetrics(BaseModel): refined_doc_boost_factor: float | None refined_question_boost_factor: float | None - duration_s: float | None + duration__s: float | None class AgentAdditionalMetrics(BaseModel): diff --git a/backend/onyx/agent_search/main/nodes.py b/backend/onyx/agent_search/main/nodes.py index 2d4505acca..c5a554335a 100644 --- a/backend/onyx/agent_search/main/nodes.py +++ b/backend/onyx/agent_search/main/nodes.py @@ -352,7 +352,7 @@ def generate_initial_answer(state: MainState) -> InitialAnswerUpdate: support_boost_factor=initial_agent_stats.agent_effectiveness.get( "support_ratio", None ), - duration_s=(agent_base_end_time - state["agent_start_time"]).total_seconds(), + duration__s=(agent_base_end_time - state["agent_start_time"]).total_seconds(), ) return InitialAnswerUpdate( @@ -765,7 +765,7 @@ def generate_refined_answer(state: MainState) -> RefinedAnswerUpdate: agent_refined_metrics = AgentRefinedMetrics( refined_doc_boost_factor=refined_agent_stats.revision_doc_efficiency, refined_question_boost_factor=refined_agent_stats.revision_question_efficiency, - duration_s=agent_refined_duration, + duration__s=agent_refined_duration, ) now_end = datetime.now() @@ -923,9 +923,9 @@ def logging_node(state: MainState) -> MainOutput: combined_agent_metrics = CombinedAgentMetrics( timings=AgentTimings( - base_duration_s=agent_base_duration, - refined_duration_s=agent_refined_duration, - full_duration_s=agent_full_duration, + base_duration__s=agent_base_duration, + refined_duration__s=agent_refined_duration, + full_duration__s=agent_full_duration, ), base_metrics=agent_base_metrics, refined_metrics=agent_refined_metrics, diff --git a/backend/onyx/db/chat.py b/backend/onyx/db/chat.py index 45de030a1d..bbd3d1166e 100644 --- a/backend/onyx/db/chat.py +++ b/backend/onyx/db/chat.py @@ -890,8 +890,8 @@ def log_agent_metrics( persona_id=persona_id, agent_type=agent_type, start_time=start_time, - base_duration_s=agent_timings.base_duration_s, - full_duration_s=agent_timings.full_duration_s, + base_duration__s=agent_timings.base_duration__s, + full_duration__s=agent_timings.full_duration__s, 
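# (vars() below serializes each Pydantic metrics model into a plain dict,
#  which is what the JSONB-backed base_metrics/refined_metrics/all_metrics
#  columns expect.)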
base_metrics=vars(agent_base_metrics), refined_metrics=vars(agent_refined_metrics), all_metrics=vars(agent_additional_metrics), diff --git a/backend/onyx/db/models.py b/backend/onyx/db/models.py index 1f817254ec..35237464a8 100644 --- a/backend/onyx/db/models.py +++ b/backend/onyx/db/models.py @@ -1689,8 +1689,8 @@ class AgentSearchMetrics(Base): ) agent_type: Mapped[str] = mapped_column(String) start_time: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True)) - base_duration_s: Mapped[float] = mapped_column(Float) - full_duration_s: Mapped[float] = mapped_column(Float) + base_duration__s: Mapped[float] = mapped_column(Float) + full_duration__s: Mapped[float] = mapped_column(Float) base_metrics: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=True) refined_metrics: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=True) all_metrics: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=True) From 239f2f271861a4d25d7b6942e242fa360bd1e36d Mon Sep 17 00:00:00 2001 From: Evan Lohn Date: Mon, 6 Jan 2025 18:37:04 -0800 Subject: [PATCH 55/78] cleanup plus get-chat-session API v1 --- .../nodes/answer_check.py | 30 --------- .../nodes/answer_generation.py | 39 ------------ .../nodes/format_answer.py | 29 --------- .../nodes/ingest_retrieval.py | 19 ------ .../answer_follow_up_question/states.py | 58 ----------------- .../nodes/answer_generation.py | 11 ---- .../answer_question/nodes/format_answer.py | 10 --- .../agent_search/expanded_retrieval/nodes.py | 63 +++++++++---------- .../expanded_retrieval/prompts.py | 0 backend/onyx/agent_search/main/nodes.py | 1 + backend/onyx/agent_search/run_graph.py | 3 +- .../shared_graph_utils/agent_prompt_ops.py | 10 --- backend/onyx/db/chat.py | 54 ++++++++++++++-- backend/onyx/db/models.py | 9 ++- backend/onyx/server/query_and_chat/models.py | 19 +++++- .../search/search_tool.py | 18 ++++-- 16 files changed, 117 insertions(+), 256 deletions(-) delete mode 100644 backend/onyx/agent_search/answer_follow_up_question/nodes/answer_check.py delete mode 100644 backend/onyx/agent_search/answer_follow_up_question/nodes/answer_generation.py delete mode 100644 backend/onyx/agent_search/answer_follow_up_question/nodes/format_answer.py delete mode 100644 backend/onyx/agent_search/answer_follow_up_question/nodes/ingest_retrieval.py delete mode 100644 backend/onyx/agent_search/answer_follow_up_question/states.py delete mode 100644 backend/onyx/agent_search/expanded_retrieval/prompts.py diff --git a/backend/onyx/agent_search/answer_follow_up_question/nodes/answer_check.py b/backend/onyx/agent_search/answer_follow_up_question/nodes/answer_check.py deleted file mode 100644 index 6349552f34..0000000000 --- a/backend/onyx/agent_search/answer_follow_up_question/nodes/answer_check.py +++ /dev/null @@ -1,30 +0,0 @@ -from langchain_core.messages import HumanMessage -from langchain_core.messages import merge_message_runs - -from onyx.agent_search.answer_question.states import AnswerQuestionState -from onyx.agent_search.answer_question.states import QACheckUpdate -from onyx.agent_search.shared_graph_utils.prompts import SUB_CHECK_PROMPT - - -def answer_check(state: AnswerQuestionState) -> QACheckUpdate: - msg = [ - HumanMessage( - content=SUB_CHECK_PROMPT.format( - question=state["question"], - base_answer=state["answer"], - ) - ) - ] - - fast_llm = state["subgraph_fast_llm"] - response = list( - fast_llm.stream( - prompt=msg, - ) - ) - - quality_str = merge_message_runs(response, chunk_separator="")[0].content - - return QACheckUpdate( - 
answer_quality=quality_str, - ) diff --git a/backend/onyx/agent_search/answer_follow_up_question/nodes/answer_generation.py b/backend/onyx/agent_search/answer_follow_up_question/nodes/answer_generation.py deleted file mode 100644 index 9b742d8e74..0000000000 --- a/backend/onyx/agent_search/answer_follow_up_question/nodes/answer_generation.py +++ /dev/null @@ -1,39 +0,0 @@ -from langchain_core.messages import HumanMessage -from langchain_core.messages import merge_message_runs - -from onyx.agent_search.answer_question.states import AnswerQuestionState -from onyx.agent_search.answer_question.states import QAGenerationUpdate -from onyx.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT -from onyx.agent_search.shared_graph_utils.utils import format_docs -from onyx.utils.logger import setup_logger - -logger = setup_logger() - - -def answer_generation(state: AnswerQuestionState) -> QAGenerationUpdate: - question = state["question"] - docs = state["documents"] - - logger.info(f"Number of verified retrieval docs: {len(docs)}") - - msg = [ - HumanMessage( - content=BASE_RAG_PROMPT.format( - question=question, - context=format_docs(docs), - original_question=state["subgraph_config"].search_request.query, - ) - ) - ] - - fast_llm = state["subgraph_fast_llm"] - response = list( - fast_llm.stream( - prompt=msg, - ) - ) - - answer_str = merge_message_runs(response, chunk_separator="")[0].content - return QAGenerationUpdate( - answer=answer_str, - ) diff --git a/backend/onyx/agent_search/answer_follow_up_question/nodes/format_answer.py b/backend/onyx/agent_search/answer_follow_up_question/nodes/format_answer.py deleted file mode 100644 index 4350f20165..0000000000 --- a/backend/onyx/agent_search/answer_follow_up_question/nodes/format_answer.py +++ /dev/null @@ -1,29 +0,0 @@ -from onyx.agent_search.answer_question.states import AnswerQuestionOutput -from onyx.agent_search.answer_question.states import AnswerQuestionState -from onyx.agent_search.answer_question.states import QuestionAnswerResults - - -def format_answer(state: AnswerQuestionState) -> AnswerQuestionOutput: - # sub_question_retrieval_stats = state["sub_question_retrieval_stats"] - # if sub_question_retrieval_stats is None: - # sub_question_retrieval_stats = [] - # elif isinstance(sub_question_retrieval_stats, list): - # sub_question_retrieval_stats = sub_question_retrieval_stats - # if isinstance(sub_question_retrieval_stats[0], list): - # sub_question_retrieval_stats = sub_question_retrieval_stats[0] - # else: - # sub_question_retrieval_stats = [sub_question_retrieval_stats] - - return AnswerQuestionOutput( - answer_results=[ - QuestionAnswerResults( - question=state["question"], - quality=state["answer_quality"], - answer=state["answer"], - expanded_retrieval_results=state["expanded_retrieval_results"], - documents=state["documents"], - sub_question_retrieval_stats=state["sub_question_retrieval_stats"], - question_id=state["question_id"], - ) - ], - ) diff --git a/backend/onyx/agent_search/answer_follow_up_question/nodes/ingest_retrieval.py b/backend/onyx/agent_search/answer_follow_up_question/nodes/ingest_retrieval.py deleted file mode 100644 index cc9e5989ff..0000000000 --- a/backend/onyx/agent_search/answer_follow_up_question/nodes/ingest_retrieval.py +++ /dev/null @@ -1,19 +0,0 @@ -from onyx.agent_search.answer_question.states import RetrievalIngestionUpdate -from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalOutput -from onyx.agent_search.shared_graph_utils.models import AgentChunkStats - - -def 
ingest_retrieval(state: ExpandedRetrievalOutput) -> RetrievalIngestionUpdate: - sub_question_retrieval_stats = state[ - "expanded_retrieval_result" - ].sub_question_retrieval_stats - if sub_question_retrieval_stats is None: - sub_question_retrieval_stats = [AgentChunkStats()] - - return RetrievalIngestionUpdate( - expanded_retrieval_results=state[ - "expanded_retrieval_result" - ].expanded_queries_results, - documents=state["expanded_retrieval_result"].all_documents, - sub_question_retrieval_stats=sub_question_retrieval_stats, - ) diff --git a/backend/onyx/agent_search/answer_follow_up_question/states.py b/backend/onyx/agent_search/answer_follow_up_question/states.py deleted file mode 100644 index 28f4dc2134..0000000000 --- a/backend/onyx/agent_search/answer_follow_up_question/states.py +++ /dev/null @@ -1,58 +0,0 @@ -from operator import add -from typing import Annotated -from typing import TypedDict - -from onyx.agent_search.answer_question.models import QuestionAnswerResults -from onyx.agent_search.core_state import SubgraphCoreState -from onyx.agent_search.expanded_retrieval.models import QueryResult -from onyx.agent_search.shared_graph_utils.models import AgentChunkStats -from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections -from onyx.context.search.models import InferenceSection - - -## Update States -class QACheckUpdate(TypedDict): - answer_quality: str - - -class QAGenerationUpdate(TypedDict): - answer: str - # answer_stat: AnswerStats - - -class RetrievalIngestionUpdate(TypedDict): - expanded_retrieval_results: list[QueryResult] - documents: Annotated[list[InferenceSection], dedup_inference_sections] - sub_question_retrieval_stats: AgentChunkStats - - -## Graph Input State - - -class AnswerQuestionInput(SubgraphCoreState): - question: str - - -## Graph State - - -class AnswerQuestionState( - AnswerQuestionInput, - QAGenerationUpdate, - QACheckUpdate, - RetrievalIngestionUpdate, -): - pass - - -## Graph Output State - - -class AnswerQuestionOutput(TypedDict): - """ - This is a list of results even though each call of this subgraph only returns one result. - This is because if we parallelize the answer query subgraph, there will be multiple - results in a list so the add operator is used to add them together. 
- """ - - answer_results: Annotated[list[QuestionAnswerResults], add] diff --git a/backend/onyx/agent_search/answer_question/nodes/answer_generation.py b/backend/onyx/agent_search/answer_question/nodes/answer_generation.py index 51f36b9180..538e279242 100644 --- a/backend/onyx/agent_search/answer_question/nodes/answer_generation.py +++ b/backend/onyx/agent_search/answer_question/nodes/answer_generation.py @@ -38,17 +38,6 @@ def answer_generation(state: AnswerQuestionState) -> QAGenerationUpdate: persona_specification=persona_specification, ) - # msg = [ - # HumanMessage( - # content=BASE_RAG_PROMPT.format( - # question=question, - # context=format_docs(docs), - # original_question=state["subgraph_search_request"].query, - # persona_specification=persona_specification, - # ) - # ) - # ] - fast_llm = state["subgraph_fast_llm"] response: list[str | list[str | dict[str, Any]]] = [] for message in fast_llm.stream( diff --git a/backend/onyx/agent_search/answer_question/nodes/format_answer.py b/backend/onyx/agent_search/answer_question/nodes/format_answer.py index 3bd1e9569f..23ffd23939 100644 --- a/backend/onyx/agent_search/answer_question/nodes/format_answer.py +++ b/backend/onyx/agent_search/answer_question/nodes/format_answer.py @@ -4,16 +4,6 @@ def format_answer(state: AnswerQuestionState) -> AnswerQuestionOutput: - # sub_question_retrieval_stats = state["sub_question_retrieval_stats"] - # if sub_question_retrieval_stats is None: - # sub_question_retrieval_stats = [] - # elif isinstance(sub_question_retrieval_stats, list): - # sub_question_retrieval_stats = sub_question_retrieval_stats - # if isinstance(sub_question_retrieval_stats[0], list): - # sub_question_retrieval_stats = sub_question_retrieval_stats[0] - # else: - # sub_question_retrieval_stats = [sub_question_retrieval_stats] - return AnswerQuestionOutput( answer_results=[ QuestionAnswerResults( diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes.py b/backend/onyx/agent_search/expanded_retrieval/nodes.py index 04a3221e71..7ff034fd59 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes.py @@ -39,6 +39,7 @@ from onyx.context.search.models import SearchRequest from onyx.context.search.pipeline import retrieval_preprocessing from onyx.context.search.postprocessing.postprocessing import rerank_sections +from onyx.db.engine import get_session_context_manager from onyx.llm.interfaces import LLM from onyx.tools.tool_implementations.search.search_tool import ( SEARCH_RESPONSE_SUMMARY_ID, @@ -87,20 +88,6 @@ def expand_queries(state: ExpandedRetrievalInput) -> QueryExpansionUpdate: rewritten_queries = llm_response.split("\n") - if state["subgraph_config"].use_persistence: - # Persist sub-queries to database - - # for query in rewritten_queries: - # sub_queries.append( - # create_sub_query( - # db_session=db_session, - # chat_session_id=chat_session_id, - # parent_question_id=sub_question_id, - # sub_query=query.strip(), - # ) - # ) - pass - return QueryExpansionUpdate( expanded_queries=rewritten_queries, ) @@ -127,21 +114,25 @@ def doc_retrieval(state: RetrievalInput) -> DocRetrievalUpdate: expanded_retrieval_results=[], retrieved_documents=[], ) - for tool_response in search_tool.run( - query=query_to_retrieve, force_no_rerank="True" - ): - if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID: - retrieved_docs = cast( - list[InferenceSection], tool_response.response.top_sections + + with get_session_context_manager() as db_session: + for tool_response in 
search_tool.run( + query=query_to_retrieve, + force_no_rerank=True, + alternate_db_session=db_session, + ): + if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID: + retrieved_docs = cast( + list[InferenceSection], tool_response.response.top_sections + ) + dispatch_custom_event( + "tool_response", + ExtendedToolResponse( + id=tool_response.id, + sub_question_id=state["sub_question_id"] or make_question_id(0, 0), + response=tool_response.response, + ), ) - dispatch_custom_event( - "tool_response", - ExtendedToolResponse( - id=tool_response.id, - sub_question_id=state["sub_question_id"] or make_question_id(0, 0), - response=tool_response.response, - ), - ) retrieved_docs = retrieved_docs[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS] pre_rerank_docs = retrieved_docs @@ -172,7 +163,6 @@ def doc_retrieval(state: RetrievalInput) -> DocRetrievalUpdate: def verification_kickoff( state: ExpandedRetrievalState, ) -> Command[Literal["doc_verification"]]: - # TODO: stream deduped docs? documents = state["retrieved_documents"] verification_question = state.get( "question", state["subgraph_config"].search_request.query @@ -239,12 +229,13 @@ def doc_reranking(state: ExpandedRetrievalState) -> DocRerankingUpdate: # then create the list of reranked sections question = state.get("question", state["subgraph_config"].search_request.query) - _search_query = retrieval_preprocessing( - search_request=SearchRequest(query=question), - user=state["subgraph_search_tool"].user, # bit of a hack - llm=state["subgraph_fast_llm"], - db_session=state["subgraph_db_session"], - ) + with get_session_context_manager() as db_session: + _search_query = retrieval_preprocessing( + search_request=SearchRequest(query=question), + user=state["subgraph_search_tool"].user, # bit of a hack + llm=state["subgraph_fast_llm"], + db_session=db_session, + ) # skip section filtering @@ -266,6 +257,8 @@ def doc_reranking(state: ExpandedRetrievalState) -> DocRerankingUpdate: else: fit_scores = RetrievalFitStats(fit_score_lift=0, rerank_effect=0, fit_scores={}) + # TODO: stream deduped docs here, or decide to use search tool ranking/verification + return DocRerankingUpdate( reranked_documents=[ doc for doc in reranked_documents if type(doc) == InferenceSection diff --git a/backend/onyx/agent_search/expanded_retrieval/prompts.py b/backend/onyx/agent_search/expanded_retrieval/prompts.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/backend/onyx/agent_search/main/nodes.py b/backend/onyx/agent_search/main/nodes.py index c5a554335a..e45195f9a7 100644 --- a/backend/onyx/agent_search/main/nodes.py +++ b/backend/onyx/agent_search/main/nodes.py @@ -676,6 +676,7 @@ def generate_refined_answer(state: MainState) -> RefinedAnswerUpdate: ) ] + # TODO: stream refined answer # Grader model = state["fast_llm"] response = model.invoke(msg) diff --git a/backend/onyx/agent_search/run_graph.py b/backend/onyx/agent_search/run_graph.py index 6c12027e03..122ca95589 100644 --- a/backend/onyx/agent_search/run_graph.py +++ b/backend/onyx/agent_search/run_graph.py @@ -138,7 +138,6 @@ def run_main_graph( if __name__ == "__main__": from onyx.llm.factory import get_default_llms - from onyx.db.persona import get_persona_by_id graph = main_graph_builder() compiled_graph = graph.compile() @@ -153,7 +152,7 @@ def run_main_graph( config, search_tool = get_test_config( db_session, primary_llm, fast_llm, search_request ) - search_request.persona = get_persona_by_id(1, None, db_session) + # search_request.persona = get_persona_by_id(1, None, db_session) config.use_persistence 
= True # with open("output.txt", "w") as f: diff --git a/backend/onyx/agent_search/shared_graph_utils/agent_prompt_ops.py b/backend/onyx/agent_search/shared_graph_utils/agent_prompt_ops.py index 29bbeffaee..4f0cf106b4 100644 --- a/backend/onyx/agent_search/shared_graph_utils/agent_prompt_ops.py +++ b/backend/onyx/agent_search/shared_graph_utils/agent_prompt_ops.py @@ -31,14 +31,4 @@ def build_sub_question_answer_prompt( ) ) - # ai_message = AIMessage(content='' - # ) - - # tool_message = ToolMessage( - # content=docs_str, - # tool_call_id='agent_search_call', - # name="search_results" - # ) - return [system_message, human_message] - # return [system_message, human_message, ai_message, tool_message] diff --git a/backend/onyx/db/chat.py b/backend/onyx/db/chat.py index bbd3d1166e..b03f615692 100644 --- a/backend/onyx/db/chat.py +++ b/backend/onyx/db/chat.py @@ -44,6 +44,8 @@ from onyx.llm.override_models import LLMOverride from onyx.llm.override_models import PromptOverride from onyx.server.query_and_chat.models import ChatMessageDetail +from onyx.server.query_and_chat.models import SubQueryDetail +from onyx.server.query_and_chat.models import SubQuestionDetail from onyx.tools.tool_runner import ToolCallFinalResult from onyx.utils.logger import setup_logger @@ -486,6 +488,7 @@ def get_chat_messages_by_session( prefetch_tool_calls: bool = False, ) -> list[ChatMessage]: if not skip_permission_check: + # bug if we ever call this expecting the permission check to not be skipped get_chat_session_by_id( chat_session_id=chat_session_id, user_id=user_id, db_session=db_session ) @@ -497,7 +500,12 @@ def get_chat_messages_by_session( ) if prefetch_tool_calls: - stmt = stmt.options(joinedload(ChatMessage.tool_call)) + stmt = stmt.options( + joinedload(ChatMessage.tool_call), + joinedload(ChatMessage.sub_questions).joinedload( + AgentSubQuestion.sub_queries + ), + ) result = db_session.scalars(stmt).unique().all() else: result = db_session.scalars(stmt).all() @@ -827,14 +835,45 @@ def translate_db_search_doc_to_server_search_doc( ) -def get_retrieval_docs_from_chat_message( - chat_message: ChatMessage, remove_doc_content: bool = False +def translate_db_sub_questions_to_server_objects( + db_sub_questions: list[AgentSubQuestion], +) -> list[SubQuestionDetail]: + sub_questions = [] + for sub_question in db_sub_questions: + sub_queries = [] + docs: list[SearchDoc] = [] + for sub_query in sub_question.sub_queries: + doc_ids = [doc.id for doc in sub_query.search_docs] + sub_queries.append( + SubQueryDetail( + query=sub_query.sub_query, + query_id=sub_query.id, + doc_ids=doc_ids, + ) + ) + docs += sub_query.search_docs + + sub_questions.append( + SubQuestionDetail( + level=sub_question.level, + level_question_nr=sub_question.level_question_nr, + question=sub_question.sub_question, + answer=sub_question.sub_answer, + sub_queries=sub_queries, + context_docs=get_retrieval_docs_from_search_docs(docs), + ) + ) + return sub_questions + + +def get_retrieval_docs_from_search_docs( + search_docs: list[SearchDoc], remove_doc_content: bool = False ) -> RetrievalDocs: top_documents = [ translate_db_search_doc_to_server_search_doc( db_doc, remove_doc_content=remove_doc_content ) - for db_doc in chat_message.search_docs + for db_doc in search_docs ] top_documents = sorted(top_documents, key=lambda doc: doc.score, reverse=True) # type: ignore return RetrievalDocs(top_documents=top_documents) @@ -851,8 +890,8 @@ def translate_db_message_to_chat_message_detail( latest_child_message=chat_message.latest_child_message, 
message=chat_message.message, rephrased_query=chat_message.rephrased_query, - context_docs=get_retrieval_docs_from_chat_message( - chat_message, remove_doc_content=remove_doc_content + context_docs=get_retrieval_docs_from_search_docs( + chat_message.search_docs, remove_doc_content=remove_doc_content ), message_type=chat_message.message_type, time_sent=chat_message.time_sent, @@ -867,6 +906,9 @@ def translate_db_message_to_chat_message_detail( else None, alternate_assistant_id=chat_message.alternate_assistant_id, overridden_model=chat_message.overridden_model, + sub_questions=translate_db_sub_questions_to_server_objects( + chat_message.sub_questions + ), ) return chat_msg_detail diff --git a/backend/onyx/db/models.py b/backend/onyx/db/models.py index 35237464a8..279c2259b5 100644 --- a/backend/onyx/db/models.py +++ b/backend/onyx/db/models.py @@ -1100,6 +1100,11 @@ class ChatMessage(Base): uselist=False, ) + sub_questions: Mapped[list["AgentSubQuestion"]] = relationship( + "AgentSubQuestion", + back_populates="primary_message", + ) + standard_answers: Mapped[list["StandardAnswer"]] = relationship( "StandardAnswer", secondary=ChatMessage__StandardAnswer.__table__, @@ -1158,7 +1163,9 @@ class AgentSubQuestion(Base): # Relationships primary_message: Mapped["ChatMessage"] = relationship( - "ChatMessage", foreign_keys=[primary_question_id] + "ChatMessage", + foreign_keys=[primary_question_id], + back_populates="sub_questions", ) chat_session: Mapped["ChatSession"] = relationship("ChatSession") sub_queries: Mapped[list["AgentSubQuery"]] = relationship( diff --git a/backend/onyx/server/query_and_chat/models.py b/backend/onyx/server/query_and_chat/models.py index 185d6cb0e2..d5ff6aa243 100644 --- a/backend/onyx/server/query_and_chat/models.py +++ b/backend/onyx/server/query_and_chat/models.py @@ -189,6 +189,22 @@ def check_click_or_search_feedback(self) -> "SearchFeedbackRequest": return self +class SubQueryDetail(BaseModel): + query: str + query_id: int + # TODO: store these to enable per-query doc selection + doc_ids: list[int] | None = None + + +class SubQuestionDetail(BaseModel): + level: int + level_question_nr: int + question: str + answer: str + sub_queries: list[SubQueryDetail] | None = None + context_docs: RetrievalDocs | None = None + + class ChatMessageDetail(BaseModel): message_id: int parent_message: int | None = None @@ -200,9 +216,10 @@ class ChatMessageDetail(BaseModel): time_sent: datetime overridden_model: str | None alternate_assistant_id: int | None = None - # Dict mapping citation number to db_doc_id chat_session_id: UUID | None = None + # Dict mapping citation number to db_doc_id citations: dict[int, int] | None = None + sub_questions: list[SubQuestionDetail] | None = None files: list[FileDescriptor] tool_call: ToolCallFinalResult | None diff --git a/backend/onyx/tools/tool_implementations/search/search_tool.py b/backend/onyx/tools/tool_implementations/search/search_tool.py index 5b2fe4bebc..2af8ddc9a3 100644 --- a/backend/onyx/tools/tool_implementations/search/search_tool.py +++ b/backend/onyx/tools/tool_implementations/search/search_tool.py @@ -283,10 +283,11 @@ def _build_response_for_specified_sections( yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs) - def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]: + def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]: query = cast(str, kwargs["query"]) # kind of awkward to require this to be str, but it's "True" or "False" - force_no_rerank = kwargs.get("force_no_rerank", "False") 
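# (With run() loosened to **kwargs: Any above, the stringly-typed
#  "True"/"False" convention is dropped for a real bool; the agent-side
#  caller in expanded_retrieval/nodes.py now passes force_no_rerank=True.)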
+ force_no_rerank = cast(bool, kwargs.get("force_no_rerank", False)) + alternate_db_session = cast(Session, kwargs.get("alternate_db_session", None)) if self.selected_sections: yield from self._build_response_for_specified_sections(query) @@ -306,8 +307,15 @@ def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]: self.retrieval_options.offset if self.retrieval_options else None ), limit=self.retrieval_options.limit if self.retrieval_options else None, - rerank_settings=None - if force_no_rerank == "True" + rerank_settings=RerankingDetails( + rerank_model_name=None, + rerank_api_url=None, + rerank_provider_type=None, + rerank_api_key=None, + num_rerank=0, + disable_rerank_for_streaming=True, + ) + if force_no_rerank else self.rerank_settings, chunks_above=self.chunks_above, chunks_below=self.chunks_below, @@ -322,7 +330,7 @@ def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]: llm=self.llm, fast_llm=self.fast_llm, bypass_acl=self.bypass_acl, - db_session=self.db_session, + db_session=alternate_db_session or self.db_session, prompt_config=self.prompt_config, ) self.search_pipeline = search_pipeline # used for agent_search metrics From 3a38407bcaed9ece69b787db736f01f942fd3fbc Mon Sep 17 00:00:00 2001 From: Evan Lohn Date: Mon, 6 Jan 2025 18:43:37 -0800 Subject: [PATCH 56/78] added type to ChatPacket --- backend/onyx/chat/process_message.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/backend/onyx/chat/process_message.py b/backend/onyx/chat/process_message.py index f0060e6f27..dd292a4bb6 100644 --- a/backend/onyx/chat/process_message.py +++ b/backend/onyx/chat/process_message.py @@ -25,6 +25,7 @@ from onyx.chat.models import OnyxContexts from onyx.chat.models import PromptConfig from onyx.chat.models import ProSearchConfig +from onyx.chat.models import ProSearchPacket from onyx.chat.models import QADocsResponse from onyx.chat.models import StreamingError from onyx.chat.models import StreamStopInfo @@ -284,6 +285,7 @@ def _get_force_search_settings( | MessageSpecificCitations | MessageResponseIDInfo | StreamStopInfo + | ProSearchPacket ) ChatPacketStream = Iterator[ChatPacket] From 7898f38a5d9b8317fcccde5909119ea76db94671 Mon Sep 17 00:00:00 2001 From: Evan Lohn Date: Mon, 6 Jan 2025 19:12:49 -0800 Subject: [PATCH 57/78] refactor and separate id fields --- .../nodes/answer_generation.py | 9 ++++-- .../agent_search/expanded_retrieval/nodes.py | 27 ++++++++++++----- backend/onyx/agent_search/main/nodes.py | 9 +++--- backend/onyx/agent_search/run_graph.py | 12 ++++---- backend/onyx/chat/models.py | 30 +++++++++---------- 5 files changed, 50 insertions(+), 37 deletions(-) diff --git a/backend/onyx/agent_search/answer_question/nodes/answer_generation.py b/backend/onyx/agent_search/answer_question/nodes/answer_generation.py index 538e279242..062c588b06 100644 --- a/backend/onyx/agent_search/answer_question/nodes/answer_generation.py +++ b/backend/onyx/agent_search/answer_question/nodes/answer_generation.py @@ -11,7 +11,8 @@ from onyx.agent_search.shared_graph_utils.prompts import ASSISTANT_SYSTEM_PROMPT_DEFAULT from onyx.agent_search.shared_graph_utils.prompts import ASSISTANT_SYSTEM_PROMPT_PERSONA from onyx.agent_search.shared_graph_utils.utils import get_persona_prompt -from onyx.chat.models import SubAnswer +from onyx.agent_search.shared_graph_utils.utils import parse_question_id +from onyx.chat.models import SubAnswerPiece from onyx.utils.logger import setup_logger logger = setup_logger() @@ -20,6 +21,7 @@ def answer_generation(state: AnswerQuestionState) -> 
QAGenerationUpdate: question = state["question"] docs = state["documents"] + level, question_nr = parse_question_id(state["question_id"]) persona_prompt = get_persona_prompt(state["subgraph_config"].search_request.persona) if len(persona_prompt) > 0: @@ -51,9 +53,10 @@ def answer_generation(state: AnswerQuestionState) -> QAGenerationUpdate: ) dispatch_custom_event( "sub_answers", - SubAnswer( + SubAnswerPiece( sub_answer=content, - sub_question_id=state["question_id"], + level=level, + level_question_nr=question_nr, ), ) response.append(content) diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes.py b/backend/onyx/agent_search/expanded_retrieval/nodes.py index 7ff034fd59..f1eb798d4c 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes.py @@ -29,9 +29,9 @@ from onyx.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI_ORIGINAL from onyx.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT from onyx.agent_search.shared_graph_utils.utils import dispatch_separated -from onyx.agent_search.shared_graph_utils.utils import make_question_id +from onyx.agent_search.shared_graph_utils.utils import parse_question_id from onyx.chat.models import ExtendedToolResponse -from onyx.chat.models import SubQuery +from onyx.chat.models import SubQueryPiece from onyx.configs.dev_configs import AGENT_MAX_QUERY_RETRIEVAL_RESULTS from onyx.configs.dev_configs import AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS from onyx.configs.dev_configs import AGENT_RERANKING_STATS @@ -49,11 +49,16 @@ logger = setup_logger() -def dispatch_subquery(subquestion_id: str) -> Callable[[str, int], None]: +def dispatch_subquery(level: int, question_nr: int) -> Callable[[str, int], None]: def helper(token: str, num: int) -> None: dispatch_custom_event( "subqueries", - SubQuery(sub_query=token, sub_question_id=subquestion_id, query_id=num), + SubQueryPiece( + sub_query=token, + level=level, + level_question_nr=question_nr, + query_id=num, + ), ) return helper @@ -69,7 +74,9 @@ def expand_queries(state: ExpandedRetrievalInput) -> QueryExpansionUpdate: chat_session_id = state["subgraph_config"].chat_session_id sub_question_id = state.get("sub_question_id") if sub_question_id is None: - sub_question_id = make_question_id(0, 0) # 0_0 for original question + level, question_nr = 0, 0 + else: + level, question_nr = parse_question_id(sub_question_id) if chat_session_id is None: raise ValueError("chat_session_id must be provided for agent search") @@ -81,7 +88,7 @@ def expand_queries(state: ExpandedRetrievalInput) -> QueryExpansionUpdate: ] llm_response_list = dispatch_separated( - llm.stream(prompt=msg), dispatch_subquery(sub_question_id) + llm.stream(prompt=msg), dispatch_subquery(level, question_nr) ) llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content @@ -125,12 +132,18 @@ def doc_retrieval(state: RetrievalInput) -> DocRetrievalUpdate: retrieved_docs = cast( list[InferenceSection], tool_response.response.top_sections ) + level, question_nr = ( + parse_question_id(state["sub_question_id"]) + if state["sub_question_id"] + else (0, 0) + ) dispatch_custom_event( "tool_response", ExtendedToolResponse( id=tool_response.id, - sub_question_id=state["sub_question_id"] or make_question_id(0, 0), response=tool_response.response, + level=level, + level_question_nr=question_nr, ), ) diff --git a/backend/onyx/agent_search/main/nodes.py b/backend/onyx/agent_search/main/nodes.py index e45195f9a7..822f759174 100644 --- 
a/backend/onyx/agent_search/main/nodes.py +++ b/backend/onyx/agent_search/main/nodes.py @@ -62,7 +62,7 @@ from onyx.agent_search.shared_graph_utils.utils import format_entity_term_extraction from onyx.agent_search.shared_graph_utils.utils import get_persona_prompt from onyx.agent_search.shared_graph_utils.utils import make_question_id -from onyx.chat.models import SubQuestion +from onyx.chat.models import SubQuestionPiece from onyx.db.chat import log_agent_metrics from onyx.db.chat import log_agent_sub_question_results from onyx.utils.logger import setup_logger @@ -74,11 +74,10 @@ def dispatch_subquestion(level: int) -> Callable[[str, int], None]: def helper(sub_question_part: str, num: int) -> None: dispatch_custom_event( "decomp_qs", - SubQuestion( + SubQuestionPiece( sub_question=sub_question_part, - question_id=make_question_id( - level, num + 1 - ), # question 0 reserved for original question if used + level=level, + level_question_nr=num + 1, ), ) diff --git a/backend/onyx/agent_search/run_graph.py b/backend/onyx/agent_search/run_graph.py index 122ca95589..c95417a1f4 100644 --- a/backend/onyx/agent_search/run_graph.py +++ b/backend/onyx/agent_search/run_graph.py @@ -14,9 +14,9 @@ from onyx.chat.models import AnswerStream from onyx.chat.models import OnyxAnswerPiece from onyx.chat.models import ProSearchConfig -from onyx.chat.models import SubAnswer -from onyx.chat.models import SubQuery -from onyx.chat.models import SubQuestion +from onyx.chat.models import SubAnswerPiece +from onyx.chat.models import SubQueryPiece +from onyx.chat.models import SubQuestionPiece from onyx.chat.models import ToolResponse from onyx.context.search.models import SearchRequest from onyx.db.engine import get_session_context_manager @@ -43,11 +43,11 @@ def _parse_agent_event( if event_type == "on_custom_event": # TODO: different AnswerStream types for different events if event["name"] == "decomp_qs": - return cast(SubQuestion, event["data"]) + return cast(SubQuestionPiece, event["data"]) elif event["name"] == "subqueries": - return cast(SubQuery, event["data"]) + return cast(SubQueryPiece, event["data"]) elif event["name"] == "sub_answers": - return cast(SubAnswer, event["data"]) + return cast(SubAnswerPiece, event["data"]) elif event["name"] == "main_answer": return OnyxAnswerPiece(answer_piece=cast(str, event["data"])) elif event["name"] == "tool_response": diff --git a/backend/onyx/chat/models.py b/backend/onyx/chat/models.py index 233d998211..e28231ce97 100644 --- a/backend/onyx/chat/models.py +++ b/backend/onyx/chat/models.py @@ -347,35 +347,33 @@ def from_model( ) -class SubQuery(BaseModel): +class SubQueryPiece(BaseModel): sub_query: str - sub_question_id: str # _ + level: int + level_question_nr: int query_id: int - @model_validator(mode="after") - def check_sub_question_id(self) -> "SubQuery": - if len(self.sub_question_id.split("_")) != 2: - raise ValueError( - "sub_question_id must be in the format _" - ) - return self - -class SubAnswer(BaseModel): +class SubAnswerPiece(BaseModel): sub_answer: str - sub_question_id: str # _ + level: int + level_question_nr: int -class SubQuestion(BaseModel): - question_id: str # _ +class SubQuestionPiece(BaseModel): sub_question: str + level: int + level_question_nr: int class ExtendedToolResponse(ToolResponse): - sub_question_id: str # _ + level: int + level_question_nr: int -ProSearchPacket = SubQuestion | SubAnswer | SubQuery | ExtendedToolResponse +ProSearchPacket = ( + SubQuestionPiece | SubAnswerPiece | SubQueryPiece | ExtendedToolResponse +) AnswerPacket = 
( AnswerQuestionPossibleReturn | ProSearchPacket | ToolCallKickoff | ToolResponse From fd1191637bc00af553334fe88744a2dca5bc2421 Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Tue, 7 Jan 2025 08:58:10 -0800 Subject: [PATCH 58/78] extra info logging and allow_refinement flag --- backend/onyx/agent_search/main/nodes.py | 89 ++++++++++++------- backend/onyx/agent_search/main/states.py | 21 ++--- backend/onyx/agent_search/run_graph.py | 16 +++- .../shared_graph_utils/operators.py | 20 +++++ backend/onyx/chat/models.py | 3 + 5 files changed, 99 insertions(+), 50 deletions(-) diff --git a/backend/onyx/agent_search/main/nodes.py b/backend/onyx/agent_search/main/nodes.py index c5a554335a..77d417b004 100644 --- a/backend/onyx/agent_search/main/nodes.py +++ b/backend/onyx/agent_search/main/nodes.py @@ -27,7 +27,6 @@ from onyx.agent_search.main.states import DecompAnswersUpdate from onyx.agent_search.main.states import EntityTermExtractionUpdate from onyx.agent_search.main.states import ExpandedRetrievalUpdate -from onyx.agent_search.main.states import FollowUpDecompAnswersUpdate from onyx.agent_search.main.states import FollowUpSubQuestionsUpdate from onyx.agent_search.main.states import InitialAnswerBASEUpdate from onyx.agent_search.main.states import InitialAnswerQualityUpdate @@ -62,6 +61,7 @@ from onyx.agent_search.shared_graph_utils.utils import format_entity_term_extraction from onyx.agent_search.shared_graph_utils.utils import get_persona_prompt from onyx.agent_search.shared_graph_utils.utils import make_question_id +from onyx.agent_search.shared_graph_utils.utils import parse_question_id from onyx.chat.models import SubQuestion from onyx.db.chat import log_agent_metrics from onyx.db.chat import log_agent_sub_question_results @@ -147,6 +147,8 @@ def main_decomp_base(state: MainState) -> BaseDecompUpdate: return BaseDecompUpdate( initial_decomp_questions=decomp_list, agent_start_time=agent_start_time, + agent_refined_start_time=None, + agent_refined_end_time=None, ) @@ -397,6 +399,15 @@ def entity_term_extraction(state: MainState) -> EntityTermExtractionUpdate: logger.info(f"XXXXXX--{now_start}--XXX---GENERATE ENTITIES & TERMS---") + if not state["config"].allow_refinement: + return EntityTermExtractionUpdate( + entity_retlation_term_extractions=EntityRelationshipTermExtraction( + entities=[], + relationships=[], + terms=[], + ) + ) + # first four lines duplicates from generate_initial_answer question = state["config"].search_request.query sub_question_docs = state["documents"] @@ -574,7 +585,7 @@ def refined_answer_decision(state: MainState) -> RequireRefinedAnswerUpdate: f"XXXXXX--{now_end}--{now_end - now_start}--XXX---REFINED ANSWER DECISION END---" ) - if False: + if not state["config"].allow_refinement or True: return RequireRefinedAnswerUpdate(require_refined_answer=False) else: @@ -602,7 +613,7 @@ def generate_refined_answer(state: MainState) -> RefinedAnswerUpdate: revision_doc_effectiveness = 10.0 decomp_answer_results = state["decomp_answer_results"] - revised_answer_results = state["follow_up_decomp_answer_results"] + # revised_answer_results = state["follow_up_decomp_answer_results"] good_qa_list: list[str] = [] decomp_questions = [] @@ -610,24 +621,25 @@ def generate_refined_answer(state: MainState) -> RefinedAnswerUpdate: initial_good_sub_questions: list[str] = [] new_revised_good_sub_questions: list[str] = [] - for answer_set in [decomp_answer_results, revised_answer_results]: - for decomp_answer_result in answer_set: - 
decomp_questions.append(decomp_answer_result.question) - if ( - decomp_answer_result.quality.lower().startswith("yes") - and len(decomp_answer_result.answer) > 0 - and decomp_answer_result.answer != "I don't know" - ): - good_qa_list.append( - SUB_QUESTION_ANSWER_TEMPLATE.format( - sub_question=decomp_answer_result.question, - sub_answer=decomp_answer_result.answer, - ) + for decomp_answer_result in decomp_answer_results: + question_level, _ = parse_question_id(decomp_answer_result.question_id) + + decomp_questions.append(decomp_answer_result.question) + if ( + decomp_answer_result.quality.lower().startswith("yes") + and len(decomp_answer_result.answer) > 0 + and decomp_answer_result.answer != "I don't know" + ): + good_qa_list.append( + SUB_QUESTION_ANSWER_TEMPLATE.format( + sub_question=decomp_answer_result.question, + sub_answer=decomp_answer_result.answer, ) - if answer_set == decomp_answer_results: - initial_good_sub_questions.append(decomp_answer_result.question) - else: - new_revised_good_sub_questions.append(decomp_answer_result.question) + ) + if question_level == 0: + initial_good_sub_questions.append(decomp_answer_result.question) + else: + new_revised_good_sub_questions.append(decomp_answer_result.question) initial_good_sub_questions = list(set(initial_good_sub_questions)) new_revised_good_sub_questions = list(set(new_revised_good_sub_questions)) @@ -758,9 +770,12 @@ def generate_refined_answer(state: MainState) -> RefinedAnswerUpdate: ) agent_refined_end_time = datetime.now() - agent_refined_duration = ( - agent_refined_end_time - state["agent_refined_start_time"] - ).total_seconds() + if state["agent_refined_start_time"]: + agent_refined_duration = ( + agent_refined_end_time - state["agent_refined_start_time"] + ).total_seconds() + else: + agent_refined_duration = None agent_refined_metrics = AgentRefinedMetrics( refined_doc_boost_factor=refined_agent_stats.revision_doc_efficiency, @@ -861,7 +876,7 @@ def follow_up_decompose(state: MainState) -> FollowUpSubQuestionsUpdate: def ingest_follow_up_answers( state: AnswerQuestionOutput, -) -> FollowUpDecompAnswersUpdate: +) -> DecompAnswersUpdate: now_start = datetime.now() logger.info(f"XXXXXX--{now_start}--XXX---INGEST FOLLOW UP ANSWERS---") @@ -877,11 +892,11 @@ def ingest_follow_up_answers( f"XXXXXX--{now_end}--{now_end - now_start}--XXX---INGEST FOLLOW UP ANSWERS END---" ) - return FollowUpDecompAnswersUpdate( + return DecompAnswersUpdate( # Deduping is done by the documents operator for the main graph # so we might not need to dedup here - follow_up_documents=dedup_inference_sections(documents, []), - follow_up_decomp_answer_results=answer_results, + documents=dedup_inference_sections(documents, []), + decomp_answer_results=answer_results, ) @@ -892,9 +907,12 @@ def logging_node(state: MainState) -> MainOutput: agent_start_time = state["agent_start_time"] agent_base_end_time = state["agent_base_end_time"] - agent_refined_start_time = state["agent_refined_start_time"] - agent_refined_end_time = state["agent_refined_end_time"] - agent_end_time = max(agent_base_end_time, agent_refined_end_time) + agent_refined_start_time = state["agent_refined_start_time"] or None + agent_refined_end_time = state["agent_refined_end_time"] or None + if agent_refined_end_time is not None: + agent_end_time = agent_refined_end_time + else: + agent_end_time = agent_base_end_time if agent_base_end_time: agent_base_duration = (agent_base_end_time - agent_start_time).total_seconds() @@ -902,9 +920,12 @@ def logging_node(state: MainState) -> MainOutput: 
agent_base_duration = None if agent_refined_end_time: - agent_refined_duration = ( - agent_refined_end_time - agent_refined_start_time - ).total_seconds() + if agent_refined_start_time and agent_refined_end_time: + agent_refined_duration = ( + agent_refined_end_time - agent_refined_start_time + ).total_seconds() + else: + agent_refined_duration = None else: agent_refined_duration = None @@ -960,7 +981,7 @@ def logging_node(state: MainState) -> MainOutput: db_session = state["db_session"] chat_session_id = state["config"].chat_session_id primary_message_id = state["config"].message_id - sub_question_answer_results = state["follow_up_decomp_answer_results"] + sub_question_answer_results = state["decomp_answer_results"] log_agent_sub_question_results( db_session=db_session, diff --git a/backend/onyx/agent_search/main/states.py b/backend/onyx/agent_search/main/states.py index a89495d955..28926616d9 100644 --- a/backend/onyx/agent_search/main/states.py +++ b/backend/onyx/agent_search/main/states.py @@ -15,6 +15,7 @@ from onyx.agent_search.shared_graph_utils.models import InitialAgentResultStats from onyx.agent_search.shared_graph_utils.models import RefinedAgentStats from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections +from onyx.agent_search.shared_graph_utils.operators import dedup_question_answer_results from onyx.context.search.models import InferenceSection @@ -25,6 +26,8 @@ class BaseDecompUpdate(TypedDict): agent_start_time: datetime + agent_refined_start_time: datetime | None + agent_refined_end_time: datetime | None initial_decomp_questions: list[str] @@ -58,7 +61,9 @@ class RequireRefinedAnswerUpdate(TypedDict): class DecompAnswersUpdate(TypedDict): documents: Annotated[list[InferenceSection], dedup_inference_sections] - decomp_answer_results: Annotated[list[QuestionAnswerResults], add] + decomp_answer_results: Annotated[ + list[QuestionAnswerResults], dedup_question_answer_results + ] class FollowUpDecompAnswersUpdate(TypedDict): @@ -80,20 +85,11 @@ class EntityTermExtractionUpdate(TypedDict): class FollowUpSubQuestionsUpdate(TypedDict): follow_up_sub_questions: dict[int, FollowUpSubQuestion] - agent_refined_start_time: datetime - - -class FollowUpAnswerQuestionOutput(TypedDict): - """ - This is a list of results even though each call of this subgraph only returns one result. - This is because if we parallelize the answer query subgraph, there will be multiple - results in a list so the add operator is used to add them together. 
- """ - - follow_up_answer_results: Annotated[list[QuestionAnswerResults], add] + agent_refined_start_time: datetime | None ## Graph Input State +## Graph Input State class MainInput(CoreState): @@ -115,7 +111,6 @@ class MainState( InitialAnswerQualityUpdate, RequireRefinedAnswerUpdate, FollowUpSubQuestionsUpdate, - FollowUpAnswerQuestionOutput, FollowUpDecompAnswersUpdate, RefinedAnswerUpdate, ): diff --git a/backend/onyx/agent_search/run_graph.py b/backend/onyx/agent_search/run_graph.py index 6c12027e03..b8036a346c 100644 --- a/backend/onyx/agent_search/run_graph.py +++ b/backend/onyx/agent_search/run_graph.py @@ -1,6 +1,7 @@ import asyncio from collections.abc import AsyncIterable from collections.abc import Iterable +from datetime import datetime from typing import cast from langchain_core.runnables.schema import StreamEvent @@ -140,8 +141,13 @@ def run_main_graph( from onyx.llm.factory import get_default_llms from onyx.db.persona import get_persona_by_id + now_start = datetime.now() + logger.info(f"Start at {now_start}") + graph = main_graph_builder() compiled_graph = graph.compile() + now_end = datetime.now() + logger.info(f"Graph compiled in {now_end - now_start} seconds") primary_llm, fast_llm = get_default_llms() search_request = SearchRequest( # query="what can you do with gitlab?", @@ -157,15 +163,19 @@ def run_main_graph( config.use_persistence = True # with open("output.txt", "w") as f: - tool_responses = [] + tool_responses: list = [] for output in run_graph( compiled_graph, config, search_tool, primary_llm, fast_llm, db_session ): + # pass if isinstance(output, OnyxAnswerPiece): tool_responses.append("|") elif isinstance(output, ToolCallKickoff): pass elif isinstance(output, ToolResponse): tool_responses.append(output.response) - for tool_response in tool_responses: - logger.info(tool_response) + elif isinstance(output, SubQuestion): + logger.info(output.sub_question, end=" | ") + + # for tool_response in tool_responses: + # logger.info(tool_response) diff --git a/backend/onyx/agent_search/shared_graph_utils/operators.py b/backend/onyx/agent_search/shared_graph_utils/operators.py index d75eb54cd5..bd28d64165 100644 --- a/backend/onyx/agent_search/shared_graph_utils/operators.py +++ b/backend/onyx/agent_search/shared_graph_utils/operators.py @@ -1,3 +1,4 @@ +from onyx.agent_search.answer_question.models import QuestionAnswerResults from onyx.chat.prune_and_merge import _merge_sections from onyx.context.search.models import InferenceSection @@ -7,3 +8,22 @@ def dedup_inference_sections( ) -> list[InferenceSection]: deduped = _merge_sections(list1 + list2) return deduped + + +def dedup_question_answer_results( + question_answer_results_1: list[QuestionAnswerResults], + question_answer_results_2: list[QuestionAnswerResults], +) -> list[QuestionAnswerResults]: + deduped_question_answer_results: list[ + QuestionAnswerResults + ] = question_answer_results_1 + utilized_question_ids: set[str] = set( + [x.question_id for x in question_answer_results_1] + ) + + for question_answer_result in question_answer_results_2: + if question_answer_result.question_id not in utilized_question_ids: + deduped_question_answer_results.append(question_answer_result) + utilized_question_ids.add(question_answer_result.question_id) + + return deduped_question_answer_results diff --git a/backend/onyx/chat/models.py b/backend/onyx/chat/models.py index 233d998211..7a79d8483f 100644 --- a/backend/onyx/chat/models.py +++ b/backend/onyx/chat/models.py @@ -221,6 +221,9 @@ class ProSearchConfig(BaseModel): # 
Whether to persistence data for the Pro Search (turned off for testing)
     use_persistence: bool = True

+    # Whether to allow creation of refinement questions (and entity extraction, etc.)
+    allow_refinement: bool = False
+

 AnswerQuestionPossibleReturn = (
     OnyxAnswerPiece

From 462db23683668b169bceebfa5fc6d571cdc87b2d Mon Sep 17 00:00:00 2001
From: joachim-danswer
Date: Tue, 7 Jan 2025 10:14:21 -0800
Subject: [PATCH 59/78] updated run_graph to stream out answer tokens

---
 backend/onyx/agent_search/run_graph.py | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/backend/onyx/agent_search/run_graph.py b/backend/onyx/agent_search/run_graph.py
index d9fdb664c8..3fe89e31f4 100644
--- a/backend/onyx/agent_search/run_graph.py
+++ b/backend/onyx/agent_search/run_graph.py
@@ -167,14 +167,21 @@ def run_main_graph(
         compiled_graph, config, search_tool, primary_llm, fast_llm, db_session
     ):
         # pass
-        if isinstance(output, OnyxAnswerPiece):
-            tool_responses.append("|")
-        elif isinstance(output, ToolCallKickoff):
+
+        if isinstance(output, ToolCallKickoff):
             pass
         elif isinstance(output, ToolResponse):
             tool_responses.append(output.response)
         elif isinstance(output, SubQuestionPiece):
-            logger.info(f"{output.sub_question} | ")
+            logger.info(
+                f"SQ {output.level} - {output.level_question_nr} - {output.sub_question} | "
+            )
+        elif isinstance(output, SubAnswerPiece):
+            logger.info(
+                f" ---- SA {output.level} - {output.level_question_nr} {output.sub_answer} | "
+            )
+        elif isinstance(output, OnyxAnswerPiece):
+            logger.info(f" ---------- FA {output.answer_piece} | ")

     # for tool_response in tool_responses:
     #     logger.info(tool_response)
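[Editor's note: after this patch the driver demultiplexes the stream by piece type. A short sketch of how a consumer could reassemble the tokens into complete strings, using the piece classes and the level / level_question_nr fields from this diff (the buffering scheme itself is illustrative, not part of the patch):

    from collections import defaultdict

    sub_questions: dict[tuple[int, int], str] = defaultdict(str)
    sub_answers: dict[tuple[int, int], str] = defaultdict(str)
    final_answer = ""

    for output in run_graph(
        compiled_graph, config, search_tool, primary_llm, fast_llm, db_session
    ):
        if isinstance(output, SubQuestionPiece):
            # tokens for sub-question (level, nr) arrive incrementally
            key = (output.level, output.level_question_nr)
            sub_questions[key] += output.sub_question
        elif isinstance(output, SubAnswerPiece):
            key = (output.level, output.level_question_nr)
            sub_answers[key] += output.sub_answer
        elif isinstance(output, OnyxAnswerPiece):
            final_answer += output.answer_piece or ""
]

From 8650b8ff51c7487e1099ea8e26f297c99bea23e6 Mon Sep 17 00:00:00 2001
From: joachim-danswer
Date: Tue, 7 Jan 2025 13:18:24 -0800
Subject: [PATCH 60/78] main nodes/edges function renaming

---
 .../answer_follow_up_question/edges.py        |  2 +-
 .../graph_builder.py                          |  8 ++--
 backend/onyx/agent_search/main/edges.py       |  8 +++-
 .../onyx/agent_search/main/graph_builder.py   | 44 +++++++++----------
 backend/onyx/agent_search/main/nodes.py       | 30 ++++++++-----
 5 files changed, 51 insertions(+), 41 deletions(-)

diff --git a/backend/onyx/agent_search/answer_follow_up_question/edges.py b/backend/onyx/agent_search/answer_follow_up_question/edges.py
index 87c550cf6b..36d7a17308 100644
--- a/backend/onyx/agent_search/answer_follow_up_question/edges.py
+++ b/backend/onyx/agent_search/answer_follow_up_question/edges.py
@@ -10,7 +10,7 @@
 logger = setup_logger()


-def send_to_expanded_follow_up_retrieval(state: AnswerQuestionInput) -> Send | Hashable:
+def send_to_expanded_refined_retrieval(state: AnswerQuestionInput) -> Send | Hashable:
     logger.info("sending to expanded retrieval for follow up question via edge")

     return Send(
diff --git a/backend/onyx/agent_search/answer_follow_up_question/graph_builder.py b/backend/onyx/agent_search/answer_follow_up_question/graph_builder.py
index dfee1a6a13..7f2766888f 100644
--- a/backend/onyx/agent_search/answer_follow_up_question/graph_builder.py
+++ b/backend/onyx/agent_search/answer_follow_up_question/graph_builder.py
@@ -3,7 +3,7 @@
 from langgraph.graph import StateGraph

 from onyx.agent_search.answer_follow_up_question.edges import (
-    send_to_expanded_follow_up_retrieval,
+    send_to_expanded_refined_retrieval,
 )
 from onyx.agent_search.answer_question.nodes.answer_check import answer_check
 from onyx.agent_search.answer_question.nodes.answer_generation import answer_generation
@@ -20,7 +20,7 @@
 logger = setup_logger()


-def 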
answer_follow_up_query_graph_builder() -> StateGraph:
+def answer_refined_query_graph_builder() -> StateGraph:
     graph = StateGraph(
         state_schema=AnswerQuestionState,
         input=AnswerQuestionInput,
@@ -55,7 +55,7 @@ def answer_follow_up_query_graph_builder() -> StateGraph:

     graph.add_conditional_edges(
         source=START,
-        path=send_to_expanded_follow_up_retrieval,
+        path=send_to_expanded_refined_retrieval,
         path_map=["decomposed_follow_up_retrieval"],
     )
     graph.add_edge(
@@ -87,7 +87,7 @@ def answer_follow_up_query_graph_builder() -> StateGraph:
     from onyx.llm.factory import get_default_llms
     from onyx.context.search.models import SearchRequest

-    graph = answer_follow_up_query_graph_builder()
+    graph = answer_refined_query_graph_builder()
     compiled_graph = graph.compile()
     primary_llm, fast_llm = get_default_llms()
     search_request = SearchRequest(
diff --git a/backend/onyx/agent_search/main/edges.py b/backend/onyx/agent_search/main/edges.py
index 7c0ac5c110..f10a18982e 100644
--- a/backend/onyx/agent_search/main/edges.py
+++ b/backend/onyx/agent_search/main/edges.py
@@ -14,7 +14,9 @@
 logger = setup_logger()


-def parallelize_decompozed_answer_queries(state: MainState) -> list[Send | Hashable]:
+def parallelize_initial_sub_question_answering(
+    state: MainState,
+) -> list[Send | Hashable]:
     if len(state["initial_decomp_questions"]) > 0:
         # sub_question_record_ids = [subq_record.id for subq_record in state["sub_question_records"]]
         # if len(state["sub_question_records"]) == 0:
@@ -59,7 +61,9 @@ def continue_to_refined_answer_or_end(
         return "logging_node"


-def parallelize_follow_up_answer_queries(state: MainState) -> list[Send | Hashable]:
+def parallelize_refined_sub_question_answering(
+    state: MainState,
+) -> list[Send | Hashable]:
     if len(state["follow_up_sub_questions"]) > 0:
         return [
             Send(
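[Editor's note: both parallelize_* functions are LangGraph fan-out edges: they return one Send per sub-question so the answer subgraph runs once per question in parallel, with the results merged back by the state reducers. A minimal, self-contained sketch of the pattern; node and field names here are hypothetical:

    from operator import add
    from typing import Annotated, TypedDict

    from langgraph.graph import StateGraph, START, END
    from langgraph.types import Send


    class MapState(TypedDict):
        questions: list[str]
        answers: Annotated[list[str], add]  # reducer merges parallel branch outputs


    def answer_one(state: dict) -> dict:
        # Stand-in for the compiled answer subgraph node.
        return {"answers": [f"answer to: {state['question']}"]}


    def fan_out(state: MapState) -> list[Send]:
        # One Send per sub-question; each invokes "answer_one" with its own payload.
        return [Send("answer_one", {"question": q}) for q in state["questions"]]


    builder = StateGraph(MapState)
    builder.add_node("answer_one", answer_one)
    builder.add_conditional_edges(START, fan_out, ["answer_one"])
    builder.add_edge("answer_one", END)

    result = builder.compile().invoke({"questions": ["q1", "q2"], "answers": []})
    # result["answers"] holds one entry per fanned-out branch
]

diff --git a/backend/onyx/agent_search/main/graph_builder.py b/backend/onyx/agent_search/main/graph_builder.py
index 52d8fff155..2da26556bb 100644
--- a/backend/onyx/agent_search/main/graph_builder.py
+++ b/backend/onyx/agent_search/main/graph_builder.py
@@ -3,26 +3,26 @@
 from langgraph.graph import StateGraph

 from onyx.agent_search.answer_follow_up_question.graph_builder import (
-    answer_follow_up_query_graph_builder,
+    answer_refined_query_graph_builder,
 )
 from onyx.agent_search.answer_question.graph_builder import answer_query_graph_builder
 from onyx.agent_search.base_raw_search.graph_builder import (
     base_raw_search_graph_builder,
 )
 from onyx.agent_search.main.edges import continue_to_refined_answer_or_end
-from onyx.agent_search.main.edges import parallelize_decompozed_answer_queries
-from onyx.agent_search.main.edges import parallelize_follow_up_answer_queries
-from onyx.agent_search.main.nodes import entity_term_extraction
-from onyx.agent_search.main.nodes import follow_up_decompose
+from onyx.agent_search.main.edges import parallelize_initial_sub_question_answering
+from onyx.agent_search.main.edges import parallelize_refined_sub_question_answering
+from onyx.agent_search.main.nodes import agent_logging
+from onyx.agent_search.main.nodes import entity_term_extraction_llm
 from onyx.agent_search.main.nodes import generate_initial_answer
 from onyx.agent_search.main.nodes import generate_refined_answer
-from onyx.agent_search.main.nodes import ingest_answers
 from onyx.agent_search.main.nodes import ingest_follow_up_answers
-from onyx.agent_search.main.nodes import ingest_initial_retrieval
+from onyx.agent_search.main.nodes import ingest_initial_base_retrieval
+from onyx.agent_search.main.nodes import 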
ingest_initial_sub_question_answers from onyx.agent_search.main.nodes import initial_answer_quality_check -from onyx.agent_search.main.nodes import logging_node -from onyx.agent_search.main.nodes import main_decomp_base +from onyx.agent_search.main.nodes import initial_sub_question_creation from onyx.agent_search.main.nodes import refined_answer_decision +from onyx.agent_search.main.nodes import refined_sub_question_creation from onyx.agent_search.main.states import MainInput from onyx.agent_search.main.states import MainState from onyx.agent_search.shared_graph_utils.utils import get_test_config @@ -48,7 +48,7 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph: graph.add_node( node="base_decomp", - action=main_decomp_base, + action=initial_sub_question_creation, ) answer_query_subgraph = answer_query_graph_builder().compile() graph.add_node( @@ -78,7 +78,7 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph: # ) graph.add_node( node="ingest_answers", - action=ingest_answers, + action=ingest_initial_sub_question_answers, ) graph.add_node( node="generate_initial_answer", @@ -134,7 +134,7 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph: ) graph.add_conditional_edges( source="base_decomp", - path=parallelize_decompozed_answer_queries, + path=parallelize_initial_sub_question_answering, path_map=["answer_query"], ) graph.add_edge( @@ -206,7 +206,7 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph: ) graph.add_node( node="ingest_initial_retrieval", - action=ingest_initial_retrieval, + action=ingest_initial_base_retrieval, ) # graph.add_node( # node="ingest_answers", @@ -312,7 +312,7 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph: else: graph.add_node( node="base_decomp", - action=main_decomp_base, + action=initial_sub_question_creation, ) answer_query_subgraph = answer_query_graph_builder().compile() graph.add_node( @@ -334,10 +334,10 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph: graph.add_node( node="follow_up_decompose", - action=follow_up_decompose, + action=refined_sub_question_creation, ) - answer_follow_up_question = answer_follow_up_query_graph_builder().compile() + answer_follow_up_question = answer_refined_query_graph_builder().compile() graph.add_node( node="answer_follow_up_question", action=answer_follow_up_question, @@ -360,11 +360,11 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph: graph.add_node( node="ingest_initial_retrieval", - action=ingest_initial_retrieval, + action=ingest_initial_base_retrieval, ) graph.add_node( node="ingest_answers", - action=ingest_answers, + action=ingest_initial_sub_question_answers, ) graph.add_node( node="generate_initial_answer", @@ -378,7 +378,7 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph: graph.add_node( node="entity_term_extraction", - action=entity_term_extraction, + action=entity_term_extraction_llm, ) graph.add_node( node="refined_answer_decision", @@ -387,7 +387,7 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph: graph.add_node( node="logging_node", - action=logging_node, + action=agent_logging, ) # if test_mode: # graph.add_node( @@ -410,7 +410,7 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph: ) graph.add_conditional_edges( source="base_decomp", - path=parallelize_decompozed_answer_queries, + path=parallelize_initial_sub_question_answering, path_map=["answer_query"], ) graph.add_edge( @@ -446,7 +446,7 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph: 
graph.add_conditional_edges( source="follow_up_decompose", - path=parallelize_follow_up_answer_queries, + path=parallelize_refined_sub_question_answering, path_map=["answer_follow_up_question"], ) graph.add_edge( diff --git a/backend/onyx/agent_search/main/nodes.py b/backend/onyx/agent_search/main/nodes.py index 6e5cd30237..4dc2de6227 100644 --- a/backend/onyx/agent_search/main/nodes.py +++ b/backend/onyx/agent_search/main/nodes.py @@ -70,8 +70,8 @@ logger = setup_logger() -def dispatch_subquestion(level: int) -> Callable[[str, int], None]: - def helper(sub_question_part: str, num: int) -> None: +def _dispatch_subquestion(level: int) -> Callable[[str, int], None]: + def _helper(sub_question_part: str, num: int) -> None: dispatch_custom_event( "decomp_qs", SubQuestionPiece( @@ -81,10 +81,10 @@ def helper(sub_question_part: str, num: int) -> None: ), ) - return helper + return _helper -def main_decomp_base(state: MainState) -> BaseDecompUpdate: +def initial_sub_question_creation(state: MainState) -> BaseDecompUpdate: now_start = datetime.now() logger.info(f"--------{now_start}--------BASE DECOMP START---") @@ -109,7 +109,7 @@ def main_decomp_base(state: MainState) -> BaseDecompUpdate: model = state["fast_llm"] # dispatches custom events for subquestion tokens, adding in subquestion ids. - streamed_tokens = dispatch_separated(model.stream(msg), dispatch_subquestion(0)) + streamed_tokens = dispatch_separated(model.stream(msg), _dispatch_subquestion(0)) response = merge_content(*streamed_tokens) @@ -398,7 +398,7 @@ def initial_answer_quality_check(state: MainState) -> InitialAnswerQualityUpdate return InitialAnswerQualityUpdate(initial_answer_quality=verdict) -def entity_term_extraction(state: MainState) -> EntityTermExtractionUpdate: +def entity_term_extraction_llm(state: MainState) -> EntityTermExtractionUpdate: now_start = datetime.now() logger.info(f"--------{now_start}--------GENERATE ENTITIES & TERMS---") @@ -494,7 +494,9 @@ def entity_term_extraction(state: MainState) -> EntityTermExtractionUpdate: ) -def generate_initial_base_answer(state: MainState) -> InitialAnswerBASEUpdate: +def generate_initial_base_search_only_answer( + state: MainState, +) -> InitialAnswerBASEUpdate: now_start = datetime.now() logger.info(f"--------{now_start}--------GENERATE INITIAL BASE ANSWER---") @@ -525,7 +527,9 @@ def generate_initial_base_answer(state: MainState) -> InitialAnswerBASEUpdate: return InitialAnswerBASEUpdate(initial_base_answer=answer) -def ingest_answers(state: AnswerQuestionOutput) -> DecompAnswersUpdate: +def ingest_initial_sub_question_answers( + state: AnswerQuestionOutput, +) -> DecompAnswersUpdate: now_start = datetime.now() logger.info(f"--------{now_start}--------INGEST ANSWERS---") @@ -548,7 +552,9 @@ def ingest_answers(state: AnswerQuestionOutput) -> DecompAnswersUpdate: ) -def ingest_initial_retrieval(state: BaseRawSearchOutput) -> ExpandedRetrievalUpdate: +def ingest_initial_base_retrieval( + state: BaseRawSearchOutput, +) -> ExpandedRetrievalUpdate: now_start = datetime.now() logger.info(f"--------{now_start}--------INGEST INITIAL RETRIEVAL---") @@ -803,7 +809,7 @@ def generate_refined_answer(state: MainState) -> RefinedAnswerUpdate: ) -def follow_up_decompose(state: MainState) -> FollowUpSubQuestionsUpdate: +def refined_sub_question_creation(state: MainState) -> FollowUpSubQuestionsUpdate: """ """ now_start = datetime.now() @@ -847,7 +853,7 @@ def follow_up_decompose(state: MainState) -> FollowUpSubQuestionsUpdate: # Grader model = state["fast_llm"] - streamed_tokens = 
dispatch_separated(model.stream(msg), dispatch_subquestion(1)) + streamed_tokens = dispatch_separated(model.stream(msg), _dispatch_subquestion(1)) response = merge_content(*streamed_tokens) if isinstance(response, str): @@ -905,7 +911,7 @@ def ingest_follow_up_answers( ) -def logging_node(state: MainState) -> MainOutput: +def agent_logging(state: MainState) -> MainOutput: now_start = datetime.now() logger.info(f"--------{now_start}--------LOGGING NODE---") From 295417f85dd04e539644ae5f97a1cc81d0522e98 Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Tue, 7 Jan 2025 13:50:36 -0800 Subject: [PATCH 61/78] more renamings.. --- .../edges.py | 2 +- .../graph_builder.py | 26 +- .../models.py | 0 .../nodes/answer_check.py | 4 +- .../nodes/answer_generation.py | 4 +- .../nodes/format_answer.py | 6 +- .../nodes/ingest_retrieval.py | 4 +- .../states.py | 2 +- .../edges.py | 2 +- .../graph_builder.py | 24 +- .../models.py | 0 backend/onyx/agent_search/main/edges.py | 8 +- .../onyx/agent_search/main/graph_builder.py | 581 +++++------------- backend/onyx/agent_search/main/nodes.py | 6 +- backend/onyx/agent_search/main/states.py | 2 +- .../shared_graph_utils/operators.py | 2 +- backend/onyx/db/chat.py | 2 +- 17 files changed, 213 insertions(+), 462 deletions(-) rename backend/onyx/agent_search/{answer_question => answer_initial_sub_question}/edges.py (89%) rename backend/onyx/agent_search/{answer_question => answer_initial_sub_question}/graph_builder.py (78%) rename backend/onyx/agent_search/{answer_question => answer_initial_sub_question}/models.py (100%) rename backend/onyx/agent_search/{answer_question => answer_initial_sub_question}/nodes/answer_check.py (81%) rename backend/onyx/agent_search/{answer_question => answer_initial_sub_question}/nodes/answer_generation.py (93%) rename backend/onyx/agent_search/{answer_question => answer_initial_sub_question}/nodes/format_answer.py (69%) rename backend/onyx/agent_search/{answer_question => answer_initial_sub_question}/nodes/ingest_retrieval.py (88%) rename backend/onyx/agent_search/{answer_question => answer_initial_sub_question}/states.py (95%) rename backend/onyx/agent_search/{answer_follow_up_question => answer_refinement_sub_question}/edges.py (89%) rename backend/onyx/agent_search/{answer_follow_up_question => answer_refinement_sub_question}/graph_builder.py (78%) rename backend/onyx/agent_search/{answer_follow_up_question => answer_refinement_sub_question}/models.py (100%) diff --git a/backend/onyx/agent_search/answer_question/edges.py b/backend/onyx/agent_search/answer_initial_sub_question/edges.py similarity index 89% rename from backend/onyx/agent_search/answer_question/edges.py rename to backend/onyx/agent_search/answer_initial_sub_question/edges.py index 569f6437c3..1f095367ee 100644 --- a/backend/onyx/agent_search/answer_question/edges.py +++ b/backend/onyx/agent_search/answer_initial_sub_question/edges.py @@ -2,7 +2,7 @@ from langgraph.types import Send -from onyx.agent_search.answer_question.states import AnswerQuestionInput +from onyx.agent_search.answer_initial_sub_question.states import AnswerQuestionInput from onyx.agent_search.core_state import in_subgraph_extract_core_fields from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput from onyx.utils.logger import setup_logger diff --git a/backend/onyx/agent_search/answer_question/graph_builder.py b/backend/onyx/agent_search/answer_initial_sub_question/graph_builder.py similarity index 78% rename from backend/onyx/agent_search/answer_question/graph_builder.py 
rename to backend/onyx/agent_search/answer_initial_sub_question/graph_builder.py index 9526853c3c..bfcc984912 100644 --- a/backend/onyx/agent_search/answer_question/graph_builder.py +++ b/backend/onyx/agent_search/answer_initial_sub_question/graph_builder.py @@ -2,14 +2,24 @@ from langgraph.graph import START from langgraph.graph import StateGraph -from onyx.agent_search.answer_question.edges import send_to_expanded_retrieval -from onyx.agent_search.answer_question.nodes.answer_check import answer_check -from onyx.agent_search.answer_question.nodes.answer_generation import answer_generation -from onyx.agent_search.answer_question.nodes.format_answer import format_answer -from onyx.agent_search.answer_question.nodes.ingest_retrieval import ingest_retrieval -from onyx.agent_search.answer_question.states import AnswerQuestionInput -from onyx.agent_search.answer_question.states import AnswerQuestionOutput -from onyx.agent_search.answer_question.states import AnswerQuestionState +from onyx.agent_search.answer_initial_sub_question.edges import ( + send_to_expanded_retrieval, +) +from onyx.agent_search.answer_initial_sub_question.nodes.answer_check import ( + answer_check, +) +from onyx.agent_search.answer_initial_sub_question.nodes.answer_generation import ( + answer_generation, +) +from onyx.agent_search.answer_initial_sub_question.nodes.format_answer import ( + format_answer, +) +from onyx.agent_search.answer_initial_sub_question.nodes.ingest_retrieval import ( + ingest_retrieval, +) +from onyx.agent_search.answer_initial_sub_question.states import AnswerQuestionInput +from onyx.agent_search.answer_initial_sub_question.states import AnswerQuestionOutput +from onyx.agent_search.answer_initial_sub_question.states import AnswerQuestionState from onyx.agent_search.expanded_retrieval.graph_builder import ( expanded_retrieval_graph_builder, ) diff --git a/backend/onyx/agent_search/answer_question/models.py b/backend/onyx/agent_search/answer_initial_sub_question/models.py similarity index 100% rename from backend/onyx/agent_search/answer_question/models.py rename to backend/onyx/agent_search/answer_initial_sub_question/models.py diff --git a/backend/onyx/agent_search/answer_question/nodes/answer_check.py b/backend/onyx/agent_search/answer_initial_sub_question/nodes/answer_check.py similarity index 81% rename from backend/onyx/agent_search/answer_question/nodes/answer_check.py rename to backend/onyx/agent_search/answer_initial_sub_question/nodes/answer_check.py index 6349552f34..ab0481191d 100644 --- a/backend/onyx/agent_search/answer_question/nodes/answer_check.py +++ b/backend/onyx/agent_search/answer_initial_sub_question/nodes/answer_check.py @@ -1,8 +1,8 @@ from langchain_core.messages import HumanMessage from langchain_core.messages import merge_message_runs -from onyx.agent_search.answer_question.states import AnswerQuestionState -from onyx.agent_search.answer_question.states import QACheckUpdate +from onyx.agent_search.answer_initial_sub_question.states import AnswerQuestionState +from onyx.agent_search.answer_initial_sub_question.states import QACheckUpdate from onyx.agent_search.shared_graph_utils.prompts import SUB_CHECK_PROMPT diff --git a/backend/onyx/agent_search/answer_question/nodes/answer_generation.py b/backend/onyx/agent_search/answer_initial_sub_question/nodes/answer_generation.py similarity index 93% rename from backend/onyx/agent_search/answer_question/nodes/answer_generation.py rename to backend/onyx/agent_search/answer_initial_sub_question/nodes/answer_generation.py index 
6ab6ac8978..9b4ac4871c 100644 --- a/backend/onyx/agent_search/answer_question/nodes/answer_generation.py +++ b/backend/onyx/agent_search/answer_initial_sub_question/nodes/answer_generation.py @@ -4,8 +4,8 @@ from langchain_core.callbacks.manager import dispatch_custom_event from langchain_core.messages import merge_message_runs -from onyx.agent_search.answer_question.states import AnswerQuestionState -from onyx.agent_search.answer_question.states import QAGenerationUpdate +from onyx.agent_search.answer_initial_sub_question.states import AnswerQuestionState +from onyx.agent_search.answer_initial_sub_question.states import QAGenerationUpdate from onyx.agent_search.shared_graph_utils.agent_prompt_ops import ( build_sub_question_answer_prompt, ) diff --git a/backend/onyx/agent_search/answer_question/nodes/format_answer.py b/backend/onyx/agent_search/answer_initial_sub_question/nodes/format_answer.py similarity index 69% rename from backend/onyx/agent_search/answer_question/nodes/format_answer.py rename to backend/onyx/agent_search/answer_initial_sub_question/nodes/format_answer.py index 23ffd23939..de96f60614 100644 --- a/backend/onyx/agent_search/answer_question/nodes/format_answer.py +++ b/backend/onyx/agent_search/answer_initial_sub_question/nodes/format_answer.py @@ -1,6 +1,6 @@ -from onyx.agent_search.answer_question.states import AnswerQuestionOutput -from onyx.agent_search.answer_question.states import AnswerQuestionState -from onyx.agent_search.answer_question.states import QuestionAnswerResults +from onyx.agent_search.answer_initial_sub_question.states import AnswerQuestionOutput +from onyx.agent_search.answer_initial_sub_question.states import AnswerQuestionState +from onyx.agent_search.answer_initial_sub_question.states import QuestionAnswerResults def format_answer(state: AnswerQuestionState) -> AnswerQuestionOutput: diff --git a/backend/onyx/agent_search/answer_question/nodes/ingest_retrieval.py b/backend/onyx/agent_search/answer_initial_sub_question/nodes/ingest_retrieval.py similarity index 88% rename from backend/onyx/agent_search/answer_question/nodes/ingest_retrieval.py rename to backend/onyx/agent_search/answer_initial_sub_question/nodes/ingest_retrieval.py index cc9e5989ff..dd74ba8d56 100644 --- a/backend/onyx/agent_search/answer_question/nodes/ingest_retrieval.py +++ b/backend/onyx/agent_search/answer_initial_sub_question/nodes/ingest_retrieval.py @@ -1,4 +1,6 @@ -from onyx.agent_search.answer_question.states import RetrievalIngestionUpdate +from onyx.agent_search.answer_initial_sub_question.states import ( + RetrievalIngestionUpdate, +) from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalOutput from onyx.agent_search.shared_graph_utils.models import AgentChunkStats diff --git a/backend/onyx/agent_search/answer_question/states.py b/backend/onyx/agent_search/answer_initial_sub_question/states.py similarity index 95% rename from backend/onyx/agent_search/answer_question/states.py rename to backend/onyx/agent_search/answer_initial_sub_question/states.py index 80bdaa80f4..74da2355ce 100644 --- a/backend/onyx/agent_search/answer_question/states.py +++ b/backend/onyx/agent_search/answer_initial_sub_question/states.py @@ -2,7 +2,7 @@ from typing import Annotated from typing import TypedDict -from onyx.agent_search.answer_question.models import QuestionAnswerResults +from onyx.agent_search.answer_initial_sub_question.models import QuestionAnswerResults from onyx.agent_search.core_state import SubgraphCoreState from onyx.agent_search.expanded_retrieval.models 
import QueryResult from onyx.agent_search.shared_graph_utils.models import AgentChunkStats diff --git a/backend/onyx/agent_search/answer_follow_up_question/edges.py b/backend/onyx/agent_search/answer_refinement_sub_question/edges.py similarity index 89% rename from backend/onyx/agent_search/answer_follow_up_question/edges.py rename to backend/onyx/agent_search/answer_refinement_sub_question/edges.py index 36d7a17308..9be3a71858 100644 --- a/backend/onyx/agent_search/answer_follow_up_question/edges.py +++ b/backend/onyx/agent_search/answer_refinement_sub_question/edges.py @@ -2,7 +2,7 @@ from langgraph.types import Send -from onyx.agent_search.answer_question.states import AnswerQuestionInput +from onyx.agent_search.answer_initial_sub_question.states import AnswerQuestionInput from onyx.agent_search.core_state import in_subgraph_extract_core_fields from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput from onyx.utils.logger import setup_logger diff --git a/backend/onyx/agent_search/answer_follow_up_question/graph_builder.py b/backend/onyx/agent_search/answer_refinement_sub_question/graph_builder.py similarity index 78% rename from backend/onyx/agent_search/answer_follow_up_question/graph_builder.py rename to backend/onyx/agent_search/answer_refinement_sub_question/graph_builder.py index 7f2766888f..5cf9897f08 100644 --- a/backend/onyx/agent_search/answer_follow_up_question/graph_builder.py +++ b/backend/onyx/agent_search/answer_refinement_sub_question/graph_builder.py @@ -2,16 +2,24 @@ from langgraph.graph import START from langgraph.graph import StateGraph -from onyx.agent_search.answer_follow_up_question.edges import ( +from onyx.agent_search.answer_initial_sub_question.nodes.answer_check import ( + answer_check, +) +from onyx.agent_search.answer_initial_sub_question.nodes.answer_generation import ( + answer_generation, +) +from onyx.agent_search.answer_initial_sub_question.nodes.format_answer import ( + format_answer, +) +from onyx.agent_search.answer_initial_sub_question.nodes.ingest_retrieval import ( + ingest_retrieval, +) +from onyx.agent_search.answer_initial_sub_question.states import AnswerQuestionInput +from onyx.agent_search.answer_initial_sub_question.states import AnswerQuestionOutput +from onyx.agent_search.answer_initial_sub_question.states import AnswerQuestionState +from onyx.agent_search.answer_refinement_sub_question.edges import ( send_to_expanded_refined_retrieval, ) -from onyx.agent_search.answer_question.nodes.answer_check import answer_check -from onyx.agent_search.answer_question.nodes.answer_generation import answer_generation -from onyx.agent_search.answer_question.nodes.format_answer import format_answer -from onyx.agent_search.answer_question.nodes.ingest_retrieval import ingest_retrieval -from onyx.agent_search.answer_question.states import AnswerQuestionInput -from onyx.agent_search.answer_question.states import AnswerQuestionOutput -from onyx.agent_search.answer_question.states import AnswerQuestionState from onyx.agent_search.expanded_retrieval.graph_builder import ( expanded_retrieval_graph_builder, ) diff --git a/backend/onyx/agent_search/answer_follow_up_question/models.py b/backend/onyx/agent_search/answer_refinement_sub_question/models.py similarity index 100% rename from backend/onyx/agent_search/answer_follow_up_question/models.py rename to backend/onyx/agent_search/answer_refinement_sub_question/models.py diff --git a/backend/onyx/agent_search/main/edges.py b/backend/onyx/agent_search/main/edges.py index 
f10a18982e..880c5d6070 100644 --- a/backend/onyx/agent_search/main/edges.py +++ b/backend/onyx/agent_search/main/edges.py @@ -3,8 +3,8 @@ from langgraph.types import Send -from onyx.agent_search.answer_question.states import AnswerQuestionInput -from onyx.agent_search.answer_question.states import AnswerQuestionOutput +from onyx.agent_search.answer_initial_sub_question.states import AnswerQuestionInput +from onyx.agent_search.answer_initial_sub_question.states import AnswerQuestionOutput from onyx.agent_search.core_state import extract_core_fields_for_subgraph from onyx.agent_search.main.states import MainState from onyx.agent_search.main.states import RequireRefinedAnswerUpdate @@ -30,7 +30,7 @@ def parallelize_initial_sub_question_answering( return [ Send( - "answer_query", + "answer_query_subgraph", AnswerQuestionInput( **extract_core_fields_for_subgraph(state), question=question, @@ -67,7 +67,7 @@ def parallelize_refined_sub_question_answering( if len(state["follow_up_sub_questions"]) > 0: return [ Send( - "answer_follow_up_question", + "answer_refinement_sub_question", AnswerQuestionInput( **extract_core_fields_for_subgraph(state), question=question_data.sub_question, diff --git a/backend/onyx/agent_search/main/graph_builder.py b/backend/onyx/agent_search/main/graph_builder.py index 2da26556bb..5e45a7fffc 100644 --- a/backend/onyx/agent_search/main/graph_builder.py +++ b/backend/onyx/agent_search/main/graph_builder.py @@ -2,10 +2,12 @@ from langgraph.graph import START from langgraph.graph import StateGraph -from onyx.agent_search.answer_follow_up_question.graph_builder import ( +from onyx.agent_search.answer_initial_sub_question.graph_builder import ( + answer_query_graph_builder, +) +from onyx.agent_search.answer_refinement_sub_question.graph_builder import ( answer_refined_query_graph_builder, ) -from onyx.agent_search.answer_question.graph_builder import answer_query_graph_builder from onyx.agent_search.base_raw_search.graph_builder import ( base_raw_search_graph_builder, ) @@ -16,9 +18,9 @@ from onyx.agent_search.main.nodes import entity_term_extraction_llm from onyx.agent_search.main.nodes import generate_initial_answer from onyx.agent_search.main.nodes import generate_refined_answer -from onyx.agent_search.main.nodes import ingest_follow_up_answers from onyx.agent_search.main.nodes import ingest_initial_base_retrieval from onyx.agent_search.main.nodes import ingest_initial_sub_question_answers +from onyx.agent_search.main.nodes import ingest_refined_answers from onyx.agent_search.main.nodes import initial_answer_quality_check from onyx.agent_search.main.nodes import initial_sub_question_creation from onyx.agent_search.main.nodes import refined_answer_decision @@ -39,456 +41,185 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph: input=MainInput, ) - graph_component = "both" - # graph_component = "right" - # graph_component = "left" - - if graph_component == "left": - ### Add nodes ### - - graph.add_node( - node="base_decomp", - action=initial_sub_question_creation, - ) - answer_query_subgraph = answer_query_graph_builder().compile() - graph.add_node( - node="answer_query", - action=answer_query_subgraph, - ) - - # graph.add_node( - # node="prep_for_initial_retrieval", - # action=prep_for_initial_retrieval, - # ) - - # expanded_retrieval_subgraph = expanded_retrieval_graph_builder().compile() - # graph.add_node( - # node="initial_retrieval", - # action=expanded_retrieval_subgraph, - # ) - - # base_raw_search_subgraph = base_raw_search_graph_builder().compile() - # 
graph.add_node( - # node="base_raw_search_data", - # action=base_raw_search_subgraph, - # ) - # graph.add_node( - # node="ingest_initial_retrieval", - # action=ingest_initial_retrieval, - # ) - graph.add_node( - node="ingest_answers", - action=ingest_initial_sub_question_answers, - ) - graph.add_node( - node="generate_initial_answer", - action=generate_initial_answer, - ) - # if test_mode: - # graph.add_node( - # node="generate_initial_base_answer", - # action=generate_initial_base_answer, - # ) - - ### Add edges ### - - # graph.add_conditional_edges( - # source=START, - # path=send_to_initial_retrieval, - # path_map=["initial_retrieval"], - # ) - - # graph.add_edge( - # start_key=START, - # end_key="prep_for_initial_retrieval", - # ) - # graph.add_edge( - # start_key="prep_for_initial_retrieval", - # end_key="initial_retrieval", - # ) - # graph.add_edge( - # start_key="initial_retrieval", - # end_key="ingest_initial_retrieval", - # ) - - # graph.add_edge( - # start_key=START, - # end_key="base_raw_search_data" - # ) - - # # graph.add_edge( - # # start_key="base_raw_search_data", - # # end_key=END - # # ) - # graph.add_edge( - # start_key="base_raw_search_data", - # end_key="ingest_initial_retrieval", - # ) - # graph.add_edge( - # start_key="ingest_initial_retrieval", - # end_key=END - # ) - graph.add_edge( - start_key=START, - end_key="base_decomp", - ) - graph.add_conditional_edges( - source="base_decomp", - path=parallelize_initial_sub_question_answering, - path_map=["answer_query"], - ) - graph.add_edge( - start_key="answer_query", - end_key="ingest_answers", - ) - - graph.add_edge( - start_key="ingest_answers", - end_key="generate_initial_answer", - ) - - # graph.add_edge( - # start_key=["ingest_answers", "ingest_initial_retrieval"], - # end_key="generate_initial_answer", - # ) - - graph.add_edge( - start_key="generate_initial_answer", - end_key=END, - ) - # graph.add_edge( - # start_key="ingest_answers", - # end_key="generate_initial_answer", - # ) - # if test_mode: - # graph.add_edge( - # start_key=["ingest_answers", "ingest_initial_retrieval"], - # end_key="generate_initial_base_answer", - # ) - # graph.add_edge( - # start_key=["generate_initial_answer", "generate_initial_base_answer"], - # end_key=END, - # ) - # else: - # graph.add_edge( - # start_key="generate_initial_answer", - # end_key=END, - # ) - - elif graph_component == "right": - ### Add nodes ### - - # graph.add_node( - # node="base_decomp", - # action=main_decomp_base, - # ) - # answer_query_subgraph = answer_query_graph_builder().compile() - # graph.add_node( - # node="answer_query", - # action=answer_query_subgraph, - # ) - - # graph.add_node( - # node="prep_for_initial_retrieval", - # action=prep_for_initial_retrieval, - # ) - - # expanded_retrieval_subgraph = expanded_retrieval_graph_builder().compile() - # graph.add_node( - # node="initial_retrieval", - # action=expanded_retrieval_subgraph, - # ) - - base_raw_search_subgraph = base_raw_search_graph_builder().compile() - graph.add_node( - node="base_raw_search_data", - action=base_raw_search_subgraph, - ) - graph.add_node( - node="ingest_initial_retrieval", - action=ingest_initial_base_retrieval, - ) - # graph.add_node( - # node="ingest_answers", - # action=ingest_answers, - # ) - graph.add_node( - node="generate_initial_answer", - action=generate_initial_answer, - ) - # if test_mode: - # graph.add_node( - # node="generate_initial_base_answer", - # action=generate_initial_base_answer, - # ) - - ### Add edges ### - - # graph.add_conditional_edges( - # source=START, - 
# path=send_to_initial_retrieval, - # path_map=["initial_retrieval"], - # ) - - # graph.add_edge( - # start_key=START, - # end_key="prep_for_initial_retrieval", - # ) - # graph.add_edge( - # start_key="prep_for_initial_retrieval", - # end_key="initial_retrieval", - # ) - # graph.add_edge( - # start_key="initial_retrieval", - # end_key="ingest_initial_retrieval", - # ) - - graph.add_edge(start_key=START, end_key="base_raw_search_data") - - # # graph.add_edge( - # # start_key="base_raw_search_data", - # # end_key=END - # # ) - graph.add_edge( - start_key="base_raw_search_data", - end_key="ingest_initial_retrieval", - ) - # graph.add_edge( - # start_key="ingest_initial_retrieval", - # end_key=END - # ) - # graph.add_edge( - # start_key=START, - # end_key="base_decomp", - # ) - # graph.add_conditional_edges( - # source="base_decomp", - # path=parallelize_decompozed_answer_queries, - # path_map=["answer_query"], - # ) - # graph.add_edge( - # start_key="answer_query", - # end_key="ingest_answers", - # ) - - # graph.add_edge( - # start_key="ingest_answers", - # end_key="generate_initial_answer", - # ) - - graph.add_edge( - start_key="ingest_initial_retrieval", - end_key="generate_initial_answer", - ) - - # graph.add_edge( - # start_key=["ingest_answers", "ingest_initial_retrieval"], - # end_key="generate_initial_answer", - # ) - - graph.add_edge( - start_key="generate_initial_answer", - end_key=END, - ) - # graph.add_edge( - # start_key="ingest_answers", - # end_key="generate_initial_answer", - # ) - # if test_mode: - # graph.add_edge( - # start_key=["ingest_answers", "ingest_initial_retrieval"], - # end_key="generate_initial_base_answer", - # ) - # graph.add_edge( - # start_key=["generate_initial_answer", "generate_initial_base_answer"], - # end_key=END, - # ) - # else: - # graph.add_edge( - # start_key="generate_initial_answer", - # end_key=END, - # ) - - else: - graph.add_node( - node="base_decomp", - action=initial_sub_question_creation, - ) - answer_query_subgraph = answer_query_graph_builder().compile() - graph.add_node( - node="answer_query", - action=answer_query_subgraph, - ) + graph.add_node( + node="initial_sub_question_creation", + action=initial_sub_question_creation, + ) + answer_query_subgraph = answer_query_graph_builder().compile() + graph.add_node( + node="answer_query_subgraph", + action=answer_query_subgraph, + ) - base_raw_search_subgraph = base_raw_search_graph_builder().compile() - graph.add_node( - node="base_raw_search_data", - action=base_raw_search_subgraph, - ) + base_raw_search_subgraph = base_raw_search_graph_builder().compile() + graph.add_node( + node="base_raw_search_subgraph", + action=base_raw_search_subgraph, + ) - # refined_answer_subgraph = refined_answers_graph_builder().compile() - # graph.add_node( - # node="refined_answer_subgraph", - # action=refined_answer_subgraph, - # ) + # refined_answer_subgraph = refined_answers_graph_builder().compile() + # graph.add_node( + # node="refined_answer_subgraph", + # action=refined_answer_subgraph, + # ) - graph.add_node( - node="follow_up_decompose", - action=refined_sub_question_creation, - ) + graph.add_node( + node="refined_sub_question_creation", + action=refined_sub_question_creation, + ) - answer_follow_up_question = answer_refined_query_graph_builder().compile() - graph.add_node( - node="answer_follow_up_question", - action=answer_follow_up_question, - ) + answer_refined_question = answer_refined_query_graph_builder().compile() + graph.add_node( + node="answer_refined_question", + 
action=answer_refined_question, + ) - graph.add_node( - node="ingest_follow_up_answers", - action=ingest_follow_up_answers, - ) + graph.add_node( + node="ingest_refined_answers", + action=ingest_refined_answers, + ) - graph.add_node( - node="generate_refined_answer", - action=generate_refined_answer, - ) + graph.add_node( + node="generate_refined_answer", + action=generate_refined_answer, + ) - # graph.add_node( - # node="check_refined_answer", - # action=check_refined_answer, - # ) + # graph.add_node( + # node="check_refined_answer", + # action=check_refined_answer, + # ) - graph.add_node( - node="ingest_initial_retrieval", - action=ingest_initial_base_retrieval, - ) - graph.add_node( - node="ingest_answers", - action=ingest_initial_sub_question_answers, - ) - graph.add_node( - node="generate_initial_answer", - action=generate_initial_answer, - ) + graph.add_node( + node="ingest_initial_retrieval", + action=ingest_initial_base_retrieval, + ) + graph.add_node( + node="ingest_initial_sub_question_answers", + action=ingest_initial_sub_question_answers, + ) + graph.add_node( + node="generate_initial_answer", + action=generate_initial_answer, + ) - graph.add_node( - node="initial_answer_quality_check", - action=initial_answer_quality_check, - ) + graph.add_node( + node="initial_answer_quality_check", + action=initial_answer_quality_check, + ) - graph.add_node( - node="entity_term_extraction", - action=entity_term_extraction_llm, - ) - graph.add_node( - node="refined_answer_decision", - action=refined_answer_decision, - ) + graph.add_node( + node="entity_term_extraction_llm", + action=entity_term_extraction_llm, + ) + graph.add_node( + node="refined_answer_decision", + action=refined_answer_decision, + ) - graph.add_node( - node="logging_node", - action=agent_logging, - ) - # if test_mode: - # graph.add_node( - # node="generate_initial_base_answer", - # action=generate_initial_base_answer, - # ) + graph.add_node( + node="logging_node", + action=agent_logging, + ) + # if test_mode: + # graph.add_node( + # node="generate_initial_base_answer", + # action=generate_initial_base_answer, + # ) - ### Add edges ### + ### Add edges ### - graph.add_edge(start_key=START, end_key="base_raw_search_data") + graph.add_edge(start_key=START, end_key="base_raw_search_subgraph") - graph.add_edge( - start_key="base_raw_search_data", - end_key="ingest_initial_retrieval", - ) + graph.add_edge( + start_key="base_raw_search_subgraph", + end_key="ingest_initial_retrieval", + ) - graph.add_edge( - start_key=START, - end_key="base_decomp", - ) - graph.add_conditional_edges( - source="base_decomp", - path=parallelize_initial_sub_question_answering, - path_map=["answer_query"], - ) - graph.add_edge( - start_key="answer_query", - end_key="ingest_answers", - ) + graph.add_edge( + start_key=START, + end_key="initial_sub_question_creation", + ) + graph.add_conditional_edges( + source="initial_sub_question_creation", + path=parallelize_initial_sub_question_answering, + path_map=["answer_query_subgraph"], + ) + graph.add_edge( + start_key="answer_query_subgraph", + end_key="ingest_initial_sub_question_answers", + ) - graph.add_edge( - start_key=["ingest_answers", "ingest_initial_retrieval"], - end_key="generate_initial_answer", - ) + graph.add_edge( + start_key=["ingest_initial_sub_question_answers", "ingest_initial_retrieval"], + end_key="generate_initial_answer", + ) - graph.add_edge( - start_key=["ingest_answers", "ingest_initial_retrieval"], - end_key="entity_term_extraction", - ) + graph.add_edge( + 
start_key=["ingest_initial_sub_question_answers", "ingest_initial_retrieval"], + end_key="entity_term_extraction_llm", + ) - graph.add_edge( - start_key="generate_initial_answer", - end_key="initial_answer_quality_check", - ) + graph.add_edge( + start_key="generate_initial_answer", + end_key="initial_answer_quality_check", + ) - graph.add_edge( - start_key=["initial_answer_quality_check", "entity_term_extraction"], - end_key="refined_answer_decision", - ) + graph.add_edge( + start_key=["initial_answer_quality_check", "entity_term_extraction_llm"], + end_key="refined_answer_decision", + ) - graph.add_conditional_edges( - source="refined_answer_decision", - path=continue_to_refined_answer_or_end, - path_map=["follow_up_decompose", "logging_node"], - ) + graph.add_conditional_edges( + source="refined_answer_decision", + path=continue_to_refined_answer_or_end, + path_map=["refined_sub_question_creation", "logging_node"], + ) - graph.add_conditional_edges( - source="follow_up_decompose", - path=parallelize_refined_sub_question_answering, - path_map=["answer_follow_up_question"], - ) - graph.add_edge( - start_key="answer_follow_up_question", - end_key="ingest_follow_up_answers", - ) + graph.add_conditional_edges( + source="refined_sub_question_creation", + path=parallelize_refined_sub_question_answering, + path_map=["answer_refined_question"], + ) + graph.add_edge( + start_key="answer_refined_question", + end_key="ingest_refined_answers", + ) - graph.add_edge( - start_key="ingest_follow_up_answers", - end_key="generate_refined_answer", - ) + graph.add_edge( + start_key="ingest_refined_answers", + end_key="generate_refined_answer", + ) - # graph.add_conditional_edges( - # source="refined_answer_decision", - # path=continue_to_refined_answer_or_end, - # path_map=["refined_answer_subgraph", END], - # ) + # graph.add_conditional_edges( + # source="refined_answer_decision", + # path=continue_to_refined_answer_or_end, + # path_map=["refined_answer_subgraph", END], + # ) - # graph.add_edge( - # start_key="refined_answer_subgraph", - # end_key="generate_refined_answer", - # ) + # graph.add_edge( + # start_key="refined_answer_subgraph", + # end_key="generate_refined_answer", + # ) - graph.add_edge( - start_key="generate_refined_answer", - end_key="logging_node", - ) + graph.add_edge( + start_key="generate_refined_answer", + end_key="logging_node", + ) - graph.add_edge( - start_key="logging_node", - end_key=END, - ) + graph.add_edge( + start_key="logging_node", + end_key=END, + ) - # graph.add_edge( - # start_key="generate_refined_answer", - # end_key="check_refined_answer", - # ) + # graph.add_edge( + # start_key="generate_refined_answer", + # end_key="check_refined_answer", + # ) - # graph.add_edge( - # start_key="check_refined_answer", - # end_key=END, - # ) + # graph.add_edge( + # start_key="check_refined_answer", + # end_key=END, + # ) return graph diff --git a/backend/onyx/agent_search/main/nodes.py b/backend/onyx/agent_search/main/nodes.py index 4dc2de6227..0e616df01d 100644 --- a/backend/onyx/agent_search/main/nodes.py +++ b/backend/onyx/agent_search/main/nodes.py @@ -10,8 +10,8 @@ from langchain_core.messages import merge_content from langchain_core.messages import merge_message_runs -from onyx.agent_search.answer_question.states import AnswerQuestionOutput -from onyx.agent_search.answer_question.states import QuestionAnswerResults +from onyx.agent_search.answer_initial_sub_question.states import AnswerQuestionOutput +from onyx.agent_search.answer_initial_sub_question.states import 
QuestionAnswerResults from onyx.agent_search.base_raw_search.states import BaseRawSearchOutput from onyx.agent_search.main.models import AgentAdditionalMetrics from onyx.agent_search.main.models import AgentBaseMetrics @@ -885,7 +885,7 @@ def refined_sub_question_creation(state: MainState) -> FollowUpSubQuestionsUpdat ) -def ingest_follow_up_answers( +def ingest_refined_answers( state: AnswerQuestionOutput, ) -> DecompAnswersUpdate: now_start = datetime.now() diff --git a/backend/onyx/agent_search/main/states.py b/backend/onyx/agent_search/main/states.py index a12e17a9ff..ee3f637c02 100644 --- a/backend/onyx/agent_search/main/states.py +++ b/backend/onyx/agent_search/main/states.py @@ -3,7 +3,7 @@ from typing import Annotated from typing import TypedDict -from onyx.agent_search.answer_question.states import QuestionAnswerResults +from onyx.agent_search.answer_initial_sub_question.states import QuestionAnswerResults from onyx.agent_search.core_state import CoreState from onyx.agent_search.expanded_retrieval.models import ExpandedRetrievalResult from onyx.agent_search.expanded_retrieval.models import QueryResult diff --git a/backend/onyx/agent_search/shared_graph_utils/operators.py b/backend/onyx/agent_search/shared_graph_utils/operators.py index bd28d64165..b3636b81ca 100644 --- a/backend/onyx/agent_search/shared_graph_utils/operators.py +++ b/backend/onyx/agent_search/shared_graph_utils/operators.py @@ -1,4 +1,4 @@ -from onyx.agent_search.answer_question.models import QuestionAnswerResults +from onyx.agent_search.answer_initial_sub_question.models import QuestionAnswerResults from onyx.chat.prune_and_merge import _merge_sections from onyx.context.search.models import InferenceSection diff --git a/backend/onyx/db/chat.py b/backend/onyx/db/chat.py index b03f615692..b31c0f29df 100644 --- a/backend/onyx/db/chat.py +++ b/backend/onyx/db/chat.py @@ -16,7 +16,7 @@ from sqlalchemy.orm import joinedload from sqlalchemy.orm import Session -from onyx.agent_search.answer_question.models import QuestionAnswerResults +from onyx.agent_search.answer_initial_sub_question.models import QuestionAnswerResults from onyx.agent_search.main.models import CombinedAgentMetrics from onyx.auth.schemas import UserRole from onyx.chat.models import DocumentRelevance From ef67f9cd1e3622e48ed22fec01309a45a6f7c1a3 Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Tue, 7 Jan 2025 14:00:56 -0800 Subject: [PATCH 62/78] more renamings follow_up -> refined --- .../answer_initial_sub_question/edges.py | 2 +- .../graph_builder.py | 6 ++-- .../answer_refinement_sub_question/edges.py | 2 +- .../graph_builder.py | 30 +++++++++---------- backend/onyx/agent_search/main/edges.py | 10 +++---- backend/onyx/agent_search/main/nodes.py | 12 ++++---- backend/onyx/agent_search/main/states.py | 6 ++-- 7 files changed, 34 insertions(+), 34 deletions(-) diff --git a/backend/onyx/agent_search/answer_initial_sub_question/edges.py b/backend/onyx/agent_search/answer_initial_sub_question/edges.py index 1f095367ee..62cc9a93fa 100644 --- a/backend/onyx/agent_search/answer_initial_sub_question/edges.py +++ b/backend/onyx/agent_search/answer_initial_sub_question/edges.py @@ -14,7 +14,7 @@ def send_to_expanded_retrieval(state: AnswerQuestionInput) -> Send | Hashable: logger.info("sending to expanded retrieval via edge") return Send( - "decomped_expanded_retrieval", + "initial_sub_question_expanded_retrieval", ExpandedRetrievalInput( **in_subgraph_extract_core_fields(state), question=state["question"], diff --git 
a/backend/onyx/agent_search/answer_initial_sub_question/graph_builder.py b/backend/onyx/agent_search/answer_initial_sub_question/graph_builder.py index bfcc984912..58e914e020 100644 --- a/backend/onyx/agent_search/answer_initial_sub_question/graph_builder.py +++ b/backend/onyx/agent_search/answer_initial_sub_question/graph_builder.py @@ -40,7 +40,7 @@ def answer_query_graph_builder() -> StateGraph: expanded_retrieval = expanded_retrieval_graph_builder().compile() graph.add_node( - node="decomped_expanded_retrieval", + node="initial_sub_question_expanded_retrieval", action=expanded_retrieval, ) graph.add_node( @@ -65,10 +65,10 @@ def answer_query_graph_builder() -> StateGraph: graph.add_conditional_edges( source=START, path=send_to_expanded_retrieval, - path_map=["decomped_expanded_retrieval"], + path_map=["initial_sub_question_expanded_retrieval"], ) graph.add_edge( - start_key="decomped_expanded_retrieval", + start_key="initial_sub_question_expanded_retrieval", end_key="ingest_retrieval", ) graph.add_edge( diff --git a/backend/onyx/agent_search/answer_refinement_sub_question/edges.py b/backend/onyx/agent_search/answer_refinement_sub_question/edges.py index 9be3a71858..41059136ea 100644 --- a/backend/onyx/agent_search/answer_refinement_sub_question/edges.py +++ b/backend/onyx/agent_search/answer_refinement_sub_question/edges.py @@ -14,7 +14,7 @@ def send_to_expanded_refined_retrieval(state: AnswerQuestionInput) -> Send | Has logger.info("sending to expanded retrieval for follow up question via edge") return Send( - "decomposed_follow_up_retrieval", + "refined_sub_question_expanded_retrieval", ExpandedRetrievalInput( **in_subgraph_extract_core_fields(state), question=state["question"], diff --git a/backend/onyx/agent_search/answer_refinement_sub_question/graph_builder.py b/backend/onyx/agent_search/answer_refinement_sub_question/graph_builder.py index 5cf9897f08..15fe7156a4 100644 --- a/backend/onyx/agent_search/answer_refinement_sub_question/graph_builder.py +++ b/backend/onyx/agent_search/answer_refinement_sub_question/graph_builder.py @@ -39,23 +39,23 @@ def answer_refined_query_graph_builder() -> StateGraph: expanded_retrieval = expanded_retrieval_graph_builder().compile() graph.add_node( - node="decomposed_follow_up_retrieval", + node="refined_sub_question_expanded_retrieval", action=expanded_retrieval, ) graph.add_node( - node="follow_up_answer_check", + node="refined_sub_answer_check", action=answer_check, ) graph.add_node( - node="follow_up_answer_generation", + node="refined_sub_answer_generation", action=answer_generation, ) graph.add_node( - node="format_follow_up_answer", + node="format_refined_sub_answer", action=format_answer, ) graph.add_node( - node="ingest_follow_up_retrieval", + node="ingest_refined_retrieval", action=ingest_retrieval, ) @@ -64,26 +64,26 @@ def answer_refined_query_graph_builder() -> StateGraph: graph.add_conditional_edges( source=START, path=send_to_expanded_refined_retrieval, - path_map=["decomposed_follow_up_retrieval"], + path_map=["refined_sub_question_expanded_retrieval"], ) graph.add_edge( - start_key="decomposed_follow_up_retrieval", - end_key="ingest_follow_up_retrieval", + start_key="refined_sub_question_expanded_retrieval", + end_key="ingest_refined_retrieval", ) graph.add_edge( - start_key="ingest_follow_up_retrieval", - end_key="follow_up_answer_generation", + start_key="ingest_refined_retrieval", + end_key="refined_sub_answer_generation", ) graph.add_edge( - start_key="follow_up_answer_generation", - end_key="follow_up_answer_check", + 
start_key="refined_sub_answer_generation", + end_key="refined_sub_answer_check", ) graph.add_edge( - start_key="follow_up_answer_check", - end_key="format_follow_up_answer", + start_key="refined_sub_answer_check", + end_key="format_refined_sub_answer", ) graph.add_edge( - start_key="format_follow_up_answer", + start_key="format_refined_sub_answer", end_key=END, ) diff --git a/backend/onyx/agent_search/main/edges.py b/backend/onyx/agent_search/main/edges.py index 880c5d6070..f36c808de8 100644 --- a/backend/onyx/agent_search/main/edges.py +++ b/backend/onyx/agent_search/main/edges.py @@ -54,9 +54,9 @@ def parallelize_initial_sub_question_answering( # Define the function that determines whether to continue or not def continue_to_refined_answer_or_end( state: RequireRefinedAnswerUpdate, -) -> Literal["follow_up_decompose", "logging_node"]: +) -> Literal["refined_decompose", "logging_node"]: if state["require_refined_answer"]: - return "follow_up_decompose" + return "refined_decompose" else: return "logging_node" @@ -64,7 +64,7 @@ def continue_to_refined_answer_or_end( def parallelize_refined_sub_question_answering( state: MainState, ) -> list[Send | Hashable]: - if len(state["follow_up_sub_questions"]) > 0: + if len(state["refined_sub_questions"]) > 0: return [ Send( "answer_refinement_sub_question", @@ -74,13 +74,13 @@ def parallelize_refined_sub_question_answering( question_id=make_question_id(1, question_nr), ), ) - for question_nr, question_data in state["follow_up_sub_questions"].items() + for question_nr, question_data in state["refined_sub_questions"].items() ] else: return [ Send( - "ingest_follow_up_answers", + "ingest_refined_sub_answers", AnswerQuestionOutput( answer_results=[], ), diff --git a/backend/onyx/agent_search/main/nodes.py b/backend/onyx/agent_search/main/nodes.py index 0e616df01d..515a65ae44 100644 --- a/backend/onyx/agent_search/main/nodes.py +++ b/backend/onyx/agent_search/main/nodes.py @@ -611,7 +611,7 @@ def generate_refined_answer(state: MainState) -> RefinedAnswerUpdate: persona_prompt = get_persona_prompt(state["config"].search_request.persona) initial_documents = state["documents"] - revised_documents = state["follow_up_documents"] + revised_documents = state["refined_documents"] combined_documents = dedup_inference_sections(initial_documents, revised_documents) @@ -623,7 +623,7 @@ def generate_refined_answer(state: MainState) -> RefinedAnswerUpdate: revision_doc_effectiveness = 10.0 decomp_answer_results = state["decomp_answer_results"] - # revised_answer_results = state["follow_up_decomp_answer_results"] + # revised_answer_results = state["refined_decomp_answer_results"] good_qa_list: list[str] = [] decomp_questions = [] @@ -861,9 +861,9 @@ def refined_sub_question_creation(state: MainState) -> FollowUpSubQuestionsUpdat else: raise ValueError("LLM response is not a string") - follow_up_sub_question_dict = {} + refined_sub_question_dict = {} for sub_question_nr, sub_question in enumerate(parsed_response): - follow_up_sub_question = FollowUpSubQuestion( + refined_sub_question = FollowUpSubQuestion( sub_question=sub_question, sub_question_id=make_question_id(1, sub_question_nr), verified=False, @@ -871,7 +871,7 @@ def refined_sub_question_creation(state: MainState) -> FollowUpSubQuestionsUpdat answer="", ) - follow_up_sub_question_dict[sub_question_nr] = follow_up_sub_question + refined_sub_question_dict[sub_question_nr] = refined_sub_question now_end = datetime.now() @@ -880,7 +880,7 @@ def refined_sub_question_creation(state: MainState) -> 
FollowUpSubQuestionsUpdat ) return FollowUpSubQuestionsUpdate( - follow_up_sub_questions=follow_up_sub_question_dict, + refined_sub_questions=refined_sub_question_dict, agent_refined_start_time=agent_refined_start_time, ) diff --git a/backend/onyx/agent_search/main/states.py b/backend/onyx/agent_search/main/states.py index ee3f637c02..18357bcf9a 100644 --- a/backend/onyx/agent_search/main/states.py +++ b/backend/onyx/agent_search/main/states.py @@ -69,8 +69,8 @@ class DecompAnswersUpdate(TypedDict): class FollowUpDecompAnswersUpdate(TypedDict): - follow_up_documents: Annotated[list[InferenceSection], dedup_inference_sections] - follow_up_decomp_answer_results: Annotated[list[QuestionAnswerResults], add] + refined_documents: Annotated[list[InferenceSection], dedup_inference_sections] + refined_decomp_answer_results: Annotated[list[QuestionAnswerResults], add] class ExpandedRetrievalUpdate(TypedDict): @@ -86,7 +86,7 @@ class EntityTermExtractionUpdate(TypedDict): class FollowUpSubQuestionsUpdate(TypedDict): - follow_up_sub_questions: dict[int, FollowUpSubQuestion] + refined_sub_questions: dict[int, FollowUpSubQuestion] agent_refined_start_time: datetime | None From 682b145a6a183d1b73f627c63eb0297d15a737cd Mon Sep 17 00:00:00 2001 From: Evan Lohn Date: Tue, 7 Jan 2025 15:13:34 -0800 Subject: [PATCH 63/78] add message history to pro search config --- backend/onyx/chat/models.py | 4 ++++ backend/onyx/chat/process_message.py | 12 +++++++++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/backend/onyx/chat/models.py b/backend/onyx/chat/models.py index f27ec16d2e..60a50089ab 100644 --- a/backend/onyx/chat/models.py +++ b/backend/onyx/chat/models.py @@ -18,6 +18,7 @@ from onyx.context.search.enums import SearchType from onyx.context.search.models import RetrievalDocs from onyx.context.search.models import SearchRequest +from onyx.llm.models import PreviousMessage from onyx.llm.override_models import PromptOverride from onyx.tools.models import ToolCallFinalResult from onyx.tools.models import ToolCallKickoff @@ -224,6 +225,9 @@ class ProSearchConfig(BaseModel): # Whether to allow creation of refinement questions (and entity extraction, etc.) 
allow_refinement: bool = False + # Message history for the current chat session + message_history: list[PreviousMessage] | None = None + AnswerQuestionPossibleReturn = ( OnyxAnswerPiece diff --git a/backend/onyx/chat/process_message.py b/backend/onyx/chat/process_message.py index dd292a4bb6..429a997cb3 100644 --- a/backend/onyx/chat/process_message.py +++ b/backend/onyx/chat/process_message.py @@ -688,6 +688,10 @@ def stream_chat_message_objects( for tool_list in tool_dict.values(): tools.extend(tool_list) + message_history = [ + PreviousMessage.from_chat_message(msg, files) for msg in history_msgs + ] + search_request = None pro_search_config = None if new_msg_req.use_pro_search: @@ -714,11 +718,15 @@ def stream_chat_message_objects( else None ), ) + # TODO: Since we're deleting the current main path in Answer, + # we should construct this unconditionally inside Answer instead + # Leaving it here for the time being to avoid breaking changes pro_search_config = ( ProSearchConfig( search_request=search_request, chat_session_id=chat_session_id, message_id=user_message.id if user_message else None, + message_history=message_history, ) if new_msg_req.use_pro_search else None @@ -745,9 +753,7 @@ def stream_chat_message_objects( ) ), fast_llm=fast_llm, - message_history=[ - PreviousMessage.from_chat_message(msg, files) for msg in history_msgs - ], + message_history=message_history, tools=tools, force_use_tool=_get_force_search_settings(new_msg_req, tools), pro_search_config=pro_search_config, From 1996e22f9ba1ae6a7191684a60d671983914dfef Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Tue, 7 Jan 2025 15:16:04 -0800 Subject: [PATCH 64/78] remove duplication of Main State keys + SubQuestion citation --- backend/onyx/agent_search/main/nodes.py | 4 ++ backend/onyx/agent_search/main/states.py | 40 ++++++++++++++----- .../shared_graph_utils/prompts.py | 7 +++- .../agent_search/shared_graph_utils/utils.py | 7 +++- 4 files changed, 46 insertions(+), 12 deletions(-) diff --git a/backend/onyx/agent_search/main/nodes.py b/backend/onyx/agent_search/main/nodes.py index 515a65ae44..461b14f074 100644 --- a/backend/onyx/agent_search/main/nodes.py +++ b/backend/onyx/agent_search/main/nodes.py @@ -263,6 +263,8 @@ def generate_initial_answer(state: MainState) -> InitialAnswerUpdate: good_qa_list: list[str] = [] decomp_questions = [] + sub_question_nr = 1 + for decomp_answer_result in decomp_answer_results: decomp_questions.append(decomp_answer_result.question) if ( @@ -274,8 +276,10 @@ def generate_initial_answer(state: MainState) -> InitialAnswerUpdate: SUB_QUESTION_ANSWER_TEMPLATE.format( sub_question=decomp_answer_result.question, sub_answer=decomp_answer_result.answer, + sub_question_nr=sub_question_nr, ) ) + sub_question_nr += 1 if len(good_qa_list) > 0: sub_question_answer_str = "\n\n------\n\n".join(good_qa_list) diff --git a/backend/onyx/agent_search/main/states.py b/backend/onyx/agent_search/main/states.py index 18357bcf9a..3e9bca346d 100644 --- a/backend/onyx/agent_search/main/states.py +++ b/backend/onyx/agent_search/main/states.py @@ -24,15 +24,26 @@ ## Update States -class BaseDecompUpdate(TypedDict): - agent_start_time: datetime +class RefinedAgentStartStats(TypedDict): agent_refined_start_time: datetime | None + + +class RefinedAgentEndStats(TypedDict): agent_refined_end_time: datetime | None agent_refined_metrics: AgentRefinedMetrics + +class BaseDecompUpdateBase(TypedDict): + agent_start_time: datetime initial_decomp_questions: list[str] +class BaseDecompUpdate( + RefinedAgentStartStats, 
RefinedAgentEndStats, BaseDecompUpdateBase
+):
+    pass
+
+
 class InitialAnswerBASEUpdate(TypedDict):
     initial_base_answer: str
@@ -45,12 +56,14 @@ class InitialAnswerUpdate(TypedDict):
     agent_base_metrics: AgentBaseMetrics
 
-class RefinedAnswerUpdate(TypedDict):
+class RefinedAnswerUpdateBase(TypedDict):
     refined_answer: str
     refined_agent_stats: RefinedAgentStats | None
     refined_answer_quality: bool
-    agent_refined_end_time: datetime
-    agent_refined_metrics: AgentRefinedMetrics
+
+
+class RefinedAnswerUpdate(RefinedAgentEndStats, RefinedAnswerUpdateBase):
+    pass
 
 class InitialAnswerQualityUpdate(TypedDict):
@@ -85,9 +98,14 @@ class EntityTermExtractionUpdate(TypedDict):
     entity_retlation_term_extractions: EntityRelationshipTermExtraction
 
-class FollowUpSubQuestionsUpdate(TypedDict):
+class FollowUpSubQuestionsUpdateBase(TypedDict):
     refined_sub_questions: dict[int, FollowUpSubQuestion]
-    agent_refined_start_time: datetime | None
+
+
+class FollowUpSubQuestionsUpdate(
+    RefinedAgentStartStats, FollowUpSubQuestionsUpdateBase
+):
+    pass
 
 ## Graph Input State
@@ -104,7 +122,7 @@ class MainInput(CoreState):
 class MainState(
     # This includes the core state
     MainInput,
-    BaseDecompUpdate,
+    BaseDecompUpdateBase,
     InitialAnswerUpdate,
     InitialAnswerBASEUpdate,
     DecompAnswersUpdate,
@@ -112,9 +130,11 @@ class MainState(
     EntityTermExtractionUpdate,
     InitialAnswerQualityUpdate,
     RequireRefinedAnswerUpdate,
-    FollowUpSubQuestionsUpdate,
+    FollowUpSubQuestionsUpdateBase,
     FollowUpDecompAnswersUpdate,
-    RefinedAnswerUpdate,
+    RefinedAnswerUpdateBase,
+    RefinedAgentStartStats,
+    RefinedAgentEndStats,
 ):
     # expanded_retrieval_result: Annotated[list[ExpandedRetrievalResult], add]
     base_raw_search_result: Annotated[list[ExpandedRetrievalResult], add]
diff --git a/backend/onyx/agent_search/shared_graph_utils/prompts.py b/backend/onyx/agent_search/shared_graph_utils/prompts.py
index 83d29fcc00..f8819c4a71 100644
--- a/backend/onyx/agent_search/shared_graph_utils/prompts.py
+++ b/backend/onyx/agent_search/shared_graph_utils/prompts.py
@@ -498,7 +498,7 @@
 """
 
 SUB_QUESTION_ANSWER_TEMPLATE = """
-    Sub-Question:\n  - {sub_question}\n  --\nAnswer:\n  - {sub_answer}\n\n
+    Sub-Question Q{sub_question_nr}:\n  - \n{sub_question}\n  --\nAnswer:\n  -\n {sub_answer}\n\n
 """
 
 INITIAL_RAG_PROMPT = """ \n
@@ -518,6 +518,11 @@
 - If the information is empty or irrelevant, just say "I don't know".
 - If the information is relevant but not fully conclusive, specify that the information is not conclusive and say why.
 
+Remember to provide inline citations of documents in the format [D1], [D2], [D3], etc., and [Q1], [Q2], ... if
+you want to cite the answer to a sub-question. If you have multiple citations, please cite for example
+as [D1][Q3], or [D2][D4], etc. Feel free to cite documents in addition to the sub-questions!
+Proper citations are important for the final answer to be verifiable! \n\n\n
+
 Again, you should be sure that the answer is supported by the information provided!
 
 Try to keep your answer concise. But also highlight uncertainties you may have should there be substantial ones,
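The [D#]/[Q#] citation scheme introduced in this prompt only works end to end if the numbering the prompts emit matches what downstream parsing expects: format_docs (next hunk) labels context documents D1, D2, ... and SUB_QUESTION_ANSWER_TEMPLATE labels sub-answers Q1, Q2, ..., so the model's inline markers can be resolved back to their sources. A minimal sketch of a parser for this convention (the regex and helper name are illustrative, not part of this patch):

import re

# Matches [D1], [Q3], and runs such as [D1][Q3]; numbers are 1-based,
# matching the D{doc_nr + 1} / Q{sub_question_nr} labels in the prompts.
CITATION_PATTERN = re.compile(r"\[([DQ])(\d+)\]")

def extract_citations(answer: str) -> list[tuple[str, int]]:
    # Returns (kind, number) pairs in order of appearance.
    return [(kind, int(nr)) for kind, nr in CITATION_PATTERN.findall(answer)]

# extract_citations("See [D1][Q3].") -> [("D", 1), ("Q", 3)]
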
But also highlight uncertainties you may have should there be substantial ones, diff --git a/backend/onyx/agent_search/shared_graph_utils/utils.py b/backend/onyx/agent_search/shared_graph_utils/utils.py index 3e6011657d..756c7915e5 100644 --- a/backend/onyx/agent_search/shared_graph_utils/utils.py +++ b/backend/onyx/agent_search/shared_graph_utils/utils.py @@ -42,7 +42,12 @@ def normalize_whitespace(text: str) -> str: # Post-processing def format_docs(docs: Sequence[InferenceSection]) -> str: - return "\n\n".join(doc.combined_content for doc in docs) + formatted_doc_list = [] + + for doc_nr, doc in enumerate(docs): + formatted_doc_list.append(f"Document D{doc_nr + 1}:\n{doc.combined_content}") + + return "\n\n".join(formatted_doc_list) def clean_and_parse_list_string(json_string: str) -> list[dict]: From 5788902d863c4a9007f0d15108da2e358ad8ee29 Mon Sep 17 00:00:00 2001 From: Evan Lohn Date: Tue, 7 Jan 2025 15:44:02 -0800 Subject: [PATCH 65/78] tiny testing change --- backend/onyx/agent_search/shared_graph_utils/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/backend/onyx/agent_search/shared_graph_utils/utils.py b/backend/onyx/agent_search/shared_graph_utils/utils.py index 756c7915e5..2a985d1d54 100644 --- a/backend/onyx/agent_search/shared_graph_utils/utils.py +++ b/backend/onyx/agent_search/shared_graph_utils/utils.py @@ -197,7 +197,8 @@ def get_test_config( config = ProSearchConfig( search_request=search_request, # chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"), - chat_session_id=UUID("edda10d5-6cef-45d8-acfb-39317552a1f4"), + # chat_session_id=UUID("edda10d5-6cef-45d8-acfb-39317552a1f4"), # Joachim + chat_session_id=UUID("d1acd613-2692-4bc3-9d65-c6d3da62e58e"), # Evan message_id=1, use_persistence=True, ) From 7ce0436d71dba251ec62dfb80797ae8a5252878e Mon Sep 17 00:00:00 2001 From: Evan Lohn Date: Tue, 7 Jan 2025 16:39:26 -0800 Subject: [PATCH 66/78] main merge --- backend/alembic/versions/98a5008d8711_agent_tracking.py | 4 ++-- .../e9cf2bd7baed_create_pro_search_persistence_tables.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/backend/alembic/versions/98a5008d8711_agent_tracking.py b/backend/alembic/versions/98a5008d8711_agent_tracking.py index 50e069c874..eb8b17e239 100644 --- a/backend/alembic/versions/98a5008d8711_agent_tracking.py +++ b/backend/alembic/versions/98a5008d8711_agent_tracking.py @@ -1,7 +1,7 @@ """agent_tracking Revision ID: 98a5008d8711 -Revises: 91a0a4d62b14 +Revises: 2955778aa44c Create Date: 2025-01-04 14:41:52.732238 """ @@ -11,7 +11,7 @@ # revision identifiers, used by Alembic. 
revision = "98a5008d8711" -down_revision = "91a0a4d62b14" +down_revision = "2955778aa44c" branch_labels = None depends_on = None diff --git a/backend/alembic/versions/e9cf2bd7baed_create_pro_search_persistence_tables.py b/backend/alembic/versions/e9cf2bd7baed_create_pro_search_persistence_tables.py index 1b56d0c91a..6ed9d783f5 100644 --- a/backend/alembic/versions/e9cf2bd7baed_create_pro_search_persistence_tables.py +++ b/backend/alembic/versions/e9cf2bd7baed_create_pro_search_persistence_tables.py @@ -1,7 +1,7 @@ """create pro search persistence tables Revision ID: e9cf2bd7baed -Revises: 91a0a4d62b14 +Revises: 98a5008d8711 Create Date: 2025-01-02 17:55:56.544246 """ From 79e7b73db1007a7cea007b341b60648f2e1a7e4b Mon Sep 17 00:00:00 2001 From: Evan Lohn Date: Wed, 8 Jan 2025 13:29:49 -0800 Subject: [PATCH 67/78] minor cleanup in preparation for Answer rework --- backend/onyx/chat/answer.py | 14 +++---- backend/onyx/chat/process_message.py | 36 ++++++++++------- .../tool_handling/tool_response_handler.py | 39 +++++++------------ backend/onyx/context/search/utils.py | 2 +- 4 files changed, 44 insertions(+), 47 deletions(-) diff --git a/backend/onyx/chat/answer.py b/backend/onyx/chat/answer.py index b7dbd3e06a..b4a21a5a64 100644 --- a/backend/onyx/chat/answer.py +++ b/backend/onyx/chat/answer.py @@ -28,6 +28,7 @@ from onyx.chat.stream_processing.utils import ( map_document_id_order, ) +from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler from onyx.file_store.utils import InMemoryChatFile from onyx.llm.interfaces import LLM @@ -56,7 +57,6 @@ def __init__( # newly passed in files to include as part of this question # TODO THIS NEEDS TO BE HANDLED latest_query_files: list[InMemoryChatFile] | None = None, - files: list[InMemoryChatFile] | None = None, tools: list[Tool] | None = None, # NOTE: for native tool-calling, this is only supported by OpenAI atm, # but we only support them anyways @@ -79,7 +79,6 @@ def __init__( self.is_connected: Callable[[], bool] | None = is_connected self.latest_query_files = latest_query_files or [] - self.file_id_to_file = {file.file_id: file for file in (files or [])} self.tools = tools or [] self.force_use_tool = force_use_tool @@ -175,11 +174,7 @@ def _get_response(self, llm_calls: list[LLMCall]) -> AnswerStream: current_llm_call.force_use_tool.tool_name, current_llm_call.force_use_tool.args, ) - tool = next( - (t for t in current_llm_call.tools if t.name == tool_name), None - ) - if not tool: - raise RuntimeError(f"Tool '{tool_name}' not found") + tool = get_tool_by_name(current_llm_call.tools, tool_name) yield from self._handle_specified_tool_call(llm_calls, tool, tool_args) return @@ -214,6 +209,11 @@ def _get_response(self, llm_calls: list[LLMCall]) -> AnswerStream: current_llm_call ) or ([], []) + # NEXT: we still want to handle the LLM response stream, but it is now: + # 1. handle the tool call requests + # 2. feed back the processed results + # 3. 
handle the citations + # Quotes are no longer supported # answer_handler: AnswerResponseHandler # if self.answer_style_config.citation_config: diff --git a/backend/onyx/chat/process_message.py b/backend/onyx/chat/process_message.py index 429a997cb3..ef8c906464 100644 --- a/backend/onyx/chat/process_message.py +++ b/backend/onyx/chat/process_message.py @@ -769,6 +769,7 @@ def stream_chat_message_objects( for packet in answer.processed_streamed_output: if isinstance(packet, ToolResponse): + # TODO: don't need to dedupe here when we do it in agent flow if packet.id == SEARCH_RESPONSE_SUMMARY_ID: ( qa_docs_response, @@ -789,25 +790,30 @@ def stream_chat_message_objects( elif packet.id == SECTION_RELEVANCE_LIST_ID: relevance_sections = packet.response - if reference_db_search_docs is not None: - llm_indices = relevant_sections_to_indices( - relevance_sections=relevance_sections, - items=[ - translate_db_search_doc_to_server_search_doc(doc) - for doc in reference_db_search_docs - ], + if reference_db_search_docs is None: + logger.warning( + "No reference docs found for relevance filtering" ) + continue - if dropped_indices: - llm_indices = drop_llm_indices( - llm_indices=llm_indices, - search_docs=reference_db_search_docs, - dropped_indices=dropped_indices, - ) + llm_indices = relevant_sections_to_indices( + relevance_sections=relevance_sections, + items=[ + translate_db_search_doc_to_server_search_doc(doc) + for doc in reference_db_search_docs + ], + ) - yield LLMRelevanceFilterResponse( - llm_selected_doc_indices=llm_indices + if dropped_indices: + llm_indices = drop_llm_indices( + llm_indices=llm_indices, + search_docs=reference_db_search_docs, + dropped_indices=dropped_indices, ) + + yield LLMRelevanceFilterResponse( + llm_selected_doc_indices=llm_indices + ) elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID: yield FinalUsedContextDocsResponse( final_context_docs=packet.response diff --git a/backend/onyx/chat/tool_handling/tool_response_handler.py b/backend/onyx/chat/tool_handling/tool_response_handler.py index 1a39e5c8d0..e188513062 100644 --- a/backend/onyx/chat/tool_handling/tool_response_handler.py +++ b/backend/onyx/chat/tool_handling/tool_response_handler.py @@ -25,6 +25,13 @@ logger = setup_logger() +def get_tool_by_name(tools: list[Tool], tool_name: str) -> Tool: + for tool in tools: + if tool.name == tool_name: + return tool + raise RuntimeError(f"Tool '{tool_name}' not found") + + class ToolResponseHandler: def __init__(self, tools: list[Tool]): self.tools = tools @@ -45,18 +52,7 @@ def get_tool_call_for_non_tool_calling_llm( ) -> tuple[Tool, dict] | None: if llm_call.force_use_tool.force_use: # if we are forcing a tool, we don't need to check which tools to run - tool = next( - ( - t - for t in llm_call.tools - if t.name == llm_call.force_use_tool.tool_name - ), - None, - ) - if not tool: - raise RuntimeError( - f"Tool '{llm_call.force_use_tool.tool_name}' not found" - ) + tool = get_tool_by_name(llm_call.tools, llm_call.force_use_tool.tool_name) tool_args = ( llm_call.force_use_tool.args @@ -118,20 +114,17 @@ def _handle_tool_call(self) -> Generator[ResponsePart, None, None]: tool for tool in self.tools if tool.name == tool_call_request["name"] ] - if not known_tools_by_name: - logger.error( - "Tool call requested with unknown name field. 
\n" - f"self.tools: {self.tools}" - f"tool_call_request: {tool_call_request}" - ) - continue - else: + if known_tools_by_name: selected_tool = known_tools_by_name[0] selected_tool_call_request = tool_call_request - - if selected_tool and selected_tool_call_request: break + logger.error( + "Tool call requested with unknown name field. \n" + f"self.tools: {self.tools}" + f"tool_call_request: {tool_call_request}" + ) + if not selected_tool or not selected_tool_call_request: return @@ -171,8 +164,6 @@ def handle_response_part( else: self.tool_call_chunk += response_item # type: ignore - return - def next_llm_call(self, current_llm_call: LLMCall) -> LLMCall | None: if ( self.tool_runner is None diff --git a/backend/onyx/context/search/utils.py b/backend/onyx/context/search/utils.py index 8a25ad1b78..4b42bb0808 100644 --- a/backend/onyx/context/search/utils.py +++ b/backend/onyx/context/search/utils.py @@ -80,7 +80,7 @@ def drop_llm_indices( search_docs: Sequence[DBSearchDoc | SavedSearchDoc], dropped_indices: list[int], ) -> list[int]: - llm_bools = [True if i in llm_indices else False for i in range(len(search_docs))] + llm_bools = [i in llm_indices for i in range(len(search_docs))] if dropped_indices: llm_bools = [ val for ind, val in enumerate(llm_bools) if ind not in dropped_indices From 4c19b19488ce524474da4f85c5b960af3852b3dc Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Wed, 8 Jan 2025 15:54:57 -0800 Subject: [PATCH 68/78] initial agent citation_processing --- backend/onyx/agent_search/models.py | 7 + backend/onyx/agent_search/run_graph.py | 156 +++++++++++++++++- .../agent_search/shared_graph_utils/utils.py | 4 +- 3 files changed, 163 insertions(+), 4 deletions(-) create mode 100644 backend/onyx/agent_search/models.py diff --git a/backend/onyx/agent_search/models.py b/backend/onyx/agent_search/models.py new file mode 100644 index 0000000000..8a5d6ff2de --- /dev/null +++ b/backend/onyx/agent_search/models.py @@ -0,0 +1,7 @@ +from pydantic import BaseModel + + +class AgentDocumentCitations(BaseModel): + document_id: str + document_title: str + link: str diff --git a/backend/onyx/agent_search/run_graph.py b/backend/onyx/agent_search/run_graph.py index 3fe89e31f4..3c325d5c15 100644 --- a/backend/onyx/agent_search/run_graph.py +++ b/backend/onyx/agent_search/run_graph.py @@ -1,4 +1,5 @@ import asyncio +from collections import defaultdict from collections.abc import AsyncIterable from collections.abc import Iterable from datetime import datetime @@ -10,9 +11,11 @@ from onyx.agent_search.main.graph_builder import main_graph_builder from onyx.agent_search.main.states import MainInput +from onyx.agent_search.models import AgentDocumentCitations from onyx.agent_search.shared_graph_utils.utils import get_test_config from onyx.chat.models import AnswerPacket from onyx.chat.models import AnswerStream +from onyx.chat.models import ExtendedToolResponse from onyx.chat.models import OnyxAnswerPiece from onyx.chat.models import ProSearchConfig from onyx.chat.models import SubAnswerPiece @@ -31,6 +34,19 @@ _COMPILED_GRAPH: CompiledStateGraph | None = None +def _set_combined_token_value( + combined_token: str, parsed_object: SubAnswerPiece | OnyxAnswerPiece +) -> SubAnswerPiece | OnyxAnswerPiece: + if isinstance(parsed_object, SubAnswerPiece): + parsed_object.sub_answer = combined_token + elif isinstance(parsed_object, OnyxAnswerPiece): + parsed_object.answer_piece = combined_token + else: + raise ValueError("Invalid parsed object type to update yielded token.") + + return parsed_object + + def 
_parse_agent_event( event: StreamEvent, ) -> AnswerPacket | None: @@ -107,11 +123,147 @@ def run_graph( db_session=db_session, search_tool=search_tool, ) + + agent_document_citations: dict[int, dict[int, list[AgentDocumentCitations]]] = {} + agent_question_citations_used_docs: defaultdict[ + int, defaultdict[int, list[str]] + ] = defaultdict(lambda: defaultdict(list)) + + # def _process_citation(current_yield_str: str) -> tuple[str, str]: + # """Process a citation string and return the formatted citation and remaining text.""" + # section_split = current_yield_str.split(']', 1) + # citation_part = section_split[0] + ']' + # remaining_text = section_split[1] if len(section_split) > 1 else '' + + # if 'D' in citation_part: + # citation_type = "Document" + # formatted_citation = citation_part.replace('[D', '[[').replace(']', ']]') + # else: # Q case + # citation_type = "Question" + # formatted_citation = citation_part.replace('[Q', '{{').replace(']', '}}') + + # return f" --- CITATION: {citation_type} - {formatted_citation}", remaining_text + + citation_potential = False + # leading_space = False + current_yield_components = [] + current_yield_str = "" + for event in _manage_async_event_streaming( compiled_graph=compiled_graph, graph_input=input ): - if parsed_object := _parse_agent_event(event): - yield parsed_object + parsed_object = _parse_agent_event(event) + + if parsed_object: + # if isinstance(parsed_object, SubAnswerPiece): + # logger.info(f"SA {parsed_object.sub_answer}") + + # token = parsed_object.sub_answer + + if isinstance(parsed_object, OnyxAnswerPiece) or isinstance( + parsed_object, SubAnswerPiece + ): + # logger.info(f"FA {parsed_object.answer_piece}") + + if isinstance(parsed_object, SubAnswerPiece): + token: str | None = parsed_object.sub_answer + elif isinstance(parsed_object, OnyxAnswerPiece): + token = parsed_object.answer_piece + if not token: + return parsed_object + else: + raise ValueError( + f"Invalid parsed object type: {type(parsed_object)}" + ) + + if not citation_potential and token: + if token.startswith(" ["): + citation_potential = True + current_yield_components = [token] + else: + yield parsed_object + elif token and citation_potential: + current_yield_components.append(token) + current_yield_str = "".join(current_yield_components) + + if current_yield_str.strip().startswith( + "[D" + ) or current_yield_str.strip().startswith("[Q"): + citation_potential = True + + else: + citation_potential = False + parsed_object = _set_combined_token_value( + current_yield_str, parsed_object + ) + yield parsed_object + + if len(current_yield_components) > 15: + citation_potential = False + parsed_object = _set_combined_token_value( + current_yield_str, parsed_object + ) + yield parsed_object + elif "]" in current_yield_str: + section_split = current_yield_str.split("]") + section_split[0] + "]" + start_of_next_section = "]".join(section_split[1:]) + citation_string = current_yield_str[ + : -len(start_of_next_section) + ] + if "[D" in citation_string: + citation_string = citation_string.replace( + "[D", "[[" + ).replace("]", "]]") + elif "[Q" in citation_string: + citation_string = citation_string.replace( + "[Q", "{{" + ).replace("]", "}}") + else: + pass + + parsed_object = _set_combined_token_value( + citation_string, parsed_object + ) + yield parsed_object + + current_yield_components = [start_of_next_section] + if not start_of_next_section.strip().startswith("["): + citation_potential = False + + elif isinstance(parsed_object, ExtendedToolResponse): + if 
parsed_object.id == "search_response_summary": + level = parsed_object.level + level_question_nr = parsed_object.level_question_nr + for inference_section in parsed_object.response.top_sections: + doc_link = inference_section.center_chunk.source_links[0] + doc_title = inference_section.center_chunk.title + doc_id = inference_section.center_chunk.document_id + + if ( + doc_id + not in agent_question_citations_used_docs[level][ + level_question_nr + ] + ): + if level not in agent_document_citations: + agent_document_citations[level] = {} + if level_question_nr not in agent_document_citations[level]: + agent_document_citations[level][level_question_nr] = [] + + agent_document_citations[level][level_question_nr].append( + AgentDocumentCitations( + document_id=doc_id, + document_title=doc_title, + link=doc_link, + ) + ) + agent_question_citations_used_docs[level][ + level_question_nr + ].append(doc_id) + else: + citation_potential = False + yield parsed_object # TODO: call this once on startup, TBD where and if it should be gated based diff --git a/backend/onyx/agent_search/shared_graph_utils/utils.py b/backend/onyx/agent_search/shared_graph_utils/utils.py index 2a985d1d54..e995b829d9 100644 --- a/backend/onyx/agent_search/shared_graph_utils/utils.py +++ b/backend/onyx/agent_search/shared_graph_utils/utils.py @@ -197,8 +197,8 @@ def get_test_config( config = ProSearchConfig( search_request=search_request, # chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"), - # chat_session_id=UUID("edda10d5-6cef-45d8-acfb-39317552a1f4"), # Joachim - chat_session_id=UUID("d1acd613-2692-4bc3-9d65-c6d3da62e58e"), # Evan + chat_session_id=UUID("edda10d5-6cef-45d8-acfb-39317552a1f4"), # Joachim + # chat_session_id=UUID("d1acd613-2692-4bc3-9d65-c6d3da62e58e"), # Evan message_id=1, use_persistence=True, ) From 0bc3bb5558d4d70bc6524d7ba557fcb4ad672a96 Mon Sep 17 00:00:00 2001 From: Evan Lohn Date: Wed, 8 Jan 2025 16:16:33 -0800 Subject: [PATCH 69/78] more minor cleanup --- backend/onyx/chat/llm_response_handler.py | 8 ++++++++ .../tools/tool_implementations/search_like_tool_utils.py | 6 +----- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/backend/onyx/chat/llm_response_handler.py b/backend/onyx/chat/llm_response_handler.py index 612ce5dd56..342ea4f790 100644 --- a/backend/onyx/chat/llm_response_handler.py +++ b/backend/onyx/chat/llm_response_handler.py @@ -13,6 +13,14 @@ class LLMResponseHandlerManager: + """ + This class is responsible for postprocessing the LLM response stream. + In particular, we: + 1. handle the tool call requests + 2. handle citations + 3. 
pass through answers generated by the LLM + """ + def __init__( self, tool_handler: ToolResponseHandler, diff --git a/backend/onyx/tools/tool_implementations/search_like_tool_utils.py b/backend/onyx/tools/tool_implementations/search_like_tool_utils.py index 44dc8f2c33..49c4b7f4e2 100644 --- a/backend/onyx/tools/tool_implementations/search_like_tool_utils.py +++ b/backend/onyx/tools/tool_implementations/search_like_tool_utils.py @@ -49,11 +49,7 @@ def build_next_prompt_for_search_like_tool( message=prompt_builder.user_message_and_token_cnt[0], prompt_config=prompt_config, context_docs=final_context_documents, - all_doc_useful=( - answer_style_config.citation_config.all_docs_useful - if answer_style_config.citation_config - else False - ), + all_doc_useful=(answer_style_config.citation_config.all_docs_useful), history_message=prompt_builder.single_message_history or "", ) ) From 0d059cf8351573080b6179f708643d25ae8f1d8b Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Wed, 8 Jan 2025 17:13:31 -0800 Subject: [PATCH 70/78] small fix --- backend/onyx/agent_search/run_graph.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/backend/onyx/agent_search/run_graph.py b/backend/onyx/agent_search/run_graph.py index 3c325d5c15..ed7a9bf657 100644 --- a/backend/onyx/agent_search/run_graph.py +++ b/backend/onyx/agent_search/run_graph.py @@ -170,7 +170,7 @@ def run_graph( elif isinstance(parsed_object, OnyxAnswerPiece): token = parsed_object.answer_piece if not token: - return parsed_object + yield parsed_object else: raise ValueError( f"Invalid parsed object type: {type(parsed_object)}" @@ -261,6 +261,9 @@ def run_graph( agent_question_citations_used_docs[level][ level_question_nr ].append(doc_id) + + yield parsed_object + else: citation_potential = False yield parsed_object From 8045d52090702a52a712476b4b2c910f03a9e7f7 Mon Sep 17 00:00:00 2001 From: Evan Lohn Date: Wed, 8 Jan 2025 17:15:12 -0800 Subject: [PATCH 71/78] minor bug fix --- backend/onyx/agent_search/expanded_retrieval/nodes.py | 3 +++ backend/onyx/tools/tool_implementations/search/search_tool.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes.py b/backend/onyx/agent_search/expanded_retrieval/nodes.py index f1eb798d4c..73c11f9c1e 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes.py @@ -122,16 +122,19 @@ def doc_retrieval(state: RetrievalInput) -> DocRetrievalUpdate: retrieved_documents=[], ) + # new db session to avoid concurrency issues with get_session_context_manager() as db_session: for tool_response in search_tool.run( query=query_to_retrieve, force_no_rerank=True, alternate_db_session=db_session, ): + # get retrieved docs to send to the rest of the graph if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID: retrieved_docs = cast( list[InferenceSection], tool_response.response.top_sections ) + level, question_nr = ( parse_question_id(state["sub_question_id"]) if state["sub_question_id"] diff --git a/backend/onyx/tools/tool_implementations/search/search_tool.py b/backend/onyx/tools/tool_implementations/search/search_tool.py index 323c635775..eac9bc9eef 100644 --- a/backend/onyx/tools/tool_implementations/search/search_tool.py +++ b/backend/onyx/tools/tool_implementations/search/search_tool.py @@ -297,7 +297,7 @@ def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]: search_request=SearchRequest( query=query, evaluation_type=LLMEvaluationType.SKIP - if 
force_no_rerank == "True" + if force_no_rerank else self.evaluation_type, human_selected_filters=( self.retrieval_options.filters if self.retrieval_options else None From 6ff439c34210653f3dcd0a6968e0502870145134 Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Wed, 8 Jan 2025 20:22:29 -0800 Subject: [PATCH 72/78] Added link for document citations - assumes that Sub-question answers will follow the same citation format - assumes that we will be able to fit initial and revised answers into the same format. --- backend/onyx/agent_search/run_graph.py | 50 ++++++++++++++++++++++---- 1 file changed, 44 insertions(+), 6 deletions(-) diff --git a/backend/onyx/agent_search/run_graph.py b/backend/onyx/agent_search/run_graph.py index ed7a9bf657..c2fe34990e 100644 --- a/backend/onyx/agent_search/run_graph.py +++ b/backend/onyx/agent_search/run_graph.py @@ -153,6 +153,18 @@ def run_graph( compiled_graph=compiled_graph, graph_input=input ): parsed_object = _parse_agent_event(event) + if not parsed_object: + continue + + if hasattr(parsed_object, "level"): + level = parsed_object.level + else: + level = None + + if hasattr(parsed_object, "level_question_nr"): + level_question_nr = parsed_object.level_question_nr + else: + level_question_nr = None if parsed_object: # if isinstance(parsed_object, SubAnswerPiece): @@ -212,19 +224,45 @@ def run_graph( : -len(start_of_next_section) ] if "[D" in citation_string: - citation_string = citation_string.replace( - "[D", "[[" - ).replace("]", "]]") + cite_open_bracket_marker, cite_close_bracket_marker = ( + "[", + "]", + ) + cite_identifyer = "D" + + try: + cited_document = int(citation_string[2:-1]) + if level and level_question_nr: + link = agent_document_citations[int(level)][ + int(level_question_nr) + ][cited_document].link + else: + link = "" + except (ValueError, IndexError): + link = "" elif "[Q" in citation_string: - citation_string = citation_string.replace( - "[Q", "{{" - ).replace("]", "}}") + cite_open_bracket_marker, cite_close_bracket_marker = ( + "{", + "}", + ) + cite_identifyer = "Q" else: pass + citation_string = citation_string.replace( + "[" + cite_identifyer, + cite_open_bracket_marker + cite_open_bracket_marker, + ).replace( + "]", cite_close_bracket_marker + cite_close_bracket_marker + ) + + if cite_identifyer == "D": + citation_string += f"({link})" + parsed_object = _set_combined_token_value( citation_string, parsed_object ) + yield parsed_object current_yield_components = [start_of_next_section] From df75a3115be13ba79434f1ae4561eab4f535d0d6 Mon Sep 17 00:00:00 2001 From: Evan Lohn Date: Thu, 9 Jan 2025 09:53:18 -0800 Subject: [PATCH 73/78] yield docs after deduping and reranking, yield initial question docs --- .../agent_search/expanded_retrieval/models.py | 2 + .../agent_search/expanded_retrieval/nodes.py | 58 +++++--- backend/onyx/agent_search/main/nodes.py | 28 ++++ backend/onyx/context/search/pipeline.py | 18 ++- backend/onyx/tools/models.py | 9 ++ .../search/search_tool.py | 126 +++++++++++------- 6 files changed, 170 insertions(+), 71 deletions(-) diff --git a/backend/onyx/agent_search/expanded_retrieval/models.py b/backend/onyx/agent_search/expanded_retrieval/models.py index 4e8caa3605..0e74135153 100644 --- a/backend/onyx/agent_search/expanded_retrieval/models.py +++ b/backend/onyx/agent_search/expanded_retrieval/models.py @@ -3,6 +3,7 @@ from onyx.agent_search.shared_graph_utils.models import AgentChunkStats from onyx.agent_search.shared_graph_utils.models import RetrievalFitStats from onyx.context.search.models import 
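+# SearchQueryInfo (defined in onyx/tools/models.py later in this patch) bundles
+# the query metadata (predicted search type, final filters, recency bias
+# multiplier) needed to re-emit search responses after reranking.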
InferenceSection +from onyx.tools.models import SearchQueryInfo ### Models ### @@ -11,6 +12,7 @@ class QueryResult(BaseModel): query: str search_results: list[InferenceSection] stats: RetrievalFitStats | None + query_info: SearchQueryInfo | None class ExpandedRetrievalResult(BaseModel): diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes.py b/backend/onyx/agent_search/expanded_retrieval/nodes.py index 73c11f9c1e..e453f692f4 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes.py @@ -41,9 +41,12 @@ from onyx.context.search.postprocessing.postprocessing import rerank_sections from onyx.db.engine import get_session_context_manager from onyx.llm.interfaces import LLM +from onyx.tools.models import SearchQueryInfo from onyx.tools.tool_implementations.search.search_tool import ( SEARCH_RESPONSE_SUMMARY_ID, ) +from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary +from onyx.tools.tool_implementations.search.search_tool import yield_search_responses from onyx.utils.logger import setup_logger logger = setup_logger() @@ -122,6 +125,7 @@ def doc_retrieval(state: RetrievalInput) -> DocRetrievalUpdate: retrieved_documents=[], ) + query_info = None # new db session to avoid concurrency issues with get_session_context_manager() as db_session: for tool_response in search_tool.run( @@ -131,25 +135,14 @@ def doc_retrieval(state: RetrievalInput) -> DocRetrievalUpdate: ): # get retrieved docs to send to the rest of the graph if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID: - retrieved_docs = cast( - list[InferenceSection], tool_response.response.top_sections + response = cast(SearchResponseSummary, tool_response.response) + retrieved_docs = response.top_sections + query_info = SearchQueryInfo( + predicted_search=response.predicted_search, + final_filters=response.final_filters, + recency_bias_multiplier=response.recency_bias_multiplier, ) - level, question_nr = ( - parse_question_id(state["sub_question_id"]) - if state["sub_question_id"] - else (0, 0) - ) - dispatch_custom_event( - "tool_response", - ExtendedToolResponse( - id=tool_response.id, - response=tool_response.response, - level=level, - level_question_nr=question_nr, - ), - ) - retrieved_docs = retrieved_docs[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS] pre_rerank_docs = retrieved_docs if search_tool.search_pipeline is not None: @@ -169,6 +162,7 @@ def doc_retrieval(state: RetrievalInput) -> DocRetrievalUpdate: query=query_to_retrieve, search_results=retrieved_docs, stats=fit_scores, + query_info=query_info, ) return DocRetrievalUpdate( expanded_retrieval_results=[expanded_retrieval_result], @@ -353,6 +347,36 @@ def _calculate_sub_question_retrieval_stats( def format_results(state: ExpandedRetrievalState) -> ExpandedRetrievalUpdate: + level, question_nr = parse_question_id(state.get("sub_question_id") or "0_0") + query_infos = [ + result.query_info + for result in state["expanded_retrieval_results"] + if result.query_info is not None + ] + if len(query_infos) == 0: + raise ValueError("No query info found") + + # main question docs will be sent later after aggregation and deduping with sub-question docs + if not (level == 0 and question_nr == 0): + for tool_response in yield_search_responses( + query=state["question"], + reranked_sections=state[ + "retrieved_documents" + ], # TODO: rename params. 
this one is supposed to be the sections pre-merging + final_context_sections=state["reranked_documents"], + search_query_info=query_infos[0], # TODO: handle differing query infos? + get_section_relevance=lambda: None, # TODO: add relevance + search_tool=state["subgraph_search_tool"], + ): + dispatch_custom_event( + "tool_response", + ExtendedToolResponse( + id=tool_response.id, + response=tool_response.response, + level=level, + level_question_nr=question_nr, + ), + ) sub_question_retrieval_stats = _calculate_sub_question_retrieval_stats( verified_documents=state["verified_documents"], expanded_retrieval_results=state["expanded_retrieval_results"], diff --git a/backend/onyx/agent_search/main/nodes.py b/backend/onyx/agent_search/main/nodes.py index 461b14f074..2dd5aff197 100644 --- a/backend/onyx/agent_search/main/nodes.py +++ b/backend/onyx/agent_search/main/nodes.py @@ -62,9 +62,11 @@ from onyx.agent_search.shared_graph_utils.utils import get_persona_prompt from onyx.agent_search.shared_graph_utils.utils import make_question_id from onyx.agent_search.shared_graph_utils.utils import parse_question_id +from onyx.chat.models import ExtendedToolResponse from onyx.chat.models import SubQuestionPiece from onyx.db.chat import log_agent_metrics from onyx.db.chat import log_agent_sub_question_results +from onyx.tools.tool_implementations.search.search_tool import yield_search_responses from onyx.utils.logger import setup_logger logger = setup_logger() @@ -253,6 +255,32 @@ def generate_initial_answer(state: MainState) -> InitialAnswerUpdate: sub_question_docs, all_original_question_documents ) + # Use the query info from the base document retrieval + query_infos = [ + result.query_info + for result in state["original_question_retrieval_results"] + if result.query_info is not None + ] + if len(query_infos) == 0: + raise ValueError("No query info found") + for tool_response in yield_search_responses( + query=question, + reranked_sections=relevant_docs, + final_context_sections=relevant_docs, + search_query_info=query_infos[0], + get_section_relevance=lambda: None, # TODO: add relevance + search_tool=state["search_tool"], + ): + dispatch_custom_event( + "tool_response", + ExtendedToolResponse( + id=tool_response.id, + response=tool_response.response, + level=0, + level_question_nr=0, # 0, 0 is the base question + ), + ) + net_new_original_question_docs = [] for all_original_question_doc in all_original_question_documents: if all_original_question_doc not in sub_question_docs: diff --git a/backend/onyx/context/search/pipeline.py b/backend/onyx/context/search/pipeline.py index c6f8d8cbea..727e29e7f7 100644 --- a/backend/onyx/context/search/pipeline.py +++ b/backend/onyx/context/search/pipeline.py @@ -406,8 +406,18 @@ def section_relevance(self) -> list[SectionRelevancePiece] | None: @property def section_relevance_list(self) -> list[bool]: - llm_indices = relevant_sections_to_indices( - relevance_sections=self.section_relevance, - items=self.final_context_sections, + return section_relevance_list_impl( + section_relevance=self.section_relevance, + final_context_sections=self.final_context_sections, ) - return [ind in llm_indices for ind in range(len(self.final_context_sections))] + + +def section_relevance_list_impl( + section_relevance: list[SectionRelevancePiece] | None, + final_context_sections: list[InferenceSection], +) -> list[bool]: + llm_indices = relevant_sections_to_indices( + relevance_sections=section_relevance, + items=final_context_sections, + ) + return [ind in llm_indices for ind in 
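Pulling the property body out into section_relevance_list_impl lets callers that never construct a SearchPipeline (such as the agent graph above) reuse the same indices-to-booleans mapping. The mapping itself is plain set membership; a self-contained illustration (a simplified stand-in, not the actual relevant_sections_to_indices implementation):

def indices_to_bool_list(llm_indices: list[int], num_items: int) -> list[bool]:
    # True at position i iff the LLM selected section i as relevant.
    selected = set(llm_indices)
    return [i in selected for i in range(num_items)]

# indices_to_bool_list([0, 2], 4) -> [True, False, True, False]
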
range(len(final_context_sections))] diff --git a/backend/onyx/tools/models.py b/backend/onyx/tools/models.py index 4f56aecd37..b20289073e 100644 --- a/backend/onyx/tools/models.py +++ b/backend/onyx/tools/models.py @@ -4,6 +4,9 @@ from pydantic import BaseModel from pydantic import model_validator +from onyx.context.search.enums import SearchType +from onyx.context.search.models import IndexFilters + class ToolResponse(BaseModel): id: str | None = None @@ -45,5 +48,11 @@ class DynamicSchemaInfo(BaseModel): message_id: int | None +class SearchQueryInfo(BaseModel): + predicted_search: SearchType | None + final_filters: IndexFilters + recency_bias_multiplier: float + + CHAT_SESSION_ID_PLACEHOLDER = "CHAT_SESSION_ID" MESSAGE_ID_PLACEHOLDER = "MESSAGE_ID" diff --git a/backend/onyx/tools/tool_implementations/search/search_tool.py b/backend/onyx/tools/tool_implementations/search/search_tool.py index eac9bc9eef..0940e9af8d 100644 --- a/backend/onyx/tools/tool_implementations/search/search_tool.py +++ b/backend/onyx/tools/tool_implementations/search/search_tool.py @@ -1,9 +1,9 @@ import json +from collections.abc import Callable from collections.abc import Generator from typing import Any from typing import cast -from pydantic import BaseModel from sqlalchemy.orm import Session from onyx.chat.chat_utils import llm_doc_from_inference_section @@ -25,13 +25,13 @@ from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS from onyx.context.search.enums import LLMEvaluationType from onyx.context.search.enums import QueryFlow -from onyx.context.search.enums import SearchType from onyx.context.search.models import IndexFilters from onyx.context.search.models import InferenceSection from onyx.context.search.models import RerankingDetails from onyx.context.search.models import RetrievalDetails from onyx.context.search.models import SearchRequest from onyx.context.search.pipeline import SearchPipeline +from onyx.context.search.pipeline import section_relevance_list_impl from onyx.db.models import Persona from onyx.db.models import User from onyx.llm.interfaces import LLM @@ -39,6 +39,7 @@ from onyx.secondary_llm_flows.choose_search import check_if_need_search from onyx.secondary_llm_flows.query_expansion import history_based_query_rephrase from onyx.tools.message import ToolCallSummary +from onyx.tools.models import SearchQueryInfo from onyx.tools.models import ToolResponse from onyx.tools.tool import Tool from onyx.tools.tool_implementations.search.search_utils import llm_doc_to_dict @@ -62,13 +63,10 @@ SEARCH_EVALUATION_ID = "llm_doc_eval" -class SearchResponseSummary(BaseModel): +class SearchResponseSummary(SearchQueryInfo): top_sections: list[InferenceSection] rephrased_query: str | None = None predicted_flow: QueryFlow | None - predicted_search: SearchType | None - final_filters: IndexFilters - recency_bias_multiplier: float SEARCH_TOOL_DESCRIPTION = """ @@ -335,53 +333,20 @@ def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]: ) self.search_pipeline = search_pipeline # used for agent_search metrics - yield ToolResponse( - id=SEARCH_RESPONSE_SUMMARY_ID, - response=SearchResponseSummary( - rephrased_query=query, - top_sections=search_pipeline.final_context_sections, - predicted_flow=search_pipeline.predicted_flow, - predicted_search=search_pipeline.predicted_search_type, - final_filters=search_pipeline.search_query.filters, - recency_bias_multiplier=search_pipeline.search_query.recency_bias_multiplier, - ), - ) - - yield ToolResponse( - id=SEARCH_DOC_CONTENT_ID, - 
response=OnyxContexts( - contexts=[ - OnyxContext( - content=section.combined_content, - document_id=section.center_chunk.document_id, - semantic_identifier=section.center_chunk.semantic_identifier, - blurb=section.center_chunk.blurb, - ) - for section in search_pipeline.reranked_sections - ] - ), - ) - - yield ToolResponse( - id=SECTION_RELEVANCE_LIST_ID, - response=search_pipeline.section_relevance, + search_query_info = SearchQueryInfo( + predicted_search=search_pipeline.search_query.search_type, + final_filters=search_pipeline.search_query.filters, + recency_bias_multiplier=search_pipeline.search_query.recency_bias_multiplier, ) - - pruned_sections = prune_sections( - sections=search_pipeline.final_context_sections, - section_relevance_list=search_pipeline.section_relevance_list, - prompt_config=self.prompt_config, - llm_config=self.llm.config, - question=query, - contextual_pruning_config=self.contextual_pruning_config, + yield from yield_search_responses( + query, + search_pipeline.reranked_sections, + search_pipeline.final_context_sections, + search_query_info, + lambda: search_pipeline.section_relevance, + self, ) - llm_docs = [ - llm_doc_from_inference_section(section) for section in pruned_sections - ] - - yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs) - def final_result(self, *args: ToolResponse) -> JSON_ro: final_docs = cast( list[LlmDoc], @@ -442,3 +407,64 @@ def get_search_result( initial_search_results = cast(list[LlmDoc], initial_search_results) return final_search_results, initial_search_results + + +# Allows yielding the same responses as a SearchTool without being a SearchTool. +# SearchTool passed in to allow for access to SearchTool properties. +# We can't just call SearchTool methods in the graph because we're operating on +# the retrieved docs (reranking, deduping, etc.) after the SearchTool has run. 
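+# Note that the section relevance is passed in as a zero-argument callable
+# rather than a value: callers that skip LLM relevance evaluation (the agent
+# graph currently passes `lambda: None`) never pay for the computation.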
+def yield_search_responses( + query: str, + reranked_sections: list[InferenceSection], + final_context_sections: list[InferenceSection], + search_query_info: SearchQueryInfo, + get_section_relevance: Callable[[], list[SectionRelevancePiece] | None], + search_tool: SearchTool, +) -> Generator[ToolResponse, None, None]: + yield ToolResponse( + id=SEARCH_RESPONSE_SUMMARY_ID, + response=SearchResponseSummary( + rephrased_query=query, + top_sections=final_context_sections, + predicted_flow=QueryFlow.QUESTION_ANSWER, + predicted_search=search_query_info.predicted_search, + final_filters=search_query_info.final_filters, + recency_bias_multiplier=search_query_info.recency_bias_multiplier, + ), + ) + + yield ToolResponse( + id=SEARCH_DOC_CONTENT_ID, + response=OnyxContexts( + contexts=[ + OnyxContext( + content=section.combined_content, + document_id=section.center_chunk.document_id, + semantic_identifier=section.center_chunk.semantic_identifier, + blurb=section.center_chunk.blurb, + ) + for section in reranked_sections + ] + ), + ) + + section_relevance = get_section_relevance() + yield ToolResponse( + id=SECTION_RELEVANCE_LIST_ID, + response=section_relevance, + ) + + pruned_sections = prune_sections( + sections=final_context_sections, + section_relevance_list=section_relevance_list_impl( + section_relevance, final_context_sections + ), + prompt_config=search_tool.prompt_config, + llm_config=search_tool.llm.config, + question=query, + contextual_pruning_config=search_tool.contextual_pruning_config, + ) + + llm_docs = [llm_doc_from_inference_section(section) for section in pruned_sections] + + yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs) From 686eddfa52671df479dcb1af59bb75d1feaf48ec Mon Sep 17 00:00:00 2001 From: Evan Lohn Date: Thu, 9 Jan 2025 10:05:49 -0800 Subject: [PATCH 74/78] code cleanup --- backend/onyx/agent_search/main/nodes.py | 45 ++++++++----------------- 1 file changed, 14 insertions(+), 31 deletions(-) diff --git a/backend/onyx/agent_search/main/nodes.py b/backend/onyx/agent_search/main/nodes.py index 2dd5aff197..eabfc1fddb 100644 --- a/backend/onyx/agent_search/main/nodes.py +++ b/backend/onyx/agent_search/main/nodes.py @@ -952,35 +952,23 @@ def agent_logging(state: MainState) -> MainOutput: agent_base_end_time = state["agent_base_end_time"] agent_refined_start_time = state["agent_refined_start_time"] or None agent_refined_end_time = state["agent_refined_end_time"] or None - if agent_refined_end_time is not None: - agent_end_time = agent_refined_end_time - else: - agent_end_time = agent_base_end_time + agent_end_time = agent_refined_end_time or agent_base_end_time + agent_base_duration = None if agent_base_end_time: agent_base_duration = (agent_base_end_time - agent_start_time).total_seconds() - else: - agent_base_duration = None - - if agent_refined_end_time: - if agent_refined_start_time and agent_refined_end_time: - agent_refined_duration = ( - agent_refined_end_time - agent_refined_start_time - ).total_seconds() - else: - agent_refined_duration = None - else: - agent_refined_duration = None + agent_refined_duration = None + if agent_refined_start_time and agent_refined_end_time: + agent_refined_duration = ( + agent_refined_end_time - agent_refined_start_time + ).total_seconds() + + agent_full_duration = None if agent_end_time: agent_full_duration = (agent_end_time - agent_start_time).total_seconds() - else: - agent_full_duration = None - if agent_refined_duration: - agent_type = "refined" - else: - agent_type = "base" + agent_type = "refined" if 
agent_refined_duration else "base" agent_base_metrics = state["agent_base_metrics"] agent_refined_metrics = state["agent_refined_metrics"] @@ -996,18 +984,13 @@ def agent_logging(state: MainState) -> MainOutput: additional_metrics=AgentAdditionalMetrics(), ) + persona_id = None if state["config"].search_request.persona: persona_id = state["config"].search_request.persona.id - else: - persona_id = None - if "user" in state: - if state["user"]: - user_id = state["user"].id - else: - user_id = None - else: - user_id = None + user_id = None + if "user" in state and state["user"]: + user_id = state["user"].id # log the agent metrics log_agent_metrics( From 685f12c5313a521f13ce05ae23770b9c22c9052f Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Thu, 9 Jan 2025 12:38:26 -0800 Subject: [PATCH 75/78] sub-question citation processing --- backend/onyx/agent_search/run_graph.py | 71 ++++++++++++------- .../shared_graph_utils/agent_prompt_ops.py | 2 +- .../shared_graph_utils/prompts.py | 2 +- 3 files changed, 47 insertions(+), 28 deletions(-) diff --git a/backend/onyx/agent_search/run_graph.py b/backend/onyx/agent_search/run_graph.py index c2fe34990e..d79903a9fe 100644 --- a/backend/onyx/agent_search/run_graph.py +++ b/backend/onyx/agent_search/run_graph.py @@ -129,6 +129,17 @@ def run_graph( int, defaultdict[int, list[str]] ] = defaultdict(lambda: defaultdict(list)) + citation_potential: defaultdict[int, defaultdict[int, bool]] = defaultdict( + lambda: defaultdict(lambda: False) + ) + + current_yield_components: defaultdict[ + int, defaultdict[int, list[str]] + ] = defaultdict(lambda: defaultdict(list)) + current_yield_str: defaultdict[int, defaultdict[int, str]] = defaultdict( + lambda: defaultdict(lambda: "") + ) + # def _process_citation(current_yield_str: str) -> tuple[str, str]: # """Process a citation string and return the formatted citation and remaining text.""" # section_split = current_yield_str.split(']', 1) @@ -144,11 +155,6 @@ def run_graph( # return f" --- CITATION: {citation_type} - {formatted_citation}", remaining_text - citation_potential = False - # leading_space = False - current_yield_components = [] - current_yield_str = "" - for event in _manage_async_event_streaming( compiled_graph=compiled_graph, graph_input=input ): @@ -179,8 +185,12 @@ def run_graph( if isinstance(parsed_object, SubAnswerPiece): token: str | None = parsed_object.sub_answer + level = parsed_object.level + level_question_nr = parsed_object.level_question_nr elif isinstance(parsed_object, OnyxAnswerPiece): token = parsed_object.answer_piece + level = 0 + level_question_nr = 0 if not token: yield parsed_object else: @@ -188,39 +198,45 @@ def run_graph( f"Invalid parsed object type: {type(parsed_object)}" ) - if not citation_potential and token: + if not citation_potential[level][level_question_nr] and token: if token.startswith(" ["): - citation_potential = True - current_yield_components = [token] + citation_potential[level][level_question_nr] = True + current_yield_components[level][level_question_nr] = [token] else: yield parsed_object - elif token and citation_potential: - current_yield_components.append(token) - current_yield_str = "".join(current_yield_components) + elif token and citation_potential[level][level_question_nr]: + current_yield_components[level][level_question_nr].append(token) + current_yield_str[level][level_question_nr] = "".join( + current_yield_components[level][level_question_nr] + ) - if current_yield_str.strip().startswith( + if 
current_yield_str[level][level_question_nr].strip().startswith( "[D" - ) or current_yield_str.strip().startswith("[Q"): - citation_potential = True + ) or current_yield_str[level][level_question_nr].strip().startswith( + "[Q" + ): + citation_potential[level][level_question_nr] = True else: - citation_potential = False + citation_potential[level][level_question_nr] = False parsed_object = _set_combined_token_value( - current_yield_str, parsed_object + current_yield_str[level][level_question_nr], parsed_object ) yield parsed_object - if len(current_yield_components) > 15: - citation_potential = False + if len(current_yield_components[level][level_question_nr]) > 15: + citation_potential[level][level_question_nr] = False parsed_object = _set_combined_token_value( - current_yield_str, parsed_object + current_yield_str[level][level_question_nr], parsed_object ) yield parsed_object - elif "]" in current_yield_str: - section_split = current_yield_str.split("]") + elif "]" in current_yield_str[level][level_question_nr]: + section_split = current_yield_str[level][ + level_question_nr + ].split("]") section_split[0] + "]" start_of_next_section = "]".join(section_split[1:]) - citation_string = current_yield_str[ + citation_string = current_yield_str[level][level_question_nr][ : -len(start_of_next_section) ] if "[D" in citation_string: @@ -231,7 +247,9 @@ def run_graph( cite_identifyer = "D" try: - cited_document = int(citation_string[2:-1]) + cited_document = int( + citation_string[level][level_question_nr][2:-1] + ) if level and level_question_nr: link = agent_document_citations[int(level)][ int(level_question_nr) @@ -265,9 +283,11 @@ def run_graph( yield parsed_object - current_yield_components = [start_of_next_section] + current_yield_components[level][level_question_nr] = [ + start_of_next_section + ] if not start_of_next_section.strip().startswith("["): - citation_potential = False + citation_potential[level][level_question_nr] = False elif isinstance(parsed_object, ExtendedToolResponse): if parsed_object.id == "search_response_summary": @@ -303,7 +323,6 @@ def run_graph( yield parsed_object else: - citation_potential = False yield parsed_object diff --git a/backend/onyx/agent_search/shared_graph_utils/agent_prompt_ops.py b/backend/onyx/agent_search/shared_graph_utils/agent_prompt_ops.py index 4f0cf106b4..9c43cb247a 100644 --- a/backend/onyx/agent_search/shared_graph_utils/agent_prompt_ops.py +++ b/backend/onyx/agent_search/shared_graph_utils/agent_prompt_ops.py @@ -18,7 +18,7 @@ def build_sub_question_answer_prompt( ) docs_format_list = [ - f"""Document Number: [{doc_nr + 1}]\n + f"""Document Number: [D{doc_nr + 1}]\n Content: {doc.combined_content}\n\n""" for doc_nr, doc in enumerate(docs) ] diff --git a/backend/onyx/agent_search/shared_graph_utils/prompts.py b/backend/onyx/agent_search/shared_graph_utils/prompts.py index f8819c4a71..ff42875b5c 100644 --- a/backend/onyx/agent_search/shared_graph_utils/prompts.py +++ b/backend/onyx/agent_search/shared_graph_utils/prompts.py @@ -56,7 +56,7 @@ Make sure that you keep all relevant information, specifically as it concerns to the ultimate goal. (But keep other details as well.) 
- Remember to provide inline citations in the format [1], [2], [3], etc.\n\n\n + Please remember to provide inline citations in the format [D1], [D2], [D3], etc.\n\n\n For your general information, here is the ultimate motivation: \n--\n {original_question} \n--\n From d379453d6431736056bd86399c7681b32c5e3c02 Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Thu, 9 Jan 2025 13:33:24 -0800 Subject: [PATCH 76/78] Creation of AgentAnswerPiece class - unifying sub-answers and main agent answers --- .../nodes/answer_generation.py | 7 +-- backend/onyx/agent_search/main/nodes.py | 19 ++++++-- backend/onyx/agent_search/run_graph.py | 44 +++++++++---------- backend/onyx/chat/models.py | 8 ++-- 4 files changed, 45 insertions(+), 33 deletions(-) diff --git a/backend/onyx/agent_search/answer_initial_sub_question/nodes/answer_generation.py b/backend/onyx/agent_search/answer_initial_sub_question/nodes/answer_generation.py index 9b4ac4871c..ad4f2eb2d6 100644 --- a/backend/onyx/agent_search/answer_initial_sub_question/nodes/answer_generation.py +++ b/backend/onyx/agent_search/answer_initial_sub_question/nodes/answer_generation.py @@ -13,7 +13,7 @@ from onyx.agent_search.shared_graph_utils.prompts import ASSISTANT_SYSTEM_PROMPT_PERSONA from onyx.agent_search.shared_graph_utils.utils import get_persona_prompt from onyx.agent_search.shared_graph_utils.utils import parse_question_id -from onyx.chat.models import SubAnswerPiece +from onyx.chat.models import AgentAnswerPiece from onyx.utils.logger import setup_logger logger = setup_logger() @@ -57,10 +57,11 @@ def answer_generation(state: AnswerQuestionState) -> QAGenerationUpdate: ) dispatch_custom_event( "sub_answers", - SubAnswerPiece( - sub_answer=content, + AgentAnswerPiece( + answer_piece=content, level=level, level_question_nr=question_nr, + answer_type="agent_sub_answer", ), ) response.append(content) diff --git a/backend/onyx/agent_search/main/nodes.py b/backend/onyx/agent_search/main/nodes.py index eabfc1fddb..f5b2061280 100644 --- a/backend/onyx/agent_search/main/nodes.py +++ b/backend/onyx/agent_search/main/nodes.py @@ -62,6 +62,7 @@ from onyx.agent_search.shared_graph_utils.utils import get_persona_prompt from onyx.agent_search.shared_graph_utils.utils import make_question_id from onyx.agent_search.shared_graph_utils.utils import parse_question_id +from onyx.chat.models import AgentAnswerPiece from onyx.chat.models import ExtendedToolResponse from onyx.chat.models import SubQuestionPiece from onyx.db.chat import log_agent_metrics @@ -343,11 +344,23 @@ def generate_initial_answer(state: MainState) -> InitialAnswerUpdate: model = state["fast_llm"] streamed_tokens: list[str | list[str | dict[str, Any]]] = [""] for message in model.stream(msg): + # TODO: in principle, the answer here COULD contain images, but we don't support that yet + content = message.content + if not isinstance(content, str): + raise ValueError( + f"Expected content to be a string, but got {type(content)}" + ) dispatch_custom_event( - "main_answer", - message.content, + "initial_agent_answer", + AgentAnswerPiece( + answer_piece=content, + level=0, + level_question_nr=0, + answer_type="agent_level_answer", + ), ) - streamed_tokens.append(message.content) + streamed_tokens.append(content) + response = merge_content(*streamed_tokens) answer = cast(str, response) diff --git a/backend/onyx/agent_search/run_graph.py b/backend/onyx/agent_search/run_graph.py index d79903a9fe..c058ced28e 100644 --- a/backend/onyx/agent_search/run_graph.py +++ b/backend/onyx/agent_search/run_graph.py @@ 
-13,12 +13,12 @@ from onyx.agent_search.main.states import MainInput
 from onyx.agent_search.models import AgentDocumentCitations
 from onyx.agent_search.shared_graph_utils.utils import get_test_config
+from onyx.chat.models import AgentAnswerPiece
 from onyx.chat.models import AnswerPacket
 from onyx.chat.models import AnswerStream
 from onyx.chat.models import ExtendedToolResponse
 from onyx.chat.models import OnyxAnswerPiece
 from onyx.chat.models import ProSearchConfig
-from onyx.chat.models import SubAnswerPiece
 from onyx.chat.models import SubQueryPiece
 from onyx.chat.models import SubQuestionPiece
 from onyx.chat.models import ToolResponse
@@ -35,14 +35,9 @@ def _set_combined_token_value(
-    combined_token: str, parsed_object: SubAnswerPiece | OnyxAnswerPiece
-) -> SubAnswerPiece | OnyxAnswerPiece:
-    if isinstance(parsed_object, SubAnswerPiece):
-        parsed_object.sub_answer = combined_token
-    elif isinstance(parsed_object, OnyxAnswerPiece):
-        parsed_object.answer_piece = combined_token
-    else:
-        raise ValueError("Invalid parsed object type to update yielded token.")
+    combined_token: str, parsed_object: AgentAnswerPiece
+) -> AgentAnswerPiece:
+    parsed_object.answer_piece = combined_token

     return parsed_object

@@ -64,9 +59,9 @@ def _parse_agent_event(
     elif event["name"] == "subqueries":
         return cast(SubQueryPiece, event["data"])
     elif event["name"] == "sub_answers":
-        return cast(SubAnswerPiece, event["data"])
-    elif event["name"] == "main_answer":
-        return OnyxAnswerPiece(answer_piece=cast(str, event["data"]))
+        return cast(AgentAnswerPiece, event["data"])
+    elif event["name"] == "initial_agent_answer":
+        return cast(AgentAnswerPiece, event["data"])
     elif event["name"] == "tool_response":
         return cast(ToolResponse, event["data"])
     return None
@@ -179,20 +174,14 @@ def run_graph(
            # token = parsed_object.sub_answer

             if isinstance(parsed_object, OnyxAnswerPiece) or isinstance(
-                parsed_object, SubAnswerPiece
+                parsed_object, AgentAnswerPiece
             ):
                 # logger.info(f"FA {parsed_object.answer_piece}")
-                if isinstance(parsed_object, SubAnswerPiece):
-                    token: str | None = parsed_object.sub_answer
+                if isinstance(parsed_object, AgentAnswerPiece):
+                    token: str | None = parsed_object.answer_piece
                     level = parsed_object.level
                     level_question_nr = parsed_object.level_question_nr
-                elif isinstance(parsed_object, OnyxAnswerPiece):
-                    token = parsed_object.answer_piece
-                    level = 0
-                    level_question_nr = 0
-                    if not token:
-                        yield parsed_object
                 else:
                     raise ValueError(
                         f"Invalid parsed object type: {type(parsed_object)}"
                     )
@@ -388,11 +378,17 @@ def run_main_graph(
             logger.info(
                 f"SQ {output.level} - {output.level_question_nr} - {output.sub_question} | "
             )
-        elif isinstance(output, SubAnswerPiece):
+        elif (
+            isinstance(output, AgentAnswerPiece)
+            and output.answer_type == "agent_sub_answer"
+        ):
             logger.info(
-                f" ---- SA {output.level} - {output.level_question_nr} {output.sub_answer} | "
+                f" ---- SA {output.level} - {output.level_question_nr} {output.answer_piece} | "
             )
-        elif isinstance(output, OnyxAnswerPiece):
+        elif (
+            isinstance(output, AgentAnswerPiece)
+            and output.answer_type == "agent_level_answer"
+        ):
             logger.info(f" ---------- FA {output.answer_piece} | ")

     # for tool_response in tool_responses:
diff --git a/backend/onyx/chat/models.py b/backend/onyx/chat/models.py
index 60a50089ab..23110387ea 100644
--- a/backend/onyx/chat/models.py
+++ b/backend/onyx/chat/models.py
@@ -3,6 +3,7 @@
 from datetime import datetime
 from enum import Enum
 from typing import Any
+from typing import Literal
 from
typing import TYPE_CHECKING from uuid import UUID @@ -361,10 +362,11 @@ class SubQueryPiece(BaseModel): query_id: int -class SubAnswerPiece(BaseModel): - sub_answer: str +class AgentAnswerPiece(BaseModel): + answer_piece: str level: int level_question_nr: int + answer_type: Literal["agent_sub_answer", "agent_level_answer"] class SubQuestionPiece(BaseModel): @@ -379,7 +381,7 @@ class ExtendedToolResponse(ToolResponse): ProSearchPacket = ( - SubQuestionPiece | SubAnswerPiece | SubQueryPiece | ExtendedToolResponse + SubQuestionPiece | AgentAnswerPiece | SubQueryPiece | ExtendedToolResponse ) AnswerPacket = ( From aee625d525149a3f671a4e18b18fecac99eafd9e Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Thu, 9 Jan 2025 14:02:05 -0800 Subject: [PATCH 77/78] agent logging level change --- .../answer_initial_sub_question/edges.py | 2 +- .../graph_builder.py | 2 +- .../nodes/answer_generation.py | 4 +- .../answer_refinement_sub_question/edges.py | 2 +- .../graph_builder.py | 4 +- .../nodes/format_raw_search_results.py | 2 +- .../nodes/generate_raw_search_data.py | 2 +- .../expanded_retrieval/graph_builder.py | 2 +- .../onyx/agent_search/main/graph_builder.py | 2 +- backend/onyx/agent_search/main/nodes.py | 102 +++++++++--------- backend/onyx/agent_search/run_graph.py | 16 +-- .../shared_graph_utils/calculations.py | 2 +- 12 files changed, 73 insertions(+), 69 deletions(-) diff --git a/backend/onyx/agent_search/answer_initial_sub_question/edges.py b/backend/onyx/agent_search/answer_initial_sub_question/edges.py index 62cc9a93fa..7f7c7d034c 100644 --- a/backend/onyx/agent_search/answer_initial_sub_question/edges.py +++ b/backend/onyx/agent_search/answer_initial_sub_question/edges.py @@ -11,7 +11,7 @@ def send_to_expanded_retrieval(state: AnswerQuestionInput) -> Send | Hashable: - logger.info("sending to expanded retrieval via edge") + logger.debug("sending to expanded retrieval via edge") return Send( "initial_sub_question_expanded_retrieval", diff --git a/backend/onyx/agent_search/answer_initial_sub_question/graph_builder.py b/backend/onyx/agent_search/answer_initial_sub_question/graph_builder.py index 58e914e020..cda4d03b49 100644 --- a/backend/onyx/agent_search/answer_initial_sub_question/graph_builder.py +++ b/backend/onyx/agent_search/answer_initial_sub_question/graph_builder.py @@ -120,4 +120,4 @@ def answer_query_graph_builder() -> StateGraph: # debug=True, # subgraphs=True, ): - logger.info(thing) + logger.debug(thing) diff --git a/backend/onyx/agent_search/answer_initial_sub_question/nodes/answer_generation.py b/backend/onyx/agent_search/answer_initial_sub_question/nodes/answer_generation.py index ad4f2eb2d6..1f6dca61b4 100644 --- a/backend/onyx/agent_search/answer_initial_sub_question/nodes/answer_generation.py +++ b/backend/onyx/agent_search/answer_initial_sub_question/nodes/answer_generation.py @@ -21,7 +21,7 @@ def answer_generation(state: AnswerQuestionState) -> QAGenerationUpdate: now_start = datetime.datetime.now() - logger.info(f"--------{now_start}--------START ANSWER GENERATION---") + logger.debug(f"--------{now_start}--------START ANSWER GENERATION---") question = state["question"] docs = state["documents"] @@ -35,7 +35,7 @@ def answer_generation(state: AnswerQuestionState) -> QAGenerationUpdate: persona_prompt=persona_prompt ) - logger.info(f"Number of verified retrieval docs: {len(docs)}") + logger.debug(f"Number of verified retrieval docs: {len(docs)}") msg = build_sub_question_answer_prompt( question=question, diff --git 
a/backend/onyx/agent_search/answer_refinement_sub_question/edges.py b/backend/onyx/agent_search/answer_refinement_sub_question/edges.py index 41059136ea..5479d2c34f 100644 --- a/backend/onyx/agent_search/answer_refinement_sub_question/edges.py +++ b/backend/onyx/agent_search/answer_refinement_sub_question/edges.py @@ -11,7 +11,7 @@ def send_to_expanded_refined_retrieval(state: AnswerQuestionInput) -> Send | Hashable: - logger.info("sending to expanded retrieval for follow up question via edge") + logger.debug("sending to expanded retrieval for follow up question via edge") return Send( "refined_sub_question_expanded_retrieval", diff --git a/backend/onyx/agent_search/answer_refinement_sub_question/graph_builder.py b/backend/onyx/agent_search/answer_refinement_sub_question/graph_builder.py index 15fe7156a4..85774bb600 100644 --- a/backend/onyx/agent_search/answer_refinement_sub_question/graph_builder.py +++ b/backend/onyx/agent_search/answer_refinement_sub_question/graph_builder.py @@ -111,6 +111,6 @@ def answer_refined_query_graph_builder() -> StateGraph: # debug=True, # subgraphs=True, ): - logger.info(thing) + logger.debug(thing) # output = compiled_graph.invoke(inputs) - # logger.info(output) + # logger.debug(output) diff --git a/backend/onyx/agent_search/base_raw_search/nodes/format_raw_search_results.py b/backend/onyx/agent_search/base_raw_search/nodes/format_raw_search_results.py index 6b39fd4fe6..4acda010ef 100644 --- a/backend/onyx/agent_search/base_raw_search/nodes/format_raw_search_results.py +++ b/backend/onyx/agent_search/base_raw_search/nodes/format_raw_search_results.py @@ -6,7 +6,7 @@ def format_raw_search_results(state: ExpandedRetrievalOutput) -> BaseRawSearchOutput: - logger.info("format_raw_search_results") + logger.debug("format_raw_search_results") return BaseRawSearchOutput( base_expanded_retrieval_result=state["expanded_retrieval_result"], # base_retrieval_results=[state["expanded_retrieval_result"]], diff --git a/backend/onyx/agent_search/base_raw_search/nodes/generate_raw_search_data.py b/backend/onyx/agent_search/base_raw_search/nodes/generate_raw_search_data.py index cd9c003f47..0aff2a4f70 100644 --- a/backend/onyx/agent_search/base_raw_search/nodes/generate_raw_search_data.py +++ b/backend/onyx/agent_search/base_raw_search/nodes/generate_raw_search_data.py @@ -6,7 +6,7 @@ def generate_raw_search_data(state: CoreState) -> ExpandedRetrievalInput: - logger.info("generate_raw_search_data") + logger.debug("generate_raw_search_data") return ExpandedRetrievalInput( subgraph_config=state["config"], subgraph_primary_llm=state["primary_llm"], diff --git a/backend/onyx/agent_search/expanded_retrieval/graph_builder.py b/backend/onyx/agent_search/expanded_retrieval/graph_builder.py index e5988c7bcb..5a225f1f94 100644 --- a/backend/onyx/agent_search/expanded_retrieval/graph_builder.py +++ b/backend/onyx/agent_search/expanded_retrieval/graph_builder.py @@ -115,4 +115,4 @@ def expanded_retrieval_graph_builder() -> StateGraph: # debug=True, subgraphs=True, ): - logger.info(thing) + logger.debug(thing) diff --git a/backend/onyx/agent_search/main/graph_builder.py b/backend/onyx/agent_search/main/graph_builder.py index 5e45a7fffc..a3c33f1dd2 100644 --- a/backend/onyx/agent_search/main/graph_builder.py +++ b/backend/onyx/agent_search/main/graph_builder.py @@ -255,4 +255,4 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph: # debug=True, subgraphs=True, ): - logger.info(thing) + logger.debug(thing) diff --git a/backend/onyx/agent_search/main/nodes.py 
b/backend/onyx/agent_search/main/nodes.py index f5b2061280..6db39a5258 100644 --- a/backend/onyx/agent_search/main/nodes.py +++ b/backend/onyx/agent_search/main/nodes.py @@ -90,7 +90,7 @@ def _helper(sub_question_part: str, num: int) -> None: def initial_sub_question_creation(state: MainState) -> BaseDecompUpdate: now_start = datetime.now() - logger.info(f"--------{now_start}--------BASE DECOMP START---") + logger.debug(f"--------{now_start}--------BASE DECOMP START---") question = state["config"].search_request.query state["db_session"] @@ -144,7 +144,7 @@ def initial_sub_question_creation(state: MainState) -> BaseDecompUpdate: now_end = datetime.now() - logger.info(f"--------{now_end}--{now_end - now_start}--------BASE DECOMP END---") + logger.debug(f"--------{now_end}--{now_end - now_start}--------BASE DECOMP END---") return BaseDecompUpdate( initial_decomp_questions=decomp_list, @@ -245,7 +245,7 @@ def _calculate_initial_agent_stats( def generate_initial_answer(state: MainState) -> InitialAnswerUpdate: now_start = datetime.now() - logger.info(f"--------{now_start}--------GENERATE INITIAL---") + logger.debug(f"--------{now_start}--------GENERATE INITIAL---") question = state["config"].search_request.query persona_prompt = get_persona_prompt(state["config"].search_request.persona) @@ -368,16 +368,18 @@ def generate_initial_answer(state: MainState) -> InitialAnswerUpdate: state["decomp_answer_results"], state["original_question_retrieval_stats"] ) - logger.info(f"\n\nYYYYY--Sub-Questions:\n\n{sub_question_answer_str}\n\nStats:\n\n") + logger.debug( + f"\n\nYYYYY--Sub-Questions:\n\n{sub_question_answer_str}\n\nStats:\n\n" + ) if initial_agent_stats: - logger.info(initial_agent_stats.original_question) - logger.info(initial_agent_stats.sub_questions) - logger.info(initial_agent_stats.agent_effectiveness) + logger.debug(initial_agent_stats.original_question) + logger.debug(initial_agent_stats.sub_questions) + logger.debug(initial_agent_stats.agent_effectiveness) now_end = datetime.now() - logger.info( + logger.debug( f"--------{now_end}--{now_end - now_start}--------INITIAL AGENT ANSWER END---\n\n" ) @@ -428,7 +430,7 @@ def initial_answer_quality_check(state: MainState) -> InitialAnswerQualityUpdate now_start = datetime.now() - logger.info( + logger.debug( f"--------{now_start}--------Checking for base answer validity - for not set True/False manually" ) @@ -436,7 +438,7 @@ def initial_answer_quality_check(state: MainState) -> InitialAnswerQualityUpdate now_end = datetime.now() - logger.info( + logger.debug( f"--------{now_end}--{now_end - now_start}--------INITIAL ANSWER QUALITY CHECK END---" ) @@ -446,7 +448,7 @@ def initial_answer_quality_check(state: MainState) -> InitialAnswerQualityUpdate def entity_term_extraction_llm(state: MainState) -> EntityTermExtractionUpdate: now_start = datetime.now() - logger.info(f"--------{now_start}--------GENERATE ENTITIES & TERMS---") + logger.debug(f"--------{now_start}--------GENERATE ENTITIES & TERMS---") if not state["config"].allow_refinement: return EntityTermExtractionUpdate( @@ -526,7 +528,7 @@ def entity_term_extraction_llm(state: MainState) -> EntityTermExtractionUpdate: now_end = datetime.now() - logger.info( + logger.debug( f"--------{now_end}--{now_end - now_start}--------ENTITY TERM EXTRACTION END---" ) @@ -544,7 +546,7 @@ def generate_initial_base_search_only_answer( ) -> InitialAnswerBASEUpdate: now_start = datetime.now() - logger.info(f"--------{now_start}--------GENERATE INITIAL BASE ANSWER---") + 
logger.debug(f"--------{now_start}--------GENERATE INITIAL BASE ANSWER---") question = state["config"].search_request.query original_question_docs = state["all_original_question_documents"] @@ -565,7 +567,7 @@ def generate_initial_base_search_only_answer( now_end = datetime.now() - logger.info( + logger.debug( f"--------{now_end}--{now_end - now_start}--------INITIAL BASE ANSWER END---\n\n" ) @@ -577,7 +579,7 @@ def ingest_initial_sub_question_answers( ) -> DecompAnswersUpdate: now_start = datetime.now() - logger.info(f"--------{now_start}--------INGEST ANSWERS---") + logger.debug(f"--------{now_start}--------INGEST ANSWERS---") documents = [] answer_results = state.get("answer_results", []) for answer_result in answer_results: @@ -585,7 +587,7 @@ def ingest_initial_sub_question_answers( now_end = datetime.now() - logger.info( + logger.debug( f"--------{now_end}--{now_end - now_start}--------INGEST ANSWERS END---" ) @@ -602,7 +604,7 @@ def ingest_initial_base_retrieval( ) -> ExpandedRetrievalUpdate: now_start = datetime.now() - logger.info(f"--------{now_start}--------INGEST INITIAL RETRIEVAL---") + logger.debug(f"--------{now_start}--------INGEST INITIAL RETRIEVAL---") sub_question_retrieval_stats = state[ "base_expanded_retrieval_result" @@ -614,7 +616,7 @@ def ingest_initial_base_retrieval( now_end = datetime.now() - logger.info( + logger.debug( f"--------{now_end}--{now_end - now_start}--------INGEST INITIAL RETRIEVAL END---" ) @@ -632,11 +634,11 @@ def ingest_initial_base_retrieval( def refined_answer_decision(state: MainState) -> RequireRefinedAnswerUpdate: now_start = datetime.now() - logger.info(f"--------{now_start}--------REFINED ANSWER DECISION---") + logger.debug(f"--------{now_start}--------REFINED ANSWER DECISION---") now_end = datetime.now() - logger.info( + logger.debug( f"--------{now_end}--{now_end - now_start}--------REFINED ANSWER DECISION END---" ) @@ -650,7 +652,7 @@ def refined_answer_decision(state: MainState) -> RequireRefinedAnswerUpdate: def generate_refined_answer(state: MainState) -> RefinedAnswerUpdate: now_start = datetime.now() - logger.info(f"--------{now_start}--------GENERATE REFINED ANSWER---") + logger.debug(f"--------{now_start}--------GENERATE REFINED ANSWER---") question = state["config"].search_request.query persona_prompt = get_persona_prompt(state["config"].search_request.persona) @@ -763,24 +765,26 @@ def generate_refined_answer(state: MainState) -> RefinedAnswerUpdate: revision_question_efficiency=revision_question_efficiency, ) - logger.info(f"\n\n---INITIAL ANSWER START---\n\n Answer:\n Agent: {initial_answer}") - logger.info("-" * 10) - logger.info(f"\n\n---REVISED AGENT ANSWER START---\n\n Answer:\n Agent: {answer}") + logger.debug( + f"\n\n---INITIAL ANSWER START---\n\n Answer:\n Agent: {initial_answer}" + ) + logger.debug("-" * 10) + logger.debug(f"\n\n---REVISED AGENT ANSWER START---\n\n Answer:\n Agent: {answer}") - logger.info("-" * 100) - logger.info(f"\n\nINITAL Sub-Questions\n\n{initial_good_sub_questions_str}\n\n") - logger.info("-" * 10) - logger.info( + logger.debug("-" * 100) + logger.debug(f"\n\nINITAL Sub-Questions\n\n{initial_good_sub_questions_str}\n\n") + logger.debug("-" * 10) + logger.debug( f"\n\nNEW REVISED Sub-Questions\n\n{new_revised_good_sub_questions_str}\n\n" ) - logger.info("-" * 100) + logger.debug("-" * 100) - logger.info( + logger.debug( f"\n\nINITAL & REVISED Sub-Questions & Answers:\n\n{sub_question_answer_str}\n\nStas:\n\n" ) - logger.info("-" * 100) + logger.debug("-" * 100) if state["initial_agent_stats"]: 
initial_doc_boost_factor = state["initial_agent_stats"].agent_effectiveness.get( @@ -799,29 +803,29 @@ def generate_refined_answer(state: MainState) -> RefinedAnswerUpdate: "initial_agent_stats" ].sub_questions.get("num_verified_documents", "--") - logger.info("INITIAL AGENT STATS") - logger.info(f"Document Boost Factor: {initial_doc_boost_factor}") - logger.info(f"Support Boost Factor: {initial_support_boost_factor}") - logger.info(f"Originally Verified Docs: {num_initial_verified_docs}") - logger.info( + logger.debug("INITIAL AGENT STATS") + logger.debug(f"Document Boost Factor: {initial_doc_boost_factor}") + logger.debug(f"Support Boost Factor: {initial_support_boost_factor}") + logger.debug(f"Originally Verified Docs: {num_initial_verified_docs}") + logger.debug( f"Originally Verified Docs Avg Score: {initial_verified_docs_avg_score}" ) - logger.info( + logger.debug( f"Sub-Questions Verified Docs: {initial_sub_questions_verified_docs}" ) if refined_agent_stats: - logger.info("-" * 10) - logger.info("REFINED AGENT STATS") - logger.info( + logger.debug("-" * 10) + logger.debug("REFINED AGENT STATS") + logger.debug( f"Revision Doc Factor: {refined_agent_stats.revision_doc_efficiency}" ) - logger.info( + logger.debug( f"Revision Question Factor: {refined_agent_stats.revision_question_efficiency}" ) now_end = datetime.now() - logger.info( + logger.debug( f"--------{now_end}--{now_end - now_start}--------INITIAL AGENT ANSWER END---\n\n" ) @@ -841,7 +845,7 @@ def generate_refined_answer(state: MainState) -> RefinedAnswerUpdate: now_end = datetime.now() - logger.info( + logger.debug( f"--------{now_end}--{now_end - now_start}--------REFINED ANSWER UPDATE END---" ) @@ -859,7 +863,7 @@ def refined_sub_question_creation(state: MainState) -> FollowUpSubQuestionsUpdat now_start = datetime.now() - logger.info(f"--------{now_start}--------FOLLOW UP DECOMPOSE---") + logger.debug(f"--------{now_start}--------FOLLOW UP DECOMPOSE---") agent_refined_start_time = datetime.now() @@ -920,7 +924,7 @@ def refined_sub_question_creation(state: MainState) -> FollowUpSubQuestionsUpdat now_end = datetime.now() - logger.info( + logger.debug( f"--------{now_end}--{now_end - now_start}--------FOLLOW UP DECOMPOSE END---" ) @@ -935,7 +939,7 @@ def ingest_refined_answers( ) -> DecompAnswersUpdate: now_start = datetime.now() - logger.info(f"--------{now_start}--------INGEST FOLLOW UP ANSWERS---") + logger.debug(f"--------{now_start}--------INGEST FOLLOW UP ANSWERS---") documents = [] answer_results = state.get("answer_results", []) @@ -944,7 +948,7 @@ def ingest_refined_answers( now_end = datetime.now() - logger.info( + logger.debug( f"--------{now_end}--{now_end - now_start}--------INGEST FOLLOW UP ANSWERS END---" ) @@ -959,7 +963,7 @@ def ingest_refined_answers( def agent_logging(state: MainState) -> MainOutput: now_start = datetime.now() - logger.info(f"--------{now_start}--------LOGGING NODE---") + logger.debug(f"--------{now_start}--------LOGGING NODE---") agent_start_time = state["agent_start_time"] agent_base_end_time = state["agent_base_end_time"] @@ -1043,6 +1047,6 @@ def agent_logging(state: MainState) -> MainOutput: now_end = datetime.now() - logger.info(f"--------{now_end}--{now_end - now_start}--------LOGGING NODE END---") + logger.debug(f"--------{now_end}--{now_end - now_start}--------LOGGING NODE END---") return main_output diff --git a/backend/onyx/agent_search/run_graph.py b/backend/onyx/agent_search/run_graph.py index c058ced28e..39c2d4a2e9 100644 --- a/backend/onyx/agent_search/run_graph.py +++ 
b/backend/onyx/agent_search/run_graph.py @@ -169,14 +169,14 @@ def run_graph( if parsed_object: # if isinstance(parsed_object, SubAnswerPiece): - # logger.info(f"SA {parsed_object.sub_answer}") + # logger.debug(f"SA {parsed_object.sub_answer}") # token = parsed_object.sub_answer if isinstance(parsed_object, OnyxAnswerPiece) or isinstance( parsed_object, AgentAnswerPiece ): - # logger.info(f"FA {parsed_object.answer_piece}") + # logger.debug(f"FA {parsed_object.answer_piece}") if isinstance(parsed_object, AgentAnswerPiece): token: str | None = parsed_object.answer_piece @@ -343,12 +343,12 @@ def run_main_graph( from onyx.llm.factory import get_default_llms now_start = datetime.now() - logger.info(f"Start at {now_start}") + logger.debug(f"Start at {now_start}") graph = main_graph_builder() compiled_graph = graph.compile() now_end = datetime.now() - logger.info(f"Graph compiled in {now_end - now_start} seconds") + logger.debug(f"Graph compiled in {now_end - now_start} seconds") primary_llm, fast_llm = get_default_llms() search_request = SearchRequest( # query="what can you do with gitlab?", @@ -375,21 +375,21 @@ def run_main_graph( elif isinstance(output, ToolResponse): tool_responses.append(output.response) elif isinstance(output, SubQuestionPiece): - logger.info( + logger.debug( f"SQ {output.level} - {output.level_question_nr} - {output.sub_question} | " ) elif ( isinstance(output, AgentAnswerPiece) and output.answer_type == "agent_sub_answer" ): - logger.info( + logger.debug( f" ---- SA {output.level} - {output.level_question_nr} {output.answer_piece} | " ) elif ( isinstance(output, AgentAnswerPiece) and output.answer_type == "agent_level_answer" ): - logger.info(f" ---------- FA {output.answer_piece} | ") + logger.debug(f" ---------- FA {output.answer_piece} | ") # for tool_response in tool_responses: - # logger.info(tool_response) + # logger.debug(tool_response) diff --git a/backend/onyx/agent_search/shared_graph_utils/calculations.py b/backend/onyx/agent_search/shared_graph_utils/calculations.py index b60441bed3..cd8a5885ae 100644 --- a/backend/onyx/agent_search/shared_graph_utils/calculations.py +++ b/backend/onyx/agent_search/shared_graph_utils/calculations.py @@ -52,7 +52,7 @@ def get_fit_scores( ) for rank_type, docs in ranked_sections.items(): - logger.info(f"rank_type: {rank_type}") + logger.debug(f"rank_type: {rank_type}") for i in [1, 5, 10]: fit_eval.fit_scores[rank_type].scores[str(i)] = ( From 4ab99fb4a785c2ddc1ec13d600b7fd72d2e30fe2 Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Thu, 9 Jan 2025 15:50:54 -0800 Subject: [PATCH 78/78] created pro_search_a directory - moved all non-common files into that new directory to enable multiple graphs --- .../answer_initial_sub_question/edges.py | 8 ++- .../graph_builder.py | 24 +++++---- .../answer_initial_sub_question/models.py | 2 +- .../nodes/answer_check.py | 8 ++- .../nodes/answer_generation.py | 8 ++- .../nodes/format_answer.py | 12 +++-- .../nodes/ingest_retrieval.py | 6 ++- .../answer_initial_sub_question/states.py | 6 ++- .../answer_refinement_sub_question/edges.py | 8 ++- .../graph_builder.py | 24 +++++---- .../answer_refinement_sub_question/models.py | 0 .../base_raw_search/graph_builder.py | 12 ++--- .../base_raw_search/models.py | 2 +- .../nodes/format_raw_search_results.py | 6 ++- .../nodes/generate_raw_search_data.py | 4 +- .../base_raw_search/states.py | 4 +- .../expanded_retrieval/edges.py | 6 ++- .../expanded_retrieval/graph_builder.py | 28 ++++++---- .../expanded_retrieval/models.py | 0 
.../expanded_retrieval/nodes.py | 38 ++++++++----- .../expanded_retrieval/states.py | 6 ++- .../{ => pro_search_a}/main/edges.py | 12 +++-- .../{ => pro_search_a}/main/graph_builder.py | 44 ++++++++------- .../{ => pro_search_a}/main/models.py | 0 .../{ => pro_search_a}/main/nodes.py | 54 ++++++++++--------- .../{ => pro_search_a}/main/states.py | 18 ++++--- backend/onyx/agent_search/run_graph.py | 4 +- .../shared_graph_utils/operators.py | 4 +- .../agent_search/shared_graph_utils/utils.py | 2 +- backend/onyx/db/chat.py | 6 ++- 30 files changed, 224 insertions(+), 132 deletions(-) rename backend/onyx/agent_search/{ => pro_search_a}/answer_initial_sub_question/edges.py (75%) rename backend/onyx/agent_search/{ => pro_search_a}/answer_initial_sub_question/graph_builder.py (77%) rename backend/onyx/agent_search/{ => pro_search_a}/answer_initial_sub_question/models.py (86%) rename backend/onyx/agent_search/{ => pro_search_a}/answer_initial_sub_question/nodes/answer_check.py (77%) rename backend/onyx/agent_search/{ => pro_search_a}/answer_initial_sub_question/nodes/answer_generation.py (92%) rename backend/onyx/agent_search/{ => pro_search_a}/answer_initial_sub_question/nodes/format_answer.py (64%) rename backend/onyx/agent_search/{ => pro_search_a}/answer_initial_sub_question/nodes/ingest_retrieval.py (79%) rename backend/onyx/agent_search/{ => pro_search_a}/answer_initial_sub_question/states.py (89%) rename backend/onyx/agent_search/{ => pro_search_a}/answer_refinement_sub_question/edges.py (76%) rename backend/onyx/agent_search/{ => pro_search_a}/answer_refinement_sub_question/graph_builder.py (76%) rename backend/onyx/agent_search/{ => pro_search_a}/answer_refinement_sub_question/models.py (100%) rename backend/onyx/agent_search/{ => pro_search_a}/base_raw_search/graph_builder.py (73%) rename backend/onyx/agent_search/{ => pro_search_a}/base_raw_search/models.py (86%) rename backend/onyx/agent_search/{ => pro_search_a}/base_raw_search/nodes/format_raw_search_results.py (69%) rename backend/onyx/agent_search/{ => pro_search_a}/base_raw_search/nodes/generate_raw_search_data.py (87%) rename backend/onyx/agent_search/{ => pro_search_a}/base_raw_search/states.py (89%) rename backend/onyx/agent_search/{ => pro_search_a}/expanded_retrieval/edges.py (80%) rename backend/onyx/agent_search/{ => pro_search_a}/expanded_retrieval/graph_builder.py (75%) rename backend/onyx/agent_search/{ => pro_search_a}/expanded_retrieval/models.py (100%) rename backend/onyx/agent_search/{ => pro_search_a}/expanded_retrieval/nodes.py (92%) rename backend/onyx/agent_search/{ => pro_search_a}/expanded_retrieval/states.py (90%) rename backend/onyx/agent_search/{ => pro_search_a}/main/edges.py (88%) rename backend/onyx/agent_search/{ => pro_search_a}/main/graph_builder.py (80%) rename backend/onyx/agent_search/{ => pro_search_a}/main/models.py (100%) rename backend/onyx/agent_search/{ => pro_search_a}/main/nodes.py (94%) rename backend/onyx/agent_search/{ => pro_search_a}/main/states.py (85%) diff --git a/backend/onyx/agent_search/answer_initial_sub_question/edges.py b/backend/onyx/agent_search/pro_search_a/answer_initial_sub_question/edges.py similarity index 75% rename from backend/onyx/agent_search/answer_initial_sub_question/edges.py rename to backend/onyx/agent_search/pro_search_a/answer_initial_sub_question/edges.py index 7f7c7d034c..9153b67a22 100644 --- a/backend/onyx/agent_search/answer_initial_sub_question/edges.py +++ b/backend/onyx/agent_search/pro_search_a/answer_initial_sub_question/edges.py @@ -2,9 +2,13 @@ 
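
The hunk that follows, like every other hunk in this patch, is mechanical: the module moves under pro_search_a/ and each absolute import is rewritten to the new path, with imports that no longer fit on one line reflowed into parenthesized form. Schematically, using a pair taken from the hunks below:

    # before the move
    from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput

    # after the move
    from onyx.agent_search.pro_search_a.expanded_retrieval.states import (
        ExpandedRetrievalInput,
    )

Note that core_state, run_graph, and shared_graph_utils stay at the top level; keeping the common pieces outside pro_search_a is what allows additional graph variants to be added alongside it, which is the stated motivation for the move.
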
from langgraph.types import Send -from onyx.agent_search.answer_initial_sub_question.states import AnswerQuestionInput from onyx.agent_search.core_state import in_subgraph_extract_core_fields -from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput +from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import ( + AnswerQuestionInput, +) +from onyx.agent_search.pro_search_a.expanded_retrieval.states import ( + ExpandedRetrievalInput, +) from onyx.utils.logger import setup_logger logger = setup_logger() diff --git a/backend/onyx/agent_search/answer_initial_sub_question/graph_builder.py b/backend/onyx/agent_search/pro_search_a/answer_initial_sub_question/graph_builder.py similarity index 77% rename from backend/onyx/agent_search/answer_initial_sub_question/graph_builder.py rename to backend/onyx/agent_search/pro_search_a/answer_initial_sub_question/graph_builder.py index cda4d03b49..8e98ef82a0 100644 --- a/backend/onyx/agent_search/answer_initial_sub_question/graph_builder.py +++ b/backend/onyx/agent_search/pro_search_a/answer_initial_sub_question/graph_builder.py @@ -2,25 +2,31 @@ from langgraph.graph import START from langgraph.graph import StateGraph -from onyx.agent_search.answer_initial_sub_question.edges import ( +from onyx.agent_search.pro_search_a.answer_initial_sub_question.edges import ( send_to_expanded_retrieval, ) -from onyx.agent_search.answer_initial_sub_question.nodes.answer_check import ( +from onyx.agent_search.pro_search_a.answer_initial_sub_question.nodes.answer_check import ( answer_check, ) -from onyx.agent_search.answer_initial_sub_question.nodes.answer_generation import ( +from onyx.agent_search.pro_search_a.answer_initial_sub_question.nodes.answer_generation import ( answer_generation, ) -from onyx.agent_search.answer_initial_sub_question.nodes.format_answer import ( +from onyx.agent_search.pro_search_a.answer_initial_sub_question.nodes.format_answer import ( format_answer, ) -from onyx.agent_search.answer_initial_sub_question.nodes.ingest_retrieval import ( +from onyx.agent_search.pro_search_a.answer_initial_sub_question.nodes.ingest_retrieval import ( ingest_retrieval, ) -from onyx.agent_search.answer_initial_sub_question.states import AnswerQuestionInput -from onyx.agent_search.answer_initial_sub_question.states import AnswerQuestionOutput -from onyx.agent_search.answer_initial_sub_question.states import AnswerQuestionState -from onyx.agent_search.expanded_retrieval.graph_builder import ( +from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import ( + AnswerQuestionInput, +) +from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import ( + AnswerQuestionOutput, +) +from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import ( + AnswerQuestionState, +) +from onyx.agent_search.pro_search_a.expanded_retrieval.graph_builder import ( expanded_retrieval_graph_builder, ) from onyx.agent_search.shared_graph_utils.utils import get_test_config diff --git a/backend/onyx/agent_search/answer_initial_sub_question/models.py b/backend/onyx/agent_search/pro_search_a/answer_initial_sub_question/models.py similarity index 86% rename from backend/onyx/agent_search/answer_initial_sub_question/models.py rename to backend/onyx/agent_search/pro_search_a/answer_initial_sub_question/models.py index 60bda54fc9..cb40261f0d 100644 --- a/backend/onyx/agent_search/answer_initial_sub_question/models.py +++ b/backend/onyx/agent_search/pro_search_a/answer_initial_sub_question/models.py @@ -1,6 +1,6 @@ from 
pydantic import BaseModel -from onyx.agent_search.expanded_retrieval.models import QueryResult +from onyx.agent_search.pro_search_a.expanded_retrieval.models import QueryResult from onyx.agent_search.shared_graph_utils.models import AgentChunkStats from onyx.context.search.models import InferenceSection diff --git a/backend/onyx/agent_search/answer_initial_sub_question/nodes/answer_check.py b/backend/onyx/agent_search/pro_search_a/answer_initial_sub_question/nodes/answer_check.py similarity index 77% rename from backend/onyx/agent_search/answer_initial_sub_question/nodes/answer_check.py rename to backend/onyx/agent_search/pro_search_a/answer_initial_sub_question/nodes/answer_check.py index ab0481191d..2d8e8c4d16 100644 --- a/backend/onyx/agent_search/answer_initial_sub_question/nodes/answer_check.py +++ b/backend/onyx/agent_search/pro_search_a/answer_initial_sub_question/nodes/answer_check.py @@ -1,8 +1,12 @@ from langchain_core.messages import HumanMessage from langchain_core.messages import merge_message_runs -from onyx.agent_search.answer_initial_sub_question.states import AnswerQuestionState -from onyx.agent_search.answer_initial_sub_question.states import QACheckUpdate +from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import ( + AnswerQuestionState, +) +from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import ( + QACheckUpdate, +) from onyx.agent_search.shared_graph_utils.prompts import SUB_CHECK_PROMPT diff --git a/backend/onyx/agent_search/answer_initial_sub_question/nodes/answer_generation.py b/backend/onyx/agent_search/pro_search_a/answer_initial_sub_question/nodes/answer_generation.py similarity index 92% rename from backend/onyx/agent_search/answer_initial_sub_question/nodes/answer_generation.py rename to backend/onyx/agent_search/pro_search_a/answer_initial_sub_question/nodes/answer_generation.py index 1f6dca61b4..c89ae09485 100644 --- a/backend/onyx/agent_search/answer_initial_sub_question/nodes/answer_generation.py +++ b/backend/onyx/agent_search/pro_search_a/answer_initial_sub_question/nodes/answer_generation.py @@ -4,8 +4,12 @@ from langchain_core.callbacks.manager import dispatch_custom_event from langchain_core.messages import merge_message_runs -from onyx.agent_search.answer_initial_sub_question.states import AnswerQuestionState -from onyx.agent_search.answer_initial_sub_question.states import QAGenerationUpdate +from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import ( + AnswerQuestionState, +) +from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import ( + QAGenerationUpdate, +) from onyx.agent_search.shared_graph_utils.agent_prompt_ops import ( build_sub_question_answer_prompt, ) diff --git a/backend/onyx/agent_search/answer_initial_sub_question/nodes/format_answer.py b/backend/onyx/agent_search/pro_search_a/answer_initial_sub_question/nodes/format_answer.py similarity index 64% rename from backend/onyx/agent_search/answer_initial_sub_question/nodes/format_answer.py rename to backend/onyx/agent_search/pro_search_a/answer_initial_sub_question/nodes/format_answer.py index de96f60614..b221147f76 100644 --- a/backend/onyx/agent_search/answer_initial_sub_question/nodes/format_answer.py +++ b/backend/onyx/agent_search/pro_search_a/answer_initial_sub_question/nodes/format_answer.py @@ -1,6 +1,12 @@ -from onyx.agent_search.answer_initial_sub_question.states import AnswerQuestionOutput -from onyx.agent_search.answer_initial_sub_question.states import AnswerQuestionState -from 
onyx.agent_search.answer_initial_sub_question.states import QuestionAnswerResults +from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import ( + AnswerQuestionOutput, +) +from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import ( + AnswerQuestionState, +) +from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import ( + QuestionAnswerResults, +) def format_answer(state: AnswerQuestionState) -> AnswerQuestionOutput: diff --git a/backend/onyx/agent_search/answer_initial_sub_question/nodes/ingest_retrieval.py b/backend/onyx/agent_search/pro_search_a/answer_initial_sub_question/nodes/ingest_retrieval.py similarity index 79% rename from backend/onyx/agent_search/answer_initial_sub_question/nodes/ingest_retrieval.py rename to backend/onyx/agent_search/pro_search_a/answer_initial_sub_question/nodes/ingest_retrieval.py index dd74ba8d56..d348c0b55e 100644 --- a/backend/onyx/agent_search/answer_initial_sub_question/nodes/ingest_retrieval.py +++ b/backend/onyx/agent_search/pro_search_a/answer_initial_sub_question/nodes/ingest_retrieval.py @@ -1,7 +1,9 @@ -from onyx.agent_search.answer_initial_sub_question.states import ( +from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import ( RetrievalIngestionUpdate, ) -from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalOutput +from onyx.agent_search.pro_search_a.expanded_retrieval.states import ( + ExpandedRetrievalOutput, +) from onyx.agent_search.shared_graph_utils.models import AgentChunkStats diff --git a/backend/onyx/agent_search/answer_initial_sub_question/states.py b/backend/onyx/agent_search/pro_search_a/answer_initial_sub_question/states.py similarity index 89% rename from backend/onyx/agent_search/answer_initial_sub_question/states.py rename to backend/onyx/agent_search/pro_search_a/answer_initial_sub_question/states.py index 74da2355ce..36d77a8c8a 100644 --- a/backend/onyx/agent_search/answer_initial_sub_question/states.py +++ b/backend/onyx/agent_search/pro_search_a/answer_initial_sub_question/states.py @@ -2,9 +2,11 @@ from typing import Annotated from typing import TypedDict -from onyx.agent_search.answer_initial_sub_question.models import QuestionAnswerResults from onyx.agent_search.core_state import SubgraphCoreState -from onyx.agent_search.expanded_retrieval.models import QueryResult +from onyx.agent_search.pro_search_a.answer_initial_sub_question.models import ( + QuestionAnswerResults, +) +from onyx.agent_search.pro_search_a.expanded_retrieval.models import QueryResult from onyx.agent_search.shared_graph_utils.models import AgentChunkStats from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections from onyx.context.search.models import InferenceSection diff --git a/backend/onyx/agent_search/answer_refinement_sub_question/edges.py b/backend/onyx/agent_search/pro_search_a/answer_refinement_sub_question/edges.py similarity index 76% rename from backend/onyx/agent_search/answer_refinement_sub_question/edges.py rename to backend/onyx/agent_search/pro_search_a/answer_refinement_sub_question/edges.py index 5479d2c34f..daa403d230 100644 --- a/backend/onyx/agent_search/answer_refinement_sub_question/edges.py +++ b/backend/onyx/agent_search/pro_search_a/answer_refinement_sub_question/edges.py @@ -2,9 +2,13 @@ from langgraph.types import Send -from onyx.agent_search.answer_initial_sub_question.states import AnswerQuestionInput from onyx.agent_search.core_state import in_subgraph_extract_core_fields -from 
onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput +from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import ( + AnswerQuestionInput, +) +from onyx.agent_search.pro_search_a.expanded_retrieval.states import ( + ExpandedRetrievalInput, +) from onyx.utils.logger import setup_logger logger = setup_logger() diff --git a/backend/onyx/agent_search/answer_refinement_sub_question/graph_builder.py b/backend/onyx/agent_search/pro_search_a/answer_refinement_sub_question/graph_builder.py similarity index 76% rename from backend/onyx/agent_search/answer_refinement_sub_question/graph_builder.py rename to backend/onyx/agent_search/pro_search_a/answer_refinement_sub_question/graph_builder.py index 85774bb600..35b7388af3 100644 --- a/backend/onyx/agent_search/answer_refinement_sub_question/graph_builder.py +++ b/backend/onyx/agent_search/pro_search_a/answer_refinement_sub_question/graph_builder.py @@ -2,25 +2,31 @@ from langgraph.graph import START from langgraph.graph import StateGraph -from onyx.agent_search.answer_initial_sub_question.nodes.answer_check import ( +from onyx.agent_search.pro_search_a.answer_initial_sub_question.nodes.answer_check import ( answer_check, ) -from onyx.agent_search.answer_initial_sub_question.nodes.answer_generation import ( +from onyx.agent_search.pro_search_a.answer_initial_sub_question.nodes.answer_generation import ( answer_generation, ) -from onyx.agent_search.answer_initial_sub_question.nodes.format_answer import ( +from onyx.agent_search.pro_search_a.answer_initial_sub_question.nodes.format_answer import ( format_answer, ) -from onyx.agent_search.answer_initial_sub_question.nodes.ingest_retrieval import ( +from onyx.agent_search.pro_search_a.answer_initial_sub_question.nodes.ingest_retrieval import ( ingest_retrieval, ) -from onyx.agent_search.answer_initial_sub_question.states import AnswerQuestionInput -from onyx.agent_search.answer_initial_sub_question.states import AnswerQuestionOutput -from onyx.agent_search.answer_initial_sub_question.states import AnswerQuestionState -from onyx.agent_search.answer_refinement_sub_question.edges import ( +from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import ( + AnswerQuestionInput, +) +from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import ( + AnswerQuestionOutput, +) +from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import ( + AnswerQuestionState, +) +from onyx.agent_search.pro_search_a.answer_refinement_sub_question.edges import ( send_to_expanded_refined_retrieval, ) -from onyx.agent_search.expanded_retrieval.graph_builder import ( +from onyx.agent_search.pro_search_a.expanded_retrieval.graph_builder import ( expanded_retrieval_graph_builder, ) from onyx.utils.logger import setup_logger diff --git a/backend/onyx/agent_search/answer_refinement_sub_question/models.py b/backend/onyx/agent_search/pro_search_a/answer_refinement_sub_question/models.py similarity index 100% rename from backend/onyx/agent_search/answer_refinement_sub_question/models.py rename to backend/onyx/agent_search/pro_search_a/answer_refinement_sub_question/models.py diff --git a/backend/onyx/agent_search/base_raw_search/graph_builder.py b/backend/onyx/agent_search/pro_search_a/base_raw_search/graph_builder.py similarity index 73% rename from backend/onyx/agent_search/base_raw_search/graph_builder.py rename to backend/onyx/agent_search/pro_search_a/base_raw_search/graph_builder.py index 5de90a8884..9a76a9ea10 100644 --- 
a/backend/onyx/agent_search/base_raw_search/graph_builder.py +++ b/backend/onyx/agent_search/pro_search_a/base_raw_search/graph_builder.py @@ -2,16 +2,16 @@ from langgraph.graph import START from langgraph.graph import StateGraph -from onyx.agent_search.base_raw_search.nodes.format_raw_search_results import ( +from onyx.agent_search.pro_search_a.base_raw_search.nodes.format_raw_search_results import ( format_raw_search_results, ) -from onyx.agent_search.base_raw_search.nodes.generate_raw_search_data import ( +from onyx.agent_search.pro_search_a.base_raw_search.nodes.generate_raw_search_data import ( generate_raw_search_data, ) -from onyx.agent_search.base_raw_search.states import BaseRawSearchInput -from onyx.agent_search.base_raw_search.states import BaseRawSearchOutput -from onyx.agent_search.base_raw_search.states import BaseRawSearchState -from onyx.agent_search.expanded_retrieval.graph_builder import ( +from onyx.agent_search.pro_search_a.base_raw_search.states import BaseRawSearchInput +from onyx.agent_search.pro_search_a.base_raw_search.states import BaseRawSearchOutput +from onyx.agent_search.pro_search_a.base_raw_search.states import BaseRawSearchState +from onyx.agent_search.pro_search_a.expanded_retrieval.graph_builder import ( expanded_retrieval_graph_builder, ) diff --git a/backend/onyx/agent_search/base_raw_search/models.py b/backend/onyx/agent_search/pro_search_a/base_raw_search/models.py similarity index 86% rename from backend/onyx/agent_search/base_raw_search/models.py rename to backend/onyx/agent_search/pro_search_a/base_raw_search/models.py index 6ee67c9e36..8890011a19 100644 --- a/backend/onyx/agent_search/base_raw_search/models.py +++ b/backend/onyx/agent_search/pro_search_a/base_raw_search/models.py @@ -1,6 +1,6 @@ from pydantic import BaseModel -from onyx.agent_search.expanded_retrieval.models import QueryResult +from onyx.agent_search.pro_search_a.expanded_retrieval.models import QueryResult from onyx.agent_search.shared_graph_utils.models import AgentChunkStats from onyx.context.search.models import InferenceSection diff --git a/backend/onyx/agent_search/base_raw_search/nodes/format_raw_search_results.py b/backend/onyx/agent_search/pro_search_a/base_raw_search/nodes/format_raw_search_results.py similarity index 69% rename from backend/onyx/agent_search/base_raw_search/nodes/format_raw_search_results.py rename to backend/onyx/agent_search/pro_search_a/base_raw_search/nodes/format_raw_search_results.py index 4acda010ef..47e7a99cb8 100644 --- a/backend/onyx/agent_search/base_raw_search/nodes/format_raw_search_results.py +++ b/backend/onyx/agent_search/pro_search_a/base_raw_search/nodes/format_raw_search_results.py @@ -1,5 +1,7 @@ -from onyx.agent_search.base_raw_search.states import BaseRawSearchOutput -from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalOutput +from onyx.agent_search.pro_search_a.base_raw_search.states import BaseRawSearchOutput +from onyx.agent_search.pro_search_a.expanded_retrieval.states import ( + ExpandedRetrievalOutput, +) from onyx.utils.logger import setup_logger logger = setup_logger() diff --git a/backend/onyx/agent_search/base_raw_search/nodes/generate_raw_search_data.py b/backend/onyx/agent_search/pro_search_a/base_raw_search/nodes/generate_raw_search_data.py similarity index 87% rename from backend/onyx/agent_search/base_raw_search/nodes/generate_raw_search_data.py rename to backend/onyx/agent_search/pro_search_a/base_raw_search/nodes/generate_raw_search_data.py index 0aff2a4f70..a578fac0c8 100644 --- 
a/backend/onyx/agent_search/base_raw_search/nodes/generate_raw_search_data.py +++ b/backend/onyx/agent_search/pro_search_a/base_raw_search/nodes/generate_raw_search_data.py @@ -1,5 +1,7 @@ from onyx.agent_search.core_state import CoreState -from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput +from onyx.agent_search.pro_search_a.expanded_retrieval.states import ( + ExpandedRetrievalInput, +) from onyx.utils.logger import setup_logger logger = setup_logger() diff --git a/backend/onyx/agent_search/base_raw_search/states.py b/backend/onyx/agent_search/pro_search_a/base_raw_search/states.py similarity index 89% rename from backend/onyx/agent_search/base_raw_search/states.py rename to backend/onyx/agent_search/pro_search_a/base_raw_search/states.py index fb073454c4..57d825b734 100644 --- a/backend/onyx/agent_search/base_raw_search/states.py +++ b/backend/onyx/agent_search/pro_search_a/base_raw_search/states.py @@ -2,7 +2,9 @@ from onyx.agent_search.core_state import CoreState from onyx.agent_search.core_state import SubgraphCoreState -from onyx.agent_search.expanded_retrieval.models import ExpandedRetrievalResult +from onyx.agent_search.pro_search_a.expanded_retrieval.models import ( + ExpandedRetrievalResult, +) ## Update States diff --git a/backend/onyx/agent_search/expanded_retrieval/edges.py b/backend/onyx/agent_search/pro_search_a/expanded_retrieval/edges.py similarity index 80% rename from backend/onyx/agent_search/expanded_retrieval/edges.py rename to backend/onyx/agent_search/pro_search_a/expanded_retrieval/edges.py index 73a0fc43ef..39ce93cbbf 100644 --- a/backend/onyx/agent_search/expanded_retrieval/edges.py +++ b/backend/onyx/agent_search/pro_search_a/expanded_retrieval/edges.py @@ -3,8 +3,10 @@ from langgraph.types import Send from onyx.agent_search.core_state import in_subgraph_extract_core_fields -from onyx.agent_search.expanded_retrieval.nodes import RetrievalInput -from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState +from onyx.agent_search.pro_search_a.expanded_retrieval.nodes import RetrievalInput +from onyx.agent_search.pro_search_a.expanded_retrieval.states import ( + ExpandedRetrievalState, +) def parallel_retrieval_edge(state: ExpandedRetrievalState) -> list[Send | Hashable]: diff --git a/backend/onyx/agent_search/expanded_retrieval/graph_builder.py b/backend/onyx/agent_search/pro_search_a/expanded_retrieval/graph_builder.py similarity index 75% rename from backend/onyx/agent_search/expanded_retrieval/graph_builder.py rename to backend/onyx/agent_search/pro_search_a/expanded_retrieval/graph_builder.py index 5a225f1f94..b1b8a21ba1 100644 --- a/backend/onyx/agent_search/expanded_retrieval/graph_builder.py +++ b/backend/onyx/agent_search/pro_search_a/expanded_retrieval/graph_builder.py @@ -2,16 +2,24 @@ from langgraph.graph import START from langgraph.graph import StateGraph -from onyx.agent_search.expanded_retrieval.edges import parallel_retrieval_edge -from onyx.agent_search.expanded_retrieval.nodes import doc_reranking -from onyx.agent_search.expanded_retrieval.nodes import doc_retrieval -from onyx.agent_search.expanded_retrieval.nodes import doc_verification -from onyx.agent_search.expanded_retrieval.nodes import expand_queries -from onyx.agent_search.expanded_retrieval.nodes import format_results -from onyx.agent_search.expanded_retrieval.nodes import verification_kickoff -from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput -from onyx.agent_search.expanded_retrieval.states import 
ExpandedRetrievalOutput -from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState +from onyx.agent_search.pro_search_a.expanded_retrieval.edges import ( + parallel_retrieval_edge, +) +from onyx.agent_search.pro_search_a.expanded_retrieval.nodes import doc_reranking +from onyx.agent_search.pro_search_a.expanded_retrieval.nodes import doc_retrieval +from onyx.agent_search.pro_search_a.expanded_retrieval.nodes import doc_verification +from onyx.agent_search.pro_search_a.expanded_retrieval.nodes import expand_queries +from onyx.agent_search.pro_search_a.expanded_retrieval.nodes import format_results +from onyx.agent_search.pro_search_a.expanded_retrieval.nodes import verification_kickoff +from onyx.agent_search.pro_search_a.expanded_retrieval.states import ( + ExpandedRetrievalInput, +) +from onyx.agent_search.pro_search_a.expanded_retrieval.states import ( + ExpandedRetrievalOutput, +) +from onyx.agent_search.pro_search_a.expanded_retrieval.states import ( + ExpandedRetrievalState, +) from onyx.agent_search.shared_graph_utils.utils import get_test_config from onyx.utils.logger import setup_logger diff --git a/backend/onyx/agent_search/expanded_retrieval/models.py b/backend/onyx/agent_search/pro_search_a/expanded_retrieval/models.py similarity index 100% rename from backend/onyx/agent_search/expanded_retrieval/models.py rename to backend/onyx/agent_search/pro_search_a/expanded_retrieval/models.py diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes.py b/backend/onyx/agent_search/pro_search_a/expanded_retrieval/nodes.py similarity index 92% rename from backend/onyx/agent_search/expanded_retrieval/nodes.py rename to backend/onyx/agent_search/pro_search_a/expanded_retrieval/nodes.py index e453f692f4..1ebefbd714 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes.py +++ b/backend/onyx/agent_search/pro_search_a/expanded_retrieval/nodes.py @@ -11,18 +11,32 @@ from langgraph.types import Send from onyx.agent_search.core_state import in_subgraph_extract_core_fields -from onyx.agent_search.expanded_retrieval.models import ExpandedRetrievalResult -from onyx.agent_search.expanded_retrieval.models import QueryResult -from onyx.agent_search.expanded_retrieval.states import DocRerankingUpdate -from onyx.agent_search.expanded_retrieval.states import DocRetrievalUpdate -from onyx.agent_search.expanded_retrieval.states import DocVerificationInput -from onyx.agent_search.expanded_retrieval.states import DocVerificationUpdate -from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput -from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState -from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalUpdate -from onyx.agent_search.expanded_retrieval.states import InferenceSection -from onyx.agent_search.expanded_retrieval.states import QueryExpansionUpdate -from onyx.agent_search.expanded_retrieval.states import RetrievalInput +from onyx.agent_search.pro_search_a.expanded_retrieval.models import ( + ExpandedRetrievalResult, +) +from onyx.agent_search.pro_search_a.expanded_retrieval.models import QueryResult +from onyx.agent_search.pro_search_a.expanded_retrieval.states import DocRerankingUpdate +from onyx.agent_search.pro_search_a.expanded_retrieval.states import DocRetrievalUpdate +from onyx.agent_search.pro_search_a.expanded_retrieval.states import ( + DocVerificationInput, +) +from onyx.agent_search.pro_search_a.expanded_retrieval.states import ( + DocVerificationUpdate, +) +from 
onyx.agent_search.pro_search_a.expanded_retrieval.states import ( + ExpandedRetrievalInput, +) +from onyx.agent_search.pro_search_a.expanded_retrieval.states import ( + ExpandedRetrievalState, +) +from onyx.agent_search.pro_search_a.expanded_retrieval.states import ( + ExpandedRetrievalUpdate, +) +from onyx.agent_search.pro_search_a.expanded_retrieval.states import InferenceSection +from onyx.agent_search.pro_search_a.expanded_retrieval.states import ( + QueryExpansionUpdate, +) +from onyx.agent_search.pro_search_a.expanded_retrieval.states import RetrievalInput from onyx.agent_search.shared_graph_utils.calculations import get_fit_scores from onyx.agent_search.shared_graph_utils.models import AgentChunkStats from onyx.agent_search.shared_graph_utils.models import RetrievalFitStats diff --git a/backend/onyx/agent_search/expanded_retrieval/states.py b/backend/onyx/agent_search/pro_search_a/expanded_retrieval/states.py similarity index 90% rename from backend/onyx/agent_search/expanded_retrieval/states.py rename to backend/onyx/agent_search/pro_search_a/expanded_retrieval/states.py index cfed4cc78a..86e5704a57 100644 --- a/backend/onyx/agent_search/expanded_retrieval/states.py +++ b/backend/onyx/agent_search/pro_search_a/expanded_retrieval/states.py @@ -3,8 +3,10 @@ from typing import TypedDict from onyx.agent_search.core_state import SubgraphCoreState -from onyx.agent_search.expanded_retrieval.models import ExpandedRetrievalResult -from onyx.agent_search.expanded_retrieval.models import QueryResult +from onyx.agent_search.pro_search_a.expanded_retrieval.models import ( + ExpandedRetrievalResult, +) +from onyx.agent_search.pro_search_a.expanded_retrieval.models import QueryResult from onyx.agent_search.shared_graph_utils.models import RetrievalFitStats from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections from onyx.context.search.models import InferenceSection diff --git a/backend/onyx/agent_search/main/edges.py b/backend/onyx/agent_search/pro_search_a/main/edges.py similarity index 88% rename from backend/onyx/agent_search/main/edges.py rename to backend/onyx/agent_search/pro_search_a/main/edges.py index f36c808de8..d731d49398 100644 --- a/backend/onyx/agent_search/main/edges.py +++ b/backend/onyx/agent_search/pro_search_a/main/edges.py @@ -3,11 +3,15 @@ from langgraph.types import Send -from onyx.agent_search.answer_initial_sub_question.states import AnswerQuestionInput -from onyx.agent_search.answer_initial_sub_question.states import AnswerQuestionOutput from onyx.agent_search.core_state import extract_core_fields_for_subgraph -from onyx.agent_search.main.states import MainState -from onyx.agent_search.main.states import RequireRefinedAnswerUpdate +from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import ( + AnswerQuestionInput, +) +from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import ( + AnswerQuestionOutput, +) +from onyx.agent_search.pro_search_a.main.states import MainState +from onyx.agent_search.pro_search_a.main.states import RequireRefinedAnswerUpdate from onyx.agent_search.shared_graph_utils.utils import make_question_id from onyx.utils.logger import setup_logger diff --git a/backend/onyx/agent_search/main/graph_builder.py b/backend/onyx/agent_search/pro_search_a/main/graph_builder.py similarity index 80% rename from backend/onyx/agent_search/main/graph_builder.py rename to backend/onyx/agent_search/pro_search_a/main/graph_builder.py index a3c33f1dd2..7e78b0bd38 100644 --- 
diff --git a/backend/onyx/agent_search/main/graph_builder.py b/backend/onyx/agent_search/pro_search_a/main/graph_builder.py
similarity index 80%
rename from backend/onyx/agent_search/main/graph_builder.py
rename to backend/onyx/agent_search/pro_search_a/main/graph_builder.py
index a3c33f1dd2..7e78b0bd38 100644
--- a/backend/onyx/agent_search/main/graph_builder.py
+++ b/backend/onyx/agent_search/pro_search_a/main/graph_builder.py
@@ -2,31 +2,37 @@ from langgraph.graph import START
 from langgraph.graph import StateGraph
-from onyx.agent_search.answer_initial_sub_question.graph_builder import (
+from onyx.agent_search.pro_search_a.answer_initial_sub_question.graph_builder import (
     answer_query_graph_builder,
 )
-from onyx.agent_search.answer_refinement_sub_question.graph_builder import (
+from onyx.agent_search.pro_search_a.answer_refinement_sub_question.graph_builder import (
     answer_refined_query_graph_builder,
 )
-from onyx.agent_search.base_raw_search.graph_builder import (
+from onyx.agent_search.pro_search_a.base_raw_search.graph_builder import (
     base_raw_search_graph_builder,
 )
-from onyx.agent_search.main.edges import continue_to_refined_answer_or_end
-from onyx.agent_search.main.edges import parallelize_initial_sub_question_answering
-from onyx.agent_search.main.edges import parallelize_refined_sub_question_answering
-from onyx.agent_search.main.nodes import agent_logging
-from onyx.agent_search.main.nodes import entity_term_extraction_llm
-from onyx.agent_search.main.nodes import generate_initial_answer
-from onyx.agent_search.main.nodes import generate_refined_answer
-from onyx.agent_search.main.nodes import ingest_initial_base_retrieval
-from onyx.agent_search.main.nodes import ingest_initial_sub_question_answers
-from onyx.agent_search.main.nodes import ingest_refined_answers
-from onyx.agent_search.main.nodes import initial_answer_quality_check
-from onyx.agent_search.main.nodes import initial_sub_question_creation
-from onyx.agent_search.main.nodes import refined_answer_decision
-from onyx.agent_search.main.nodes import refined_sub_question_creation
-from onyx.agent_search.main.states import MainInput
-from onyx.agent_search.main.states import MainState
+from onyx.agent_search.pro_search_a.main.edges import continue_to_refined_answer_or_end
+from onyx.agent_search.pro_search_a.main.edges import (
+    parallelize_initial_sub_question_answering,
+)
+from onyx.agent_search.pro_search_a.main.edges import (
+    parallelize_refined_sub_question_answering,
+)
+from onyx.agent_search.pro_search_a.main.nodes import agent_logging
+from onyx.agent_search.pro_search_a.main.nodes import entity_term_extraction_llm
+from onyx.agent_search.pro_search_a.main.nodes import generate_initial_answer
+from onyx.agent_search.pro_search_a.main.nodes import generate_refined_answer
+from onyx.agent_search.pro_search_a.main.nodes import ingest_initial_base_retrieval
+from onyx.agent_search.pro_search_a.main.nodes import (
+    ingest_initial_sub_question_answers,
+)
+from onyx.agent_search.pro_search_a.main.nodes import ingest_refined_answers
+from onyx.agent_search.pro_search_a.main.nodes import initial_answer_quality_check
+from onyx.agent_search.pro_search_a.main.nodes import initial_sub_question_creation
+from onyx.agent_search.pro_search_a.main.nodes import refined_answer_decision
+from onyx.agent_search.pro_search_a.main.nodes import refined_sub_question_creation
+from onyx.agent_search.pro_search_a.main.states import MainInput
+from onyx.agent_search.pro_search_a.main.states import MainState
 from onyx.agent_search.shared_graph_utils.utils import get_test_config
 from onyx.utils.logger import setup_logger
diff --git a/backend/onyx/agent_search/main/models.py b/backend/onyx/agent_search/pro_search_a/main/models.py
similarity index 100%
rename from backend/onyx/agent_search/main/models.py
rename to backend/onyx/agent_search/pro_search_a/main/models.py
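Note: every hunk in this patch follows the same mechanical pattern -- re-rooting
onyx.agent_search.<pkg> imports under onyx.agent_search.pro_search_a.<pkg>. A
rename of this breadth is typically produced by a small codemod plus a formatter
pass rather than by hand. A minimal sketch of such a script (hypothetical, not
part of this patch; the package list mirrors the directories moved here):

    import re
    from pathlib import Path

    # Packages relocated under onyx.agent_search.pro_search_a in this patch.
    MOVED = [
        "answer_initial_sub_question",
        "answer_refinement_sub_question",
        "base_raw_search",
        "expanded_retrieval",
        "main",
    ]
    PATTERN = re.compile(r"\bonyx\.agent_search\.(" + "|".join(MOVED) + r")\b")

    for path in Path("backend").rglob("*.py"):
        text = path.read_text()
        new_text = PATTERN.sub(r"onyx.agent_search.pro_search_a.\1", text)
        if new_text != text:
            path.write_text(new_text)
    # Re-run the repo formatter afterwards so over-long imports get rewrapped.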
diff --git a/backend/onyx/agent_search/main/nodes.py b/backend/onyx/agent_search/pro_search_a/main/nodes.py
similarity index 94%
rename from backend/onyx/agent_search/main/nodes.py
rename to backend/onyx/agent_search/pro_search_a/main/nodes.py
index 6db39a5258..5ef408ae1b 100644
--- a/backend/onyx/agent_search/main/nodes.py
+++ b/backend/onyx/agent_search/pro_search_a/main/nodes.py
@@ -10,31 +10,35 @@ from langchain_core.messages import merge_content
 from langchain_core.messages import merge_message_runs
-from onyx.agent_search.answer_initial_sub_question.states import AnswerQuestionOutput
-from onyx.agent_search.answer_initial_sub_question.states import QuestionAnswerResults
-from onyx.agent_search.base_raw_search.states import BaseRawSearchOutput
-from onyx.agent_search.main.models import AgentAdditionalMetrics
-from onyx.agent_search.main.models import AgentBaseMetrics
-from onyx.agent_search.main.models import AgentRefinedMetrics
-from onyx.agent_search.main.models import AgentTimings
-from onyx.agent_search.main.models import CombinedAgentMetrics
-from onyx.agent_search.main.models import Entity
-from onyx.agent_search.main.models import EntityRelationshipTermExtraction
-from onyx.agent_search.main.models import FollowUpSubQuestion
-from onyx.agent_search.main.models import Relationship
-from onyx.agent_search.main.models import Term
-from onyx.agent_search.main.states import BaseDecompUpdate
-from onyx.agent_search.main.states import DecompAnswersUpdate
-from onyx.agent_search.main.states import EntityTermExtractionUpdate
-from onyx.agent_search.main.states import ExpandedRetrievalUpdate
-from onyx.agent_search.main.states import FollowUpSubQuestionsUpdate
-from onyx.agent_search.main.states import InitialAnswerBASEUpdate
-from onyx.agent_search.main.states import InitialAnswerQualityUpdate
-from onyx.agent_search.main.states import InitialAnswerUpdate
-from onyx.agent_search.main.states import MainOutput
-from onyx.agent_search.main.states import MainState
-from onyx.agent_search.main.states import RefinedAnswerUpdate
-from onyx.agent_search.main.states import RequireRefinedAnswerUpdate
+from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import (
+    AnswerQuestionOutput,
+)
+from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import (
+    QuestionAnswerResults,
+)
+from onyx.agent_search.pro_search_a.base_raw_search.states import BaseRawSearchOutput
+from onyx.agent_search.pro_search_a.main.models import AgentAdditionalMetrics
+from onyx.agent_search.pro_search_a.main.models import AgentBaseMetrics
+from onyx.agent_search.pro_search_a.main.models import AgentRefinedMetrics
+from onyx.agent_search.pro_search_a.main.models import AgentTimings
+from onyx.agent_search.pro_search_a.main.models import CombinedAgentMetrics
+from onyx.agent_search.pro_search_a.main.models import Entity
+from onyx.agent_search.pro_search_a.main.models import EntityRelationshipTermExtraction
+from onyx.agent_search.pro_search_a.main.models import FollowUpSubQuestion
+from onyx.agent_search.pro_search_a.main.models import Relationship
+from onyx.agent_search.pro_search_a.main.models import Term
+from onyx.agent_search.pro_search_a.main.states import BaseDecompUpdate
+from onyx.agent_search.pro_search_a.main.states import DecompAnswersUpdate
+from onyx.agent_search.pro_search_a.main.states import EntityTermExtractionUpdate
+from onyx.agent_search.pro_search_a.main.states import ExpandedRetrievalUpdate
+from onyx.agent_search.pro_search_a.main.states import FollowUpSubQuestionsUpdate
+from onyx.agent_search.pro_search_a.main.states import InitialAnswerBASEUpdate
+from onyx.agent_search.pro_search_a.main.states import InitialAnswerQualityUpdate
+from onyx.agent_search.pro_search_a.main.states import InitialAnswerUpdate
+from onyx.agent_search.pro_search_a.main.states import MainOutput
+from onyx.agent_search.pro_search_a.main.states import MainState
+from onyx.agent_search.pro_search_a.main.states import RefinedAnswerUpdate
+from onyx.agent_search.pro_search_a.main.states import RequireRefinedAnswerUpdate
 from onyx.agent_search.shared_graph_utils.models import AgentChunkStats
 from onyx.agent_search.shared_graph_utils.models import InitialAgentResultStats
 from onyx.agent_search.shared_graph_utils.models import RefinedAgentStats
diff --git a/backend/onyx/agent_search/main/states.py b/backend/onyx/agent_search/pro_search_a/main/states.py
similarity index 85%
rename from backend/onyx/agent_search/main/states.py
rename to backend/onyx/agent_search/pro_search_a/main/states.py
index 3e9bca346d..e9a548a10c 100644
--- a/backend/onyx/agent_search/main/states.py
+++ b/backend/onyx/agent_search/pro_search_a/main/states.py
@@ -3,14 +3,18 @@ from typing import Annotated
 from typing import TypedDict
-from onyx.agent_search.answer_initial_sub_question.states import QuestionAnswerResults
 from onyx.agent_search.core_state import CoreState
-from onyx.agent_search.expanded_retrieval.models import ExpandedRetrievalResult
-from onyx.agent_search.expanded_retrieval.models import QueryResult
-from onyx.agent_search.main.models import AgentBaseMetrics
-from onyx.agent_search.main.models import AgentRefinedMetrics
-from onyx.agent_search.main.models import EntityRelationshipTermExtraction
-from onyx.agent_search.main.models import FollowUpSubQuestion
+from onyx.agent_search.pro_search_a.answer_initial_sub_question.states import (
+    QuestionAnswerResults,
+)
+from onyx.agent_search.pro_search_a.expanded_retrieval.models import (
+    ExpandedRetrievalResult,
+)
+from onyx.agent_search.pro_search_a.expanded_retrieval.models import QueryResult
+from onyx.agent_search.pro_search_a.main.models import AgentBaseMetrics
+from onyx.agent_search.pro_search_a.main.models import AgentRefinedMetrics
+from onyx.agent_search.pro_search_a.main.models import EntityRelationshipTermExtraction
+from onyx.agent_search.pro_search_a.main.models import FollowUpSubQuestion
 from onyx.agent_search.shared_graph_utils.models import AgentChunkStats
 from onyx.agent_search.shared_graph_utils.models import InitialAgentResultStats
 from onyx.agent_search.shared_graph_utils.models import RefinedAgentStats
diff --git a/backend/onyx/agent_search/run_graph.py b/backend/onyx/agent_search/run_graph.py
index 39c2d4a2e9..cb47ac622f 100644
--- a/backend/onyx/agent_search/run_graph.py
+++ b/backend/onyx/agent_search/run_graph.py
@@ -9,9 +9,9 @@ from langgraph.graph.state import CompiledStateGraph
 from sqlalchemy.orm import Session
-from onyx.agent_search.main.graph_builder import main_graph_builder
-from onyx.agent_search.main.states import MainInput
 from onyx.agent_search.models import AgentDocumentCitations
+from onyx.agent_search.pro_search_a.main.graph_builder import main_graph_builder
+from onyx.agent_search.pro_search_a.main.states import MainInput
 from onyx.agent_search.shared_graph_utils.utils import get_test_config
 from onyx.chat.models import AgentAnswerPiece
 from onyx.chat.models import AnswerPacket
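Note: the remaining files -- run_graph.py above, plus shared_graph_utils/ and
db/chat.py below -- are not moved themselves; only their imports are re-pointed
at the new package. After this patch, downstream code reaches the main graph via
the pro_search_a path, e.g. (a minimal sketch; only the import path is taken
from this patch, the compile call follows the graph-builder convention used in
these modules):

    from onyx.agent_search.pro_search_a.main.graph_builder import main_graph_builder

    # Build the state graph and compile it into a runnable graph.
    compiled_graph = main_graph_builder().compile()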
diff --git a/backend/onyx/agent_search/shared_graph_utils/operators.py b/backend/onyx/agent_search/shared_graph_utils/operators.py
index b3636b81ca..cc98c11a18 100644
--- a/backend/onyx/agent_search/shared_graph_utils/operators.py
+++ b/backend/onyx/agent_search/shared_graph_utils/operators.py
@@ -1,4 +1,6 @@
-from onyx.agent_search.answer_initial_sub_question.models import QuestionAnswerResults
+from onyx.agent_search.pro_search_a.answer_initial_sub_question.models import (
+    QuestionAnswerResults,
+)
 from onyx.chat.prune_and_merge import _merge_sections
 from onyx.context.search.models import InferenceSection
diff --git a/backend/onyx/agent_search/shared_graph_utils/utils.py b/backend/onyx/agent_search/shared_graph_utils/utils.py
index e995b829d9..a6ca3c26df 100644
--- a/backend/onyx/agent_search/shared_graph_utils/utils.py
+++ b/backend/onyx/agent_search/shared_graph_utils/utils.py
@@ -13,7 +13,7 @@ from langchain_core.messages import BaseMessage
 from sqlalchemy.orm import Session
-from onyx.agent_search.main.models import EntityRelationshipTermExtraction
+from onyx.agent_search.pro_search_a.main.models import EntityRelationshipTermExtraction
 from onyx.chat.models import AnswerStyleConfig
 from onyx.chat.models import CitationConfig
 from onyx.chat.models import DocumentPruningConfig
diff --git a/backend/onyx/db/chat.py b/backend/onyx/db/chat.py
index 4a446c7829..a4af34d9cb 100644
--- a/backend/onyx/db/chat.py
+++ b/backend/onyx/db/chat.py
@@ -16,8 +16,10 @@ from sqlalchemy.orm import joinedload
 from sqlalchemy.orm import Session
-from onyx.agent_search.answer_initial_sub_question.models import QuestionAnswerResults
-from onyx.agent_search.main.models import CombinedAgentMetrics
+from onyx.agent_search.pro_search_a.answer_initial_sub_question.models import (
+    QuestionAnswerResults,
+)
+from onyx.agent_search.pro_search_a.main.models import CombinedAgentMetrics
 from onyx.auth.schemas import UserRole
 from onyx.chat.models import DocumentRelevance
 from onyx.configs.chat_configs import HARD_DELETE_CHATS