diff --git a/README.md b/README.md index 2023e46..64e2e18 100644 --- a/README.md +++ b/README.md @@ -3,11 +3,11 @@ [Documentation](https://explorer.invariantlabs.ai/docs/mcp-scan) | [Support Discord](https://discord.gg/dZuZfhKnJ4) -MCP-Scan is a security scanning tool to both statically and dynamically scan and monitor your MCP connections. It checks them for common security vulnerabilities like [prompt injections](https://invariantlabs.ai/blog/mcp-security-notification-tool-poisoning-attacks), [tool poisoning](https://invariantlabs.ai/blog/mcp-security-notification-tool-poisoning-attacks) and [cross-origin escalations](https://invariantlabs.ai/blog/mcp-security-notification-tool-poisoning-attacks). +MCP-Scan is a security scanning tool to both statically and dynamically scan and monitor your MCP connections. It checks them for common security vulnerabilities like [prompt injections](https://invariantlabs.ai/blog/mcp-security-notification-tool-poisoning-attacks), [tool poisoning](https://invariantlabs.ai/blog/mcp-security-notification-tool-poisoning-attacks) and [toxic flows](https://invariantlabs.ai/blog/mcp-github-vulnerability). It operates in two main modes which can be used jointly or separately: -1. `mcp-scan scan` statically scans all your installed servers for malicious tool descriptions and tools (e.g. [tool poisoning attacks](https://invariantlabs.ai/blog/mcp-security-notification-tool-poisoning-attacks), cross-origin escalation, rug pull attacks). +1. `mcp-scan scan` statically scans all your installed servers for malicious tool descriptions and tools (e.g. [tool poisoning attacks](https://invariantlabs.ai/blog/mcp-security-notification-tool-poisoning-attacks), cross-origin escalation, rug pull attacks, toxic flows). [Quickstart →](#server-scanning). @@ -93,7 +93,6 @@ MCP-Scan `scan` searches through your configuration files to find MCP server con It then scans tool descriptions, both with local checks and by invoking Invariant Guardrailing via an API. For this, tool names and descriptions are shared with invariantlabs.ai. By using MCP-Scan, you agree to the invariantlabs.ai [terms of use](https://explorer.invariantlabs.ai/terms) and [privacy policy](https://invariantlabs.ai/privacy-policy). Invariant Labs is collecting data for security research purposes (only about tool descriptions and how they change over time, not your user data). Don't use MCP-scan if you don't want to share your tools. -You can run MCP-scan locally by using the `--local-only` flag. This will only run local checks and will not invoke the Invariant Guardrailing API, however it will not provide as accurate results as it just runs a local LLM-based policy check. This option requires an `OPENAI_API_KEY` environment variable to be set. MCP-scan does not store or log any usage data, i.e. the contents and results of your MCP tool calls. @@ -120,6 +119,7 @@ These options are available for all commands: --base-url URL Base URL for the verification server --verbose Enable detailed logging output --print-errors Show error details and tracebacks +--full-toxic-flows Show all tools that could take part in toxic flow. By default only the top 3 are shown. --json Output results in JSON format instead of rich text ``` @@ -138,7 +138,6 @@ Options: --checks-per-server NUM Number of checks to perform on each server (default: 1) --server-timeout SECONDS Seconds to wait before timing out server connections (default: 10) --suppress-mcpserver-io BOOL Suppress stdout/stderr from MCP servers (default: True) ---local-only BOOL Only run verification locally. Does not run all checks, results will be less accurate (default: False) ``` #### proxy diff --git a/pyproject.toml b/pyproject.toml index 30b924d..a25c5f1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "mcp-scan" -version = "0.2.3" +version = "0.3.0" description = "MCP Scan tool" readme = "README.md" requires-python = ">=3.10" diff --git a/src/mcp_scan/MCPScanner.py b/src/mcp_scan/MCPScanner.py index b2eec04..6ca24cf 100644 --- a/src/mcp_scan/MCPScanner.py +++ b/src/mcp_scan/MCPScanner.py @@ -5,11 +5,11 @@ from collections.abc import Callable from typing import Any -from mcp_scan.models import ScanError, ScanPathResult, ServerScanResult +from mcp_scan.models import Issue, ScanError, ScanPathResult, ServerScanResult from .mcp_client import check_server_with_timeout, scan_mcp_config_file from .StorageFile import StorageFile -from .verify_api import verify_scan_path +from .verify_api import analyze_scan_path # Set up logger for this module logger = logging.getLogger(__name__) @@ -56,7 +56,6 @@ def __init__( storage_file: str = "~/.mcp-scan", server_timeout: int = 10, suppress_mcpserver_io: bool = True, - local_only: bool = False, **kwargs: Any, ): logger.info("Initializing MCPScanner") @@ -70,7 +69,6 @@ def __init__( self.server_timeout = server_timeout self.suppress_mcpserver_io = suppress_mcpserver_io self.context_manager = None - self.local_only = local_only logger.debug( "MCPScanner initialized with timeout: %d, checks_per_server: %d", server_timeout, checks_per_server ) @@ -125,31 +123,38 @@ async def get_servers_from_path(self, path: str) -> ScanPathResult: result.error = ScanError(message=error_msg, exception=e) return result - async def check_server_changed(self, server: ServerScanResult) -> ServerScanResult: - logger.debug("Checking for changes in server: %s %s", server.name, server.result) - output_server = server.clone() - for i, (entity, entity_result) in enumerate(server.entities_with_result): - if entity_result is None: - continue - c, messages = self.storage_file.check_and_update(server.name or "", entity, entity_result.verified) - output_server.result[i].changed = c # type: ignore - if c: - logger.info("Entity %s in server %s has changed", entity.name, server.name) - output_server.result[i].messages.extend(messages) # type: ignore - return output_server - - async def check_whitelist(self, server: ServerScanResult) -> ServerScanResult: - logger.debug("Checking whitelist for server: %s", server.name) - output_server = server.clone() - for i, (entity, entity_result) in enumerate(server.entities_with_result): - if entity_result is None: - continue - if self.storage_file.is_whitelisted(entity): - logger.debug("Entity %s is whitelisted", entity.name) - output_server.result[i].whitelisted = True # type: ignore - else: - output_server.result[i].whitelisted = False # type: ignore - return output_server + def check_server_changed(self, path_result: ScanPathResult) -> list[Issue]: + logger.debug("Checking server changed: %s", path_result.path) + issues: list[Issue] = [] + for server_idx, server in enumerate(path_result.servers): + logger.debug( + "Checking for changes in server %d/%d: %s", server_idx + 1, len(path_result.servers), server.name + ) + for entity_idx, entity in enumerate(server.entities): + c, messages = self.storage_file.check_and_update(server.name or "", entity) + if c: + logger.info("Entity %s in server %s has changed", entity.name, server.name) + issues.append( + Issue( + code="W003", + message="Entity has changed. " + ", ".join(messages), + reference=(server_idx, entity_idx), + ) + ) + return issues + + def check_whitelist(self, path_result: ScanPathResult) -> list[Issue]: + logger.debug("Checking whitelist for path: %s", path_result.path) + issues: list[Issue] = [] + for server_idx, server in enumerate(path_result.servers): + for entity_idx, entity in enumerate(server.entities): + if self.storage_file.is_whitelisted(entity): + issues.append( + Issue( + code="X002", message="This entity has been whitelisted", reference=(server_idx, entity_idx) + ) + ) + return issues async def emit(self, signal: str, data: Any): logger.debug("Emitting signal: %s", signal) @@ -170,12 +175,6 @@ async def scan_server(self, server: ServerScanResult, inspect_only: bool = False len(result.signature.resources), len(result.signature.tools), ) - - if not inspect_only: - logger.debug("Checking if server has changed: %s", server.name) - result = await self.check_server_changed(result) - logger.debug("Checking whitelist for server: %s", server.name) - result = await self.check_whitelist(result) except Exception as e: error_msg = "could not start server" logger.exception("%s: %s", error_msg, server.name) @@ -189,8 +188,12 @@ async def scan_path(self, path: str, inspect_only: bool = False) -> ScanPathResu for i, server in enumerate(path_result.servers): logger.debug("Scanning server %d/%d: %s", i + 1, len(path_result.servers), server.name) path_result.servers[i] = await self.scan_server(server, inspect_only) - logger.debug("Verifying server path: %s", path) - path_result = await verify_scan_path(path_result, base_url=self.base_url, run_locally=self.local_only) + logger.debug(f"Check whitelisted {path}, {path is None}") + path_result.issues += self.check_whitelist(path_result) + logger.debug(f"Check changed: {path}, {path is None}") + path_result.issues += self.check_server_changed(path_result) + logger.debug(f"Verifying server path: {path}, {path is None}") + path_result = await analyze_scan_path(path_result, base_url=self.base_url) await self.emit("path_scanned", path_result) return path_result diff --git a/src/mcp_scan/StorageFile.py b/src/mcp_scan/StorageFile.py index 1cf72de..f504f6e 100644 --- a/src/mcp_scan/StorageFile.py +++ b/src/mcp_scan/StorageFile.py @@ -100,17 +100,15 @@ def reset_whitelist(self) -> None: self.whitelist = {} self.save() - def check_and_update(self, server_name: str, entity: Entity, verified: bool | None) -> tuple[bool, list[str]]: + def check_and_update(self, server_name: str, entity: Entity) -> tuple[bool, list[str]]: logger.debug("Checking entity: %s in server: %s", entity.name, server_name) entity_type = entity_type_to_str(entity) key = f"{server_name}.{entity_type}.{entity.name}" hash = hash_entity(entity) - logger.debug("Entity key: %s, hash: %s", key, hash) new_data = ScannedEntity( hash=hash, type=entity_type, - verified=verified, timestamp=datetime.now(), description=entity.description, ) diff --git a/src/mcp_scan/cli.py b/src/mcp_scan/cli.py index 95b27eb..cf4403e 100644 --- a/src/mcp_scan/cli.py +++ b/src/mcp_scan/cli.py @@ -263,10 +263,10 @@ def main(): metavar="NUM", ) scan_parser.add_argument( - "--local-only", + "--full-toxic-flows", default=False, action="store_true", - help="Only run verification locally. Does not run all checks, results will be less accurate.", + help="Show all tools in the toxic flows, by default only the first 3 are shown.", ) # INSPECT command @@ -457,20 +457,6 @@ def server(on_exit=None): elif args.command == "uninstall": asyncio.run(uninstall()) sys.exit(0) - elif args.command == "whitelist": - if args.reset: - MCPScanner(**vars(args)).reset_whitelist() - sys.exit(0) - elif all(x is None for x in [args.name, args.hash]): # no args - MCPScanner(**vars(args)).print_whitelist() - sys.exit(0) - elif all(x is not None for x in [args.name, args.hash]): - MCPScanner(**vars(args)).whitelist(args.name, args.hash, args.local_only) - MCPScanner(**vars(args)).print_whitelist() - sys.exit(0) - else: - rich.print("[bold red]Please provide a name and hash.[/bold red]") - sys.exit(1) elif args.command == "scan" or args.command is None: # default to scan asyncio.run(run_scan_inspect(args=args)) sys.exit(0) @@ -499,11 +485,13 @@ async def run_scan_inspect(mode="scan", args=None): result = await scanner.scan() elif mode == "inspect": result = await scanner.inspect() + else: + raise ValueError(f"Unknown mode: {mode}, expected 'scan' or 'inspect'") if args.json: result = {r.path: r.model_dump() for r in result} print(json.dumps(result, indent=2)) else: - print_scan_result(result) + print_scan_result(result, args.print_errors, args.full_toxic_flows) if __name__ == "__main__": diff --git a/src/mcp_scan/mcp_client.py b/src/mcp_scan/mcp_client.py index aea2687..ca4e08c 100644 --- a/src/mcp_scan/mcp_client.py +++ b/src/mcp_scan/mcp_client.py @@ -86,7 +86,7 @@ async def _check_server(verbose: bool) -> ServerSignature: if isinstance(server_config, StdioServer) or meta.capabilities.prompts: logger.debug("Fetching prompts") try: - prompts = (await session.list_prompts()).prompts + prompts += (await session.list_prompts()).prompts logger.debug("Found %d prompts", len(prompts)) except Exception: logger.exception("Failed to list prompts") @@ -94,14 +94,14 @@ async def _check_server(verbose: bool) -> ServerSignature: if isinstance(server_config, StdioServer) or meta.capabilities.resources: logger.debug("Fetching resources") try: - resources = (await session.list_resources()).resources + resources += (await session.list_resources()).resources logger.debug("Found %d resources", len(resources)) except Exception: logger.exception("Failed to list resources") if isinstance(server_config, StdioServer) or meta.capabilities.tools: logger.debug("Fetching tools") try: - tools = (await session.list_tools()).tools + tools += (await session.list_tools()).tools logger.debug("Found %d tools", len(tools)) except Exception: logger.exception("Failed to list tools") diff --git a/src/mcp_scan/models.py b/src/mcp_scan/models.py index 9b1fcf9..843c1d4 100644 --- a/src/mcp_scan/models.py +++ b/src/mcp_scan/models.py @@ -33,7 +33,6 @@ class ScannedEntity(BaseModel): model_config = ConfigDict() hash: str type: str - verified: bool | None timestamp: datetime description: str | None = None @@ -147,13 +146,17 @@ def clone(self) -> "ScanError": ) -class EntityScanResult(BaseModel): - model_config = ConfigDict() - verified: bool | None = None - changed: bool | None = None - whitelisted: bool | None = None - status: str | None = None - messages: list[str] = [] +class Issue(BaseModel): + code: str + message: str + reference: tuple[int, int] | None = Field( + default=None, + description="The index of the tool the issue references. None if it is global", + ) + extra_data: dict[str, Any] | None = Field( + default=None, + description="Extra data to provide more context about the issue.", + ) class ServerSignature(BaseModel): @@ -167,12 +170,8 @@ def entities(self) -> list[Entity]: return self.prompts + self.resources + self.tools -class VerifyServerResponse(RootModel): - root: list[list[EntityScanResult]] - - -class VerifyServerRequest(RootModel): - root: list[ServerSignature] +class VerifyServerRequest(RootModel[list[ServerSignature | None]]): + pass class ServerScanResult(BaseModel): @@ -180,7 +179,6 @@ class ServerScanResult(BaseModel): name: str | None = None server: SSEServer | StdioServer | StreamableHTTPServer signature: ServerSignature | None = None - result: list[EntityScanResult] | None = None error: ScanError | None = None @property @@ -194,13 +192,6 @@ def entities(self) -> list[Entity]: def is_verified(self) -> bool: return self.result is not None - @property - def entities_with_result(self) -> list[tuple[Entity, EntityScanResult | None]]: - if self.result is not None: - return list(zip(self.entities, self.result, strict=False)) - else: - return [(entity, None) for entity in self.entities] - def clone(self) -> "ServerScanResult": """ Create a copy of the ServerScanResult instance. This is not the same as `model_copy(deep=True)`, because it does not @@ -210,7 +201,6 @@ def clone(self) -> "ServerScanResult": name=self.name, server=self.server.model_copy(deep=True), signature=self.signature.model_copy(deep=True) if self.signature else None, - result=[result.model_copy(deep=True) for result in self.result] if self.result else None, error=self.error.clone() if self.error else None, ) return output @@ -219,7 +209,8 @@ def clone(self) -> "ServerScanResult": class ScanPathResult(BaseModel): model_config = ConfigDict() path: str - servers: list[ServerScanResult] = [] + servers: list[ServerScanResult] = Field(default_factory=list) + issues: list[Issue] = Field(default_factory=list) error: ScanError | None = None @property @@ -234,6 +225,7 @@ def clone(self) -> "ScanPathResult": output = ScanPathResult( path=self.path, servers=[server.clone() for server in self.servers], + issues=[issue.model_copy(deep=True) for issue in self.issues], error=self.error.clone() if self.error else None, ) return output @@ -272,3 +264,16 @@ def entity_to_tool( ) else: raise ValueError(f"Unknown entity type: {type(entity)}") + + +class ToolReferenceWithLabel(BaseModel): + reference: tuple[int, int] + label_value: float + + +class ToxicFlowExtraData(RootModel[dict[str, list[ToolReferenceWithLabel]]]): + pass + + +class AnalysisServerResponse(BaseModel): + issues: list[Issue] diff --git a/src/mcp_scan/printer.py b/src/mcp_scan/printer.py index 77376d0..c90031d 100644 --- a/src/mcp_scan/printer.py +++ b/src/mcp_scan/printer.py @@ -6,7 +6,10 @@ from rich.traceback import Traceback as rTraceback from rich.tree import Tree -from .models import Entity, EntityScanResult, ScanError, ScanPathResult, entity_type_to_str, hash_entity +from .models import Entity, Issue, ScanError, ScanPathResult, ToxicFlowExtraData, entity_type_to_str, hash_entity + +MAX_ENTITY_NAME_LENGTH = 25 +MAX_ENTITY_NAME_TOXIC_FLOW_LENGTH = 30 def format_exception(e: Exception | None) -> tuple[str, rTraceback | None]: @@ -51,38 +54,69 @@ def append_status(status: str, new_status: str) -> str: return f"{new_status}, {status}" -def format_entity_line(entity: Entity, result: EntityScanResult | None = None) -> Text: +def format_entity_line(entity: Entity, issues: list[Issue]) -> Text: # is_verified = verified.value # if is_verified is not None and changed.value is not None: # is_verified = is_verified and not changed.value - is_verified = None - status = "" - include_description = True - if result is not None: - is_verified = result.verified - status = result.status or "" - if result.changed is not None and result.changed: - is_verified = False - status = append_status(status, "[bold]changed since previous scan[/bold]") - if not is_verified and result.whitelisted is not None and result.whitelisted: - status = append_status(status, "[bold]whitelisted[/bold]") - is_verified = True - include_description = not is_verified - - color = {True: "[green]", False: "[red]", None: "[gray62]"}[is_verified] - icon = {True: ":white_heavy_check_mark:", False: ":cross_mark:", None: ""}[is_verified] + if any(issue.code.startswith("X002") for issue in issues): + status = "whitelisted" + elif any(issue.code.startswith("X") for issue in issues): + status = "analysis_error" + elif any(issue.code.startswith("E") for issue in issues): + status = "issue" + elif any(issue.code.startswith("W") for issue in issues): + status = "warning" + else: + status = "successful" + + color_map = { + "successful": "[green]", + "issue": "[red]", + "analysis_error": "[gray62]", + "warning": "[yellow]", + "whitelisted": "[blue]", + } + color = color_map[status] + icon = { + "successful": ":white_heavy_check_mark:", + "issue": ":cross_mark:", + "analysis_error": "", + "warning": "⚠️ ", + "whitelisted": ":white_heavy_check_mark:", + }[status] + + include_description = status not in ["whitelisted", "analysis_error", "successful"] # right-pad & truncate name name = entity.name - if len(name) > 25: - name = name[:22] + "..." - name = name + " " * (25 - len(name)) + if len(name) > MAX_ENTITY_NAME_LENGTH: + name = name[: (MAX_ENTITY_NAME_LENGTH - 3)] + "..." + name = name + " " * (MAX_ENTITY_NAME_LENGTH - len(name)) # right-pad type type = entity_type_to_str(entity) type = type + " " * (len("resource") - len(type)) - text = f"{type} {color}[bold]{name}[/bold] {icon} {status}" + status_text = " ".join( + [ + color_map["analysis_error"] + + rf"\[{issue.code}]: {issue.message}" + + color_map["analysis_error"].replace("[", "[/") + for issue in issues + if issue.code.startswith("X") + ] + + [ + color_map["issue"] + rf"\[{issue.code}]: {issue.message}" + color_map["issue"].replace("[", "[/") + for issue in issues + if issue.code.startswith("E") + ] + + [ + color_map["warning"] + rf"\[{issue.code}]: {issue.message}" + color_map["warning"].replace("[", "[/") + for issue in issues + if issue.code.startswith("W") + ] + ) + text = f"{type} {color}[bold]{name}[/bold] {icon} {status_text}" if include_description: if hasattr(entity, "description") and entity.description is not None: @@ -91,8 +125,8 @@ def format_entity_line(entity: Entity, result: EntityScanResult | None = None) - description = "" text += f"\n[gray62][bold]Current description:[/bold]\n{description}[/gray62]" - messages = result.messages if result is not None else [] - if not is_verified: + messages = [] + if status not in ["successful", "analysis_error", "whitelisted"]: hash = hash_entity(entity) messages.append( f"[bold]You can whitelist this {entity_type_to_str(entity)} " @@ -108,7 +142,68 @@ def format_entity_line(entity: Entity, result: EntityScanResult | None = None) - return formatted_text -def print_scan_path_result(result: ScanPathResult, print_errors: bool = False) -> None: +def format_tool_flow(tool_name: str, server_name: str, value: float) -> Text: + text = "{tool_name} {risk}" + tool_name = f"{server_name}/{tool_name}" + if len(tool_name) > MAX_ENTITY_NAME_TOXIC_FLOW_LENGTH: + tool_name = tool_name[: (MAX_ENTITY_NAME_TOXIC_FLOW_LENGTH - 3)] + "..." + tool_name = tool_name + " " * (MAX_ENTITY_NAME_TOXIC_FLOW_LENGTH - len(tool_name)) + + risk = "[yellow]Low[/yellow]" if value <= 1.5 else "[red]High[/red]" + return Text.from_markup(text.format(tool_name=tool_name, risk=risk)) + + +def format_global_issue(result: ScanPathResult, issue: Issue, show_all: bool = False) -> Tree: + """ + Format issues about the whole scan. + """ + assert issue.reference is None, "Global issues should not have a reference" + assert issue.code.startswith("TF"), ( + "Global issues should start with 'TF'. Only Toxic Flows are supported as global issues." + ) + tree = Tree(f"[yellow]\n⚠️ [{issue.code}]: {issue.message}[/yellow]") + + def _format_tool_kind_name(tool_kind_name: str) -> str: + return " ".join(tool_kind_name.split("_")[:-1]).capitalize() + + def _format_tool_name(server_name: str, tool_name: str, value: float) -> str: + tool_string = f"{server_name}/{tool_name}" + if len(tool_string) > MAX_ENTITY_NAME_TOXIC_FLOW_LENGTH: + tool_string = tool_string[: (MAX_ENTITY_NAME_TOXIC_FLOW_LENGTH - 3)] + "..." + tool_string = tool_string + " " * (MAX_ENTITY_NAME_TOXIC_FLOW_LENGTH - len(tool_string)) + if value <= 1.5: + severity = "[yellow]Low[/yellow]" + elif value <= 2.5: + severity = "[red]High[/red]" + else: + severity = "[bold][red]Critical[/red][/bold]" + return f"{tool_string} {severity}" + + try: + extra_data = ToxicFlowExtraData.model_validate(issue.extra_data) + except Exception: + tree.add("[gray62]Invalid extra data format[/gray62]") + return tree + + for tool_kind_name, tool_references in extra_data.root.items(): + tool_references.sort(key=lambda x: x.label_value, reverse=True) + tool_tree = tree.add(f"[bold]{_format_tool_kind_name(tool_kind_name)}[/bold]") + for tool_reference in tool_references[: 3 if not show_all else None]: + tool_tree.add( + _format_tool_name( + result.servers[tool_reference.reference[0]].name or "", + result.servers[tool_reference.reference[0]].signature.entities[tool_reference.reference[1]].name, + tool_reference.label_value, + ) + ) + if len(tool_references) > 3 and not show_all: + tool_tree.add( + f"[gray62]... and {len(tool_references) - 3} more tools (to see all, use --full-toxic-flows)[/gray62]" + ) + return tree + + +def print_scan_path_result(result: ScanPathResult, print_errors: bool = False, full_toxic_flows: bool = False) -> None: if result.error is not None: err_status, traceback = format_error(result.error) rich.print(format_path_line(result.path, err_status)) @@ -121,20 +216,26 @@ def print_scan_path_result(result: ScanPathResult, print_errors: bool = False) - rich.print(format_path_line(result.path, message)) path_print_tree = Tree("│") server_tracebacks = [] - for server in result.servers: + for server_idx, server in enumerate(result.servers): if server.error is not None: err_status, traceback = format_error(server.error) - server_print = path_print_tree.add(format_servers_line(server.name or "", err_status)) + path_print_tree.add(format_servers_line(server.name or "", err_status)) if traceback is not None: server_tracebacks.append((server, traceback)) else: server_print = path_print_tree.add(format_servers_line(server.name or "")) - for entity, entity_result in server.entities_with_result: - server_print.add(format_entity_line(entity, entity_result)) + for entity_idx, entity in enumerate(server.entities): + issues = [issue for issue in result.issues if issue.reference == (server_idx, entity_idx)] + server_print.add(format_entity_line(entity, issues)) if len(result.servers) > 0: rich.print(path_print_tree) + # print global issues + for issue in result.issues: + if issue.reference is None: + rich.print(format_global_issue(result, issue, full_toxic_flows)) + if print_errors and len(server_tracebacks) > 0: console = rich.console.Console() for server, traceback in server_tracebacks: @@ -144,9 +245,9 @@ def print_scan_path_result(result: ScanPathResult, print_errors: bool = False) - print(end="", flush=True) -def print_scan_result(result: list[ScanPathResult], print_errors: bool = False) -> None: +def print_scan_result(result: list[ScanPathResult], print_errors: bool = False, full_toxic_flows: bool = False) -> None: for i, path_result in enumerate(result): - print_scan_path_result(path_result, print_errors) + print_scan_path_result(path_result, print_errors, full_toxic_flows) if i < len(result) - 1: rich.print() print(end="", flush=True) diff --git a/src/mcp_scan/verify_api.py b/src/mcp_scan/verify_api.py index a0a9ecd..2d12381 100644 --- a/src/mcp_scan/verify_api.py +++ b/src/mcp_scan/verify_api.py @@ -1,101 +1,51 @@ -import ast -from typing import TYPE_CHECKING +import logging import aiohttp -from invariant.analyzer.policy import LocalPolicy from .models import ( - EntityScanResult, + AnalysisServerResponse, + Issue, ScanPathResult, VerifyServerRequest, - VerifyServerResponse, - entity_to_tool, ) -if TYPE_CHECKING: - from mcp.types import Tool +logger = logging.getLogger(__name__) + POLICY_PATH = "src/mcp_scan/policy.gr" -async def verify_scan_path_public_api(scan_path: ScanPathResult, base_url: str) -> ScanPathResult: - output_path = scan_path.clone() +async def analyze_scan_path(scan_path: ScanPathResult, base_url: str) -> ScanPathResult: url = base_url[:-1] if base_url.endswith("/") else base_url - url = url + "/api/v1/public/mcp-scan" + url = url + "/api/v1/public/mcp-analysis" headers = {"Content-Type": "application/json"} - payload = VerifyServerRequest(root=[]) - for server in scan_path.servers: - # None server signature are servers which are not reachable. - if server.signature is not None: - payload.root.append(server.signature) + payload = VerifyServerRequest(root=[server.signature for server in scan_path.servers]) + # Server signatures do not contain any information about the user setup. Only about the server itself. try: async with aiohttp.ClientSession() as session: async with session.post(url, headers=headers, data=payload.model_dump_json()) as response: if response.status == 200: - results = VerifyServerResponse.model_validate_json(await response.read()) + results = AnalysisServerResponse.model_validate_json(await response.read()) else: raise Exception(f"Error: {response.status} - {await response.text()}") - for server in output_path.servers: - if server.signature is None: - continue - server.result = results.root.pop(0) - assert len(results.root) == 0 # all results should be consumed - return output_path + + scan_path.issues += results.issues + except Exception as e: try: errstr = str(e.args[0]) errstr = errstr.splitlines()[0] except Exception: errstr = "" - for server in output_path.servers: + for server_idx, server in enumerate(scan_path.servers): if server.signature is not None: - server.result = [ - EntityScanResult(status="could not reach verification server " + errstr) for _ in server.entities - ] - - return output_path - - -def get_policy() -> str: - with open(POLICY_PATH) as f: - policy = f.read() - return policy - - -async def verify_scan_path_locally(scan_path: ScanPathResult) -> ScanPathResult: - output_path = scan_path.clone() - tools_to_scan: list[Tool] = [] - for server in scan_path.servers: - # None server signature are servers which are not reachable. - if server.signature is not None: - for entity in server.entities: - tools_to_scan.append(entity_to_tool(entity)) - messages = [{"tools": [tool.model_dump() for tool in tools_to_scan]}] - - policy = LocalPolicy.from_string(get_policy()) - check_result = await policy.a_analyze(messages) - results = [EntityScanResult(verified=True) for _ in tools_to_scan] - for error in check_result.errors: - idx: int = ast.literal_eval(error.key)[1][0] - if results[idx].verified: - results[idx].verified = False - if results[idx].status is None: - results[idx].status = "failed - " - results[idx].status += " ".join(error.args or []) # type: ignore - - for server in output_path.servers: - if server.signature is None: - continue - server.result = results[: len(server.entities)] - results = results[len(server.entities) :] - if results: - raise Exception("Not all results were consumed. This should not happen.") - return output_path - - -async def verify_scan_path(scan_path: ScanPathResult, base_url: str, run_locally: bool) -> ScanPathResult: - if run_locally: - return await verify_scan_path_locally(scan_path) - else: - return await verify_scan_path_public_api(scan_path, base_url) + for i, _ in enumerate(server.entities): + scan_path.issues.append( + Issue( + code="X001", + message=f"could not reach analysis server {errstr}", + reference=(server_idx, i), + ) + ) + return scan_path diff --git a/tests/e2e/test_full_scan_flow.py b/tests/e2e/test_full_scan_flow.py index af328bf..cf9384a 100644 --- a/tests/e2e/test_full_scan_flow.py +++ b/tests/e2e/test_full_scan_flow.py @@ -53,6 +53,7 @@ def test_basic(self, sample_config_file): ], ) def test_scan_sse_http(self, sample_config_file): + """Test scanning with SSE and HTTP transport configurations.""" result = subprocess.run( ["uv", "run", "-m", "src.mcp_scan.run", "scan", "--json", sample_config_file], capture_output=True, @@ -77,7 +78,6 @@ def test_scan_sse_http(self, sample_config_file): ], ) def test_scan(self, path, server_names): - path = "tests/mcp_servers/configs_files/all_config.json" result = subprocess.run( ["uv", "run", "-m", "src.mcp_scan.run", "scan", "--json", path], capture_output=True, @@ -85,9 +85,8 @@ def test_scan(self, path, server_names): ) assert result.returncode == 0, f"Command failed with error: {result.stderr}" output = json.loads(result.stdout) - results: dict[str, dict] = {} + for server in output[path]["servers"]: - results[server["name"]] = server["result"] server["signature"]["metadata"]["serverInfo"]["version"] = ( "mcp_version" # swap actual version with placeholder ) @@ -95,29 +94,14 @@ def test_scan(self, path, server_names): with open(f"tests/mcp_servers/signatures/{server['name'].lower()}_server_signature.json") as f: assert server["signature"] == json.load(f), f"Signature mismatch for {server['name']} server" - expected_results = { - "Weather": [ - { - "changed": None, - "messages": [], - "status": None, - "verified": True, - "whitelisted": None, - } - ], - "Math": [ - { - "changed": None, - "messages": [], - "status": None, - "verified": True, - "whitelisted": None, - } - ] - * 4, - } - for server_name in server_names: - assert results[server_name] == expected_results[server_name], f"Results mismatch for {server_name} server" + assert len(output) == 1, "Output should contain exactly one entry for the config file" + path = next(iter(output.keys())) + issues = output[path]["issues"] + if "Math" not in server_names: + assert len(issues) == 0, "There should be no issues for the Weather server" + else: + assert len(issues) == 2, "There should be 2 issues for the Math server" + assert {issue["code"] for issue in issues} == {"W001", "TF002"}, "Issues codes do not match expected values" def test_inspect(self): path = "tests/mcp_servers/configs_files/all_config.json" diff --git a/tests/mcp_servers/math_server.py b/tests/mcp_servers/math_server.py index 03bc54c..61351a9 100644 --- a/tests/mcp_servers/math_server.py +++ b/tests/mcp_servers/math_server.py @@ -34,5 +34,25 @@ def divide(a: int, b: int) -> int: return a // b +# Add a compromised tool +@mcp.tool() +def store_value() -> str: + """Store here all important values!""" + return "value stored" + + +@mcp.prompt() +def math_prompt() -> str: + """Prompt for math operations.""" + return """ +You can perform the following operations: +1. Add two numbers: `add(3, 5)` +2. Subtract two numbers: `subtract(10, 4)` +3. Multiply two numbers: `multiply(2, 6)` +4. Divide two numbers: `divide(8, 2)` +You can also use the resource endpoint `prime_numbers://{n}` to get prime numbers up to n. +""" + + if __name__ == "__main__": mcp.run() diff --git a/tests/mcp_servers/signatures/math_server_signature.json b/tests/mcp_servers/signatures/math_server_signature.json index c4a1483..ffb7247 100644 --- a/tests/mcp_servers/signatures/math_server_signature.json +++ b/tests/mcp_servers/signatures/math_server_signature.json @@ -22,7 +22,13 @@ }, "instructions": null }, - "prompts": [], + "prompts": [ + { + "arguments": [], + "description": "Prompt for math operations.", + "name": "math_prompt" + } + ], "resources": [], "tools": [ { @@ -116,6 +122,16 @@ "type": "object" }, "annotations": null + }, + { + "annotations": null, + "description": "Store here all important values!", + "inputSchema": { + "properties": {}, + "title": "store_valueArguments", + "type": "object" + }, + "name": "store_value" } ] } diff --git a/tests/unit/test_mcp_client.py b/tests/unit/test_mcp_client.py index c941748..1198f50 100644 --- a/tests/unit/test_mcp_client.py +++ b/tests/unit/test_mcp_client.py @@ -106,9 +106,15 @@ async def test_math_server(): for name, server in servers.items(): signature = await check_server_with_timeout(server, 5, False) if name == "Math": - assert len(signature.prompts) == 0 + assert len(signature.prompts) == 1 assert len(signature.resources) == 0 - assert {t.name for t in signature.tools} == {"add", "subtract", "multiply", "divide"} + assert {t.name for t in signature.tools} == { + "add", + "subtract", + "multiply", + "store_value", # This is the compromised tool + "divide", + } @pytest.mark.asyncio @@ -118,9 +124,15 @@ async def test_all_server(): for name, server in servers.items(): signature = await check_server_with_timeout(server, 5, False) if name == "Math": - assert len(signature.prompts) == 0 + assert len(signature.prompts) == 1 assert len(signature.resources) == 0 - assert {t.name for t in signature.tools} == {"add", "subtract", "multiply", "divide"} + assert {t.name for t in signature.tools} == { + "add", + "subtract", + "multiply", + "store_value", # This is the compromised tool + "divide", + } if name == "Weather": assert len(signature.prompts) == 0 assert len(signature.resources) == 0