From 9644762d388ae4f4f277dfee98b891389eb2fe42 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 26 Jun 2025 15:55:15 +0200 Subject: [PATCH 1/6] Refactor to increase code reuse, simplify --- .../tools/mcp/mcp_tool.py | 298 +++++++++++++----- .../tools/mcp/mcp_toolset.py | 124 ++------ 2 files changed, 244 insertions(+), 178 deletions(-) diff --git a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py index 71ca5f876..77e70738c 100644 --- a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py +++ b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py @@ -632,6 +632,207 @@ def create_client(self) -> MCPClient: return StdioClient(self.command, self.args, self.env) +def _extract_error_message(exception: Exception) -> str: + """ + Extracts meaningful error message from various exception types. + Handles ExceptionGroup, empty messages, etc. + """ + error_message = str(exception) + # Handle ExceptionGroup to extract more useful error messages + if isinstance(exception, ExceptionGroup): + if exception.exceptions: + first_exception = exception.exceptions[0] + error_message = first_exception.message if hasattr(first_exception, "message") else str(first_exception) + + # Ensure we always have a meaningful error message + if not error_message or error_message.strip() == "": + # Provide platform-independent fallback message for connection errors + error_message = "Connection failed to MCP server" + + return error_message + + +def _create_stdio_connection_error_message(server_info: StdioServerInfo, operation: str, context: str) -> str: + """ + Creates stdio-specific error messages with command troubleshooting. + """ + base_message = f"Failed to {operation} {context} via stdio" + + # Build command string for diagnostics + args_str = " ".join(server_info.args) if server_info.args else "" + cmd = f"{server_info.command}{' ' + args_str if args_str else ''}" + + checks = [f"1. The command and arguments are correct (attempted: {cmd})"] + + return f"{base_message}. Please check if:\n" + "\n".join(checks) + + +def _create_http_connection_error_message( + server_info: SSEServerInfo | StreamableHttpServerInfo, exception: Exception, operation: str, context: str +) -> str: + """ + Creates HTTP-specific error messages with troubleshooting guidance. + """ + # Determine transport type + transport_name = "SSE" if isinstance(server_info, SSEServerInfo) else "streamable HTTP" + server_url = server_info.url + + base_message = f"Failed to {operation} {context} via {transport_name}" + + # Standard troubleshooting steps + checks = [ + f"1. The server URL is correct (attempted: {server_url})", + "2. The server is running and accessible", + "3. Authentication token is correct (if required)", + ] + + # Check if exception indicates a network connection error + import httpx + + has_connect_error = isinstance(exception, httpx.ConnectError) or ( + isinstance(exception, ExceptionGroup) + and any(isinstance(exc, httpx.ConnectError) for exc in exception.exceptions) + ) + + # Add network-specific guidance for connection errors + if has_connect_error: + from urllib.parse import urlparse + + # Use urlparse to reliably get scheme, hostname, and port + parsed_url = urlparse(server_url) + port_str = "" + if parsed_url.port: + port_str = str(parsed_url.port) + elif parsed_url.scheme == "http": + port_str = "80 (default)" + elif parsed_url.scheme == "https": + port_str = "443 (default)" + else: + port_str = "unknown (scheme not http/https or missing)" + + # Ensure hostname is handled correctly (it might be None) + hostname_str = str(parsed_url.hostname) if parsed_url.hostname else "" + + checks[1] = f"2. The address '{hostname_str}' and port '{port_str}' are correct" + checks.append("4. There are no firewall or network connectivity issues") + + return f"{base_message}. Please check if:\n" + "\n".join(checks) + + +def _create_connection_error_message( + server_info: MCPServerInfo, exception: Exception, operation: str, context: str = "" +) -> str: + """ + Creates contextual error messages based on server type and failure details. + This replaces the duplicate error handling blocks in both classes. + """ + + # Generate server-type specific guidance + if isinstance(server_info, SSEServerInfo | StreamableHttpServerInfo): + return _create_http_connection_error_message(server_info, exception, operation, context) + elif isinstance(server_info, StdioServerInfo): + return _create_stdio_connection_error_message(server_info, operation, context) + else: + error_message = _extract_error_message(exception) + return f"Failed to {operation} {context}: {error_message}" + + +class MCPConnectionManager: + """ + Utility class that encapsulates common MCP connection logic shared between + MCPTool and MCPToolset. + """ + + def __init__(self, server_info: MCPServerInfo, connection_timeout: float): + self.server_info = server_info + self.connection_timeout = connection_timeout + self._client = None + self._worker = None + + def connect_and_discover_tools(self) -> list[Tool]: + """ + Establishes connection and returns available tools. + This replaces the duplicate connection logic in both classes. + """ + try: + # Create the client and spin up a worker so open/close happen in the + # same coroutine, avoiding AnyIO cancel-scope issues. + self._client = self.server_info.create_client() + self._worker = _MCPClientSessionManager(self._client, timeout=self.connection_timeout) + return self._worker.tools() + except Exception: + # Handle cleanup internally + self.close() + raise + + def validate_requested_tools(self, requested_tool_names: list[str], available_tools: list[Tool]) -> None: + """ + Validates that requested tools exist on the server. + Shared validation logic between both classes. + """ + available_tool_names = {tool.name for tool in available_tools} + missing_tools = set(requested_tool_names) - available_tool_names + if missing_tools: + message = ( + f"The following tools were not found: {', '.join(missing_tools)}. " + f"Available tools: {', '.join(available_tool_names)}" + ) + raise MCPToolNotFoundError( + message=message, tool_name=next(iter(missing_tools)), available_tools=list(available_tool_names) + ) + + def create_tool_invoke_function(self, tool_name: str, invocation_timeout: float): + """ + Creates the invoke function that both classes use. + MCPTool uses this directly, MCPToolset uses this in its closure factory. + """ + + def invoke_tool(**kwargs) -> str: + """Unified invoke logic - no more duplication""" + logger.debug(f"Invoking tool '{tool_name}' with args: {kwargs}") + try: + + async def invoke(): + logger.debug(f"Inside invoke coroutine for '{tool_name}'") + result = await asyncio.wait_for( + self._client.call_tool(tool_name, kwargs), timeout=invocation_timeout + ) + logger.debug(f"Invoke successful for '{tool_name}'") + return result + + logger.debug(f"About to run invoke for '{tool_name}'") + result = AsyncExecutor.get_instance().run(invoke(), timeout=invocation_timeout) + logger.debug(f"Invoke complete for '{tool_name}', result type: {type(result)}") + return result + except (MCPError, TimeoutError) as e: + logger.debug(f"Known error during invoke of '{tool_name}': {e!s}") + # Pass through known errors + raise + except Exception as e: + # Wrap other errors + logger.debug(f"Unknown error during invoke of '{tool_name}': {e!s}") + message = f"Failed to invoke tool '{tool_name}' with args: {kwargs} , got error: {e!s}" + raise MCPInvocationError(message, tool_name, kwargs) from e + + return invoke_tool + + def get_client(self): + """Allow direct access to client for MCPTool's async method access""" + return self._client + + def close(self): + """Shared cleanup logic""" + if hasattr(self, "_worker") and self._worker: + try: + self._worker.stop() + except Exception as e: + logger.debug(f"Error during worker stop: {e!s}") + + # Clear references + self._worker = None + self._client = None + + class MCPTool(Tool): """ A Tool that represents a single tool from an MCP server. @@ -706,41 +907,35 @@ def __init__( logger.debug(f"TOOL: Initializing MCPTool '{name}'") try: - # Create client and spin up a long-lived worker that keeps the - # connect/close lifecycle inside one coroutine. - self._client = server_info.create_client() - logger.debug(f"TOOL: Created client for MCPTool '{name}'") + # Use shared connection logic + self._connection_manager = MCPConnectionManager(server_info, connection_timeout) + available_tools = self._connection_manager.connect_and_discover_tools() - # The worker starts immediately and blocks here until the connection - # is established (or fails), returning the tool list. - self._worker = _MCPClientSessionManager(self._client, timeout=connection_timeout) - - tools = self._worker.tools() # Handle no tools case - if not tools: + if not available_tools: logger.debug(f"TOOL: No tools found for '{name}'") message = "No tools available on server" raise MCPToolNotFoundError(message, tool_name=name) + # Validate that the requested tool exists + self._connection_manager.validate_requested_tools([name], available_tools) + # Find the specified tool - tool_dict = {t.name: t for t in tools} + tool_dict = {t.name: t for t in available_tools} logger.debug(f"TOOL: Available tools: {list(tool_dict.keys())}") + tool_info = tool_dict[name] # Safe to use direct access since validation passed - tool_info = tool_dict.get(name) + logger.debug(f"TOOL: Found tool '{name}', initializing Tool parent class") - if not tool_info: - available = list(tool_dict.keys()) - logger.debug(f"TOOL: Tool '{name}' not found in available tools") - message = f"Tool '{name}' not found on server. Available tools: {', '.join(available)}" - raise MCPToolNotFoundError(message, tool_name=name, available_tools=available) + # Create shared invoke function + invoke_func = self._connection_manager.create_tool_invoke_function(name, invocation_timeout) - logger.debug(f"TOOL: Found tool '{name}', initializing Tool parent class") # Initialize the parent class super().__init__( name=name, description=description or tool_info.description, parameters=tool_info.inputSchema, - function=self._invoke_tool, + function=invoke_func, ) logger.debug(f"TOOL: Initialization complete for '{name}'") @@ -749,55 +944,11 @@ def __init__( # fail because of an MCPToolNotFoundError self.close() - # Extract more detailed error information from TaskGroup/ExceptionGroup exceptions - error_message = str(e) - # Handle ExceptionGroup to extract more useful error messages - if isinstance(e, ExceptionGroup): - if e.exceptions: - first_exception = e.exceptions[0] - error_message = ( - first_exception.message if hasattr(first_exception, "message") else str(first_exception) - ) - - # Ensure we always have a meaningful error message - if not error_message or error_message.strip() == "": - # Provide platform-independent fallback message for connection errors - error_message = f"Connection failed to MCP server (using {type(server_info).__name__})" - - message = f"Failed to initialize MCPTool '{name}': {error_message}" - raise MCPConnectionError(message=message, server_info=server_info, operation="initialize") from e - - def _invoke_tool(self, **kwargs: Any) -> str: - """ - Synchronous tool invocation. - - :param kwargs: Arguments to pass to the tool - :returns: JSON string representation of the tool invocation result - """ - logger.debug(f"TOOL: Invoking tool '{self.name}' with args: {kwargs}") - try: - - async def invoke(): - logger.debug(f"TOOL: Inside invoke coroutine for '{self.name}'") - result = await asyncio.wait_for( - self._client.call_tool(self.name, kwargs), timeout=self._invocation_timeout - ) - logger.debug(f"TOOL: Invoke successful for '{self.name}'") - return result - - logger.debug(f"TOOL: About to run invoke for '{self.name}'") - result = AsyncExecutor.get_instance().run(invoke(), timeout=self._invocation_timeout) - logger.debug(f"TOOL: Invoke complete for '{self.name}', result type: {type(result)}") - return result - except (MCPError, TimeoutError) as e: - logger.debug(f"TOOL: Known error during invoke of '{self.name}': {e!s}") - # Pass through known errors - raise - except Exception as e: - # Wrap other errors - logger.debug(f"TOOL: Unknown error during invoke of '{self.name}': {e!s}") - message = f"Failed to invoke tool '{self.name}' with args: {kwargs} , got error: {e!s}" - raise MCPInvocationError(message, self.name, kwargs) from e + # Use shared error handling + error_message = _create_connection_error_message( + server_info=server_info, exception=e, operation="initialize", context=f"MCPTool '{name}'" + ) + raise MCPConnectionError(message=error_message, server_info=server_info, operation="initialize") from e async def ainvoke(self, **kwargs: Any) -> str: """ @@ -809,7 +960,8 @@ async def ainvoke(self, **kwargs: Any) -> str: :raises TimeoutError: If the operation times out """ try: - return await asyncio.wait_for(self._client.call_tool(self.name, kwargs), timeout=self._invocation_timeout) + client = self._connection_manager.get_client() + return await asyncio.wait_for(client.call_tool(self.name, kwargs), timeout=self._invocation_timeout) except asyncio.TimeoutError as e: message = f"Tool invocation timed out after {self._invocation_timeout} seconds" raise TimeoutError(message) from e @@ -881,13 +1033,11 @@ def from_dict(cls, data: dict[str, Any]) -> "Tool": def close(self): """Close the tool synchronously.""" - if hasattr(self, "_client") and self._client: + if hasattr(self, "_connection_manager") and self._connection_manager: try: - # Tell the background worker to shut down gracefully. - if hasattr(self, "_worker") and self._worker: - self._worker.stop() + self._connection_manager.close() except Exception as e: - logger.debug(f"TOOL: Error during synchronous worker stop: {e!s}") + logger.debug(f"TOOL: Error during connection manager close: {e!s}") def __del__(self): """Cleanup resources when the tool is garbage collected.""" diff --git a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py index 39fcf3d98..38bed4462 100644 --- a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py +++ b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py @@ -2,26 +2,17 @@ # # SPDX-License-Identifier: Apache-2.0 -from collections.abc import Callable from typing import Any -from urllib.parse import urlparse -import httpx -from exceptiongroup import ExceptionGroup from haystack import logging from haystack.core.serialization import generate_qualified_class_name, import_class_by_name from haystack.tools import Tool, Toolset from .mcp_tool import ( - AsyncExecutor, - MCPClient, MCPConnectionError, + MCPConnectionManager, MCPServerInfo, - MCPToolNotFoundError, - SSEServerInfo, - StdioServerInfo, - StreamableHttpServerInfo, - _MCPClientSessionManager, + _create_connection_error_message, ) logger = logging.getLogger(__name__) @@ -126,46 +117,17 @@ def __init__( # Connect and load tools try: - # Create the client and spin up a worker so open/close happen in the - # same coroutine, avoiding AnyIO cancel-scope issues. - client = self.server_info.create_client() - self._worker = _MCPClientSessionManager(client, timeout=self.connection_timeout) - - tools = self._worker.tools() + # Use shared connection logic + self._connection_manager = MCPConnectionManager(server_info, connection_timeout) + available_tools = self._connection_manager.connect_and_discover_tools() # If tool_names is provided, validate that all requested tools exist if self.tool_names: - available_tools = {tool.name for tool in tools} - missing_tools = set(self.tool_names) - available_tools - if missing_tools: - message = ( - f"The following tools were not found: {', '.join(missing_tools)}. " - f"Available tools: {', '.join(available_tools)}" - ) - raise MCPToolNotFoundError( - message=message, tool_name=next(iter(missing_tools)), available_tools=list(available_tools) - ) - - # This is a factory that creates the invocation function for the Tool - def create_invoke_tool( - owner_toolset: "MCPToolset", - mcp_client: MCPClient, - tool_name: str, - tool_timeout: float, - ) -> Callable[..., Any]: - """Return a closure that keeps a strong reference to *owner_toolset* alive.""" - - def invoke_tool(**kwargs) -> Any: - _ = owner_toolset # strong reference so GC can't collect the toolset too early - return AsyncExecutor.get_instance().run( - mcp_client.call_tool(tool_name, kwargs), timeout=tool_timeout - ) + self._connection_manager.validate_requested_tools(self.tool_names, available_tools) - return invoke_tool - - # Create Tool instances not MCPTool for each available tool + # Create Tool instances using shared invoke function creation haystack_tools = [] - for tool_info in tools: + for tool_info in available_tools: # Skip tools not in the tool_names list if tool_names is provided if self.tool_names is not None and tool_info.name not in self.tool_names: logger.debug( @@ -173,12 +135,14 @@ def invoke_tool(**kwargs) -> Any: ) continue - # Use the helper function to create the invoke_tool function + # Use shared function creation instead of local closure + invoke_func = self._connection_manager.create_tool_invoke_function(tool_info.name, invocation_timeout) + tool = Tool( name=tool_info.name, description=tool_info.description, parameters=tool_info.inputSchema, - function=create_invoke_tool(self, client, tool_info.name, self.invocation_timeout), + function=invoke_func, ) haystack_tools.append(tool) @@ -190,59 +154,11 @@ def invoke_tool(**kwargs) -> Any: # fail because of an MCPToolNotFoundError self.close() - # Create informative error message for SSE connection errors - # Common error handling for HTTP-based transports - if isinstance(self.server_info, (SSEServerInfo | StreamableHttpServerInfo)): - # Determine transport type for messages - transport_name = "SSE" if isinstance(self.server_info, SSEServerInfo) else "streamable HTTP" - server_url = self.server_info.url - - base_message = f"Failed to connect to MCP server via {transport_name}" - checks = [ - f"1. The server URL is correct (attempted: {server_url})", - "2. The server is running and accessible", - "3. Authentication token is correct (if required)", - ] - - # Add specific connection error details for network issues - has_connect_error = isinstance(e, httpx.ConnectError) or ( - isinstance(e, ExceptionGroup) and any(isinstance(exc, httpx.ConnectError) for exc in e.exceptions) - ) - - if has_connect_error: - # Use urlparse to reliably get scheme, hostname, and port - parsed_url = urlparse(server_url) - port_str = "" - if parsed_url.port: - port_str = str(parsed_url.port) - elif parsed_url.scheme == "http": - port_str = "80 (default)" - elif parsed_url.scheme == "https": - port_str = "443 (default)" - else: - port_str = "unknown (scheme not http/https or missing)" - - # Ensure hostname is handled correctly (it might be None) - hostname_str = str(parsed_url.hostname) if parsed_url.hostname else "" - - # Replace generic accessible message with specific network details - checks[1] = f"2. The address '{hostname_str}' and port '{port_str}' are correct" - checks.append("4. There are no firewall or network connectivity issues") - - message = f"{base_message}. Please check if:\n" + "\n".join(checks) - - # and for stdio connection errors - elif isinstance(self.server_info, StdioServerInfo): # stdio connection - base_message = "Failed to start MCP server process" - stdio_info = self.server_info - args_str = " ".join(stdio_info.args) if stdio_info.args else "" - cmd = f"{stdio_info.command}{' ' + args_str if args_str else ''}" - checks = [f"1. The command and arguments are correct (attempted: {cmd})"] - message = f"{base_message}. Please check if:\n" + "\n".join(checks) - else: - message = f"Unsupported server info type: {type(self.server_info)}" - - raise MCPConnectionError(message=message, server_info=self.server_info, operation="initialize") from e + # Use shared error handling + error_message = _create_connection_error_message( + server_info=self.server_info, exception=e, operation="initialize", context="MCPToolset" + ) + raise MCPConnectionError(message=error_message, server_info=self.server_info, operation="initialize") from e def to_dict(self) -> dict[str, Any]: """ @@ -285,11 +201,11 @@ def from_dict(cls, data: dict[str, Any]) -> "MCPToolset": def close(self): """Close the underlying MCP client safely.""" - if hasattr(self, "_worker") and self._worker: + if hasattr(self, "_connection_manager") and self._connection_manager: try: - self._worker.stop() + self._connection_manager.close() except Exception as e: - logger.debug(f"TOOLSET: error during worker stop: {e!s}") + logger.debug(f"TOOLSET: error during connection manager close: {e!s}") def __del__(self): self.close() From 62b258750f64929689091fb3bef4269fb78532f1 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 26 Jun 2025 16:09:43 +0200 Subject: [PATCH 2/6] Linting --- .../src/haystack_integrations/tools/mcp/mcp_tool.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py index 77e70738c..ec4f28205 100644 --- a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py +++ b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py @@ -746,8 +746,8 @@ class MCPConnectionManager: def __init__(self, server_info: MCPServerInfo, connection_timeout: float): self.server_info = server_info self.connection_timeout = connection_timeout - self._client = None - self._worker = None + self._client: MCPClient | None = None + self._worker: "_MCPClientSessionManager | None" = None def connect_and_discover_tools(self) -> list[Tool]: """ @@ -794,6 +794,9 @@ def invoke_tool(**kwargs) -> str: async def invoke(): logger.debug(f"Inside invoke coroutine for '{tool_name}'") + if self._client is None: + message = "MCP client is not connected" + raise MCPConnectionError(message=message, operation="invoke") result = await asyncio.wait_for( self._client.call_tool(tool_name, kwargs), timeout=invocation_timeout ) @@ -816,7 +819,7 @@ async def invoke(): return invoke_tool - def get_client(self): + def get_client(self) -> MCPClient | None: """Allow direct access to client for MCPTool's async method access""" return self._client @@ -961,6 +964,9 @@ async def ainvoke(self, **kwargs: Any) -> str: """ try: client = self._connection_manager.get_client() + if client is None: + message = "MCP client is not connected" + raise MCPConnectionError(message=message, operation="ainvoke") return await asyncio.wait_for(client.call_tool(self.name, kwargs), timeout=self._invocation_timeout) except asyncio.TimeoutError as e: message = f"Tool invocation timed out after {self._invocation_timeout} seconds" From 3710060419e89c791fea28037298354e73c5b7fe Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 26 Jun 2025 16:21:15 +0200 Subject: [PATCH 3/6] Ruffing --- .../mcp/src/haystack_integrations/tools/mcp/mcp_tool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py index ec4f28205..c799e84a7 100644 --- a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py +++ b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py @@ -747,7 +747,7 @@ def __init__(self, server_info: MCPServerInfo, connection_timeout: float): self.server_info = server_info self.connection_timeout = connection_timeout self._client: MCPClient | None = None - self._worker: "_MCPClientSessionManager | None" = None + self._worker: _MCPClientSessionManager | None = None def connect_and_discover_tools(self) -> list[Tool]: """ From d2435bf4f5fc8ad16e07d455eea17195007aedb9 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 26 Jun 2025 16:36:04 +0200 Subject: [PATCH 4/6] Fix imports --- .../mcp/src/haystack_integrations/tools/mcp/mcp_tool.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py index c799e84a7..4874283c6 100644 --- a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py +++ b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py @@ -13,7 +13,9 @@ from dataclasses import dataclass, fields from datetime import timedelta from typing import Any, cast +from urllib.parse import urlparse +import httpx from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from exceptiongroup import ExceptionGroup from haystack import logging @@ -687,8 +689,6 @@ def _create_http_connection_error_message( ] # Check if exception indicates a network connection error - import httpx - has_connect_error = isinstance(exception, httpx.ConnectError) or ( isinstance(exception, ExceptionGroup) and any(isinstance(exc, httpx.ConnectError) for exc in exception.exceptions) @@ -696,8 +696,6 @@ def _create_http_connection_error_message( # Add network-specific guidance for connection errors if has_connect_error: - from urllib.parse import urlparse - # Use urlparse to reliably get scheme, hostname, and port parsed_url = urlparse(server_url) port_str = "" From 8489ac47c6dd2ff8dcc60161ba0a0757bb115bcd Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 2 Jul 2025 14:13:59 +0200 Subject: [PATCH 5/6] Improve pydocs --- .../tools/mcp/mcp_tool.py | 99 +++++++++++-------- .../tools/mcp/mcp_toolset.py | 2 +- 2 files changed, 60 insertions(+), 41 deletions(-) diff --git a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py index 4874283c6..01c977e78 100644 --- a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py +++ b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py @@ -49,7 +49,7 @@ def get_instance(cls) -> "AsyncExecutor": return cls._singleton_instance def __init__(self): - """Initialize a dedicated event loop""" + """Initialize a dedicated event loop.""" self._loop: asyncio.AbstractEventLoop = asyncio.new_event_loop() self._thread: threading.Thread = threading.Thread(target=self._run_loop, daemon=True) self._started = threading.Event() @@ -59,7 +59,7 @@ def __init__(self): raise RuntimeError(message) def _run_loop(self): - """Run the event loop""" + """Run the event loop.""" asyncio.set_event_loop(self._loop) self._started.set() self._loop.run_forever() @@ -70,7 +70,7 @@ def run(self, coro: Coroutine[Any, Any, Any], timeout: float | None = None) -> A :param coro: Coroutine to execute :param timeout: Optional timeout in seconds - :return: Result of the coroutine + :returns: Result of the coroutine :raises TimeoutError: If execution exceeds timeout """ future = asyncio.run_coroutine_threadsafe(coro, self._loop) @@ -93,16 +93,11 @@ def run_background( self, coro_factory: Callable[[asyncio.Event], Coroutine[Any, Any, Any]], timeout: float | None = None ) -> tuple[concurrent.futures.Future[Any], asyncio.Event]: """ - Schedule `coro_factory` to run in the executor's event loop **without** blocking the - caller thread. + Schedule a coroutine factory to run in the executor's event loop without blocking the caller thread. - The factory receives an :class:`asyncio.Event` that can be used to cooperatively shut - the coroutine down. The method returns **both** the concurrent future (to observe - completion or failure) and the created *stop_event* so that callers can signal termination. - - :param coro_factory: A callable receiving the stop_event and returning the coroutine to execute. - :param timeout: Optional timeout while waiting for the stop_event to be created. - :returns: Tuple ``(future, stop_event)``. + :param coro_factory: A callable receiving the stop_event and returning the coroutine to execute + :param timeout: Optional timeout while waiting for the stop_event to be created + :returns: Tuple of (future, stop_event) """ # A promise that will be fulfilled from inside the coroutine_with_stop_event coroutine once the # stop_event is created *inside* the target event loop to ensure it is bound to the @@ -636,8 +631,10 @@ def create_client(self) -> MCPClient: def _extract_error_message(exception: Exception) -> str: """ - Extracts meaningful error message from various exception types. - Handles ExceptionGroup, empty messages, etc. + Extract meaningful error message from various exception types. + + :param exception: Exception to extract message from + :returns: Meaningful error message string """ error_message = str(exception) # Handle ExceptionGroup to extract more useful error messages @@ -656,7 +653,12 @@ def _extract_error_message(exception: Exception) -> str: def _create_stdio_connection_error_message(server_info: StdioServerInfo, operation: str, context: str) -> str: """ - Creates stdio-specific error messages with command troubleshooting. + Create stdio-specific error messages with command troubleshooting. + + :param server_info: Stdio server configuration + :param operation: Operation that failed + :param context: Context description + :returns: Formatted error message with troubleshooting guidance """ base_message = f"Failed to {operation} {context} via stdio" @@ -673,7 +675,13 @@ def _create_http_connection_error_message( server_info: SSEServerInfo | StreamableHttpServerInfo, exception: Exception, operation: str, context: str ) -> str: """ - Creates HTTP-specific error messages with troubleshooting guidance. + Create HTTP-specific error messages with troubleshooting guidance. + + :param server_info: HTTP server configuration + :param exception: Original exception that occurred + :param operation: Operation that failed + :param context: Context description + :returns: Formatted error message with troubleshooting guidance """ # Determine transport type transport_name = "SSE" if isinstance(server_info, SSEServerInfo) else "streamable HTTP" @@ -721,8 +729,13 @@ def _create_connection_error_message( server_info: MCPServerInfo, exception: Exception, operation: str, context: str = "" ) -> str: """ - Creates contextual error messages based on server type and failure details. - This replaces the duplicate error handling blocks in both classes. + Create contextual error messages based on server type and failure details. + + :param server_info: Server configuration + :param exception: Original exception that occurred + :param operation: Operation that failed + :param context: Context description + :returns: Formatted error message with troubleshooting guidance """ # Generate server-type specific guidance @@ -749,8 +762,10 @@ def __init__(self, server_info: MCPServerInfo, connection_timeout: float): def connect_and_discover_tools(self) -> list[Tool]: """ - Establishes connection and returns available tools. - This replaces the duplicate connection logic in both classes. + Establish connection and return available tools. + + :returns: List of available tools on the server + :raises MCPConnectionError: If connection fails """ try: # Create the client and spin up a worker so open/close happen in the @@ -765,8 +780,11 @@ def connect_and_discover_tools(self) -> list[Tool]: def validate_requested_tools(self, requested_tool_names: list[str], available_tools: list[Tool]) -> None: """ - Validates that requested tools exist on the server. - Shared validation logic between both classes. + Validate that requested tools exist on the server. + + :param requested_tool_names: List of tool names that were requested + :param available_tools: List of tools available on the server + :raises MCPToolNotFoundError: If any requested tools are not found """ available_tool_names = {tool.name for tool in available_tools} missing_tools = set(requested_tool_names) - available_tool_names @@ -781,8 +799,11 @@ def validate_requested_tools(self, requested_tool_names: list[str], available_to def create_tool_invoke_function(self, tool_name: str, invocation_timeout: float): """ - Creates the invoke function that both classes use. - MCPTool uses this directly, MCPToolset uses this in its closure factory. + Create the invoke function for a specific tool. + + :param tool_name: Name of the tool to create invoke function for + :param invocation_timeout: Timeout for tool invocations + :returns: Invoke function that can be called with tool arguments """ def invoke_tool(**kwargs) -> str: @@ -818,11 +839,15 @@ async def invoke(): return invoke_tool def get_client(self) -> MCPClient | None: - """Allow direct access to client for MCPTool's async method access""" + """ + Get direct access to the MCP client. + + :returns: The MCP client instance or None if not connected + """ return self._client def close(self): - """Shared cleanup logic""" + """Close the connection and clean up resources.""" if hasattr(self, "_worker") and self._worker: try: self._worker.stop() @@ -977,14 +1002,9 @@ async def ainvoke(self, **kwargs: Any) -> str: def to_dict(self) -> dict[str, Any]: """ - Serializes the MCPTool to a dictionary. - - The serialization preserves all information needed to recreate the tool, - including server connection parameters and timeout settings. Note that the - active connection is not maintained. + Serialize the MCPTool to a dictionary. - :returns: Dictionary with serialized data in the format: - {"type": fully_qualified_class_name, "data": {parameters}} + :returns: Dictionary with serialized data """ serialized = { "name": self.name, @@ -1001,15 +1021,10 @@ def to_dict(self) -> dict[str, Any]: @classmethod def from_dict(cls, data: dict[str, Any]) -> "Tool": """ - Deserializes the MCPTool from a dictionary. - - This method reconstructs an MCPTool instance from a serialized dictionary, - including recreating the server_info object. A new connection will be established - to the MCP server during initialization. + Deserialize the MCPTool from a dictionary. :param data: Dictionary containing serialized tool data :returns: A fully initialized MCPTool instance - :raises: Various exceptions if connection fails """ # Extract the tool parameters from the data dictionary inner_data = data["data"] @@ -1087,7 +1102,11 @@ def __init__(self, client: "MCPClient", *, timeout: float | None = None): raise def tools(self) -> list[Tool]: - """Return the tool list already collected during startup.""" + """ + Return the tool list already collected during startup. + + :returns: List of available tools + """ return self._tools_promise.result() diff --git a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py index 38bed4462..144da34b2 100644 --- a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py +++ b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py @@ -104,7 +104,7 @@ def __init__( :param server_info: Connection information for the MCP server :param tool_names: Optional list of tool names to include. If provided, only tools with - matching names will be added to the toolset. + matching names will be added to the toolset :param connection_timeout: Timeout in seconds for server connection :param invocation_timeout: Default timeout in seconds for tool invocations :raises MCPToolNotFoundError: If any of the specified tool names are not found on the server From ba437b63094e2c46f94e63bcb9c71e4fbd2e899d Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 2 Jul 2025 15:00:18 +0200 Subject: [PATCH 6/6] Remove indirection --- .../tools/mcp/mcp_tool.py | 35 +++++++------------ 1 file changed, 12 insertions(+), 23 deletions(-) diff --git a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py index 01c977e78..f918bd61b 100644 --- a/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py +++ b/integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py @@ -629,28 +629,6 @@ def create_client(self) -> MCPClient: return StdioClient(self.command, self.args, self.env) -def _extract_error_message(exception: Exception) -> str: - """ - Extract meaningful error message from various exception types. - - :param exception: Exception to extract message from - :returns: Meaningful error message string - """ - error_message = str(exception) - # Handle ExceptionGroup to extract more useful error messages - if isinstance(exception, ExceptionGroup): - if exception.exceptions: - first_exception = exception.exceptions[0] - error_message = first_exception.message if hasattr(first_exception, "message") else str(first_exception) - - # Ensure we always have a meaningful error message - if not error_message or error_message.strip() == "": - # Provide platform-independent fallback message for connection errors - error_message = "Connection failed to MCP server" - - return error_message - - def _create_stdio_connection_error_message(server_info: StdioServerInfo, operation: str, context: str) -> str: """ Create stdio-specific error messages with command troubleshooting. @@ -744,7 +722,18 @@ def _create_connection_error_message( elif isinstance(server_info, StdioServerInfo): return _create_stdio_connection_error_message(server_info, operation, context) else: - error_message = _extract_error_message(exception) + error_message = str(exception) + # Handle ExceptionGroup to extract more useful error messages + if isinstance(exception, ExceptionGroup): + if exception.exceptions: + first_exception = exception.exceptions[0] + error_message = first_exception.message if hasattr(first_exception, "message") else str(first_exception) + + # Ensure we always have a meaningful error message + if not error_message or error_message.strip() == "": + # Provide platform-independent fallback message for connection errors + error_message = "Connection failed to MCP server" + return f"Failed to {operation} {context}: {error_message}"