From 8a72c3c5ff44fe2d8a7a30b8bffb3bf6cfb040a3 Mon Sep 17 00:00:00 2001 From: syossan27 Date: Tue, 8 Jul 2025 17:59:02 +0900 Subject: [PATCH] fix: enhance error handling and logging for server capabilities --- src/mcp_scan/MCPScanner.py | 21 +++++++++++++++++---- src/mcp_scan/mcp_client.py | 24 ++++++++++++++++++------ src/mcp_scan/verify_api.py | 27 ++++++++++++++++++++++++++- 3 files changed, 61 insertions(+), 11 deletions(-) diff --git a/src/mcp_scan/MCPScanner.py b/src/mcp_scan/MCPScanner.py index b2eec04..67f2d4b 100644 --- a/src/mcp_scan/MCPScanner.py +++ b/src/mcp_scan/MCPScanner.py @@ -5,7 +5,7 @@ from collections.abc import Callable from typing import Any -from mcp_scan.models import ScanError, ScanPathResult, ServerScanResult +from mcp_scan.models import ScanError, ScanPathResult, ServerScanResult, StdioServer from .mcp_client import check_server_with_timeout, scan_mcp_config_file from .StorageFile import StorageFile @@ -160,9 +160,22 @@ async def scan_server(self, server: ServerScanResult, inspect_only: bool = False logger.info("Scanning server: %s, inspect_only: %s", server.name, inspect_only) result = server.clone() try: - result.signature = await check_server_with_timeout( - server.server, self.server_timeout, self.suppress_mcpserver_io - ) + if isinstance(server.server, StdioServer) and server.server.command == 'docker': + logger.info("Docker command detected, applying special timeout handling") + timeout_task = asyncio.create_task( + check_server_with_timeout(server.server, self.server_timeout, self.suppress_mcpserver_io) + ) + try: + result.signature = await asyncio.wait_for(timeout_task, self.server_timeout + 2) + except asyncio.TimeoutError: + timeout_task.cancel() + logger.error("Docker command timed out for server: %s", server.name) + raise asyncio.TimeoutError(f"Docker command timed out after {self.server_timeout + 2} seconds") + else: + result.signature = await check_server_with_timeout( + server.server, self.server_timeout, self.suppress_mcpserver_io + ) + logger.debug( "Server %s has %d prompts, %d resources, %d tools", server.name, diff --git a/src/mcp_scan/mcp_client.py b/src/mcp_scan/mcp_client.py index aea2687..b8dc41d 100644 --- a/src/mcp_scan/mcp_client.py +++ b/src/mcp_scan/mcp_client.py @@ -88,23 +88,35 @@ async def _check_server(verbose: bool) -> ServerSignature: try: prompts = (await session.list_prompts()).prompts logger.debug("Found %d prompts", len(prompts)) - except Exception: - logger.exception("Failed to list prompts") + except Exception as e: + logger.debug("Failed to list prompts: %s", str(e)) + if "prompts not supported" in str(e) or "Method not found" in str(e): + logger.debug("Server does not support prompts capability or method, skipping") + else: + logger.exception("Failed to list prompts") if isinstance(server_config, StdioServer) or meta.capabilities.resources: logger.debug("Fetching resources") try: resources = (await session.list_resources()).resources logger.debug("Found %d resources", len(resources)) - except Exception: - logger.exception("Failed to list resources") + except Exception as e: + logger.debug("Failed to list resources: %s", str(e)) + if "resources not supported" in str(e) or "Method not found" in str(e): + logger.debug("Server does not support resources capability or method, skipping") + else: + logger.exception("Failed to list resources") if isinstance(server_config, StdioServer) or meta.capabilities.tools: logger.debug("Fetching tools") try: tools = (await session.list_tools()).tools logger.debug("Found %d tools", len(tools)) - except Exception: - logger.exception("Failed to list tools") + except Exception as e: + logger.debug("Failed to list tools: %s", str(e)) + if "tools not supported" in str(e) or "Method not found" in str(e): + logger.debug("Server does not support tools capability or method, skipping") + else: + logger.exception("Failed to list tools") logger.info("Server check completed successfully") return ServerSignature( metadata=meta, diff --git a/src/mcp_scan/verify_api.py b/src/mcp_scan/verify_api.py index a0a9ecd..4db9dd2 100644 --- a/src/mcp_scan/verify_api.py +++ b/src/mcp_scan/verify_api.py @@ -30,6 +30,17 @@ async def verify_scan_path_public_api(scan_path: ScanPathResult, base_url: str) payload.root.append(server.signature) # Server signatures do not contain any information about the user setup. Only about the server itself. try: + # Check if there's data to send + if not payload.root: + # If no data, skip API call and return appropriate result + for server in output_path.servers: + if server.signature is None: + continue + server.result = [ + EntityScanResult(status="no server signature data available") for _ in server.entities + ] if server.entities else [] + return output_path + async with aiohttp.ClientSession() as session: async with session.post(url, headers=headers, data=payload.model_dump_json()) as response: if response.status == 200: @@ -52,7 +63,7 @@ async def verify_scan_path_public_api(scan_path: ScanPathResult, base_url: str) if server.signature is not None: server.result = [ EntityScanResult(status="could not reach verification server " + errstr) for _ in server.entities - ] + ] if server.entities else [] return output_path @@ -71,6 +82,17 @@ async def verify_scan_path_locally(scan_path: ScanPathResult) -> ScanPathResult: if server.signature is not None: for entity in server.entities: tools_to_scan.append(entity_to_tool(entity)) + + # If no tools to scan, return early with appropriate message + if not tools_to_scan: + for server in output_path.servers: + if server.signature is None: + continue + server.result = [ + EntityScanResult(status="no tools available for scanning") for _ in server.entities + ] if server.entities else [] + return output_path + messages = [{"tools": [tool.model_dump() for tool in tools_to_scan]}] policy = LocalPolicy.from_string(get_policy()) @@ -87,6 +109,9 @@ async def verify_scan_path_locally(scan_path: ScanPathResult) -> ScanPathResult: for server in output_path.servers: if server.signature is None: continue + if not server.entities: + server.result = [] + continue server.result = results[: len(server.entities)] results = results[len(server.entities) :] if results: