Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""single tool call per message

Revision ID: 33cb72ea4d80
Revises: 949b4a92a401
Create Date: 2024-11-01 12:51:01.535003

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "33cb72ea4d80"
down_revision = "949b4a92a401"
branch_labels = None
depends_on = None


def upgrade() -> None:
# 1. Add 'message_id' column to 'tool_call' table
op.add_column("tool_call", sa.Column("message_id", sa.Integer(), nullable=True))

# 2. Create foreign key constraint from 'tool_call.message_id' to 'chat_message.id'
op.create_foreign_key(
"fk_tool_call_message_id",
"tool_call",
"chat_message",
["message_id"],
["id"],
)

# 3. Migrate existing data from 'chat_message.tool_call_id' to 'tool_call.message_id'
op.execute(
"""
UPDATE tool_call
SET message_id = chat_message.id
FROM chat_message
WHERE chat_message.tool_call_id = tool_call.id
"""
)

# 4. Drop the foreign key constraint and column 'tool_call_id' from 'chat_message' table
op.drop_constraint("fk_chat_message_tool_call", "chat_message", type_="foreignkey")
op.drop_column("chat_message", "tool_call_id")

# 5. Optionally drop the unique constraint if it was previously added
# op.drop_constraint("uq_chat_message_tool_call_id", "chat_message", type_="unique")


def downgrade() -> None:
# 1. Add 'tool_call_id' column back to 'chat_message' table
op.add_column(
"chat_message", sa.Column("tool_call_id", sa.Integer(), nullable=True)
)

# 2. Restore foreign key constraint from 'chat_message.tool_call_id' to 'tool_call.id'
op.create_foreign_key(
"fk_chat_message_tool_call",
"chat_message",
"tool_call",
["tool_call_id"],
["id"],
)

# 3. Migrate data back from 'tool_call.message_id' to 'chat_message.tool_call_id'
op.execute(
"""
UPDATE chat_message
SET tool_call_id = tool_call.id
FROM tool_call
WHERE tool_call.message_id = chat_message.id
"""
)

# 4. Drop the foreign key constraint and column 'message_id' from 'tool_call' table
op.drop_constraint("fk_tool_call_message_id", "tool_call", type_="foreignkey")
op.drop_column("tool_call", "message_id")
24 changes: 12 additions & 12 deletions backend/danswer/chat/process_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,6 +729,9 @@ def stream_chat_message_objects(
tool_result = None

for packet in answer.processed_streamed_output:
if isinstance(packet, StreamStopInfo):
break

if isinstance(packet, ToolResponse):
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
(
Expand Down Expand Up @@ -805,8 +808,7 @@ def stream_chat_message_objects(
response=custom_tool_response.tool_result,
tool_name=custom_tool_response.tool_name,
)
elif isinstance(packet, StreamStopInfo):
pass

else:
if isinstance(packet, ToolCallFinalResult):
tool_result = packet
Expand Down Expand Up @@ -864,17 +866,15 @@ def stream_chat_message_objects(
if message_specific_citations
else None,
error=None,
tool_calls=(
[
ToolCall(
tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_name=tool_result.tool_name,
tool_arguments=tool_result.tool_args,
tool_result=tool_result.tool_result,
)
]
tool_call=(
ToolCall(
tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_name=tool_result.tool_name,
tool_arguments=tool_result.tool_args,
tool_result=tool_result.tool_result,
)
if tool_result
else []
else None
),
)

Expand Down
23 changes: 11 additions & 12 deletions backend/danswer/db/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ def get_chat_messages_by_session(
)

if prefetch_tool_calls:
stmt = stmt.options(joinedload(ChatMessage.tool_calls))
stmt = stmt.options(joinedload(ChatMessage.tool_call))
result = db_session.scalars(stmt).unique().all()
else:
result = db_session.scalars(stmt).all()
Expand Down Expand Up @@ -474,7 +474,7 @@ def create_new_chat_message(
alternate_assistant_id: int | None = None,
# Maps the citation number [n] to the DB SearchDoc
citations: dict[int, int] | None = None,
tool_calls: list[ToolCall] | None = None,
tool_call: ToolCall | None = None,
commit: bool = True,
reserved_message_id: int | None = None,
overridden_model: str | None = None,
Expand All @@ -494,7 +494,7 @@ def create_new_chat_message(
existing_message.message_type = message_type
existing_message.citations = citations
existing_message.files = files
existing_message.tool_calls = tool_calls if tool_calls else []
existing_message.tool_call = tool_call
existing_message.error = error
existing_message.alternate_assistant_id = alternate_assistant_id
existing_message.overridden_model = overridden_model
Expand All @@ -513,7 +513,7 @@ def create_new_chat_message(
message_type=message_type,
citations=citations,
files=files,
tool_calls=tool_calls if tool_calls else [],
tool_call=tool_call,
error=error,
alternate_assistant_id=alternate_assistant_id,
overridden_model=overridden_model,
Expand Down Expand Up @@ -749,14 +749,13 @@ def translate_db_message_to_chat_message_detail(
time_sent=chat_message.time_sent,
citations=chat_message.citations,
files=chat_message.files or [],
tool_calls=[
ToolCallFinalResult(
tool_name=tool_call.tool_name,
tool_args=tool_call.tool_arguments,
tool_result=tool_call.tool_result,
)
for tool_call in chat_message.tool_calls
],
tool_call=ToolCallFinalResult(
tool_name=chat_message.tool_call.tool_name,
tool_args=chat_message.tool_call.tool_arguments,
tool_result=chat_message.tool_call.tool_result,
)
if chat_message.tool_call
else None,
alternate_assistant_id=chat_message.alternate_assistant_id,
overridden_model=chat_message.overridden_model,
)
Expand Down
17 changes: 12 additions & 5 deletions backend/danswer/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,10 +917,15 @@ class ToolCall(Base):
tool_arguments: Mapped[dict[str, JSON_ro]] = mapped_column(postgresql.JSONB())
tool_result: Mapped[JSON_ro] = mapped_column(postgresql.JSONB())

message_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id"))
message_id: Mapped[int | None] = mapped_column(
ForeignKey("chat_message.id"), nullable=True
)

# Update the relationship
message: Mapped["ChatMessage"] = relationship(
"ChatMessage", back_populates="tool_calls"
"ChatMessage",
back_populates="tool_call",
uselist=False,
)


Expand Down Expand Up @@ -1051,12 +1056,14 @@ class ChatMessage(Base):
secondary=ChatMessage__SearchDoc.__table__,
back_populates="chat_messages",
)
# NOTE: Should always be attached to the `assistant` message.
# represents the tool calls used to generate this message
tool_calls: Mapped[list["ToolCall"]] = relationship(

tool_call: Mapped["ToolCall"] = relationship(
"ToolCall",
back_populates="message",
uselist=False,
# foreign_keys=[ToolCall.message_id], # Specify foreign key if needed
)

standard_answers: Mapped[list["StandardAnswer"]] = relationship(
"StandardAnswer",
secondary=ChatMessage__StandardAnswer.__table__,
Expand Down
17 changes: 8 additions & 9 deletions backend/danswer/llm/answering/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class PreviousMessage(BaseModel):
token_count: int
message_type: MessageType
files: list[InMemoryChatFile]
tool_calls: list[ToolCallFinalResult]
tool_call: ToolCallFinalResult | None

@classmethod
def from_chat_message(
Expand All @@ -51,14 +51,13 @@ def from_chat_message(
for file in available_files
if str(file.file_id) in message_file_ids
],
tool_calls=[
ToolCallFinalResult(
tool_name=tool_call.tool_name,
tool_args=tool_call.tool_arguments,
tool_result=tool_call.tool_result,
)
for tool_call in chat_message.tool_calls
],
tool_call=ToolCallFinalResult(
tool_name=chat_message.tool_call.tool_name,
tool_args=chat_message.tool_call.tool_arguments,
tool_result=chat_message.tool_call.tool_result,
)
if chat_message.tool_call
else None,
)

def to_langchain_msg(self) -> BaseMessage:
Expand Down
6 changes: 4 additions & 2 deletions backend/danswer/llm/chat_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,10 @@ def _convert_litellm_message_to_langchain_message(
"args": json.loads(tool_call.function.arguments),
"id": tool_call.id,
}
for tool_call in (tool_calls if tool_calls else [])
],
for tool_call in tool_calls
]
if tool_calls
else [],
)
elif role == "system":
return SystemMessage(content=content)
Expand Down
2 changes: 1 addition & 1 deletion backend/danswer/server/query_and_chat/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ class ChatMessageDetail(BaseModel):
chat_session_id: UUID | None = None
citations: dict[int, int] | None = None
files: list[FileDescriptor]
tool_calls: list[ToolCallFinalResult]
tool_call: ToolCallFinalResult | None

def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore
Expand Down
48 changes: 21 additions & 27 deletions web/src/app/chat/ChatPage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -277,13 +277,13 @@ export function ChatPage({
if (
lastMessage &&
lastMessage.type === "assistant" &&
lastMessage.toolCalls[0] &&
lastMessage.toolCalls[0].tool_result === undefined
lastMessage.toolCall &&
lastMessage.toolCall.tool_result === undefined
) {
const newCompleteMessageMap = new Map(
currentMessageMap(completeMessageDetail)
);
const updatedMessage = { ...lastMessage, toolCalls: [] };
const updatedMessage = { ...lastMessage, toolCall: null };
newCompleteMessageMap.set(lastMessage.messageId, updatedMessage);
updateCompleteMessageDetail(currentSession, newCompleteMessageMap);
}
Expand Down Expand Up @@ -513,7 +513,7 @@ export function ChatPage({
message: "",
type: "system",
files: [],
toolCalls: [],
toolCall: null,
parentMessageId: null,
childrenMessageIds: [firstMessageId],
latestChildMessageId: firstMessageId,
Expand Down Expand Up @@ -1104,7 +1104,7 @@ export function ChatPage({
let stackTrace: string | null = null;

let finalMessage: BackendMessage | null = null;
let toolCalls: ToolCallMetadata[] = [];
let toolCall: ToolCallMetadata | null = null;

let initialFetchDetails: null | {
user_message_id: number;
Expand Down Expand Up @@ -1209,7 +1209,7 @@ export function ChatPage({
message: currMessage,
type: "user",
files: currentMessageFiles,
toolCalls: [],
toolCall: null,
parentMessageId: parentMessage?.messageId || SYSTEM_MESSAGE_ID,
},
];
Expand Down Expand Up @@ -1262,26 +1262,23 @@ export function ChatPage({
setSelectedMessageForDocDisplay(user_message_id);
}
} else if (Object.hasOwn(packet, "tool_name")) {
toolCalls = [
{
tool_name: (packet as ToolCallMetadata).tool_name,
tool_args: (packet as ToolCallMetadata).tool_args,
tool_result: (packet as ToolCallMetadata).tool_result,
},
];
if (
!toolCalls[0].tool_result ||
toolCalls[0].tool_result == undefined
) {
// Will only ever be one tool call per message
toolCall = {
tool_name: (packet as ToolCallMetadata).tool_name,
tool_args: (packet as ToolCallMetadata).tool_args,
tool_result: (packet as ToolCallMetadata).tool_result,
};

if (!toolCall.tool_result || toolCall.tool_result == undefined) {
updateChatState("toolBuilding", frozenSessionId);
} else {
updateChatState("streaming", frozenSessionId);
}

// This will be consolidated in upcoming tool calls udpate,
// but for now, we need to set query as early as possible
if (toolCalls[0].tool_name == SEARCH_TOOL_NAME) {
query = toolCalls[0].tool_args["query"];
if (toolCall.tool_name == SEARCH_TOOL_NAME) {
query = toolCall.tool_args["query"];
}
} else if (Object.hasOwn(packet, "file_ids")) {
aiMessageImages = (packet as ImageGenerationDisplay).file_ids.map(
Expand Down Expand Up @@ -1339,7 +1336,7 @@ export function ChatPage({
message: currMessage,
type: "user",
files: currentMessageFiles,
toolCalls: [],
toolCall: null,
parentMessageId: error ? null : lastSuccessfulMessageId,
childrenMessageIds: [
...(regenerationRequest?.parentMessage?.childrenMessageIds ||
Expand All @@ -1358,7 +1355,7 @@ export function ChatPage({
finalMessage?.context_docs?.top_documents || documents,
citations: finalMessage?.citations || {},
files: finalMessage?.files || aiMessageImages || [],
toolCalls: finalMessage?.tool_calls || toolCalls,
toolCall: finalMessage?.tool_call || null,
parentMessageId: regenerationRequest
? regenerationRequest?.parentMessage?.messageId!
: initialFetchDetails.user_message_id,
Expand All @@ -1381,7 +1378,7 @@ export function ChatPage({
message: currMessage,
type: "user",
files: currentMessageFiles,
toolCalls: [],
toolCall: null,
parentMessageId: parentMessage?.messageId || SYSTEM_MESSAGE_ID,
},
{
Expand All @@ -1391,7 +1388,7 @@ export function ChatPage({
message: errorMsg,
type: "error",
files: aiMessageImages || [],
toolCalls: [],
toolCall: null,
parentMessageId:
initialFetchDetails?.user_message_id || TEMP_USER_MESSAGE_ID,
},
Expand Down Expand Up @@ -2237,10 +2234,7 @@ export function ChatPage({
citedDocuments={getCitedDocumentsFromMessage(
message
)}
toolCall={
message.toolCalls &&
message.toolCalls[0]
}
toolCall={message.toolCall}
isComplete={
i !== messageHistory.length - 1 ||
(currentSessionChatState !=
Expand Down
Loading