Skip to content

feat: Refactor MCPTool and MCPToolsetto increase code reuse, simplify #2004

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
302 changes: 228 additions & 74 deletions integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
from dataclasses import dataclass, fields
from datetime import timedelta
from typing import Any, cast
from urllib.parse import urlparse

import httpx
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from exceptiongroup import ExceptionGroup
from haystack import logging
Expand Down Expand Up @@ -632,6 +634,206 @@ def create_client(self) -> MCPClient:
return StdioClient(self.command, self.args, self.env)


def _extract_error_message(exception: Exception) -> str:
"""
Extracts meaningful error message from various exception types.
Handles ExceptionGroup, empty messages, etc.
"""
error_message = str(exception)
# Handle ExceptionGroup to extract more useful error messages
if isinstance(exception, ExceptionGroup):
if exception.exceptions:
first_exception = exception.exceptions[0]
error_message = first_exception.message if hasattr(first_exception, "message") else str(first_exception)

# Ensure we always have a meaningful error message
if not error_message or error_message.strip() == "":
# Provide platform-independent fallback message for connection errors
error_message = "Connection failed to MCP server"

return error_message


def _create_stdio_connection_error_message(server_info: StdioServerInfo, operation: str, context: str) -> str:
"""
Creates stdio-specific error messages with command troubleshooting.
"""
base_message = f"Failed to {operation} {context} via stdio"

# Build command string for diagnostics
args_str = " ".join(server_info.args) if server_info.args else ""
cmd = f"{server_info.command}{' ' + args_str if args_str else ''}"

checks = [f"1. The command and arguments are correct (attempted: {cmd})"]

return f"{base_message}. Please check if:\n" + "\n".join(checks)


def _create_http_connection_error_message(
server_info: SSEServerInfo | StreamableHttpServerInfo, exception: Exception, operation: str, context: str
) -> str:
"""
Creates HTTP-specific error messages with troubleshooting guidance.
"""
# Determine transport type
transport_name = "SSE" if isinstance(server_info, SSEServerInfo) else "streamable HTTP"
server_url = server_info.url

base_message = f"Failed to {operation} {context} via {transport_name}"

# Standard troubleshooting steps
checks = [
f"1. The server URL is correct (attempted: {server_url})",
"2. The server is running and accessible",
"3. Authentication token is correct (if required)",
]

# Check if exception indicates a network connection error
has_connect_error = isinstance(exception, httpx.ConnectError) or (
isinstance(exception, ExceptionGroup)
and any(isinstance(exc, httpx.ConnectError) for exc in exception.exceptions)
)

# Add network-specific guidance for connection errors
if has_connect_error:
# Use urlparse to reliably get scheme, hostname, and port
parsed_url = urlparse(server_url)
port_str = ""
if parsed_url.port:
port_str = str(parsed_url.port)
elif parsed_url.scheme == "http":
port_str = "80 (default)"
elif parsed_url.scheme == "https":
port_str = "443 (default)"
else:
port_str = "unknown (scheme not http/https or missing)"

# Ensure hostname is handled correctly (it might be None)
hostname_str = str(parsed_url.hostname) if parsed_url.hostname else "<unknown>"

checks[1] = f"2. The address '{hostname_str}' and port '{port_str}' are correct"
checks.append("4. There are no firewall or network connectivity issues")

return f"{base_message}. Please check if:\n" + "\n".join(checks)


def _create_connection_error_message(
server_info: MCPServerInfo, exception: Exception, operation: str, context: str = ""
) -> str:
"""
Creates contextual error messages based on server type and failure details.
This replaces the duplicate error handling blocks in both classes.
"""

# Generate server-type specific guidance
if isinstance(server_info, SSEServerInfo | StreamableHttpServerInfo):
return _create_http_connection_error_message(server_info, exception, operation, context)
elif isinstance(server_info, StdioServerInfo):
return _create_stdio_connection_error_message(server_info, operation, context)
else:
error_message = _extract_error_message(exception)
return f"Failed to {operation} {context}: {error_message}"


class MCPConnectionManager:
"""
Utility class that encapsulates common MCP connection logic shared between
MCPTool and MCPToolset.
"""

def __init__(self, server_info: MCPServerInfo, connection_timeout: float):
self.server_info = server_info
self.connection_timeout = connection_timeout
self._client: MCPClient | None = None
self._worker: _MCPClientSessionManager | None = None

def connect_and_discover_tools(self) -> list[Tool]:
"""
Establishes connection and returns available tools.
This replaces the duplicate connection logic in both classes.
"""
try:
# Create the client and spin up a worker so open/close happen in the
# same coroutine, avoiding AnyIO cancel-scope issues.
self._client = self.server_info.create_client()
self._worker = _MCPClientSessionManager(self._client, timeout=self.connection_timeout)
return self._worker.tools()
except Exception:
# Handle cleanup internally
self.close()
raise

def validate_requested_tools(self, requested_tool_names: list[str], available_tools: list[Tool]) -> None:
"""
Validates that requested tools exist on the server.
Shared validation logic between both classes.
"""
available_tool_names = {tool.name for tool in available_tools}
missing_tools = set(requested_tool_names) - available_tool_names
if missing_tools:
message = (
f"The following tools were not found: {', '.join(missing_tools)}. "
f"Available tools: {', '.join(available_tool_names)}"
)
raise MCPToolNotFoundError(
message=message, tool_name=next(iter(missing_tools)), available_tools=list(available_tool_names)
)

def create_tool_invoke_function(self, tool_name: str, invocation_timeout: float):
"""
Creates the invoke function that both classes use.
MCPTool uses this directly, MCPToolset uses this in its closure factory.
"""

def invoke_tool(**kwargs) -> str:
"""Unified invoke logic - no more duplication"""
logger.debug(f"Invoking tool '{tool_name}' with args: {kwargs}")
try:

async def invoke():
logger.debug(f"Inside invoke coroutine for '{tool_name}'")
if self._client is None:
message = "MCP client is not connected"
raise MCPConnectionError(message=message, operation="invoke")
result = await asyncio.wait_for(
self._client.call_tool(tool_name, kwargs), timeout=invocation_timeout
)
logger.debug(f"Invoke successful for '{tool_name}'")
return result

logger.debug(f"About to run invoke for '{tool_name}'")
result = AsyncExecutor.get_instance().run(invoke(), timeout=invocation_timeout)
logger.debug(f"Invoke complete for '{tool_name}', result type: {type(result)}")
return result
except (MCPError, TimeoutError) as e:
logger.debug(f"Known error during invoke of '{tool_name}': {e!s}")
# Pass through known errors
raise
except Exception as e:
# Wrap other errors
logger.debug(f"Unknown error during invoke of '{tool_name}': {e!s}")
message = f"Failed to invoke tool '{tool_name}' with args: {kwargs} , got error: {e!s}"
raise MCPInvocationError(message, tool_name, kwargs) from e

return invoke_tool

def get_client(self) -> MCPClient | None:
"""Allow direct access to client for MCPTool's async method access"""
return self._client

def close(self):
"""Shared cleanup logic"""
if hasattr(self, "_worker") and self._worker:
try:
self._worker.stop()
except Exception as e:
logger.debug(f"Error during worker stop: {e!s}")

# Clear references
self._worker = None
self._client = None


class MCPTool(Tool):
"""
A Tool that represents a single tool from an MCP server.
Expand Down Expand Up @@ -706,41 +908,35 @@ def __init__(
logger.debug(f"TOOL: Initializing MCPTool '{name}'")

try:
# Create client and spin up a long-lived worker that keeps the
# connect/close lifecycle inside one coroutine.
self._client = server_info.create_client()
logger.debug(f"TOOL: Created client for MCPTool '{name}'")

# The worker starts immediately and blocks here until the connection
# is established (or fails), returning the tool list.
self._worker = _MCPClientSessionManager(self._client, timeout=connection_timeout)
# Use shared connection logic
self._connection_manager = MCPConnectionManager(server_info, connection_timeout)
available_tools = self._connection_manager.connect_and_discover_tools()

tools = self._worker.tools()
# Handle no tools case
if not tools:
if not available_tools:
logger.debug(f"TOOL: No tools found for '{name}'")
message = "No tools available on server"
raise MCPToolNotFoundError(message, tool_name=name)

# Validate that the requested tool exists
self._connection_manager.validate_requested_tools([name], available_tools)

# Find the specified tool
tool_dict = {t.name: t for t in tools}
tool_dict = {t.name: t for t in available_tools}
logger.debug(f"TOOL: Available tools: {list(tool_dict.keys())}")
tool_info = tool_dict[name] # Safe to use direct access since validation passed

tool_info = tool_dict.get(name)
logger.debug(f"TOOL: Found tool '{name}', initializing Tool parent class")

if not tool_info:
available = list(tool_dict.keys())
logger.debug(f"TOOL: Tool '{name}' not found in available tools")
message = f"Tool '{name}' not found on server. Available tools: {', '.join(available)}"
raise MCPToolNotFoundError(message, tool_name=name, available_tools=available)
# Create shared invoke function
invoke_func = self._connection_manager.create_tool_invoke_function(name, invocation_timeout)

logger.debug(f"TOOL: Found tool '{name}', initializing Tool parent class")
# Initialize the parent class
super().__init__(
name=name,
description=description or tool_info.description,
parameters=tool_info.inputSchema,
function=self._invoke_tool,
function=invoke_func,
)
logger.debug(f"TOOL: Initialization complete for '{name}'")

Expand All @@ -749,55 +945,11 @@ def __init__(
# fail because of an MCPToolNotFoundError
self.close()

# Extract more detailed error information from TaskGroup/ExceptionGroup exceptions
error_message = str(e)
# Handle ExceptionGroup to extract more useful error messages
if isinstance(e, ExceptionGroup):
if e.exceptions:
first_exception = e.exceptions[0]
error_message = (
first_exception.message if hasattr(first_exception, "message") else str(first_exception)
)

# Ensure we always have a meaningful error message
if not error_message or error_message.strip() == "":
# Provide platform-independent fallback message for connection errors
error_message = f"Connection failed to MCP server (using {type(server_info).__name__})"

message = f"Failed to initialize MCPTool '{name}': {error_message}"
raise MCPConnectionError(message=message, server_info=server_info, operation="initialize") from e

def _invoke_tool(self, **kwargs: Any) -> str:
"""
Synchronous tool invocation.

:param kwargs: Arguments to pass to the tool
:returns: JSON string representation of the tool invocation result
"""
logger.debug(f"TOOL: Invoking tool '{self.name}' with args: {kwargs}")
try:

async def invoke():
logger.debug(f"TOOL: Inside invoke coroutine for '{self.name}'")
result = await asyncio.wait_for(
self._client.call_tool(self.name, kwargs), timeout=self._invocation_timeout
)
logger.debug(f"TOOL: Invoke successful for '{self.name}'")
return result

logger.debug(f"TOOL: About to run invoke for '{self.name}'")
result = AsyncExecutor.get_instance().run(invoke(), timeout=self._invocation_timeout)
logger.debug(f"TOOL: Invoke complete for '{self.name}', result type: {type(result)}")
return result
except (MCPError, TimeoutError) as e:
logger.debug(f"TOOL: Known error during invoke of '{self.name}': {e!s}")
# Pass through known errors
raise
except Exception as e:
# Wrap other errors
logger.debug(f"TOOL: Unknown error during invoke of '{self.name}': {e!s}")
message = f"Failed to invoke tool '{self.name}' with args: {kwargs} , got error: {e!s}"
raise MCPInvocationError(message, self.name, kwargs) from e
# Use shared error handling
error_message = _create_connection_error_message(
server_info=server_info, exception=e, operation="initialize", context=f"MCPTool '{name}'"
)
raise MCPConnectionError(message=error_message, server_info=server_info, operation="initialize") from e

async def ainvoke(self, **kwargs: Any) -> str:
"""
Expand All @@ -809,7 +961,11 @@ async def ainvoke(self, **kwargs: Any) -> str:
:raises TimeoutError: If the operation times out
"""
try:
return await asyncio.wait_for(self._client.call_tool(self.name, kwargs), timeout=self._invocation_timeout)
client = self._connection_manager.get_client()
if client is None:
message = "MCP client is not connected"
raise MCPConnectionError(message=message, operation="ainvoke")
return await asyncio.wait_for(client.call_tool(self.name, kwargs), timeout=self._invocation_timeout)
except asyncio.TimeoutError as e:
message = f"Tool invocation timed out after {self._invocation_timeout} seconds"
raise TimeoutError(message) from e
Expand Down Expand Up @@ -881,13 +1037,11 @@ def from_dict(cls, data: dict[str, Any]) -> "Tool":

def close(self):
"""Close the tool synchronously."""
if hasattr(self, "_client") and self._client:
if hasattr(self, "_connection_manager") and self._connection_manager:
try:
# Tell the background worker to shut down gracefully.
if hasattr(self, "_worker") and self._worker:
self._worker.stop()
self._connection_manager.close()
except Exception as e:
logger.debug(f"TOOL: Error during synchronous worker stop: {e!s}")
logger.debug(f"TOOL: Error during connection manager close: {e!s}")

def __del__(self):
"""Cleanup resources when the tool is garbage collected."""
Expand Down
Loading