From ec4617e3d53993319318eeb3780c75495bed48ec Mon Sep 17 00:00:00 2001 From: jakubduda-dsai Date: Thu, 30 Oct 2025 11:15:50 +0100 Subject: [PATCH 1/2] init confirm agent tool codebase --- examples/agents/agent_with_confirmation.py | 120 +++++++++++++ examples/chat/calendar_agent.py | 162 ++++++++++++++++++ .../src/ragbits/agents/_main.py | 153 ++++++++++++++--- .../src/ragbits/agents/confirmation.py | 108 ++++++++++++ .../ragbits-agents/src/ragbits/agents/tool.py | 6 +- packages/ragbits-chat/src/ragbits/chat/api.py | 23 ++- .../src/ragbits/chat/interface/types.py | 9 + .../ragbits/chat/providers/model_provider.py | 3 + .../@ragbits/api-client/src/autogen.types.ts | 20 +++ typescript/@ragbits/api-client/src/types.ts | 1 + .../components/ChatMessage/ChatMessage.tsx | 47 ++++- .../ChatMessage/ConfirmationDialog.tsx | 89 ++++++++++ .../eventHandlers/eventHandlerRegistry.ts | 4 + .../eventHandlers/messageHandlers.ts | 8 + typescript/ui/src/types/history.ts | 2 + 15 files changed, 724 insertions(+), 31 deletions(-) create mode 100644 examples/agents/agent_with_confirmation.py create mode 100644 examples/chat/calendar_agent.py create mode 100644 packages/ragbits-agents/src/ragbits/agents/confirmation.py create mode 100644 typescript/ui/src/core/components/ChatMessage/ConfirmationDialog.tsx diff --git a/examples/agents/agent_with_confirmation.py b/examples/agents/agent_with_confirmation.py new file mode 100644 index 000000000..4c2722e34 --- /dev/null +++ b/examples/agents/agent_with_confirmation.py @@ -0,0 +1,120 @@ +""" +Example demonstrating an agent with tool confirmation. + +This example shows how to create an agent that requires user confirmation +before executing certain tools. + +Run this example: + uv run python examples/agents/agent_with_confirmation.py +""" + +import asyncio +from types import SimpleNamespace + +from ragbits.agents import Agent +from ragbits.agents._main import AgentDependencies, AgentRunContext +from ragbits.agents.confirmation import ConfirmationManager, ConfirmationRequest +from ragbits.core.llms import LiteLLM + + +# Define some example tools +def get_weather(city: str) -> str: + """ + Get the weather for a city. + + Args: + city: The city to get weather for + """ + return f"ā˜€ļø Weather in {city}: Sunny, 72°F" + + +def send_email(to: str, subject: str, body: str) -> str: + """ + Send an email to someone. + + Args: + to: Email recipient + subject: Email subject + body: Email body + """ + return f"šŸ“§ Email sent to {to} with subject '{subject}'" + + +def delete_file(filename: str) -> str: + """ + Delete a file from the system. + + Args: + filename: The file to delete + """ + return f"šŸ—‘ļø Deleted file: {filename}" + + +async def main() -> None: + """Run the agent with confirmation example.""" + # Create LLM + llm = LiteLLM(model_name="gpt-4o-mini") + + # Create agent + agent: Agent = Agent( + llm=llm, + prompt="You are a helpful assistant. Help the user with their requests.", + tools=[get_weather, send_email, delete_file], + ) + + # Mark tools that require confirmation + for tool in agent.tools: + if tool.name in ["send_email", "delete_file"]: + tool.requires_confirmation = True + print(f"āœ“ Tool '{tool.name}' marked as requiring confirmation") + + # Create confirmation manager + confirmation_manager = ConfirmationManager() + + # Create agent context with confirmation manager + deps_value = SimpleNamespace(confirmation_manager=confirmation_manager) + agent_context = AgentRunContext(deps=AgentDependencies(value=deps_value)) + + print("\n" + "=" * 60) + print("Agent with Confirmation Example") + print("=" * 60) + print("\nTools available:") + print(" - get_weather (no confirmation)") + print(" - send_email (requires confirmation)") + print(" - delete_file (requires confirmation)") + print("\nTry: 'Send an email to john@example.com about the meeting'") + print("=" * 60 + "\n") + + # Test query + user_query = "Send an email to john@example.com with subject 'Meeting Reminder' about our 2pm meeting tomorrow" + + print(f"User: {user_query}\n") + + # Stream agent responses + async for response in agent.run_streaming(user_query, context=agent_context): + if isinstance(response, str): + print(f"Agent: {response}", end="", flush=True) + + elif isinstance(response, ConfirmationRequest): + print("\n\nāš ļø CONFIRMATION REQUIRED āš ļø") + print(f"Tool: {response.tool_name}") + print(f"Description: {response.tool_description}") + print(f"Arguments: {response.arguments}") + print(f"Timeout: {response.timeout_seconds}s") + + # Simulate user input + user_input = input("\nDo you want to proceed? (yes/no): ").strip().lower() + + confirmed = user_input in ["yes", "y"] + confirmation_manager.resolve_confirmation(response.confirmation_id, confirmed) + + if confirmed: + print("āœ… Confirmed - proceeding with action\n") + else: + print("āŒ Cancelled - skipping action\n") + + print("\n\nDone!") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/chat/calendar_agent.py b/examples/chat/calendar_agent.py new file mode 100644 index 000000000..02bff3e07 --- /dev/null +++ b/examples/chat/calendar_agent.py @@ -0,0 +1,162 @@ +""" +Ragbits Chat Example: Calendar Agent with Confirmation + +This example demonstrates how to use the ChatInterface with an agent that requires +user confirmation for destructive actions like deleting events or inviting people. + +To run the script: + ragbits api run examples.chat.calendar_agent:CalendarChat +""" + +from collections.abc import AsyncGenerator +from types import SimpleNamespace + +from ragbits.agents import Agent, ToolCallResult +from ragbits.agents._main import AgentDependencies, AgentRunContext, DownstreamAgentResult +from ragbits.agents.confirmation import ConfirmationManager, ConfirmationRequest +from ragbits.chat.interface import ChatInterface +from ragbits.chat.interface.types import ChatContext, ChatResponse, ChatResponseType, LiveUpdateType +from ragbits.chat.interface.ui_customization import HeaderCustomization, UICustomization +from ragbits.core.llms import LiteLLM, ToolCall +from ragbits.core.llms.base import Usage +from ragbits.core.prompt import ChatFormat +from ragbits.core.prompt.base import BasePrompt + + +# Define calendar tools +def analyze_calendar() -> str: + """Analyze the user's calendar and provide insights.""" + return "šŸ“Š Calendar analyzed: You have 5 meetings this week, 2 tomorrow" + + +def get_meetings(date: str = "today") -> str: + """ + Get meetings for a specific date. + + Args: + date: Date to get meetings for (today, tomorrow, or YYYY-MM-DD) + """ + return f"šŸ“… Meetings for {date}: Team sync (2pm), Client call (4pm)" + + +def invite_people(email: str, event_id: str, message: str = "") -> str: + """ + Invite people to a calendar event. + + Args: + email: Email address of the person to invite + event_id: ID of the event + message: Optional message to include + """ + return f"āœ‰ļø Successfully invited {email} to event {event_id}" + + +def delete_event(event_id: str, reason: str = "") -> str: + """ + Delete a calendar event. + + Args: + event_id: ID of the event to delete + reason: Optional reason for deletion + """ + return f"šŸ—‘ļø Successfully deleted event {event_id}" + + +# Type alias for response types +ResponseType = ( + str | ToolCall | ToolCallResult | ConfirmationRequest | BasePrompt | Usage | SimpleNamespace | DownstreamAgentResult +) + + +class CalendarChat(ChatInterface): + """Calendar agent with confirmation for destructive actions.""" + + ui_customization = UICustomization( + header=HeaderCustomization( + title="Calendar Assistant", subtitle="with confirmation for important actions", logo="šŸ“…" + ), + welcome_message=( + "Hello! I'm your calendar assistant.\n\n" + "I can help you manage your calendar, but I'll ask for confirmation " + "before deleting events or inviting people." + ), + ) + + conversation_history = True + show_usage = True + + def __init__(self) -> None: + self.llm = LiteLLM(model_name="gpt-4o-mini") + + # Create agent with tools marked for confirmation + self.agent: Agent = Agent( + llm=self.llm, + prompt=""" + You are a helpful calendar assistant. Help users manage their calendar by: + - Analyzing their schedule + - Showing meetings + - Inviting people to events + - Deleting events when requested + + Always be clear about what actions you're taking. + """, + tools=[ + analyze_calendar, + get_meetings, + invite_people, + delete_event, + ], + keep_history=True, + ) + # Mark specific tools as requiring confirmation + for tool in self.agent.tools: + if tool.name in ["invite_people", "delete_event"]: + tool.requires_confirmation = True + + async def chat( + self, + message: str, + history: ChatFormat, + context: ChatContext, + ) -> AsyncGenerator[ChatResponse, None]: + """Chat implementation with confirmation support.""" + # Get the confirmation manager from context (provided by RagbitsAPI) + confirmation_manager: ConfirmationManager = context.confirmation_manager # type: ignore[attr-defined] + + # Create a simple namespace to hold our dependencies + deps_value = SimpleNamespace(confirmation_manager=confirmation_manager) + agent_context: AgentRunContext = AgentRunContext(deps=AgentDependencies(value=deps_value)) + + # Run agent in streaming mode with just the message (agent handles history internally) + async for response in self.agent.run_streaming( + message, + context=agent_context, + ): + # Pattern match on response types + match response: + case str(): + # Regular text response + if response.strip(): + yield self.create_text_response(response) + + case ToolCall(): + # Tool is being called + yield self.create_live_update(response.id, LiveUpdateType.START, f"šŸ”§ {response.name}") + + case ConfirmationRequest(): + # Confirmation needed - send to frontend + yield ChatResponse( + type=ChatResponseType.CONFIRMATION_REQUEST, + content=response, + ) + + case ToolCallResult(): + # Tool execution completed + result_preview = str(response.result)[:50] + yield self.create_live_update( + response.id, LiveUpdateType.FINISH, f"āœ… {response.name}", result_preview + ) + + case Usage(): + # Usage information + yield self.create_usage_response(response) diff --git a/packages/ragbits-agents/src/ragbits/agents/_main.py b/packages/ragbits-agents/src/ragbits/agents/_main.py index 43783c1c4..4aa9e3246 100644 --- a/packages/ragbits-agents/src/ragbits/agents/_main.py +++ b/packages/ragbits-agents/src/ragbits/agents/_main.py @@ -14,10 +14,12 @@ from pydantic import ( BaseModel, Field, + PrivateAttr, ) from typing_extensions import Self from ragbits import agents +from ragbits.agents.confirmation import ConfirmationManager, ConfirmationRequest from ragbits.agents.exceptions import ( AgentInvalidPostProcessorError, AgentInvalidPromptInputError, @@ -78,6 +80,7 @@ class DownstreamAgentResult: str, ToolCall, ToolCallResult, + ConfirmationRequest, "DownstreamAgentResult", BasePrompt, Usage, @@ -144,23 +147,27 @@ class AgentDependencies(BaseModel, Generic[DepsT]): model_config = {"arbitrary_types_allowed": True} - _frozen: bool - _value: DepsT | None + _frozen: bool = PrivateAttr(default=False) + _value: DepsT | None = PrivateAttr(default=None) - def __init__(self, value: DepsT | None = None) -> None: - super().__init__() + def __init__(self, value: DepsT | None = None, **data) -> None: # type: ignore[no-untyped-def] + super().__init__(**data) self._value = value self._frozen = False def __setattr__(self, name: str, value: object) -> None: - is_frozen = False - if name != "_frozen": - try: - is_frozen = object.__getattribute__(self, "_frozen") - except AttributeError: - is_frozen = False + # Check if we're frozen, but allow setting private attributes during init + if name in ("_frozen", "_value"): + super().__setattr__(name, value) + return + + try: + pydantic_private = object.__getattribute__(self, "__pydantic_private__") + is_frozen = pydantic_private.get("_frozen", False) + except AttributeError: + is_frozen = False - if is_frozen and name not in {"_frozen"}: + if is_frozen: raise RuntimeError("Dependencies are immutable after first access") super().__setattr__(name, value) @@ -171,23 +178,34 @@ def value(self) -> DepsT | None: @value.setter def value(self, value: DepsT) -> None: - if self._frozen: + # Access _frozen from __pydantic_private__ to avoid recursion + pydantic_private = object.__getattribute__(self, "__pydantic_private__") + if pydantic_private.get("_frozen"): raise RuntimeError("Dependencies are immutable after first access") - self._value = value + pydantic_private["_value"] = value def _freeze(self) -> None: - if not self._frozen: - self._frozen = True + # Access _frozen from __pydantic_private__ to avoid recursion + pydantic_private = object.__getattribute__(self, "__pydantic_private__") + if not pydantic_private.get("_frozen"): + pydantic_private["_frozen"] = True def __getattr__(self, name: str) -> object: - value = object.__getattribute__(self, "_value") + # Access _value from __pydantic_private__ to avoid recursion + pydantic_private = object.__getattribute__(self, "__pydantic_private__") + value = pydantic_private.get("_value") if value is None: raise AttributeError(name) self._freeze() return getattr(value, name) def __contains__(self, key: str) -> bool: - value = object.__getattribute__(self, "_value") + # Access _value from __pydantic_private__ to avoid recursion + try: + pydantic_private = object.__getattribute__(self, "__pydantic_private__") + value = pydantic_private.get("_value") + except AttributeError: + return False return hasattr(value, key) if value is not None else False @@ -228,7 +246,16 @@ def get_agent(self, agent_id: str) -> "Agent | None": class AgentResultStreaming( - AsyncIterator[str | ToolCall | ToolCallResult | BasePrompt | Usage | SimpleNamespace | DownstreamAgentResult] + AsyncIterator[ + str + | ToolCall + | ToolCallResult + | ConfirmationRequest + | BasePrompt + | Usage + | SimpleNamespace + | DownstreamAgentResult + ] ): """ An async iterator that will collect all yielded items by LLM.generate_streaming(). This object is returned @@ -239,7 +266,15 @@ class AgentResultStreaming( def __init__( self, generator: AsyncGenerator[ - str | ToolCall | ToolCallResult | DownstreamAgentResult | SimpleNamespace | BasePrompt | Usage + str + | ToolCall + | ToolCallResult + | ConfirmationRequest + | DownstreamAgentResult + | SimpleNamespace + | BasePrompt + | Usage, + None, ], ): self._generator = generator @@ -252,12 +287,30 @@ def __init__( def __aiter__( self, - ) -> AsyncIterator[str | ToolCall | ToolCallResult | BasePrompt | Usage | SimpleNamespace | DownstreamAgentResult]: + ) -> AsyncIterator[ + str + | ToolCall + | ToolCallResult + | ConfirmationRequest + | BasePrompt + | Usage + | SimpleNamespace + | DownstreamAgentResult + ]: return self - async def __anext__( + async def __anext__( # noqa: PLR0912 self, - ) -> str | ToolCall | ToolCallResult | BasePrompt | Usage | SimpleNamespace | DownstreamAgentResult: + ) -> ( + str + | ToolCall + | ToolCallResult + | ConfirmationRequest + | BasePrompt + | Usage + | SimpleNamespace + | DownstreamAgentResult + ): try: item = await self._generator.__anext__() @@ -270,6 +323,9 @@ async def __anext__( if self.tool_calls is None: self.tool_calls = [] self.tool_calls.append(item) + case ConfirmationRequest(): + # Pass through confirmation requests to the caller + pass case DownstreamAgentResult(): if item.agent_id not in self.downstream: self.downstream[item.agent_id] = [] @@ -628,7 +684,17 @@ async def _stream_internal( # noqa: PLR0912 options: AgentOptions[LLMClientOptionsT] | None = None, context: AgentRunContext | None = None, tool_choice: ToolChoice | None = None, - ) -> AsyncGenerator[str | ToolCall | ToolCallResult | DownstreamAgentResult | SimpleNamespace | BasePrompt | Usage]: + ) -> AsyncGenerator[ + str + | ToolCall + | ToolCallResult + | DownstreamAgentResult + | ConfirmationRequest + | SimpleNamespace + | BasePrompt + | Usage, + None, + ]: if context is None: context = AgentRunContext() @@ -698,7 +764,7 @@ async def _execute_tool_calls( tools_mapping: dict[str, Tool], context: AgentRunContext, parallel_tool_calling: bool, - ) -> AsyncGenerator[ToolCallResult | DownstreamAgentResult, None]: + ) -> AsyncGenerator[ToolCallResult | DownstreamAgentResult | ConfirmationRequest, None]: """Execute tool calls either in parallel or sequentially based on `parallel_tool_calling` value.""" if parallel_tool_calling: queue: asyncio.Queue = asyncio.Queue() @@ -838,12 +904,12 @@ async def _get_all_tools(self) -> dict[str, Tool]: return tools_mapping - async def _execute_tool( + async def _execute_tool( # noqa: PLR0912, PLR0915 self, tool_call: ToolCall, tools_mapping: dict[str, Tool], context: AgentRunContext, - ) -> AsyncGenerator[ToolCallResult | DownstreamAgentResult, None]: + ) -> AsyncGenerator[ToolCallResult | DownstreamAgentResult | ConfirmationRequest, None]: if tool_call.type != "function": raise AgentToolNotSupportedError(tool_call.type) if tool_call.name not in tools_mapping: @@ -851,6 +917,41 @@ async def _execute_tool( tool = tools_mapping[tool_call.name] + # Check if tool requires confirmation + if tool.requires_confirmation: + confirmation_manager: ConfirmationManager | None = None + # Use __contains__ (in operator) instead of hasattr for AgentDependencies + if "confirmation_manager" in context.deps: + with suppress(AttributeError): + confirmation_manager = context.deps.confirmation_manager # type: ignore[assignment, attr-defined] + + if confirmation_manager: + # Request confirmation + request, future = await confirmation_manager.request_confirmation( + tool_name=tool_call.name, + tool_description=tool.description or "", + arguments=tool_call.arguments, + ) + + # Yield confirmation request (will be streamed to frontend) + yield request + + # Wait for user response + try: + confirmed = await asyncio.wait_for(future, timeout=65) + except asyncio.TimeoutError: + confirmed = False + + if not confirmed: + # User denied or timeout - return cancelled result + yield ToolCallResult( + id=tool_call.id, + name=tool_call.name, + arguments=tool_call.arguments, + result="āŒ Action cancelled by user", + ) + return + with trace(agent_id=self.id, tool_name=tool_call.name, tool_arguments=tool_call.arguments) as outputs: try: call_args = tool_call.arguments.copy() diff --git a/packages/ragbits-agents/src/ragbits/agents/confirmation.py b/packages/ragbits-agents/src/ragbits/agents/confirmation.py new file mode 100644 index 000000000..467dd0029 --- /dev/null +++ b/packages/ragbits-agents/src/ragbits/agents/confirmation.py @@ -0,0 +1,108 @@ +""" +Tool confirmation functionality for agents. + +This module provides the ability to request user confirmation before executing certain tools. +""" + +import asyncio +import uuid +from typing import Any + +from pydantic import BaseModel + + +class ConfirmationRequest(BaseModel): + """Represents a tool confirmation request sent to the user.""" + + confirmation_id: str + """Unique identifier for this confirmation request.""" + tool_name: str + """Name of the tool requiring confirmation.""" + tool_description: str + """Description of what the tool does.""" + arguments: dict[str, Any] + """Arguments that will be passed to the tool.""" + timeout_seconds: int = 60 + """Timeout in seconds before auto-denying the request.""" + + +class ConfirmationManager: + """ + Manages pending tool confirmations. + + This manager tracks confirmation requests and their associated futures, + allowing the agent to pause execution while waiting for user input. + """ + + def __init__(self) -> None: + """Initialize the confirmation manager.""" + self._pending: dict[str, asyncio.Future[bool]] = {} + + async def request_confirmation( + self, + tool_name: str, + tool_description: str, + arguments: dict[str, Any], + timeout_seconds: int = 60, + ) -> tuple[ConfirmationRequest, asyncio.Future[bool]]: + """ + Request confirmation for a tool execution. + + Args: + tool_name: Name of the tool requiring confirmation. + tool_description: Description of what the tool does. + arguments: Arguments that will be passed to the tool. + timeout_seconds: Timeout in seconds before auto-denying. + + Returns: + Tuple of (confirmation_request, future that resolves to True/False). + """ + confirmation_id = str(uuid.uuid4()) + future: asyncio.Future[bool] = asyncio.Future() + + self._pending[confirmation_id] = future + + # Set timeout + asyncio.create_task(self._handle_timeout(confirmation_id, timeout_seconds)) + + request = ConfirmationRequest( + confirmation_id=confirmation_id, + tool_name=tool_name, + tool_description=tool_description, + arguments=arguments, + timeout_seconds=timeout_seconds, + ) + + return request, future + + async def _handle_timeout(self, confirmation_id: str, timeout_seconds: int) -> None: + """ + Handle timeout for a confirmation request. + + Args: + confirmation_id: ID of the confirmation request. + timeout_seconds: Timeout duration in seconds. + """ + await asyncio.sleep(timeout_seconds) + + if confirmation_id in self._pending: + future = self._pending.pop(confirmation_id) + if not future.done(): + future.set_result(False) # Default to deny on timeout + + def resolve_confirmation(self, confirmation_id: str, confirmed: bool) -> bool: + """ + Resolve a pending confirmation with the user's decision. + + Args: + confirmation_id: ID of the confirmation request. + confirmed: Whether the user confirmed (True) or denied (False). + + Returns: + True if confirmation was found and resolved, False otherwise. + """ + future = self._pending.pop(confirmation_id, None) + if future and not future.done(): + future.set_result(confirmed) + return True + return False diff --git a/packages/ragbits-agents/src/ragbits/agents/tool.py b/packages/ragbits-agents/src/ragbits/agents/tool.py index c8e279f29..ec6e387df 100644 --- a/packages/ragbits-agents/src/ragbits/agents/tool.py +++ b/packages/ragbits-agents/src/ragbits/agents/tool.py @@ -51,14 +51,17 @@ class Tool: context_var_name: str | None = None """The name of the context variable that this tool accepts.""" id: str | None = None + requires_confirmation: bool = False + """Whether this tool requires user confirmation before execution.""" @classmethod - def from_callable(cls, callable: Callable) -> Self: + def from_callable(cls, callable: Callable, requires_confirmation: bool = False) -> Self: """ Create a Tool instance from a callable function. Args: callable: The function to convert into a Tool + requires_confirmation: Whether this tool requires user confirmation before execution Returns: A new Tool instance representing the callable function. @@ -71,6 +74,7 @@ def from_callable(cls, callable: Callable) -> Self: parameters=schema["function"]["parameters"], on_tool_call=callable, context_var_name=get_context_variable_name(callable), + requires_confirmation=requires_confirmation, ) def to_function_schema(self) -> dict[str, Any]: diff --git a/packages/ragbits-chat/src/ragbits/chat/api.py b/packages/ragbits-chat/src/ragbits/chat/api.py index 13f938d45..8e7ff0a07 100644 --- a/packages/ragbits-chat/src/ragbits/chat/api.py +++ b/packages/ragbits-chat/src/ragbits/chat/api.py @@ -17,6 +17,7 @@ from fastapi.staticfiles import StaticFiles from pydantic import BaseModel +from ragbits.agents.confirmation import ConfirmationManager from ragbits.chat.auth import AuthenticationBackend, User from ragbits.chat.auth.types import LoginRequest, LoginResponse, LogoutRequest from ragbits.chat.interface import ChatInterface @@ -48,6 +49,13 @@ CHUNK_SIZE = 102400 # ~100KB bytes base64 chunks for ultra-safe JSON parsing and SSE transmission +class ConfirmationResponse(BaseModel): + """Response to a confirmation request from the user.""" + + confirmation_id: str + confirmed: bool + + class RagbitsAPI: """ RagbitsAPI class for running API with Demo UI for testing purposes @@ -81,6 +89,7 @@ def __init__( self.auth_backend = self._load_auth_backend(auth_backend) self.security = HTTPBearer(auto_error=False) if auth_backend else None self.theme_path = Path(theme_path) if theme_path else None + self.confirmation_manager = ConfirmationManager() @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: @@ -211,6 +220,13 @@ async def theme() -> PlainTextResponse: logger.error(f"Error serving theme: {e}") raise HTTPException(status_code=500, detail="Error loading theme") from e + # Confirmation endpoint - for handling user confirmations of tool calls + @self.app.post("/api/confirm", response_class=JSONResponse) + async def confirm(response: ConfirmationResponse) -> JSONResponse: + """Handle user confirmation response for tool calls.""" + success = self.confirmation_manager.resolve_confirmation(response.confirmation_id, response.confirmed) + return JSONResponse(content={"status": success}) + @self.app.get("/{full_path:path}", response_class=HTMLResponse) async def root() -> HTMLResponse: index_file = self.dist_dir / "index.html" @@ -240,8 +256,8 @@ async def _validate_authentication(self, credentials: HTTPAuthorizationCredentia return auth_result.user - @staticmethod def _prepare_chat_context( + self, request: ChatMessageRequest, authenticated_user: User | None, credentials: HTTPAuthorizationCredentials | None, @@ -249,6 +265,9 @@ def _prepare_chat_context( """Prepare and validate chat context from request.""" chat_context = ChatContext(**request.context) + # Add confirmation manager to context + chat_context.confirmation_manager = self.confirmation_manager # type: ignore[attr-defined] + # Add session_id to context if authenticated if authenticated_user and credentials: chat_context.session_id = credentials.credentials @@ -302,7 +321,7 @@ async def _handle_chat_message( raise HTTPException(status_code=500, detail="Chat implementation is not initialized") # Prepare chat context - chat_context = RagbitsAPI._prepare_chat_context(request, authenticated_user, credentials) + chat_context = self._prepare_chat_context(request, authenticated_user, credentials) # Get the response generator from the chat interface response_generator = self.chat_interface.chat( diff --git a/packages/ragbits-chat/src/ragbits/chat/interface/types.py b/packages/ragbits-chat/src/ragbits/chat/interface/types.py index 6c465d5af..52a0d5275 100644 --- a/packages/ragbits-chat/src/ragbits/chat/interface/types.py +++ b/packages/ragbits-chat/src/ragbits/chat/interface/types.py @@ -3,6 +3,7 @@ from pydantic import BaseModel, ConfigDict, Field +from ragbits.agents.confirmation import ConfirmationRequest from ragbits.agents.tools.todo import Task from ragbits.chat.auth.types import User from ragbits.chat.interface.forms import UserSettings @@ -125,6 +126,7 @@ class ChatResponseType(str, Enum): CLEAR_MESSAGE = "clear_message" USAGE = "usage" TODO_ITEM = "todo_item" + CONFIRMATION_REQUEST = "confirmation_request" class ChatContext(BaseModel): @@ -153,6 +155,7 @@ class ChatResponse(BaseModel): | ChunkedContent | None | Task + | ConfirmationRequest ) def as_text(self) -> str | None: @@ -235,6 +238,12 @@ def as_task(self) -> Task | None: """ return cast(Task, self.content) if self.type == ChatResponseType.TODO_ITEM else None + def as_confirmation_request(self) -> ConfirmationRequest | None: + """ + Return the content as ConfirmationRequest if this is a confirmation request, else None. + """ + return cast(ConfirmationRequest, self.content) if self.type == ChatResponseType.CONFIRMATION_REQUEST else None + def as_conversation_summary(self) -> str | None: """ Return the content as string if this is an conversation summary response, else None diff --git a/packages/ragbits-chat/src/ragbits/chat/providers/model_provider.py b/packages/ragbits-chat/src/ragbits/chat/providers/model_provider.py index b94bfe4bc..91347a1b8 100644 --- a/packages/ragbits-chat/src/ragbits/chat/providers/model_provider.py +++ b/packages/ragbits-chat/src/ragbits/chat/providers/model_provider.py @@ -10,6 +10,7 @@ from pydantic import BaseModel +from ragbits.agents.confirmation import ConfirmationRequest from ragbits.agents.tools.todo import Task, TaskStatus from ragbits.chat.interface.types import AuthType @@ -87,6 +88,7 @@ def get_models(self) -> dict[str, type[BaseModel | Enum]]: # Core data models "ChatContext": ChatContext, "ChunkedContent": ChunkedContent, + "ConfirmationRequest": ConfirmationRequest, "LiveUpdate": LiveUpdate, "LiveUpdateContent": LiveUpdateContent, "Message": Message, @@ -144,6 +146,7 @@ def get_categories(self) -> dict[str, list[str]]: "core_data": [ "ChatContext", "ChunkedContent", + "ConfirmationRequest", "LiveUpdate", "LiveUpdateContent", "Message", diff --git a/typescript/@ragbits/api-client/src/autogen.types.ts b/typescript/@ragbits/api-client/src/autogen.types.ts index ebeeb03f8..da10ce6f0 100644 --- a/typescript/@ragbits/api-client/src/autogen.types.ts +++ b/typescript/@ragbits/api-client/src/autogen.types.ts @@ -25,6 +25,7 @@ export const ChatResponseType = { ClearMessage: 'clear_message', Usage: 'usage', TodoItem: 'todo_item', + ConfirmationRequest: 'confirmation_request', } as const export type ChatResponseType = TypeFrom @@ -109,6 +110,19 @@ export interface ChunkedContent { data: string } +/** + * Represents a tool confirmation request sent to the user. + */ +export interface ConfirmationRequest { + confirmation_id: string + tool_name: string + tool_description: string + arguments: { + [k: string]: unknown + } + timeout_seconds: number +} + /** * Represents an live update performed by an agent. */ @@ -509,6 +523,11 @@ export interface TodoItemChatResonse { content: Task } +export interface ConfirmationRequestChatResponse { + type: 'confirmation_request' + content: ConfirmationRequest +} + export interface ConversationSummaryResponse { type: 'conversation_summary' content: string @@ -534,4 +553,5 @@ export type ChatResponse = | ClearMessageResponse | MessageUsageChatResponse | TodoItemChatResonse + | ConfirmationRequestChatResponse | ConversationSummaryResponse diff --git a/typescript/@ragbits/api-client/src/types.ts b/typescript/@ragbits/api-client/src/types.ts index ea084aca3..e2266523b 100644 --- a/typescript/@ragbits/api-client/src/types.ts +++ b/typescript/@ragbits/api-client/src/types.ts @@ -49,6 +49,7 @@ export interface BaseApiEndpoints { '/api/auth/login': EndpointDefinition '/api/auth/logout': EndpointDefinition '/api/theme': EndpointDefinition + } /** diff --git a/typescript/ui/src/core/components/ChatMessage/ChatMessage.tsx b/typescript/ui/src/core/components/ChatMessage/ChatMessage.tsx index 2a802cb6e..ed680e5fa 100644 --- a/typescript/ui/src/core/components/ChatMessage/ChatMessage.tsx +++ b/typescript/ui/src/core/components/ChatMessage/ChatMessage.tsx @@ -7,6 +7,7 @@ import ImageGallery from "./ImageGallery.tsx"; import MessageReferences from "./MessageReferences.tsx"; import MessageActions from "./MessageActions.tsx"; import LoadingIndicator from "./LoadingIndicator.tsx"; +import ConfirmationDialog from "./ConfirmationDialog.tsx"; import { useConversationProperty, useMessage, @@ -14,6 +15,7 @@ import { import { MessageRole } from "@ragbits/api-client"; import TodoList from "../TodoList.tsx"; import { AnimatePresence, motion } from "framer-motion"; +import { useHistoryStore } from "../../stores/HistoryStore/useHistoryStore.ts"; type ChatMessageProps = { classNames?: { @@ -30,13 +32,23 @@ const ChatMessage = forwardRef( const lastMessageId = useConversationProperty((s) => s.lastMessageId); const isHistoryLoading = useConversationProperty((s) => s.isLoading); const message = useMessage(messageId); + const conversation = useHistoryStore((s) => + s.primitives.getCurrentConversation(), + ); if (!message) { throw new Error("Tried to render non-existent message"); } - const { serverId, content, role, references, liveUpdates, images } = - message; + const { + serverId, + content, + role, + references, + liveUpdates, + images, + confirmationRequest, + } = message; const rightAlign = role === MessageRole.User; const isLoading = isHistoryLoading && @@ -52,6 +64,30 @@ const ChatMessage = forwardRef( const showMessageReferences = !isLoading && references && references.length > 0; const showLiveUpdates = liveUpdates; + const showConfirmation = !!confirmationRequest; + + const handleConfirmation = async (confirmed: boolean) => { + if (!confirmationRequest) return; + + try { + await fetch("/api/confirm", { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + confirmation_id: confirmationRequest.confirmation_id, + confirmed, + }), + }); + + // Clear the confirmation request from the message after responding + // Note: Direct mutation here - zustand will see the change on next render + conversation.history[messageId].confirmationRequest = undefined; + } catch (error) { + console.error("Failed to send confirmation:", error); + } + }; return (
( )} + {showConfirmation && ( + handleConfirmation(true)} + onSkip={() => handleConfirmation(false)} + /> + )} void; + onSkip: () => void; +}; + +const ConfirmationDialog = ({ + confirmationRequest, + onConfirm, + onSkip, +}: ConfirmationDialogProps) => { + const [isResponded, setIsResponded] = useState(false); + + const handleConfirm = () => { + setIsResponded(true); + onConfirm(); + }; + + const handleSkip = () => { + setIsResponded(true); + onSkip(); + }; + + if (isResponded) { + return null; + } + + return ( + + + + +

+ āš ļø Confirmation Required +

+

+ {confirmationRequest.tool_description} +

+
+ +
+

+ Tool: {confirmationRequest.tool_name} +

+ {Object.keys(confirmationRequest.arguments).length > 0 && ( +
+

Arguments:

+
+                    {JSON.stringify(confirmationRequest.arguments, null, 2)}
+                  
+
+ )} +
+
+ + + + +
+
+
+ ); +}; + +export default ConfirmationDialog; diff --git a/typescript/ui/src/core/stores/HistoryStore/eventHandlers/eventHandlerRegistry.ts b/typescript/ui/src/core/stores/HistoryStore/eventHandlers/eventHandlerRegistry.ts index b52dcb37f..24a7711bb 100644 --- a/typescript/ui/src/core/stores/HistoryStore/eventHandlers/eventHandlerRegistry.ts +++ b/typescript/ui/src/core/stores/HistoryStore/eventHandlers/eventHandlerRegistry.ts @@ -9,6 +9,7 @@ import { } from "./nonMessageHandlers"; import { handleClearMessage, + handleConfirmationRequest, handleImage, handleLiveUpdate, handleMessageId, @@ -109,3 +110,6 @@ ChatHandlerRegistry.register(ChatResponseType.TodoItem, { ChatHandlerRegistry.register(ChatResponseType.ConversationSummary, { handle: handleConversationSummary, }); +ChatHandlerRegistry.register(ChatResponseType.ConfirmationRequest, { + handle: handleConfirmationRequest, +}); diff --git a/typescript/ui/src/core/stores/HistoryStore/eventHandlers/messageHandlers.ts b/typescript/ui/src/core/stores/HistoryStore/eventHandlers/messageHandlers.ts index d03437a1b..55c20c72f 100644 --- a/typescript/ui/src/core/stores/HistoryStore/eventHandlers/messageHandlers.ts +++ b/typescript/ui/src/core/stores/HistoryStore/eventHandlers/messageHandlers.ts @@ -1,5 +1,6 @@ import { ClearMessageResponse, + ConfirmationRequestChatResponse, ImageChatResponse, LiveUpdateChatResponse, LiveUpdateType, @@ -117,3 +118,10 @@ export const handleTodoItem: PrimaryHandler = ( message.tasks = newTasks; }; + +export const handleConfirmationRequest: PrimaryHandler< + ConfirmationRequestChatResponse +> = (response, draft, ctx) => { + const message = draft.history[ctx.messageId]; + message.confirmationRequest = response.content; +}; diff --git a/typescript/ui/src/types/history.ts b/typescript/ui/src/types/history.ts index 249264577..ae885287d 100644 --- a/typescript/ui/src/types/history.ts +++ b/typescript/ui/src/types/history.ts @@ -1,5 +1,6 @@ import { ChatResponse, + ConfirmationRequest, LiveUpdate, MessageRole, Reference, @@ -26,6 +27,7 @@ export interface ChatMessage { images?: Record; usage?: Record; tasks?: Task[]; + confirmationRequest?: ConfirmationRequest; } export interface Conversation { From bf55eb9eea5642f7e8c2f75efb49bbdb9cbc0839 Mon Sep 17 00:00:00 2001 From: jakubduda-dsai Date: Tue, 4 Nov 2025 07:42:33 +0100 Subject: [PATCH 2/2] confirmaion improvement --- packages/ragbits-agents/src/ragbits/agents/confirmation.py | 3 +++ .../ui/src/core/components/ChatMessage/ConfirmationDialog.tsx | 3 +++ 2 files changed, 6 insertions(+) diff --git a/packages/ragbits-agents/src/ragbits/agents/confirmation.py b/packages/ragbits-agents/src/ragbits/agents/confirmation.py index 467dd0029..43b0e4df2 100644 --- a/packages/ragbits-agents/src/ragbits/agents/confirmation.py +++ b/packages/ragbits-agents/src/ragbits/agents/confirmation.py @@ -106,3 +106,6 @@ def resolve_confirmation(self, confirmation_id: str, confirmed: bool) -> bool: future.set_result(confirmed) return True return False + + + diff --git a/typescript/ui/src/core/components/ChatMessage/ConfirmationDialog.tsx b/typescript/ui/src/core/components/ChatMessage/ConfirmationDialog.tsx index 9318bf330..aa3f27238 100644 --- a/typescript/ui/src/core/components/ChatMessage/ConfirmationDialog.tsx +++ b/typescript/ui/src/core/components/ChatMessage/ConfirmationDialog.tsx @@ -87,3 +87,6 @@ const ConfirmationDialog = ({ }; export default ConfirmationDialog; + + +