From 12e5d2925aeba97c4e411052f64c31a4e1303301 Mon Sep 17 00:00:00 2001 From: Marco Milanta Date: Tue, 22 Apr 2025 21:18:14 +0200 Subject: [PATCH 01/21] feat: mcp install gateway --- src/mcp_scan/MCPScanner.py | 20 ++++ src/mcp_scan/cli.py | 50 +++++++++ src/mcp_scan/gateway.py | 209 +++++++++++++++++++++++++++++++++++++ tests/test_configs.json | 32 ++++++ 4 files changed, 311 insertions(+) create mode 100644 src/mcp_scan/gateway.py create mode 100644 tests/test_configs.json diff --git a/src/mcp_scan/MCPScanner.py b/src/mcp_scan/MCPScanner.py index ce5f5d4..398159f 100644 --- a/src/mcp_scan/MCPScanner.py +++ b/src/mcp_scan/MCPScanner.py @@ -1,4 +1,10 @@ import os +from mcp import ClientSession, StdioServerParameters +from mcp.client.stdio import stdio_client +from mcp.client.sse import sse_client +from typing import Type + +import json import textwrap import asyncio from uu import Error @@ -11,6 +17,20 @@ from .StorageFile import StorageFile from .verify_api import verify_server from typing import Any +from .models import ( + VSCodeConfigFile, + VSCodeMCPConfig, + ClaudeConfigFile, + SSEServer, + StdioServer, + MCPConfig, +) +from .suppressIO import SuppressStd +from collections import namedtuple +from datetime import datetime +from hashlib import md5 +import pyjson5 +from .utils import rebalance_command_args def format_err_str(e: Exception, max_length: int | None=None) -> str: diff --git a/src/mcp_scan/cli.py b/src/mcp_scan/cli.py index 7ddb75e..eb79c79 100644 --- a/src/mcp_scan/cli.py +++ b/src/mcp_scan/cli.py @@ -2,6 +2,7 @@ import argparse from .MCPScanner import MCPScanner from .StorageFile import StorageFile +from mcp_scan.gateway import MCPGatewayInstaller, MCPGatewayConfig import rich from .version import version_info import psutil @@ -209,6 +210,43 @@ def main(): help="Hash of the entity to whitelist", metavar="HASH", ) + # install + install_parser = subparsers.add_parser("install", help="Install Invariant Gateway") + install_parser.add_argument( + "files", + type=str, + nargs="*", + default=WELL_KNOWN_MCP_PATHS, + help="Different file locations to scan. This can include custom file locations as long as they are in an expected format, including Claude, Cursor or VSCode format.", + ) + install_parser.add_argument( + "--project_name", + type=str, + default="mcp-gateway", + help="Project name for the Invariant Gateway", + ) + install_parser.add_argument( + "--api-key", + type=str, + required=True, + help="api key for the Invariant Gateway", + ) + install_parser.add_argument( + "--local-only", + default=False, + action="store_true", + help="Prevent pushing traces to the explorer.", + ) + + # uninstall + uninstall_parser = subparsers.add_parser("uninstall", help="Uninstall Invariant Gateway") + uninstall_parser.add_argument( + "files", + type=str, + nargs="*", + default=WELL_KNOWN_MCP_PATHS, + help="Different file locations to scan. This can include custom file locations as long as they are in an expected format, including Claude, Cursor or VSCode format.", + ) # HELP command help_parser = subparsers.add_parser( @@ -247,6 +285,18 @@ def main(): elif args.command == 'inspect': MCPScanner(**vars(args)).inspect() sys.exit(0) + elif args.command == 'install': + installer = MCPGatewayInstaller(paths=args.files) + installer.install(gateway_config=MCPGatewayConfig( + project_name=args.project_name, + push_explorer=not args.local_only, + api_key=args.api_key, + ), verbose=True) + # install logic here + elif args.command == 'uninstall': + installer = MCPGatewayInstaller(paths=args.files) + installer.uninstall(verbose=True) + # uninstall logic here elif args.command == 'whitelist': if args.reset: MCPScanner(**vars(args)).reset_whitelist() diff --git a/src/mcp_scan/gateway.py b/src/mcp_scan/gateway.py new file mode 100644 index 0000000..985efcd --- /dev/null +++ b/src/mcp_scan/gateway.py @@ -0,0 +1,209 @@ +from mcp_scan.models import StdioServer, SSEServer, MCPConfig +from mcp_scan.MCPScanner import scan_config_file, format_path_line, format_servers_line +from pydantic import BaseModel +from rich.tree import Tree +import argparse +import json +import os +import rich + +parser = argparse.ArgumentParser( + description="MCP-scan CLI", + prog="invariant-gateway@latest mcp", +) + +parser.add_argument( + "--project-name", + type=str, + required=True, +) +parser.add_argument( + "--push-explorer", + action="store_true", +) +parser.add_argument( + "--exec", + type=str, + required=True, + nargs=argparse.REMAINDER +) + +class MCPServerIsNotGateway(Exception): + pass + + +class MCPServerAlreadyGateway(Exception): + pass + + +class MCPGatewayConfig(BaseModel): + project_name: str + push_explorer: bool + api_key: str + +def is_invariant_installed(server: StdioServer) -> bool: + if server.args is None: + return False + return server.args[0] == "invariant-gateway@latest" + +def install_gateway( + server: StdioServer, config: MCPGatewayConfig, +) -> StdioServer: + """ + Install the gateway for the given server. + """ + if is_invariant_installed(server): + raise MCPServerAlreadyGateway() + return StdioServer( + command="uvx", + args= [ + "invariant-gateway@latest", + "mcp", + "--project-name", + config.project_name, + ] + (["--push-explorer"] if config.push_explorer else []) + [ + "--exec", + server.command + ] + (server.args if server.args else []), + env=server.env | {"INVARIANT_API_KEY": config.api_key}, + ) + +def uninstall_gateway( + server: StdioServer, +) -> StdioServer: + """ + Uninstall the gateway for the given server. + """ + + if not is_invariant_installed(server): + raise MCPServerIsNotGateway() + + assert isinstance(server.args, list), "args is not a list" + args = parser.parse_args(server.args[2:]) + new_env = {k: v for k, v in server.env.items() if k != "INVARIANT_API_KEY"} + assert args.exec is not None, "exec is None" + assert args.exec, "exec is empty" + return StdioServer( + command=args.exec[0], + args=args.exec[1:], + env=new_env, + ) + +def format_install_line(server: str, status: str, success: bool | None) -> Tree: + color = {True: "[green]", False: "[red]", None: "[gray62]"}[success] + + if len(server) > 25: + server = server[:22] + "..." + server = server + " " * (25 - len(server)) + icon = {True: ":white_heavy_check_mark:", False: ":cross_mark:", None: ""}[ + success + ] + + text = f"{color}[bold]{server}[/bold]{icon} {status}{color.replace('[', '[/')}" + return rich.text.Text.from_markup(text) + +class MCPGatewayInstaller: + """ + A class to install and uninstall the gateway for a given server. + """ + def __init__(self, paths: list[str], ) -> None: + self.paths = paths + + + def install(self, gateway_config: MCPGatewayConfig, verbose: bool = False, ) -> None: + for path in self.paths: + config: MCPConfig | None = None + try: + config = scan_config_file(path) + status = f"found {len(config.get_servers())} server{'' if len(config.get_servers()) == 1 else 's'}" + except FileNotFoundError: + status = f"file does not exist" + except Exception: + status = f"could not parse file" + if verbose: + rich.print(format_path_line(path, status, operation="Installing Gateway")) + if config is None: + continue + + path_print_tree = Tree("│") + new_servers: dict[str, SSEServer | StdioServer] = {} + for name, server in config.get_servers().items(): + if isinstance(server, StdioServer): + try: + new_servers[name] = install_gateway(server, gateway_config) + path_print_tree.add(format_install_line(server=name, status="Installed", success=True)) + except MCPServerAlreadyGateway: + new_servers[name] = server + path_print_tree.add(format_install_line(server=name, status="Already installed", success=True)) + except Exception: + new_servers[name] = server + path_print_tree.add(format_install_line(server=name, status="Failed to install", success=False)) + + else: + new_servers[name] = server + path_print_tree.add(format_install_line(server=name, status="sse servers not supported yet", success=False)) + + if verbose: + rich.print(path_print_tree) + config.set_servers(new_servers) + with open(os.path.expanduser(path), "w") as f: + f.write(config.model_dump_json(indent=4) + "\n") + + def uninstall(self, verbose: bool = False) -> None: + for path in self.paths: + config: MCPConfig | None = None + try: + config = scan_config_file(path) + status = f"found {len(config.get_servers())} server{'' if len(config.get_servers()) == 1 else 's'}" + except FileNotFoundError: + status = f"file does not exist" + except Exception: + status = f"could not parse file" + if verbose: + rich.print(format_path_line(path, status, operation="Installing Gateway")) + if config is None: + continue + + path_print_tree = Tree("│") + config = scan_config_file(path) + new_servers: dict[str, SSEServer | StdioServer] = {} + for name, server in config.get_servers().items(): + if isinstance(server, StdioServer): + try: + new_servers[name] = uninstall_gateway(server) + path_print_tree.add(format_install_line(server=name, status="Uninstalled", success=True)) + except MCPServerIsNotGateway: + new_servers[name] = server + path_print_tree.add(format_install_line(server=name, status="Already not installed", success=True)) + except Exception: + new_servers[name] = server + path_print_tree.add(format_install_line(server=name, status="Failed to uninstall", success=False)) + else: + new_servers[name] = server + path_print_tree.add(format_install_line(server=name, status="sse servers not supported yet", success=None)) + config.set_servers(new_servers) + if verbose: + rich.print(path_print_tree) + with open(os.path.expanduser(path), "w") as f: + f.write(config.model_dump_json(indent=4) + "\n") + +if __name__ == "__main__": + gateway_conf = MCPGatewayConfig( + project_name="test", + push_explorer=True, + api_key="inv-bb02e19170dd24aad2a13dceb17082a6d91b16597cb7bb579b9b777d78e54aaf", + ) + installer = MCPGatewayInstaller( + paths=["tests/test_configs.json"], + config=gateway_conf, + ) + confs = [scan_config_file(path) for path in installer.paths] + installer.install() + print("Installed") + installer.uninstall() + print("Uninstalled") + new_confs = [scan_config_file(path) for path in installer.paths] + for conf, new_conf in zip(confs, new_confs): + print(conf.model_dump_json(indent=4)) + print(new_conf.model_dump_json(indent=4)) + assert conf == new_conf \ No newline at end of file diff --git a/tests/test_configs.json b/tests/test_configs.json new file mode 100644 index 0000000..62f8e35 --- /dev/null +++ b/tests/test_configs.json @@ -0,0 +1,32 @@ +{ + "mcpServers": { + "Random Facts MCP Server": { + "command": "uv", + "args": [ + "run", + "--with", + "mcp[cli]", + "mcp", + "run", + "/Users/marcomilanta/Documents/invariant/mcp-injection-experiments/whatsapp-takeover.py" + ], + "type": "stdio", + "env": {} + }, + "WhatsApp Server": { + "command": "uv", + "args": [ + "run", + "--with", + "mcp[cli]", + "--with", + "requests", + "mcp", + "run", + "/Users/marcomilanta/Documents/invariant/mcp-injection-experiments/whatsapp.py" + ], + "type": "stdio", + "env": {} + } + } +} From 815dc146d9f719869fa7963aa46578407db21a3b Mon Sep 17 00:00:00 2001 From: Marco Milanta Date: Thu, 24 Apr 2025 10:27:45 +0200 Subject: [PATCH 02/21] fix: add tests --- src/mcp_scan/gateway.py | 2 ++ tests/unit/test_gateway.py | 45 +++++++++++++++++++++++++++++++++++ tests/unit/test_mcp_client.py | 2 +- 3 files changed, 48 insertions(+), 1 deletion(-) create mode 100644 tests/unit/test_gateway.py diff --git a/src/mcp_scan/gateway.py b/src/mcp_scan/gateway.py index 985efcd..ba77a54 100644 --- a/src/mcp_scan/gateway.py +++ b/src/mcp_scan/gateway.py @@ -44,6 +44,8 @@ class MCPGatewayConfig(BaseModel): def is_invariant_installed(server: StdioServer) -> bool: if server.args is None: return False + if not server.args: + return False return server.args[0] == "invariant-gateway@latest" def install_gateway( diff --git a/tests/unit/test_gateway.py b/tests/unit/test_gateway.py new file mode 100644 index 0000000..5ccd08e --- /dev/null +++ b/tests/unit/test_gateway.py @@ -0,0 +1,45 @@ +import pytest +from mcp_scan.gateway import is_invariant_installed, MCPGatewayInstaller, MCPGatewayConfig +from mcp_scan.MCPScanner import scan_config_file +from tests.unit.test_mcp_client import SAMPLE_CONFIGS +import pyjson5 +import tempfile +import os + + +@pytest.fixture +def temp_file(): + with tempfile.NamedTemporaryFile(delete=False) as tf: + yield tf.name + os.remove(tf.name) + +@pytest.mark.parametrize("server_config", SAMPLE_CONFIGS) +def test_install_gateway(server_config: str, temp_file): + with open(temp_file, "w") as f: + f.write(server_config) + + config_dict = pyjson5.loads(server_config) + installer = MCPGatewayInstaller(paths=[temp_file]) + for server in scan_config_file(temp_file).get_servers().values(): + print(f"{server=}") + assert not is_invariant_installed(server) + installer.install(gateway_config=MCPGatewayConfig( + project_name="test", + push_explorer=True, + api_key="my-very-secret-api-key" + ), verbose=True) + config_dict_installed = pyjson5.loads(server_config) + for server in scan_config_file(temp_file).get_servers().values(): + assert is_invariant_installed(server) + installer.uninstall(verbose=True) + mcp = scan_config_file(temp_file) + print(f"{mcp=}") + for server in scan_config_file(temp_file).get_servers().values(): + assert not is_invariant_installed(server) + + config_dict_uninstalled = pyjson5.loads(server_config) + + assert config_dict_uninstalled == config_dict + print(f"{config_dict=}") + print(f"{config_dict_installed=}") + print(f"{config_dict_uninstalled=}") \ No newline at end of file diff --git a/tests/unit/test_mcp_client.py b/tests/unit/test_mcp_client.py index 1284c90..dece0c3 100644 --- a/tests/unit/test_mcp_client.py +++ b/tests/unit/test_mcp_client.py @@ -64,7 +64,7 @@ vscode_config, ] -def test_scan_mcp_config(): +def test_scan_mcp_config_file(): for config in SAMPLE_CONFIGS: with tempfile.NamedTemporaryFile(mode="w") as temp_file: temp_file.write(config) From a118e5a64c00e63f7f3e91333542d23350cc95d4 Mon Sep 17 00:00:00 2001 From: Marco Milanta Date: Thu, 24 Apr 2025 10:30:43 +0200 Subject: [PATCH 03/21] fix: minor --- tests/unit/test_gateway.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/unit/test_gateway.py b/tests/unit/test_gateway.py index 5ccd08e..545fb12 100644 --- a/tests/unit/test_gateway.py +++ b/tests/unit/test_gateway.py @@ -40,6 +40,3 @@ def test_install_gateway(server_config: str, temp_file): config_dict_uninstalled = pyjson5.loads(server_config) assert config_dict_uninstalled == config_dict - print(f"{config_dict=}") - print(f"{config_dict_installed=}") - print(f"{config_dict_uninstalled=}") \ No newline at end of file From 601e400034bd98bcbbc61088525fabec9e845c52 Mon Sep 17 00:00:00 2001 From: Marco Milanta Date: Thu, 24 Apr 2025 10:30:58 +0200 Subject: [PATCH 04/21] fix: minor --- tests/unit/test_gateway.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/unit/test_gateway.py b/tests/unit/test_gateway.py index 545fb12..f15db11 100644 --- a/tests/unit/test_gateway.py +++ b/tests/unit/test_gateway.py @@ -21,7 +21,6 @@ def test_install_gateway(server_config: str, temp_file): config_dict = pyjson5.loads(server_config) installer = MCPGatewayInstaller(paths=[temp_file]) for server in scan_config_file(temp_file).get_servers().values(): - print(f"{server=}") assert not is_invariant_installed(server) installer.install(gateway_config=MCPGatewayConfig( project_name="test", @@ -33,7 +32,6 @@ def test_install_gateway(server_config: str, temp_file): assert is_invariant_installed(server) installer.uninstall(verbose=True) mcp = scan_config_file(temp_file) - print(f"{mcp=}") for server in scan_config_file(temp_file).get_servers().values(): assert not is_invariant_installed(server) From e0e8de5d7a146805faab983f4dde8377cbc681d3 Mon Sep 17 00:00:00 2001 From: Marco Milanta Date: Thu, 24 Apr 2025 13:30:18 +0200 Subject: [PATCH 05/21] fix: add test to use mcp client --- tests/mcp_servers/math.py | 34 +++++++++++++++++++++++++++++++ tests/mcp_servers/mcp_config.json | 4 ++++ 2 files changed, 38 insertions(+) create mode 100644 tests/mcp_servers/math.py diff --git a/tests/mcp_servers/math.py b/tests/mcp_servers/math.py new file mode 100644 index 0000000..a0ed124 --- /dev/null +++ b/tests/mcp_servers/math.py @@ -0,0 +1,34 @@ +from mcp.server.fastmcp import FastMCP + +# Create an MCP server +mcp = FastMCP("Demo") + + +# Add an addition tool +@mcp.tool() +def add(a: int, b: int) -> int: + """Add two numbers""" + return a + b + +# Add a subtraction tool +@mcp.tool() +def subtract(a: int, b: int) -> int: + """Subtract two numbers""" + return a - b + +# Add a multiplication tool +@mcp.tool() +def multiply(a: int, b: int) -> int: + """Multiply two numbers""" + return a * b + +# Add a division tool +@mcp.tool() +def divide(a: int, b: int) -> int: + """Divide two numbers""" + if b == 0: + raise ValueError("Cannot divide by zero") + return a // b + +if __name__ == "__main__": + mcp.run() \ No newline at end of file diff --git a/tests/mcp_servers/mcp_config.json b/tests/mcp_servers/mcp_config.json index 599b7ac..39637e6 100644 --- a/tests/mcp_servers/mcp_config.json +++ b/tests/mcp_servers/mcp_config.json @@ -2,7 +2,11 @@ "mcpServers": { "Math": { "command": "python3", +<<<<<<< HEAD "args": ["tests/mcp_servers/math_server.py"] +======= + "args": ["tests/mcp_servers/math.py"] +>>>>>>> daaa3d9 (fix: add test to use mcp client) } } } \ No newline at end of file From 1d7f2b56d34e7805361eba3fab806150e37d561d Mon Sep 17 00:00:00 2001 From: Marco Milanta Date: Thu, 24 Apr 2025 15:06:30 +0200 Subject: [PATCH 06/21] add mypy precommit hook --- src/mcp_scan/MCPScanner.py | 8 +++----- src/mcp_scan/gateway.py | 35 ++++++++--------------------------- src/mcp_scan/mcp_client.py | 1 + tests/unit/test_gateway.py | 26 ++++++++++++++++++-------- tests/unit/test_mcp_client.py | 2 +- 5 files changed, 31 insertions(+), 41 deletions(-) diff --git a/src/mcp_scan/MCPScanner.py b/src/mcp_scan/MCPScanner.py index 398159f..617969f 100644 --- a/src/mcp_scan/MCPScanner.py +++ b/src/mcp_scan/MCPScanner.py @@ -1,9 +1,4 @@ import os -from mcp import ClientSession, StdioServerParameters -from mcp.client.stdio import stdio_client -from mcp.client.sse import sse_client -from typing import Type - import json import textwrap import asyncio @@ -31,6 +26,9 @@ from hashlib import md5 import pyjson5 from .utils import rebalance_command_args +from .models import Result +from .StorageFile import StorageFile +from .verify_api import verify_server def format_err_str(e: Exception, max_length: int | None=None) -> str: diff --git a/src/mcp_scan/gateway.py b/src/mcp_scan/gateway.py index ba77a54..55ea8cb 100644 --- a/src/mcp_scan/gateway.py +++ b/src/mcp_scan/gateway.py @@ -1,7 +1,9 @@ from mcp_scan.models import StdioServer, SSEServer, MCPConfig -from mcp_scan.MCPScanner import scan_config_file, format_path_line, format_servers_line +from mcp_scan.MCPScanner import format_path_line, format_servers_line +from mcp_scan.mcp_client import scan_mcp_config_file from pydantic import BaseModel from rich.tree import Tree +from rich.text import Text import argparse import json import os @@ -91,7 +93,7 @@ def uninstall_gateway( env=new_env, ) -def format_install_line(server: str, status: str, success: bool | None) -> Tree: +def format_install_line(server: str, status: str, success: bool | None) -> Text: color = {True: "[green]", False: "[red]", None: "[gray62]"}[success] if len(server) > 25: @@ -102,7 +104,7 @@ def format_install_line(server: str, status: str, success: bool | None) -> Tree: ] text = f"{color}[bold]{server}[/bold]{icon} {status}{color.replace('[', '[/')}" - return rich.text.Text.from_markup(text) + return Text.from_markup(text) class MCPGatewayInstaller: """ @@ -116,7 +118,7 @@ def install(self, gateway_config: MCPGatewayConfig, verbose: bool = False, ) -> for path in self.paths: config: MCPConfig | None = None try: - config = scan_config_file(path) + config = scan_mcp_config_file(path) status = f"found {len(config.get_servers())} server{'' if len(config.get_servers()) == 1 else 's'}" except FileNotFoundError: status = f"file does not exist" @@ -155,7 +157,7 @@ def uninstall(self, verbose: bool = False) -> None: for path in self.paths: config: MCPConfig | None = None try: - config = scan_config_file(path) + config = scan_mcp_config_file(path) status = f"found {len(config.get_servers())} server{'' if len(config.get_servers()) == 1 else 's'}" except FileNotFoundError: status = f"file does not exist" @@ -167,7 +169,7 @@ def uninstall(self, verbose: bool = False) -> None: continue path_print_tree = Tree("│") - config = scan_config_file(path) + config = scan_mcp_config_file(path) new_servers: dict[str, SSEServer | StdioServer] = {} for name, server in config.get_servers().items(): if isinstance(server, StdioServer): @@ -188,24 +190,3 @@ def uninstall(self, verbose: bool = False) -> None: rich.print(path_print_tree) with open(os.path.expanduser(path), "w") as f: f.write(config.model_dump_json(indent=4) + "\n") - -if __name__ == "__main__": - gateway_conf = MCPGatewayConfig( - project_name="test", - push_explorer=True, - api_key="inv-bb02e19170dd24aad2a13dceb17082a6d91b16597cb7bb579b9b777d78e54aaf", - ) - installer = MCPGatewayInstaller( - paths=["tests/test_configs.json"], - config=gateway_conf, - ) - confs = [scan_config_file(path) for path in installer.paths] - installer.install() - print("Installed") - installer.uninstall() - print("Uninstalled") - new_confs = [scan_config_file(path) for path in installer.paths] - for conf, new_conf in zip(confs, new_confs): - print(conf.model_dump_json(indent=4)) - print(new_conf.model_dump_json(indent=4)) - assert conf == new_conf \ No newline at end of file diff --git a/src/mcp_scan/mcp_client.py b/src/mcp_scan/mcp_client.py index a09bcbb..972f703 100644 --- a/src/mcp_scan/mcp_client.py +++ b/src/mcp_scan/mcp_client.py @@ -16,6 +16,7 @@ import pyjson5 import os from typing import Type, AsyncContextManager +from typing import Type async def check_server( server_config: SSEServer | StdioServer, timeout: int, suppress_mcpserver_io: bool diff --git a/tests/unit/test_gateway.py b/tests/unit/test_gateway.py index f15db11..06592a5 100644 --- a/tests/unit/test_gateway.py +++ b/tests/unit/test_gateway.py @@ -1,6 +1,7 @@ import pytest from mcp_scan.gateway import is_invariant_installed, MCPGatewayInstaller, MCPGatewayConfig -from mcp_scan.MCPScanner import scan_config_file +from mcp_scan.MCPScanner import scan_mcp_config_file +from mcp_scan.models import StdioServer from tests.unit.test_mcp_client import SAMPLE_CONFIGS import pyjson5 import tempfile @@ -20,20 +21,29 @@ def test_install_gateway(server_config: str, temp_file): config_dict = pyjson5.loads(server_config) installer = MCPGatewayInstaller(paths=[temp_file]) - for server in scan_config_file(temp_file).get_servers().values(): - assert not is_invariant_installed(server) + for server in scan_mcp_config_file(temp_file).get_servers().values(): + if isinstance(server, StdioServer): + assert not is_invariant_installed(server) installer.install(gateway_config=MCPGatewayConfig( project_name="test", push_explorer=True, api_key="my-very-secret-api-key" ), verbose=True) + config_dict_installed = pyjson5.loads(server_config) - for server in scan_config_file(temp_file).get_servers().values(): - assert is_invariant_installed(server) + + for server in scan_mcp_config_file(temp_file).get_servers().values(): + if isinstance(server, StdioServer): + assert is_invariant_installed(server) + + installer.uninstall(verbose=True) - mcp = scan_config_file(temp_file) - for server in scan_config_file(temp_file).get_servers().values(): - assert not is_invariant_installed(server) + mcp = scan_mcp_config_file(temp_file) + + + for server in scan_mcp_config_file(temp_file).get_servers().values(): + if isinstance(server, StdioServer): + assert not is_invariant_installed(server) config_dict_uninstalled = pyjson5.loads(server_config) diff --git a/tests/unit/test_mcp_client.py b/tests/unit/test_mcp_client.py index dece0c3..1284c90 100644 --- a/tests/unit/test_mcp_client.py +++ b/tests/unit/test_mcp_client.py @@ -64,7 +64,7 @@ vscode_config, ] -def test_scan_mcp_config_file(): +def test_scan_mcp_config(): for config in SAMPLE_CONFIGS: with tempfile.NamedTemporaryFile(mode="w") as temp_file: temp_file.write(config) From b64e0921e78734564385995b9ef7c35ec9007978 Mon Sep 17 00:00:00 2001 From: Marco Milanta Date: Thu, 24 Apr 2025 17:23:54 +0200 Subject: [PATCH 07/21] fix: enforce type in large scale --- src/mcp_scan/MCPScanner.py | 6 ++++++ src/mcp_scan/mcp_client.py | 1 - src/mcp_scan/models.py | 7 ++++++- 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/mcp_scan/MCPScanner.py b/src/mcp_scan/MCPScanner.py index 617969f..a61509f 100644 --- a/src/mcp_scan/MCPScanner.py +++ b/src/mcp_scan/MCPScanner.py @@ -29,6 +29,7 @@ from .models import Result from .StorageFile import StorageFile from .verify_api import verify_server +from typing import Literal def format_err_str(e: Exception, max_length: int | None=None) -> str: @@ -75,8 +76,13 @@ def format_servers_line(server: str, status: str | None=None) -> Text: return Text.from_markup(text) +<<<<<<< HEAD def format_entity_line( entity: Entity, +======= +def format_tool_line( + tool: Entity, +>>>>>>> 5b567ef (fix: enforce type in large scale) verified: Result, changed: Result = Result(), include_description: bool=False, diff --git a/src/mcp_scan/mcp_client.py b/src/mcp_scan/mcp_client.py index 972f703..a09bcbb 100644 --- a/src/mcp_scan/mcp_client.py +++ b/src/mcp_scan/mcp_client.py @@ -16,7 +16,6 @@ import pyjson5 import os from typing import Type, AsyncContextManager -from typing import Type async def check_server( server_config: SSEServer | StdioServer, timeout: int, suppress_mcpserver_io: bool diff --git a/src/mcp_scan/models.py b/src/mcp_scan/models.py index 8c1739f..ab17fe6 100644 --- a/src/mcp_scan/models.py +++ b/src/mcp_scan/models.py @@ -2,8 +2,13 @@ from typing import Any, Literal, TypeAlias, NamedTuple from datetime import datetime from mcp.types import Prompt, Resource, Tool +from typing import Any, Literal +from typing import Any +from typing import NamedTuple +from datetime import datetime +from mcp.types import Prompt, Resource, Tool -Entity: TypeAlias = Prompt | Resource | Tool +Entity = Prompt | Resource | Tool def entity_type_to_str(entity: Entity) -> str: if isinstance(entity, Prompt): From cb1695798dd1991190479d0fef865314bb138a7b Mon Sep 17 00:00:00 2001 From: Marco Milanta Date: Thu, 24 Apr 2025 17:59:24 +0200 Subject: [PATCH 08/21] fix: minor --- src/mcp_scan/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcp_scan/models.py b/src/mcp_scan/models.py index ab17fe6..013e210 100644 --- a/src/mcp_scan/models.py +++ b/src/mcp_scan/models.py @@ -8,7 +8,7 @@ from datetime import datetime from mcp.types import Prompt, Resource, Tool -Entity = Prompt | Resource | Tool +Entity: TypeAlias = Prompt | Resource | Tool def entity_type_to_str(entity: Entity) -> str: if isinstance(entity, Prompt): From 30e825cf0330f816a33ffff31a749ea33a932656 Mon Sep 17 00:00:00 2001 From: Marco Milanta Date: Mon, 28 Apr 2025 10:14:01 +0200 Subject: [PATCH 09/21] fix: add assert errors --- tests/unit/test_gateway.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/unit/test_gateway.py b/tests/unit/test_gateway.py index 222a38b..a2f52d2 100644 --- a/tests/unit/test_gateway.py +++ b/tests/unit/test_gateway.py @@ -26,7 +26,7 @@ def test_install_gateway(server_config: str, temp_file): installer = MCPGatewayInstaller(paths=[temp_file]) for server in scan_mcp_config_file(temp_file).get_servers().values(): if isinstance(server, StdioServer): - assert not is_invariant_installed(server) + assert not is_invariant_installed(server), "Invariant should not be installed" installer.install( gateway_config=MCPGatewayConfig(project_name="test", push_explorer=True, api_key="my-very-secret-api-key"), verbose=True, @@ -37,14 +37,14 @@ def test_install_gateway(server_config: str, temp_file): for server in scan_mcp_config_file(temp_file).get_servers().values(): if isinstance(server, StdioServer): - assert is_invariant_installed(server) + assert is_invariant_installed(server), "Invariant should be installed" installer.uninstall(verbose=True) for server in scan_mcp_config_file(temp_file).get_servers().values(): if isinstance(server, StdioServer): - assert not is_invariant_installed(server) + assert not is_invariant_installed(server), "Invariant should be uninstalled" config_dict_uninstalled = pyjson5.loads(server_config) - assert config_dict_uninstalled == config_dict + assert config_dict_uninstalled == config_dict, "Installation and uninstallation of the gateway should not change the config file" From 221d97628b007e38c689b415bdf7a2f7beda6e94 Mon Sep 17 00:00:00 2001 From: knielsen404 Date: Mon, 28 Apr 2025 16:27:49 +0200 Subject: [PATCH 10/21] add mcp scan server --- pyproject.toml | 5 + src/mcp_scan/StorageFile.py | 52 +++++++-- src/mcp_scan/cli.py | 18 +++ src/mcp_scan_server/models.py | 107 ++++++++++++++++++ src/mcp_scan_server/routes/policies.py | 146 +++++++++++++++++++++++++ src/mcp_scan_server/routes/push.py | 15 +++ src/mcp_scan_server/routes/trace.py | 11 ++ src/mcp_scan_server/routes/user.py | 11 ++ src/mcp_scan_server/server.py | 37 +++++++ 9 files changed, 394 insertions(+), 8 deletions(-) create mode 100644 src/mcp_scan_server/models.py create mode 100644 src/mcp_scan_server/routes/policies.py create mode 100644 src/mcp_scan_server/routes/push.py create mode 100644 src/mcp_scan_server/routes/trace.py create mode 100644 src/mcp_scan_server/routes/user.py create mode 100644 src/mcp_scan_server/server.py diff --git a/pyproject.toml b/pyproject.toml index b8d7748..a6cd9db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,11 @@ dependencies = [ "pydantic>=2.11.2", "lark-parser[regex]>=0.12.0", "psutil>=5.9.0", + "invariant-ai>=0.3", + "fastapi>=0.115.12", + "uvicorn>=0.34.2", + "invariant-sdk>=0.0.11", + "pyyaml>=6.0.2", ] [project.scripts] diff --git a/src/mcp_scan/StorageFile.py b/src/mcp_scan/StorageFile.py index 3346f50..d523d43 100644 --- a/src/mcp_scan/StorageFile.py +++ b/src/mcp_scan/StorageFile.py @@ -9,19 +9,22 @@ from .models import Entity, Result, ScannedEntities, ScannedEntity, entity_type_to_str from .utils import upload_whitelist_entry - +from mcp_scan_server.models import GuardrailConfig +import yaml class StorageFile: def __init__(self, path: str): self.path = os.path.expanduser(path) + # if path is a file self.scanned_entities: ScannedEntities = ScannedEntities({}) self.whitelist: dict[str, str] = {} + self.guardrails_config: GuardrailConfig = GuardrailConfig({}) - if os.path.isfile(path): - rich.print(f"[bold]Legacy storage file detected at {path}, converting to new format[/bold]") + if os.path.isfile(self.path): + rich.print(f"[bold]Legacy storage file detected at {self.path}, converting to new format[/bold]") # legacy format - with open(path, "r") as f: + with open(self.path, "r") as f: legacy_data = json.load(f) if "__whitelist" in legacy_data: self.whitelist = legacy_data["__whitelist"] @@ -32,8 +35,9 @@ def __init__(self, path: str): rich.print(f"[bold red]Could not load legacy storage file {self.path}: {e}[/bold red]") os.remove(path) - if os.path.exists(path) and os.path.isdir(path): - scanned_entities_path = os.path.join(path, "scanned_entities.json") + print(path, os.path.exists(path), os.path.isdir(path)) + if os.path.exists(self.path) and os.path.isdir(self.path): + scanned_entities_path = os.path.join(self.path, "scanned_entities.json") if os.path.exists(scanned_entities_path): with open(scanned_entities_path, "r") as f: try: @@ -42,9 +46,26 @@ def __init__(self, path: str): rich.print( f"[bold red]Could not load scanned entities file {scanned_entities_path}: {e}[/bold red]" ) - if os.path.exists(os.path.join(path, "whitelist.json")): - with open(os.path.join(path, "whitelist.json"), "r") as f: + if os.path.exists(os.path.join(self.path, "whitelist.json")): + with open(os.path.join(self.path, "whitelist.json"), "r") as f: self.whitelist = json.load(f) + + guardrails_config_path = os.path.join(self.path, "guardrails_config.yml") + print(guardrails_config_path) + if os.path.exists(guardrails_config_path): + print("Reading guardrails config") + with open(guardrails_config_path, "r") as f: + try: + guardrails_config_data = yaml.safe_load(f.read()) or {} + self.guardrails_config = GuardrailConfig.model_validate(guardrails_config_data) + except yaml.YAMLError as e: + rich.print( + f"[bold red]Could not parse guardrails config file {guardrails_config_path}: {e}[/bold red]" + ) + except ValidationError as e: + rich.print( + f"[bold red]Could not validate guardrails config file {guardrails_config_path}: {e}[/bold red]" + ) def reset_whitelist(self) -> None: self.whitelist = {} @@ -101,10 +122,25 @@ def add_to_whitelist(self, entity_type: str, name: str, hash: str, base_url: str def is_whitelisted(self, entity: Entity) -> bool: hash = self.compute_hash(entity) return hash in self.whitelist.values() + + def create_guardrails_config(self) -> str: + """ + If the guardrails config file does not exist, create it with default values. + Returns the path to the guardrails config file. + """ + guardrails_config_path = os.path.join(self.path, "guardrails_config.yml") + if not os.path.exists(guardrails_config_path): + with open(guardrails_config_path, "w") as f: + if self.guardrails_config is not None: + f.write(self.guardrails_config) + return guardrails_config_path + def save(self) -> None: os.makedirs(self.path, exist_ok=True) with open(os.path.join(self.path, "scanned_entities.json"), "w") as f: f.write(self.scanned_entities.model_dump_json()) with open(os.path.join(self.path, "whitelist.json"), "w") as f: json.dump(self.whitelist, f) + with open(os.path.join(self.path, "guardrails_config.yml"), "w") as f: + f.write(self.guardrails_config) diff --git a/src/mcp_scan/cli.py b/src/mcp_scan/cli.py index 82feb7a..298dd1d 100644 --- a/src/mcp_scan/cli.py +++ b/src/mcp_scan/cli.py @@ -5,6 +5,7 @@ import rich from mcp_scan.gateway import MCPGatewayConfig, MCPGatewayInstaller +from mcp_scan_server.server import MCPScanServer from .MCPScanner import MCPScanner from .StorageFile import StorageFile @@ -266,6 +267,17 @@ def main(): description="Display detailed help information and examples.", ) + # SERVER command + server_parser = subparsers.add_parser("server", help="Start the MCP scan server") + server_parser.add_argument( + "--port", + type=int, + default=8000, + help="Port to run the server on (default: 8000)", + metavar="PORT", + ) + add_common_arguments(server_parser) + # Display version banner rich.print(f"[bold blue]Invariant MCP-scan v{version_info}[/bold blue]\n") @@ -333,6 +345,12 @@ def main(): elif args.command == "scan" or args.command is None: # default to scan MCPScanner(**vars(args)).start() sys.exit(0) + elif args.command == "server": + sf = StorageFile(args.storage_file) + guardrails_config_path = sf.create_guardrails_config() + mcp_scan_server = MCPScanServer(port=args.port, config_file_path=guardrails_config_path) + mcp_scan_server.run() + sys.exit(0) else: # This shouldn't happen due to argparse's handling rich.print(f"[bold red]Unknown command: {args.command}[/bold red]") diff --git a/src/mcp_scan_server/models.py b/src/mcp_scan_server/models.py new file mode 100644 index 0000000..322fad6 --- /dev/null +++ b/src/mcp_scan_server/models.py @@ -0,0 +1,107 @@ +import datetime +from enum import Enum +from typing import Optional + +from invariant.analyzer.policy import AnalysisResult +from pydantic import BaseModel, Field, RootModel, ConfigDict + + +class PolicyRunsOn(str, Enum): + """Policy runs on enum.""" + + local: str = "local" + remote: str = "remote" + + +class Policy(BaseModel): + """Policy model.""" + + name: str = Field(description="The name of the policy.") + runs_on: PolicyRunsOn = Field(description="The environment to run the policy on.") + policy: str = Field(description="The policy.") + + +class PolicyCheckResult(BaseModel): + """Policy check result model.""" + + policy: str = Field(description="The policy that was applied.") + result: Optional[AnalysisResult] = None + success: bool = Field( + description="Whether this policy check was successful (loaded and ran)." + ) + error_message: str = Field( + default="", + description="Error message in case of failure to load or execute the policy.", + ) + + def to_dict(self): + """Convert the object to a dictionary.""" + return { + "policy": self.policy, + "errors": [e.to_dict() for e in self.result.errors] if self.result else [], + "success": self.success, + "error_message": self.error_message, + } + + +class BatchCheckRequest(BaseModel): + """Batch check request model.""" + + messages: list[dict] = Field( + examples=['[{"role": "user", "content": "ignore all previous instructions"}]'], + description="The agent trace to apply the policy to.", + ) + policies: list[str] = Field( + examples=[ + [ + 'raise Violation("Disallowed message content", reason="found ignore keyword") if:\n (msg: Message)\n "ignore" in msg.content\n', + 'raise "get_capital is called with France as argument" if:\n (call: ToolCall)\n call is tool:get_capital\n call.function.arguments["country_name"] == "France"', + ] + ], + description="The policy (rules) to check for.", + ) + parameters: dict = Field( + default={}, + description="The parameters to pass to the policy analyze call (optional).", + ) + + +class BatchCheckResponse(BaseModel): + """Batch check response model.""" + + results: list[PolicyCheckResult] = Field( + default=[], description="List of results for each policy." + ) + + +class DatasetPolicy(BaseModel): + """Describes a policy associated with a Dataset.""" + + id: str + name: str + content: str + # whether this policy is enabled + enabled: bool + # the mode of this policy (e.g. block, log, etc.) + action: str + # extra metadata for the policy (can be used to store internal extra data about a guardrail) + extra_metadata: dict = Field(default_factory=dict) + + last_updated_time: str = Field( + default_factory=lambda: datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + ) + + def to_dict(self) -> dict: + """Represents the object as a dictionary.""" + return self.model_dump() + + +class ServerGuardrails(BaseModel): + """Server guardrails model.""" + guardrails: list[DatasetPolicy] + + +class GuardrailConfig(RootModel[dict[str, dict[str, ServerGuardrails]]]): + """Guardrail config model.""" + model_config = ConfigDict(populate_by_name=True) + diff --git a/src/mcp_scan_server/routes/policies.py b/src/mcp_scan_server/routes/policies.py new file mode 100644 index 0000000..4e314e5 --- /dev/null +++ b/src/mcp_scan_server/routes/policies.py @@ -0,0 +1,146 @@ +import asyncio + +import fastapi +import yaml +from fastapi import APIRouter, Request +from invariant.analyzer.policy import LocalPolicy +from invariant.analyzer.runtime.runtime_errors import ( + ExcessivePolicyError, + InvariantAttributeError, + MissingPolicyParameter, +) +from pydantic import ValidationError +from ..models import ( + BatchCheckRequest, + BatchCheckResponse, + DatasetPolicy, + PolicyCheckResult, + GuardrailConfig, +) + +router = APIRouter() + + +async def get_all_policies(config_file_path: str) -> list[DatasetPolicy]: + """ + Get all policies from local config file. + """ + with open(config_file_path, "r") as f: + config = yaml.load(f, Loader=yaml.FullLoader) + try: + config = GuardrailConfig.model_validate(config) + except ValidationError as e: + raise fastapi.HTTPException(status_code=400, detail=str(e)) + except Exception as e: + raise fastapi.HTTPException(status_code=400, detail=str(e)) + + policies = [ + DatasetPolicy( + id=guardrail.id, + name=guardrail.name, + content=guardrail.content, + enabled=guardrail.enabled, + action=guardrail.action, + extra_metadata={ + "platform": platform_name, + "tool": tool_name, + }, + ) + for platform_name, platform_data in config.root.items() + for tool_name, server_guardrails in platform_data.items() + for guardrail in server_guardrails.guardrails + ] + + return policies + + +@router.get("/dataset/byuser/{username}/{dataset_name}/policy") +async def get_policy(username: str, dataset_name: str, request: Request): + """ + Get a policy from local config file. + """ + policies = await get_all_policies(request.app.state.config_file_path) + return {"policies": policies} + + +async def check_policy( + policy_str: str, messages: list[dict], parameters: dict = {} +) -> PolicyCheckResult: + """ + Check a policy using the invariant analyzer. + + Args: + policy_str: The policy to check. + messages: The messages to check the policy against. + parameters: The parameters to pass to the policy analyze call. + + Returns: + A PolicyCheckResult object. + """ + try: + policy = LocalPolicy.from_string(policy_str) + + if isinstance(policy, Exception): + return PolicyCheckResult( + policy=policy_str, + success=False, + error_message=str(policy), + ) + + result = await policy.a_analyze(messages, **parameters) + + return PolicyCheckResult( + policy=policy_str, + result=result, + success=True, + ) + + except (MissingPolicyParameter, ExcessivePolicyError, InvariantAttributeError) as e: + return PolicyCheckResult( + policy=policy_str, + success=False, + error_message=str(e), + ) + except Exception as e: + return PolicyCheckResult( + policy=policy_str, + success=False, + error_message="Unexpected error: " + str(e), + ) + + +def to_json_serializable_dict(obj): + """ + Converts a dictionary to a JSON serializable dictionary. + """ + if isinstance(obj, dict): + return {k: to_json_serializable_dict(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [to_json_serializable_dict(v) for v in obj] + elif isinstance(obj, str): + return obj + elif isinstance(obj, (int, float, bool)): + return obj + else: + return type(obj).__name__ + "(" + str(obj) + ")" + + +@router.post("/policy/check/batch", response_model=BatchCheckResponse) +async def batch_check_policies(request: BatchCheckRequest): + """ + Check a policy using the invariant analyzer. + """ + results = await asyncio.gather( + *[ + check_policy(policy, request.messages, request.parameters) + for policy in request.policies + ] + ) + + return fastapi.responses.JSONResponse( + content={ + "result": [ + to_json_serializable_dict(result.to_dict()) for result in results + ] + } + ) diff --git a/src/mcp_scan_server/routes/push.py b/src/mcp_scan_server/routes/push.py new file mode 100644 index 0000000..82b5ca8 --- /dev/null +++ b/src/mcp_scan_server/routes/push.py @@ -0,0 +1,15 @@ +import uuid + +from fastapi import APIRouter +from invariant_sdk.types.push_traces import PushTracesResponse + +router = APIRouter() + + +@router.post("/trace") +async def push_trace(): + """ + Push a trace. For now, this is a dummy response. + """ + return PushTracesResponse(id=[str(uuid.uuid4())], success=True) + diff --git a/src/mcp_scan_server/routes/trace.py b/src/mcp_scan_server/routes/trace.py new file mode 100644 index 0000000..4e5072c --- /dev/null +++ b/src/mcp_scan_server/routes/trace.py @@ -0,0 +1,11 @@ +from fastapi import APIRouter, Request + +router = APIRouter() + + +@router.post("/{trace_id}/messages") +async def append_messages(request: Request): + """ + Append messages to a trace. For now this is a NoOp. + """ + return {"success": True} diff --git a/src/mcp_scan_server/routes/user.py b/src/mcp_scan_server/routes/user.py new file mode 100644 index 0000000..523ee12 --- /dev/null +++ b/src/mcp_scan_server/routes/user.py @@ -0,0 +1,11 @@ +from fastapi import APIRouter + +router = APIRouter() + + +@router.get("/identity") +async def identity(): + """ + Get the identity of the user. For now, this is a dummy response. + """ + return {"username": "user"} diff --git a/src/mcp_scan_server/server.py b/src/mcp_scan_server/server.py new file mode 100644 index 0000000..0a8d223 --- /dev/null +++ b/src/mcp_scan_server/server.py @@ -0,0 +1,37 @@ +import os + +import uvicorn +from fastapi import FastAPI + +from .routes.policies import router as policies_router +from .routes.push import router as push_router +from .routes.trace import router as dataset_trace_router +from .routes.user import router as user_router + + +class MCPScanServer: + """ + MCP Scan Server + + Args: + port: The port to run the server on. + config_file_path: The path to the config file. + """ + def __init__(self, port: int = 8000, config_file_path: str = None): + self.port = port + self.config_file_path = config_file_path + + self.app = FastAPI() + self.app.state.config_file_path = config_file_path + + self.app.include_router(policies_router, prefix="/api/v1") + self.app.include_router(push_router, prefix="/api/v1/push") + self.app.include_router(dataset_trace_router, prefix="/api/v1/trace") + self.app.include_router(user_router, prefix="/api/v1/user") + + def run(self): + """ + Run the MCP scan server. + """ + uvicorn.run(self.app, host="0.0.0.0", port=self.port) + From 2025d7debce667935ece631cd30fda4aa7fb234b Mon Sep 17 00:00:00 2001 From: knielsen404 Date: Mon, 28 Apr 2025 16:58:38 +0200 Subject: [PATCH 11/21] fix test --- src/mcp_scan/StorageFile.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/mcp_scan/StorageFile.py b/src/mcp_scan/StorageFile.py index d523d43..5937ff6 100644 --- a/src/mcp_scan/StorageFile.py +++ b/src/mcp_scan/StorageFile.py @@ -142,5 +142,3 @@ def save(self) -> None: f.write(self.scanned_entities.model_dump_json()) with open(os.path.join(self.path, "whitelist.json"), "w") as f: json.dump(self.whitelist, f) - with open(os.path.join(self.path, "guardrails_config.yml"), "w") as f: - f.write(self.guardrails_config) From ebe9d4feec3ece6826cd47c6ff38aa9c576f3237 Mon Sep 17 00:00:00 2001 From: Marc Fischer Date: Tue, 29 Apr 2025 15:52:41 +0200 Subject: [PATCH 12/21] dependency fix for lark --- Makefile | 5 +++++ pyproject.toml | 4 ++-- src/mcp_scan/utils.py | 1 - 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/Makefile b/Makefile index f6dda85..62f4157 100644 --- a/Makefile +++ b/Makefile @@ -42,3 +42,8 @@ publish: publish-pypi publish-npm pre-commit: pre-commit run --all-files + +reset-uv: + rm -rf .venv || true + rm uv.lock || true + uv venv diff --git a/pyproject.toml b/pyproject.toml index a6cd9db..76c02e8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ dependencies = [ "rich>=14.0.0", "pyjson5>=1.6.8", "pydantic>=2.11.2", - "lark-parser[regex]>=0.12.0", + "lark>=1.1.9", "psutil>=5.9.0", "invariant-ai>=0.3", "fastapi>=0.115.12", @@ -56,4 +56,4 @@ include = '\.pyi?$' [tool.isort] profile = "black" -line_length = 120 +line_length = 120 \ No newline at end of file diff --git a/src/mcp_scan/utils.py b/src/mcp_scan/utils.py index 9868a02..69a24a3 100644 --- a/src/mcp_scan/utils.py +++ b/src/mcp_scan/utils.py @@ -20,7 +20,6 @@ def rebalance_command_args(command, args): %ignore WS """, parser="lalr", - lexer="standard", start="command", regex=True, ) From 7ded7b277319ee96f966846a5fcf611f66417410 Mon Sep 17 00:00:00 2001 From: knielsen404 Date: Tue, 29 Apr 2025 16:54:11 +0200 Subject: [PATCH 13/21] fix: pre-commit hooks --- src/mcp_scan/MCPScanner.py | 1 + src/mcp_scan/StorageFile.py | 25 ++++++++------ src/mcp_scan_server/__init__.py | 0 src/mcp_scan_server/models.py | 35 +++++++++++-------- src/mcp_scan_server/routes/__init__.py | 0 src/mcp_scan_server/routes/policies.py | 47 ++++++++++++-------------- src/mcp_scan_server/routes/push.py | 5 +-- src/mcp_scan_server/routes/trace.py | 4 +-- src/mcp_scan_server/routes/user.py | 4 +-- src/mcp_scan_server/server.py | 12 +++---- 10 files changed, 63 insertions(+), 70 deletions(-) create mode 100644 src/mcp_scan_server/__init__.py create mode 100644 src/mcp_scan_server/routes/__init__.py diff --git a/src/mcp_scan/MCPScanner.py b/src/mcp_scan/MCPScanner.py index 6860e13..8ba5c77 100644 --- a/src/mcp_scan/MCPScanner.py +++ b/src/mcp_scan/MCPScanner.py @@ -4,6 +4,7 @@ from typing import Any import rich +from exceptiongroup import ExceptionGroup from rich.text import Text from rich.tree import Tree diff --git a/src/mcp_scan/StorageFile.py b/src/mcp_scan/StorageFile.py index 5937ff6..5d787c7 100644 --- a/src/mcp_scan/StorageFile.py +++ b/src/mcp_scan/StorageFile.py @@ -5,12 +5,14 @@ from typing import Any import rich +import yaml # type: ignore from pydantic import ValidationError +from mcp_scan_server.models import GuardrailConfig + from .models import Entity, Result, ScannedEntities, ScannedEntity, entity_type_to_str from .utils import upload_whitelist_entry -from mcp_scan_server.models import GuardrailConfig -import yaml + class StorageFile: def __init__(self, path: str): @@ -35,7 +37,6 @@ def __init__(self, path: str): rich.print(f"[bold red]Could not load legacy storage file {self.path}: {e}[/bold red]") os.remove(path) - print(path, os.path.exists(path), os.path.isdir(path)) if os.path.exists(self.path) and os.path.isdir(self.path): scanned_entities_path = os.path.join(self.path, "scanned_entities.json") if os.path.exists(scanned_entities_path): @@ -49,22 +50,22 @@ def __init__(self, path: str): if os.path.exists(os.path.join(self.path, "whitelist.json")): with open(os.path.join(self.path, "whitelist.json"), "r") as f: self.whitelist = json.load(f) - + guardrails_config_path = os.path.join(self.path, "guardrails_config.yml") - print(guardrails_config_path) if os.path.exists(guardrails_config_path): - print("Reading guardrails config") with open(guardrails_config_path, "r") as f: try: guardrails_config_data = yaml.safe_load(f.read()) or {} self.guardrails_config = GuardrailConfig.model_validate(guardrails_config_data) except yaml.YAMLError as e: rich.print( - f"[bold red]Could not parse guardrails config file {guardrails_config_path}: {e}[/bold red]" + f"[bold red]Could not parse guardrails config file " + f"{guardrails_config_path}: {e}[/bold red]" ) except ValidationError as e: rich.print( - f"[bold red]Could not validate guardrails config file {guardrails_config_path}: {e}[/bold red]" + f"[bold red]Could not validate guardrails config file " + f"{guardrails_config_path}: {e}[/bold red]" ) def reset_whitelist(self) -> None: @@ -122,7 +123,7 @@ def add_to_whitelist(self, entity_type: str, name: str, hash: str, base_url: str def is_whitelisted(self, entity: Entity) -> bool: hash = self.compute_hash(entity) return hash in self.whitelist.values() - + def create_guardrails_config(self) -> str: """ If the guardrails config file does not exist, create it with default values. @@ -133,12 +134,14 @@ def create_guardrails_config(self) -> str: if not os.path.exists(guardrails_config_path): with open(guardrails_config_path, "w") as f: if self.guardrails_config is not None: - f.write(self.guardrails_config) + f.write(self.guardrails_config.model_dump_yaml()) return guardrails_config_path - + def save(self) -> None: os.makedirs(self.path, exist_ok=True) with open(os.path.join(self.path, "scanned_entities.json"), "w") as f: f.write(self.scanned_entities.model_dump_json()) with open(os.path.join(self.path, "whitelist.json"), "w") as f: json.dump(self.whitelist, f) + with open(os.path.join(self.path, "guardrails_config.yml"), "w") as f: + f.write(self.guardrails_config.model_dump_yaml()) diff --git a/src/mcp_scan_server/__init__.py b/src/mcp_scan_server/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/mcp_scan_server/models.py b/src/mcp_scan_server/models.py index 322fad6..d954873 100644 --- a/src/mcp_scan_server/models.py +++ b/src/mcp_scan_server/models.py @@ -2,15 +2,16 @@ from enum import Enum from typing import Optional +import yaml # type: ignore from invariant.analyzer.policy import AnalysisResult -from pydantic import BaseModel, Field, RootModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field, RootModel class PolicyRunsOn(str, Enum): """Policy runs on enum.""" - local: str = "local" - remote: str = "remote" + local = "local" + remote = "remote" class Policy(BaseModel): @@ -26,9 +27,7 @@ class PolicyCheckResult(BaseModel): policy: str = Field(description="The policy that was applied.") result: Optional[AnalysisResult] = None - success: bool = Field( - description="Whether this policy check was successful (loaded and ran)." - ) + success: bool = Field(description="Whether this policy check was successful (loaded and ran).") error_message: str = Field( default="", description="Error message in case of failure to load or execute the policy.", @@ -54,8 +53,12 @@ class BatchCheckRequest(BaseModel): policies: list[str] = Field( examples=[ [ - 'raise Violation("Disallowed message content", reason="found ignore keyword") if:\n (msg: Message)\n "ignore" in msg.content\n', - 'raise "get_capital is called with France as argument" if:\n (call: ToolCall)\n call is tool:get_capital\n call.function.arguments["country_name"] == "France"', + """raise Violation("Disallowed message content", reason="found ignore keyword") if:\n + (msg: Message)\n "ignore" in msg.content\n""", + """raise "get_capital is called with France as argument" if:\n + (call: ToolCall)\n call is tool:get_capital\n + call.function.arguments["country_name"] == "France" + """, ] ], description="The policy (rules) to check for.", @@ -69,9 +72,7 @@ class BatchCheckRequest(BaseModel): class BatchCheckResponse(BaseModel): """Batch check response model.""" - results: list[PolicyCheckResult] = Field( - default=[], description="List of results for each policy." - ) + results: list[PolicyCheckResult] = Field(default=[], description="List of results for each policy.") class DatasetPolicy(BaseModel): @@ -87,21 +88,25 @@ class DatasetPolicy(BaseModel): # extra metadata for the policy (can be used to store internal extra data about a guardrail) extra_metadata: dict = Field(default_factory=dict) - last_updated_time: str = Field( - default_factory=lambda: datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") - ) + last_updated_time: str = Field(default_factory=lambda: datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")) def to_dict(self) -> dict: - """Represents the object as a dictionary.""" + """Represent the object as a dictionary.""" return self.model_dump() class ServerGuardrails(BaseModel): """Server guardrails model.""" + guardrails: list[DatasetPolicy] class GuardrailConfig(RootModel[dict[str, dict[str, ServerGuardrails]]]): """Guardrail config model.""" + model_config = ConfigDict(populate_by_name=True) + def model_dump_yaml(self) -> str: + """Dump the object as a YAML string.""" + data = self.model_dump() + return yaml.dump(data, sort_keys=False, default_flow_style=False) diff --git a/src/mcp_scan_server/routes/__init__.py b/src/mcp_scan_server/routes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/mcp_scan_server/routes/policies.py b/src/mcp_scan_server/routes/policies.py index 4e314e5..2825499 100644 --- a/src/mcp_scan_server/routes/policies.py +++ b/src/mcp_scan_server/routes/policies.py @@ -1,7 +1,9 @@ import asyncio +import os import fastapi -import yaml +import rich +import yaml # type: ignore from fastapi import APIRouter, Request from invariant.analyzer.policy import LocalPolicy from invariant.analyzer.runtime.runtime_errors import ( @@ -10,21 +12,29 @@ MissingPolicyParameter, ) from pydantic import ValidationError + from ..models import ( BatchCheckRequest, BatchCheckResponse, DatasetPolicy, - PolicyCheckResult, GuardrailConfig, + PolicyCheckResult, ) router = APIRouter() async def get_all_policies(config_file_path: str) -> list[DatasetPolicy]: - """ - Get all policies from local config file. - """ + """Get all policies from local config file.""" + if not os.path.exists(config_file_path): + # Format this as multiple lines without printing it like this + rich.print( + f"""[bold red]Guardrail config file not found: {config_file_path}. Creating an empty one.[/bold red]""" + ) + config = GuardrailConfig.model_validate({}) + with open(config_file_path, "w") as f: + f.write(config.model_dump_yaml()) + with open(config_file_path, "r") as f: config = yaml.load(f, Loader=yaml.FullLoader) try: @@ -56,16 +66,12 @@ async def get_all_policies(config_file_path: str) -> list[DatasetPolicy]: @router.get("/dataset/byuser/{username}/{dataset_name}/policy") async def get_policy(username: str, dataset_name: str, request: Request): - """ - Get a policy from local config file. - """ + """Get a policy from local config file.""" policies = await get_all_policies(request.app.state.config_file_path) return {"policies": policies} -async def check_policy( - policy_str: str, messages: list[dict], parameters: dict = {} -) -> PolicyCheckResult: +async def check_policy(policy_str: str, messages: list[dict], parameters: dict = {}) -> PolicyCheckResult: """ Check a policy using the invariant analyzer. @@ -110,9 +116,7 @@ async def check_policy( def to_json_serializable_dict(obj): - """ - Converts a dictionary to a JSON serializable dictionary. - """ + """Convert a dictionary to a JSON serializable dictionary.""" if isinstance(obj, dict): return {k: to_json_serializable_dict(v) for k, v in obj.items()} elif isinstance(obj, list): @@ -127,20 +131,11 @@ def to_json_serializable_dict(obj): @router.post("/policy/check/batch", response_model=BatchCheckResponse) async def batch_check_policies(request: BatchCheckRequest): - """ - Check a policy using the invariant analyzer. - """ + """Check a policy using the invariant analyzer.""" results = await asyncio.gather( - *[ - check_policy(policy, request.messages, request.parameters) - for policy in request.policies - ] + *[check_policy(policy, request.messages, request.parameters) for policy in request.policies] ) return fastapi.responses.JSONResponse( - content={ - "result": [ - to_json_serializable_dict(result.to_dict()) for result in results - ] - } + content={"result": [to_json_serializable_dict(result.to_dict()) for result in results]} ) diff --git a/src/mcp_scan_server/routes/push.py b/src/mcp_scan_server/routes/push.py index 82b5ca8..2da75ba 100644 --- a/src/mcp_scan_server/routes/push.py +++ b/src/mcp_scan_server/routes/push.py @@ -8,8 +8,5 @@ @router.post("/trace") async def push_trace(): - """ - Push a trace. For now, this is a dummy response. - """ + """Push a trace. For now, this is a dummy response.""" return PushTracesResponse(id=[str(uuid.uuid4())], success=True) - diff --git a/src/mcp_scan_server/routes/trace.py b/src/mcp_scan_server/routes/trace.py index 4e5072c..a64c646 100644 --- a/src/mcp_scan_server/routes/trace.py +++ b/src/mcp_scan_server/routes/trace.py @@ -5,7 +5,5 @@ @router.post("/{trace_id}/messages") async def append_messages(request: Request): - """ - Append messages to a trace. For now this is a NoOp. - """ + """Append messages to a trace. For now this is a dummy response.""" return {"success": True} diff --git a/src/mcp_scan_server/routes/user.py b/src/mcp_scan_server/routes/user.py index 523ee12..3e16cce 100644 --- a/src/mcp_scan_server/routes/user.py +++ b/src/mcp_scan_server/routes/user.py @@ -5,7 +5,5 @@ @router.get("/identity") async def identity(): - """ - Get the identity of the user. For now, this is a dummy response. - """ + """Get the identity of the user. For now, this is a dummy response.""" return {"username": "user"} diff --git a/src/mcp_scan_server/server.py b/src/mcp_scan_server/server.py index 0a8d223..fa61822 100644 --- a/src/mcp_scan_server/server.py +++ b/src/mcp_scan_server/server.py @@ -1,5 +1,3 @@ -import os - import uvicorn from fastapi import FastAPI @@ -11,13 +9,14 @@ class MCPScanServer: """ - MCP Scan Server + MCP Scan Server. Args: port: The port to run the server on. config_file_path: The path to the config file. """ - def __init__(self, port: int = 8000, config_file_path: str = None): + + def __init__(self, port: int = 8000, config_file_path: str | None = None): self.port = port self.config_file_path = config_file_path @@ -30,8 +29,5 @@ def __init__(self, port: int = 8000, config_file_path: str = None): self.app.include_router(user_router, prefix="/api/v1/user") def run(self): - """ - Run the MCP scan server. - """ + """Run the MCP scan server.""" uvicorn.run(self.app, host="0.0.0.0", port=self.port) - From 42049fa92a97620abe176b60bd6e77595ccb2f52 Mon Sep 17 00:00:00 2001 From: knielsen404 Date: Wed, 30 Apr 2025 15:05:21 +0200 Subject: [PATCH 14/21] add tests --- src/mcp_scan_server/routes/policies.py | 3 +- tests/unit/test_gateway.py | 4 +- tests/unit/test_mcp_scan_server.py | 215 +++++++++++++++++++++++++ 3 files changed, 219 insertions(+), 3 deletions(-) create mode 100644 tests/unit/test_mcp_scan_server.py diff --git a/src/mcp_scan_server/routes/policies.py b/src/mcp_scan_server/routes/policies.py index 2825499..69fb25e 100644 --- a/src/mcp_scan_server/routes/policies.py +++ b/src/mcp_scan_server/routes/policies.py @@ -27,7 +27,6 @@ async def get_all_policies(config_file_path: str) -> list[DatasetPolicy]: """Get all policies from local config file.""" if not os.path.exists(config_file_path): - # Format this as multiple lines without printing it like this rich.print( f"""[bold red]Guardrail config file not found: {config_file_path}. Creating an empty one.[/bold red]""" ) @@ -93,7 +92,7 @@ async def check_policy(policy_str: str, messages: list[dict], parameters: dict = error_message=str(policy), ) - result = await policy.a_analyze(messages, **parameters) + result = await policy.a_analyze_pending(messages[:-1], [messages[-1]], **parameters) return PolicyCheckResult( policy=policy_str, diff --git a/tests/unit/test_gateway.py b/tests/unit/test_gateway.py index a2f52d2..64f9994 100644 --- a/tests/unit/test_gateway.py +++ b/tests/unit/test_gateway.py @@ -47,4 +47,6 @@ def test_install_gateway(server_config: str, temp_file): config_dict_uninstalled = pyjson5.loads(server_config) - assert config_dict_uninstalled == config_dict, "Installation and uninstallation of the gateway should not change the config file" + assert ( + config_dict_uninstalled == config_dict + ), "Installation and uninstallation of the gateway should not change the config file" diff --git a/tests/unit/test_mcp_scan_server.py b/tests/unit/test_mcp_scan_server.py new file mode 100644 index 0000000..c05b036 --- /dev/null +++ b/tests/unit/test_mcp_scan_server.py @@ -0,0 +1,215 @@ +import os +from unittest.mock import patch + +import pytest +import yaml # type: ignore +from fastapi import HTTPException +from fastapi.testclient import TestClient + +from mcp_scan_server.models import DatasetPolicy, GuardrailConfig +from mcp_scan_server.routes.policies import check_policy, get_all_policies +from mcp_scan_server.server import MCPScanServer + +client = TestClient(MCPScanServer().app) + + +@pytest.fixture +def valid_guardrail_config_file(tmp_path): + """Fixture that creates a temporary valid config file and returns its path.""" + config_file = tmp_path / "config.yaml" + config_file.write_text( + """ +cursor: + browsermcp: + guardrails: + - name: "Guardrail 1" + id: "guardrail_1" + runs-on: "local" + enabled: true + action: "block" + content: | + raise "error" if: + (msg: ToolOutput) + "Test1" in msg.content + - name: "Guardrail 2" + id: "guardrail_2" + runs-on: "local" + enabled: true + action: "block" + content: | + raise "error" if: + (msg: ToolOutput) + "Test2" in msg.content +""" + ) + return str(config_file) + + +@pytest.fixture +def invalid_guardrail_config_file(tmp_path): + """Fixture that creates a temporary invalid config file and returns its path.""" + config_file = tmp_path / "config.yaml" + config_file.write_text( + """ +cursor: + browsermcp: + guardrails: + - name: "Guardrail 1" + id: "guardrail_1" + runs-on: "local" + enabled: true + action: "block" +""" + ) + return str(config_file) + + +@pytest.mark.asyncio +async def test_get_all_policies_valid_config(valid_guardrail_config_file): + """Test that the get_all_policies function returns the correct policies for a valid config file.""" + policies = await get_all_policies(valid_guardrail_config_file) + print(policies) + assert len(policies) == 2 + assert all(isinstance(policy, DatasetPolicy) for policy in policies) + assert policies[0].id == "guardrail_1" + assert policies[1].id == "guardrail_2" + assert policies[0].name == "Guardrail 1" + assert policies[1].name == "Guardrail 2" + + +@pytest.mark.asyncio +async def test_get_all_policies_invalid_config(invalid_guardrail_config_file): + """Test that the get_all_policies function raises an HTTPException for an invalid config file.""" + with pytest.raises(HTTPException): + await get_all_policies(invalid_guardrail_config_file) + + +@pytest.mark.asyncio +async def test_get_all_policies_creates_file_when_missing(tmp_path): + """Test that get_all_policies creates a config file if it doesn't exist.""" + # Create a path to a non-existent file + config_file_path = str(tmp_path / "nonexistent_config.yaml") + + # Verify the file doesn't exist before calling the function + assert not os.path.exists(config_file_path) + + # Call the function + await get_all_policies(config_file_path) + + # Verify the file now exists + assert os.path.exists(config_file_path) + + # Verify the file contains a valid empty config + with open(config_file_path, "r") as f: + config_content = f.read() + loaded_config = yaml.safe_load(config_content) + + # Validate the config + GuardrailConfig.model_validate(loaded_config) + + +@pytest.mark.asyncio +async def mock_get_all_policies(config_file_path: str) -> list[str]: + return ["some_guardrail"] + + +@patch("mcp_scan_server.routes.policies.get_all_policies", mock_get_all_policies) +def test_get_policy_endpoint(): + """Test that the get_policy returns a dict with a list of policies.""" + response = client.get("/api/v1/dataset/byuser/testuser/test_dataset/policy") + assert response.status_code == 200 + assert response.json() == {"policies": ["some_guardrail"]} + + +# fixture policy_str +@pytest.fixture +def error_one_policy_str(): + return """ + raise "error_one" if: + (msg: Message) + "error_one" in msg.content + """ + + +@pytest.fixture +def error_two_policy_str(): + return """ + raise "error_two" if: + (msg: Message) + "error_two" in msg.content + """ + + +@pytest.fixture +def detect_random_policy_str(): + return """ + raise "error_random" if: + (msg: Message) + "random" in msg.content + """ + + +@pytest.fixture +def detect_simple_flow_policy_str(): + return """ + raise "error_flow" if: + (msg1: Message) + (msg2: ToolOutput) + msg1.content == "request_tool" + msg2.content == "tool_output" + """ + + +@pytest.fixture +def simple_trace(): + return [ + {"content": "error_one", "role": "user"}, + {"content": "error_two", "role": "user"}, + ] + + +@pytest.fixture +def simple_flow_trace(): + return [ + {"content": "request_tool", "role": "user"}, + {"content": "some_response", "role": "assistant"}, + {"content": "tool_output", "role": "tool"}, + ] + + +@pytest.mark.asyncio +async def test_check_policy_raises_exception_when_trace_violates_policy(error_two_policy_str, simple_trace): + """Test that the check_policy endpoint raises an exception when the trace violates the policy.""" + result = await check_policy(error_two_policy_str, simple_trace) + assert len(result.result.errors) == 1 + assert result.result.errors[0].args[0] == "error_two" + + +@pytest.mark.asyncio +async def test_check_policy_only_raises_error_on_last_message(error_one_policy_str, error_two_policy_str, simple_trace): + """Test that the check_policy endpoint only raises an error on the last message.""" + # Should not raise an error as the last message does not contain "error_one" + result_one = await check_policy(error_one_policy_str, simple_trace) + assert len(result_one.result.errors) == 0 + assert result_one.error_message == "" + + # Should raise an error as the last message contains "error_two" + result_two = await check_policy(error_two_policy_str, simple_trace) + assert len(result_two.result.errors) == 1 + assert result_two.result.errors[0].args[0] == "error_two" + + +@pytest.mark.asyncio +async def test_check_policy_returns_success_when_trace_does_not_violate_policy(detect_random_policy_str, simple_trace): + """Test that the check_policy endpoint returns success when the trace does not violate the policy.""" + result = await check_policy(detect_random_policy_str, simple_trace) + assert len(result.result.errors) == 0 + assert result.error_message == "" + + +@pytest.mark.asyncio +async def test_check_policy_catches_flow_violations(detect_simple_flow_policy_str, simple_flow_trace): + """Test that the check_policy endpoint catches flow violations.""" + result = await check_policy(detect_simple_flow_policy_str, simple_flow_trace) + assert len(result.result.errors) == 1 + assert result.result.errors[0].args[0] == "error_flow" From 954d12abd2f11493b9bd56893e7e63da52e0434b Mon Sep 17 00:00:00 2001 From: knielsen404 Date: Wed, 30 Apr 2025 16:22:23 +0200 Subject: [PATCH 15/21] add local-only installation --- src/mcp_scan/cli.py | 27 +++++++++++++++++++++++---- src/mcp_scan/gateway.py | 9 ++++++--- src/mcp_scan_server/server.py | 2 ++ 3 files changed, 31 insertions(+), 7 deletions(-) diff --git a/src/mcp_scan/cli.py b/src/mcp_scan/cli.py index 298dd1d..86f98b7 100644 --- a/src/mcp_scan/cli.py +++ b/src/mcp_scan/cli.py @@ -98,6 +98,11 @@ def add_server_arguments(parser): ) +def check_install_args(args): + if args.command == "install" and not args.local_only and not args.api_key: + raise argparse.ArgumentError(None, "argument --api-key is required when --local-only is not set") + + def main(): # Create main parser with description program_name = get_invoking_name() @@ -236,8 +241,7 @@ def main(): install_parser.add_argument( "--api-key", type=str, - required=True, - help="api key for the Invariant Gateway", + help="API key for the Invariant Gateway", ) install_parser.add_argument( "--local-only", @@ -245,6 +249,13 @@ def main(): action="store_true", help="Prevent pushing traces to the explorer.", ) + install_parser.add_argument( + "--mcp-scan-server-port", + type=int, + default=8000, + help="MCP scan server port (default: 8000).", + metavar="PORT", + ) # uninstall uninstall_parser = subparsers.add_parser("uninstall", help="Uninstall Invariant Gateway") @@ -314,12 +325,20 @@ def main(): MCPScanner(**vars(args)).inspect() sys.exit(0) elif args.command == "install": - installer = MCPGatewayInstaller(paths=args.files) + try: + check_install_args(args) + except argparse.ArgumentError as e: + parser.error(e) + + invariant_api_url = ( + f"http://localhost:{args.mcp_scan_server_port}" if args.local_only else "https://explorer.invariantlabs.ai" + ) + installer = MCPGatewayInstaller(paths=args.files, invariant_api_url=invariant_api_url) installer.install( gateway_config=MCPGatewayConfig( project_name=args.project_name, push_explorer=not args.local_only, - api_key=args.api_key, + api_key=args.api_key or "", ), verbose=True, ) diff --git a/src/mcp_scan/gateway.py b/src/mcp_scan/gateway.py index 5338e27..a55e8e5 100644 --- a/src/mcp_scan/gateway.py +++ b/src/mcp_scan/gateway.py @@ -52,6 +52,7 @@ def is_invariant_installed(server: StdioServer) -> bool: def install_gateway( server: StdioServer, config: MCPGatewayConfig, + invariant_api_url: str = "https://explorer.invariantlabs.ai", ) -> StdioServer: """Install the gateway for the given server.""" if is_invariant_installed(server): @@ -67,7 +68,7 @@ def install_gateway( + (["--push-explorer"] if config.push_explorer else []) + ["--exec", server.command] + (server.args if server.args else []), - env=server.env | {"INVARIANT_API_KEY": config.api_key}, + env=server.env | {"INVARIANT_API_KEY": config.api_key, "INVARIANT_API_URL": invariant_api_url}, ) @@ -80,7 +81,7 @@ def uninstall_gateway( assert isinstance(server.args, list), "args is not a list" args = parser.parse_args(server.args[2:]) - new_env = {k: v for k, v in server.env.items() if k != "INVARIANT_API_KEY"} + new_env = {k: v for k, v in server.env.items() if k != "INVARIANT_API_KEY" and k != "INVARIANT_API_URL"} assert args.exec is not None, "exec is None" assert args.exec, "exec is empty" return StdioServer( @@ -108,8 +109,10 @@ class MCPGatewayInstaller: def __init__( self, paths: list[str], + invariant_api_url: str = "https://explorer.invariantlabs.ai", ) -> None: self.paths = paths + self.invariant_api_url = invariant_api_url def install( self, @@ -135,7 +138,7 @@ def install( for name, server in config.get_servers().items(): if isinstance(server, StdioServer): try: - new_servers[name] = install_gateway(server, gateway_config) + new_servers[name] = install_gateway(server, gateway_config, self.invariant_api_url) path_print_tree.add(format_install_line(server=name, status="Installed", success=True)) except MCPServerAlreadyGateway: new_servers[name] = server diff --git a/src/mcp_scan_server/server.py b/src/mcp_scan_server/server.py index fa61822..348cc6c 100644 --- a/src/mcp_scan_server/server.py +++ b/src/mcp_scan_server/server.py @@ -1,3 +1,4 @@ +import rich import uvicorn from fastapi import FastAPI @@ -30,4 +31,5 @@ def __init__(self, port: int = 8000, config_file_path: str | None = None): def run(self): """Run the MCP scan server.""" + rich.print("[bold green]Starting MCP-scan server.[/bold green]") uvicorn.run(self.app, host="0.0.0.0", port=self.port) From d3f49612b8f4d1c5450d17d650ad8241d538d6a5 Mon Sep 17 00:00:00 2001 From: Marc Fischer Date: Wed, 30 Apr 2025 18:23:59 +0200 Subject: [PATCH 16/21] fixes --- pyproject.toml | 2 -- src/mcp_scan/MCPScanner.py | 4 ++-- src/mcp_scan/cli.py | 38 +++++++++++++++++------------------ src/mcp_scan/mcp_client.py | 9 ++------- tests/unit/test_mcp_client.py | 8 ++++---- 5 files changed, 27 insertions(+), 34 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 072a34f..9cd805c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,8 +11,6 @@ classifiers = [ dependencies = [ "mcp[cli]>=1.6.0", "rich>=14.0.0", - "aiofiles>=23.1.0", - "types-aiofiles", "pyjson5>=1.6.8", "pydantic>=2.11.2", "lark>=1.1.9", diff --git a/src/mcp_scan/MCPScanner.py b/src/mcp_scan/MCPScanner.py index dab1812..eae39d6 100644 --- a/src/mcp_scan/MCPScanner.py +++ b/src/mcp_scan/MCPScanner.py @@ -5,7 +5,7 @@ from mcp_scan.models import CrossRefResult, ScanException, ScanPathResult, ServerScanResult -from .mcp_client import a_scan_mcp_config_file, check_server_with_timeout +from .mcp_client import check_server_with_timeout, scan_mcp_config_file from .StorageFile import StorageFile from .verify_api import verify_server @@ -83,7 +83,7 @@ def hook(self, signal: str, async_callback: Callable[[str, Any], None]): async def get_servers_from_path(self, path: str) -> ScanPathResult: result = ScanPathResult(path=path) try: - servers = (await a_scan_mcp_config_file(path)).get_servers() + servers = scan_mcp_config_file(path).get_servers() result.servers = [ ServerScanResult(name=server_name, server=server) for server_name, server in servers.items() ] diff --git a/src/mcp_scan/cli.py b/src/mcp_scan/cli.py index 28a6f74..a9075ec 100644 --- a/src/mcp_scan/cli.py +++ b/src/mcp_scan/cli.py @@ -106,7 +106,7 @@ def check_install_args(args): raise argparse.ArgumentError(None, "argument --api-key is required when --local-only is not set") -async def main(): +def main(): # Create main parser with description program_name = get_invoking_name() parser = argparse.ArgumentParser( @@ -306,7 +306,7 @@ async def main(): args = parser.parse_args(["scan"] if len(sys.argv) == 1 else None) # Display version banner - if not args.json: + if not (hasattr(args, "json") and args.json): rich.print(f"[bold blue]Invariant MCP-scan v{version_info}[/bold blue]\n") # Handle commands @@ -336,12 +336,7 @@ async def main(): whitelist_parser.print_help() sys.exit(1) elif args.command == "inspect": - result = await MCPScanner(**vars(args)).inspect() - if args.json: - result = dict((r.path, r.model_dump()) for r in result) - print(json.dumps(result, indent=2)) - else: - print_scan_result(result) + asyncio.run(run_scan_inspect(mode="inspect", args=args)) sys.exit(0) elif args.command == "install": try: @@ -361,11 +356,9 @@ async def main(): ), verbose=True, ) - # install logic here elif args.command == "uninstall": installer = MCPGatewayInstaller(paths=args.files) installer.uninstall(verbose=True) - # uninstall logic here elif args.command == "whitelist": if args.reset: MCPScanner(**vars(args)).reset_whitelist() @@ -381,14 +374,7 @@ async def main(): rich.print("[bold red]Please provide a name and hash.[/bold red]") sys.exit(1) elif args.command == "scan" or args.command is None: # default to scan - async with MCPScanner(**vars(args)) as scanner: - # scanner.hook('path_scanned', print_path_scanned) - result = await scanner.scan() - if args.json: - result = dict((r.path, r.model_dump()) for r in result) - print(json.dumps(result, indent=2)) - else: - print_scan_result(result) + asyncio.run(run_scan_inspect(args=args)) sys.exit(0) elif args.command == "server": sf = StorageFile(args.storage_file) @@ -409,5 +395,19 @@ async def main(): sys.exit(1) +async def run_scan_inspect(mode="scan", args=None): + async with MCPScanner(**vars(args)) as scanner: + # scanner.hook('path_scanned', print_path_scanned) + if mode == "scan": + result = await scanner.scan() + elif mode == "inspect": + result = await scanner.inspect() + if args.json: + result = dict((r.path, r.model_dump()) for r in result) + print(json.dumps(result, indent=2)) + else: + print_scan_result(result) + + if __name__ == "__main__": - asyncio.run(main()) + main() diff --git a/src/mcp_scan/mcp_client.py b/src/mcp_scan/mcp_client.py index 1315590..14116d6 100644 --- a/src/mcp_scan/mcp_client.py +++ b/src/mcp_scan/mcp_client.py @@ -2,7 +2,6 @@ import os from typing import AsyncContextManager, Type -import aiofiles # type: ignore import pyjson5 from mcp import ClientSession, StdioServerParameters from mcp.client.sse import sse_client @@ -86,10 +85,6 @@ async def check_server_with_timeout( def scan_mcp_config_file(path: str) -> MCPConfig: - return asyncio.run(a_scan_mcp_config_file(path)) - - -async def a_scan_mcp_config_file(path: str) -> MCPConfig: path = os.path.expanduser(path) def parse_and_validate(config: dict) -> MCPConfig: @@ -113,8 +108,8 @@ def parse_and_validate(config: dict) -> MCPConfig: ) raise Exception("Could not parse config file") - async with aiofiles.open(path, "r") as f: - content = await f.read() + with open(path, "r") as f: + content = f.read() # use json5 to support comments as in vscode config = pyjson5.loads(content) # try to parse model diff --git a/tests/unit/test_mcp_client.py b/tests/unit/test_mcp_client.py index 9137781..ef60f3f 100644 --- a/tests/unit/test_mcp_client.py +++ b/tests/unit/test_mcp_client.py @@ -6,18 +6,18 @@ import pytest from pytest_lazy_fixtures import lf -from mcp_scan.mcp_client import a_scan_mcp_config_file, check_server +from mcp_scan.mcp_client import check_server, scan_mcp_config_file from mcp_scan.models import StdioServer @pytest.mark.anyio @pytest.mark.parametrize("sample_config", [lf("claudestyle_config"), lf("vscode_mcp_config"), lf("vscode_config")]) -async def test_scan_mcp_config(sample_config): +def test_scan_mcp_config(sample_config): with tempfile.NamedTemporaryFile(mode="w") as temp_file: temp_file.write(sample_config) temp_file.flush() - await a_scan_mcp_config_file(temp_file.name) + scan_mcp_config_file(temp_file.name) @pytest.mark.anyio @@ -83,7 +83,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): @pytest.mark.anyio async def test_mcp_server(): path = "tests/mcp_servers/mcp_config.json" - servers = (await a_scan_mcp_config_file(path)).get_servers() + servers = scan_mcp_config_file(path).get_servers() for name, server in servers.items(): prompts, resources, tools = await check_server(server, 5, False) if name == "Math": From c83bee8f74a4bbeb402b2c2b118a031cb4460200 Mon Sep 17 00:00:00 2001 From: Luca Beurer-Kellner Date: Mon, 5 May 2025 15:34:05 +0200 Subject: [PATCH 17/21] wip --- src/mcp_scan/cli.py | 217 ++++++++++++++++--------- src/mcp_scan_server/activity_logger.py | 114 +++++++++++++ src/mcp_scan_server/routes/push.py | 19 ++- src/mcp_scan_server/routes/trace.py | 13 +- src/mcp_scan_server/server.py | 30 +++- test-client.py | 32 ++++ test.sh | 3 + 7 files changed, 338 insertions(+), 90 deletions(-) create mode 100644 src/mcp_scan_server/activity_logger.py create mode 100644 test-client.py create mode 100644 test.sh diff --git a/src/mcp_scan/cli.py b/src/mcp_scan/cli.py index a9075ec..f2f5a92 100644 --- a/src/mcp_scan/cli.py +++ b/src/mcp_scan/cli.py @@ -99,11 +99,73 @@ def add_server_arguments(parser): help="Suppress stdout/stderr from MCP servers (default: True)", metavar="BOOL", ) + server_group.add_argument( + "--pretty", + type=str, + default="oneline", + choices=["oneline", "compact", "full"], + help="Pretty print the output (default: compact)", + ) + + +def add_install_arguments(parser): + parser.add_argument( + "files", + type=str, + nargs="*", + default=WELL_KNOWN_MCP_PATHS, + help=( + "Different file locations to scan. " + "This can include custom file locations as long as " + "they are in an expected format, including Claude, " + "Cursor or VSCode format." + ), + ) + parser.add_argument( + "--project_name", + type=str, + default="mcp-gateway", + help="Project name for the Invariant Gateway", + ) + parser.add_argument( + "--api-key", + type=str, + help="API key for the Invariant Gateway", + ) + parser.add_argument( + "--local-only", + default=False, + action="store_true", + help="Prevent pushing traces to the explorer.", + ) + parser.add_argument( + "--mcp-scan-server-port", + type=int, + default=8000, + help="MCP scan server port (default: 8000).", + metavar="PORT", + ) + + +def add_uninstall_arguments(parser): + parser.add_argument( + "files", + type=str, + nargs="*", + default=WELL_KNOWN_MCP_PATHS, + help=( + "Different file locations to scan. " + "This can include custom file locations as long as " + "they are in an expected format, including Claude, Cursor or VSCode format." + ), + ) def check_install_args(args): if args.command == "install" and not args.local_only and not args.api_key: - raise argparse.ArgumentError(None, "argument --api-key is required when --local-only is not set") + raise argparse.ArgumentError( + None, "argument --api-key is required when --local-only is not set" + ) def main(): @@ -187,7 +249,8 @@ def main(): "whitelist", help="Manage the whitelist of approved entities", description=( - "View, add, or reset whitelisted entities. " "Whitelisted entities bypass security checks during scans." + "View, add, or reset whitelisted entities. " + "Whitelisted entities bypass security checks during scans." ), ) add_common_arguments(whitelist_parser) @@ -233,56 +296,13 @@ def main(): ) # install install_parser = subparsers.add_parser("install", help="Install Invariant Gateway") - install_parser.add_argument( - "files", - type=str, - nargs="*", - default=WELL_KNOWN_MCP_PATHS, - help=( - "Different file locations to scan. " - "This can include custom file locations as long as " - "they are in an expected format, including Claude, " - "Cursor or VSCode format." - ), - ) - install_parser.add_argument( - "--project_name", - type=str, - default="mcp-gateway", - help="Project name for the Invariant Gateway", - ) - install_parser.add_argument( - "--api-key", - type=str, - help="API key for the Invariant Gateway", - ) - install_parser.add_argument( - "--local-only", - default=False, - action="store_true", - help="Prevent pushing traces to the explorer.", - ) - install_parser.add_argument( - "--mcp-scan-server-port", - type=int, - default=8000, - help="MCP scan server port (default: 8000).", - metavar="PORT", - ) + add_install_arguments(install_parser) # uninstall - uninstall_parser = subparsers.add_parser("uninstall", help="Uninstall Invariant Gateway") - uninstall_parser.add_argument( - "files", - type=str, - nargs="*", - default=WELL_KNOWN_MCP_PATHS, - help=( - "Different file locations to scan. " - "This can include custom file locations as long as " - "they are in an expected format, including Claude, Cursor or VSCode format." - ), + uninstall_parser = subparsers.add_parser( + "uninstall", help="Uninstall Invariant Gateway" ) + add_uninstall_arguments(uninstall_parser) # HELP command help_parser = subparsers.add_parser( # noqa: F841 @@ -302,6 +322,21 @@ def main(): ) add_common_arguments(server_parser) + # PROXY command + proxy_parser = subparsers.add_parser( + "proxy", help="Installs and proxies MCP requests, uninstalls on exit" + ) + proxy_parser.add_argument( + "--port", + type=int, + default=8000, + help="Port to run the server on (default: 8000)", + metavar="PORT", + ) + add_common_arguments(proxy_parser) + add_server_arguments(proxy_parser) + add_install_arguments(proxy_parser) + # Parse arguments (default to 'scan' if no command provided) args = parser.parse_args(["scan"] if len(sys.argv) == 1 else None) @@ -309,6 +344,44 @@ def main(): if not (hasattr(args, "json") and args.json): rich.print(f"[bold blue]Invariant MCP-scan v{version_info}[/bold blue]\n") + def install(): + try: + check_install_args(args) + except argparse.ArgumentError as e: + parser.error(e) + + invariant_api_url = ( + f"http://localhost:{args.mcp_scan_server_port}" + if args.local_only + else "https://explorer.invariantlabs.ai" + ) + installer = MCPGatewayInstaller( + paths=args.files, invariant_api_url=invariant_api_url + ) + installer.install( + gateway_config=MCPGatewayConfig( + project_name=args.project_name, + push_explorer=not args.local_only, + api_key=args.api_key or "", + ), + verbose=True, + ) + + def uninstall(): + installer = MCPGatewayInstaller(paths=args.files) + installer.uninstall(verbose=True) + + def server(on_exit=None): + sf = StorageFile(args.storage_file) + guardrails_config_path = sf.create_guardrails_config() + mcp_scan_server = MCPScanServer( + port=args.port, + config_file_path=guardrails_config_path, + on_exit=on_exit, + pretty=args.pretty + ) + mcp_scan_server.run() + # Handle commands if args.command == "help": parser.print_help() @@ -319,7 +392,9 @@ def main(): sf.reset_whitelist() rich.print("[bold]Whitelist reset[/bold]") sys.exit(0) - elif all(map(lambda x: x is None, [args.type, args.name, args.hash])): # no args + elif all( + map(lambda x: x is None, [args.type, args.name, args.hash]) + ): # no args sf.print_whitelist() sys.exit(0) elif all(map(lambda x: x is not None, [args.type, args.name, args.hash])): @@ -332,33 +407,20 @@ def main(): sf.print_whitelist() sys.exit(0) else: - rich.print("[bold red]Please provide all three parameters: type, name, and hash.[/bold red]") + rich.print( + "[bold red]Please provide all three parameters: type, name, and hash.[/bold red]" + ) whitelist_parser.print_help() sys.exit(1) elif args.command == "inspect": asyncio.run(run_scan_inspect(mode="inspect", args=args)) sys.exit(0) elif args.command == "install": - try: - check_install_args(args) - except argparse.ArgumentError as e: - parser.error(e) - - invariant_api_url = ( - f"http://localhost:{args.mcp_scan_server_port}" if args.local_only else "https://explorer.invariantlabs.ai" - ) - installer = MCPGatewayInstaller(paths=args.files, invariant_api_url=invariant_api_url) - installer.install( - gateway_config=MCPGatewayConfig( - project_name=args.project_name, - push_explorer=not args.local_only, - api_key=args.api_key or "", - ), - verbose=True, - ) + install() + sys.exit(0) elif args.command == "uninstall": - installer = MCPGatewayInstaller(paths=args.files) - installer.uninstall(verbose=True) + uninstall() + sys.exit(0) elif args.command == "whitelist": if args.reset: MCPScanner(**vars(args)).reset_whitelist() @@ -377,17 +439,12 @@ def main(): asyncio.run(run_scan_inspect(args=args)) sys.exit(0) elif args.command == "server": - sf = StorageFile(args.storage_file) - guardrails_config_path = sf.create_guardrails_config() - mcp_scan_server = MCPScanServer(port=args.port, config_file_path=guardrails_config_path) - mcp_scan_server.run() - sys.exit(0) - elif args.command == "server": - sf = StorageFile(args.storage_file) - guardrails_config_path = sf.create_guardrails_config() - mcp_scan_server = MCPScanServer(port=args.port, config_file_path=guardrails_config_path) - mcp_scan_server.run() + server() sys.exit(0) + elif args.command == "proxy": + args.local_only = True + # install() + server(on_exit=uninstall) else: # This shouldn't happen due to argparse's handling rich.print(f"[bold red]Unknown command: {args.command}[/bold red]") diff --git a/src/mcp_scan_server/activity_logger.py b/src/mcp_scan_server/activity_logger.py new file mode 100644 index 0000000..6ef4875 --- /dev/null +++ b/src/mcp_scan_server/activity_logger.py @@ -0,0 +1,114 @@ +from typing import Literal +import uuid +import rich +import json + +from fastapi import APIRouter, FastAPI, Request +from invariant_sdk.types.push_traces import PushTracesResponse + +from rich import print +from rich.panel import Panel +from rich.syntax import Syntax +from rich.markdown import Markdown +from rich.rule import Rule +from textwrap import shorten + +class ActivityLogger: + def __init__(self, pretty: Literal["oneline", "compact", "full"] = "compact"): + self.cached_metadata = {} + # level of pretty printing + self.pretty = pretty + + async def handle_push(self, messages, metadata): + """ + Handles a push request with the given messages and metadata. + """ + for i, batch_items in enumerate(messages): + trace_id = str(uuid.uuid4()) + self.cached_metadata[trace_id] = metadata[i] + await self.log(batch_items, metadata[i]) + + return trace_id + + async def handle_append(self, trace_id: str, messages: list[dict]): + """ + Handles an append request with the given trace ID and messages. + """ + metadata = self.cached_metadata.get(trace_id, None) + await self.log(messages, metadata) + + async def log(self, messages, metadata=None): + """ + Console-logs the relevant parts of the given messages and metadata. + """ + + client = metadata.get("client", "Unknown Client").capitalize() + server = metadata.get("mcp_server", "Unknown Server").capitalize() + user = metadata.get("user", "Unknown User") + + for tc in tool_calls(messages): + name = tc['name'] + + print(Rule()) + print(f"● [bold blue]{client}[/bold blue] (@[bold red]{user}[/bold red]) used [bold green]{server}[/bold green] to [bold green]{name}[/bold green]") + print(Rule()) + + if self.pretty != 'oneline': + args = tc.get("arguments", {}) + result = tc.get("result", "") + + if self.pretty == 'compact': + truncated_result = truncate_preserving_whitespace(result) + + print(Syntax(json.dumps(args, indent=2), "json", theme="monokai")) + print(Rule(style="grey50")) + print(Syntax(truncated_result, "json" if not truncated_result.startswith("Error") else "pytb", theme="monokai")) + print(Rule(style="grey50")) + else: + print(Syntax(json.dumps(args, indent=2), "json", theme="monokai")) + print(Rule(style="grey50")) + print(Syntax(result, "json" if not result.startswith("Error") else "pytb", theme="monokai")) + print(Rule(style="grey50")) + + +def tool_calls(messages: list[dict]) -> list[dict]: + calls = {} + + # First pass: index tool call requests + for msg in messages: + if 'tool_calls' in msg: + for call in (msg['tool_calls'] or []): + calls[call['id']] = { + "name": call['function'].get('name', ""), + "arguments": call['function'].get('arguments', {}) + } + + # Second pass: find responses with matching tool_call_id + for msg in messages: + if msg.get('tool_call_id') in calls: + result_texts = [c['text'] for c in msg.get('content', []) if c['type'] == 'text'] + calls[msg['tool_call_id']]["result"] = "\n".join(result_texts) + + return list(calls.values()) + +def truncate_preserving_whitespace(text, max_lines=20, max_chars=2000): + lines = text.splitlines() + truncated = "\n".join(lines[:max_lines]) + if len(truncated) > max_chars: + truncated = truncated[:max_chars] + "\n... [truncated]" + elif len(lines) > max_lines: + truncated += "# \n... [truncated]" + return truncated + +async def get_activity_logger(request: Request) -> ActivityLogger: + """ + Returns a singleton instance of the ActivityLogger. + """ + return request.app.state.activity_logger + +def setup_activity_logger(app: FastAPI, pretty: Literal["oneline", "compact", "full"] = "compact"): + """ + Sets up the ActivityLogger as a dependency for the given FastAPI app. + """ + app.state.activity_logger = ActivityLogger(pretty=pretty) + diff --git a/src/mcp_scan_server/routes/push.py b/src/mcp_scan_server/routes/push.py index 2da75ba..dc91e7e 100644 --- a/src/mcp_scan_server/routes/push.py +++ b/src/mcp_scan_server/routes/push.py @@ -1,12 +1,23 @@ import uuid +import rich +import json -from fastapi import APIRouter +from fastapi import APIRouter, Request, Depends +from typing import Annotated from invariant_sdk.types.push_traces import PushTracesResponse -router = APIRouter() +from mcp_scan_server.activity_logger import ActivityLogger, get_activity_logger +router = APIRouter() @router.post("/trace") -async def push_trace(): +async def push_trace(request: Request, activity_logger: Annotated[ActivityLogger, Depends(get_activity_logger)]) -> PushTracesResponse: """Push a trace. For now, this is a dummy response.""" - return PushTracesResponse(id=[str(uuid.uuid4())], success=True) + body = await request.json() + metadata = body.get("metadata", [{}]) + messages = body.get("messages", [[]]) + + trace_id = await activity_logger.handle_push(messages, metadata) + + # return the trace ID + return PushTracesResponse(id=[trace_id], success=True) diff --git a/src/mcp_scan_server/routes/trace.py b/src/mcp_scan_server/routes/trace.py index a64c646..b042f89 100644 --- a/src/mcp_scan_server/routes/trace.py +++ b/src/mcp_scan_server/routes/trace.py @@ -1,9 +1,18 @@ -from fastapi import APIRouter, Request +from fastapi import APIRouter, Request, Depends +from typing import Annotated + +from mcp_scan_server.activity_logger import ActivityLogger, get_activity_logger router = APIRouter() @router.post("/{trace_id}/messages") -async def append_messages(request: Request): +async def append_messages(trace_id: str, request: Request, activity_logger: Annotated[ActivityLogger, Depends(get_activity_logger)]): """Append messages to a trace. For now this is a dummy response.""" + + body = await request.json() + messages = body.get("messages", []) + + await activity_logger.handle_append(trace_id, messages) + return {"success": True} diff --git a/src/mcp_scan_server/server.py b/src/mcp_scan_server/server.py index 348cc6c..e0ca6fb 100644 --- a/src/mcp_scan_server/server.py +++ b/src/mcp_scan_server/server.py @@ -1,6 +1,9 @@ import rich import uvicorn from fastapi import FastAPI +from typing import Literal, Optional + +from mcp_scan_server.activity_logger import setup_activity_logger from .routes.policies import router as policies_router from .routes.push import router as push_router @@ -15,13 +18,18 @@ class MCPScanServer: Args: port: The port to run the server on. config_file_path: The path to the config file. + on_exit: A callback function to be called on exit of the server. + log_level: The log level for the server. """ - def __init__(self, port: int = 8000, config_file_path: str | None = None): + def __init__(self, port: int = 8000, config_file_path: str | None = None, on_exit: Optional[callable] = None, log_level: str = "error", pretty: Literal["oneline", "compact", "full"] = "compact"): self.port = port self.config_file_path = config_file_path + self.on_exit = on_exit + self.log_level = log_level + self.pretty = pretty - self.app = FastAPI() + self.app = FastAPI(lifespan=self.life_span) self.app.state.config_file_path = config_file_path self.app.include_router(policies_router, prefix="/api/v1") @@ -29,7 +37,21 @@ def __init__(self, port: int = 8000, config_file_path: str | None = None): self.app.include_router(dataset_trace_router, prefix="/api/v1/trace") self.app.include_router(user_router, prefix="/api/v1/user") + async def on_startup(self): + """Startup event for the FastAPI app.""" + rich.print("[bold green]MCP-scan server started.[/bold green]") + + setup_activity_logger(self.app, pretty=self.pretty) + + async def life_span(self, app: FastAPI): + """Lifespan event for the FastAPI app.""" + await self.on_startup() + + yield + + if callable(self.on_exit): + self.on_exit() + def run(self): """Run the MCP scan server.""" - rich.print("[bold green]Starting MCP-scan server.[/bold green]") - uvicorn.run(self.app, host="0.0.0.0", port=self.port) + uvicorn.run(self.app, host="0.0.0.0", port=self.port, log_level=self.log_level) diff --git a/test-client.py b/test-client.py new file mode 100644 index 0000000..47ebf3c --- /dev/null +++ b/test-client.py @@ -0,0 +1,32 @@ +import asyncio +import json +from pathlib import Path +from mcp import ClientSession, StdioServerParameters, types +from mcp.client.stdio import stdio_client + +CONFIG_PATH = Path("/Users/luca/.cursor/mcp.json") +SERVER_KEY = "whatsapp-mcp" + +def load_server_params(key: str) -> StdioServerParameters: + config = json.loads(CONFIG_PATH.read_text()) + server_cfg = config["mcpServers"][key] + return StdioServerParameters( + command=server_cfg["command"], + args=server_cfg.get("args", []), + env=server_cfg.get("env", {}) + ) + +async def run(): + server_params = load_server_params(SERVER_KEY) + async with stdio_client(server_params) as (read, write): + async with ClientSession(read, write) as session: + await session.initialize() + await session.list_tools() + await session.call_tool("list_chats", arguments={"limit": 20, "include_last_message": True, "sort_by": "last_active"}) + + await session.call_tool("send_message", arguments={"chat_id": "123", "message": "Hello, world!"}) + + await asyncio.sleep(0.2) + +if __name__ == "__main__": + asyncio.run(run()) \ No newline at end of file diff --git a/test.sh b/test.sh new file mode 100644 index 0000000..b1d5c62 --- /dev/null +++ b/test.sh @@ -0,0 +1,3 @@ +npx concurrently -p none -k \ + "uv run mcp-scan proxy --pretty full" \ + "sleep 0.1 && python test-client.py" \ No newline at end of file From b65701fe5d111301fedb9d36da891eab235c4fe7 Mon Sep 17 00:00:00 2001 From: Luca Beurer-Kellner Date: Tue, 6 May 2025 20:46:41 +0200 Subject: [PATCH 18/21] activity logging --- src/mcp_scan/cli.py | 57 +++++------ src/mcp_scan/gateway.py | 73 +++++++++----- src/mcp_scan/paths.py | 87 ++++++++++++++++ src/mcp_scan_server/activity_logger.py | 132 ++++++++++++++----------- src/mcp_scan_server/routes/policies.py | 18 +++- src/mcp_scan_server/routes/push.py | 3 +- src/mcp_scan_server/routes/trace.py | 2 +- 7 files changed, 254 insertions(+), 118 deletions(-) create mode 100644 src/mcp_scan/paths.py diff --git a/src/mcp_scan/cli.py b/src/mcp_scan/cli.py index f2f5a92..e2b1703 100644 --- a/src/mcp_scan/cli.py +++ b/src/mcp_scan/cli.py @@ -13,6 +13,7 @@ from .printer import print_scan_result from .StorageFile import StorageFile from .version import version_info +from .paths import WELL_KNOWN_MCP_PATHS, client_shorthands_to_paths def get_invoking_name(): @@ -35,35 +36,6 @@ def get_invoking_name(): def str2bool(v: str) -> bool: return v.lower() in ("true", "1", "t", "y", "yes") - -if sys.platform == "linux" or sys.platform == "linux2": - WELL_KNOWN_MCP_PATHS = [ - "~/.codeium/windsurf/mcp_config.json", # windsurf - "~/.cursor/mcp.json", # cursor - "~/.vscode/mcp.json", # vscode - "~/.config/Code/User/settings.json", # vscode linux - ] -elif sys.platform == "darwin": - # OS X - WELL_KNOWN_MCP_PATHS = [ - "~/.codeium/windsurf/mcp_config.json", # windsurf - "~/.cursor/mcp.json", # cursor - "~/Library/Application Support/Claude/claude_desktop_config.json", # Claude Desktop mac - "~/.vscode/mcp.json", # vscode - "~/Library/Application Support/Code/User/settings.json", # vscode mac - ] -elif sys.platform == "win32": - WELL_KNOWN_MCP_PATHS = [ - "~/.codeium/windsurf/mcp_config.json", # windsurf - "~/.cursor/mcp.json", # cursor - "~/AppData/Roaming/Claude/claude_desktop_config.json", # Claude Desktop windows - "~/.vscode/mcp.json", # vscode - "~/AppData/Roaming/Code/User/settings.json", # vscode windows - ] -else: - WELL_KNOWN_MCP_PATHS = [] - - def add_common_arguments(parser): """Add arguments that are common to multiple commands.""" parser.add_argument( @@ -138,6 +110,12 @@ def add_install_arguments(parser): action="store_true", help="Prevent pushing traces to the explorer.", ) + parser.add_argument( + "--gateway-dir", + type=str, + help="Source directory for the Invariant Gateway. Set this, if you want to install a custom gateway implementation. (default: the published package is used).", + default=None, + ) parser.add_argument( "--mcp-scan-server-port", type=int, @@ -163,9 +141,15 @@ def add_uninstall_arguments(parser): def check_install_args(args): if args.command == "install" and not args.local_only and not args.api_key: - raise argparse.ArgumentError( - None, "argument --api-key is required when --local-only is not set" - ) + # prompt for api key + print("To install mcp-scan with remote logging, you need an Invariant API key (https://explorer.invariantlabs.ai/settings).\n") + args.api_key = input("API key (or just press enter to install with --local-only): ") + if not args.api_key: + args.local_only = True + + # raise argparse.ArgumentError( + # None, "argument --api-key is required when --local-only is not set" + # ) def main(): @@ -339,6 +323,9 @@ def main(): # Parse arguments (default to 'scan' if no command provided) args = parser.parse_args(["scan"] if len(sys.argv) == 1 else None) + + # postprocess the files argument (if shorthands are used) + args.files = client_shorthands_to_paths(args.files) # Display version banner if not (hasattr(args, "json") and args.json): @@ -361,8 +348,9 @@ def install(): installer.install( gateway_config=MCPGatewayConfig( project_name=args.project_name, - push_explorer=not args.local_only, + push_explorer=True, api_key=args.api_key or "", + source_dir=args.gateway_dir, ), verbose=True, ) @@ -443,7 +431,8 @@ def server(on_exit=None): sys.exit(0) elif args.command == "proxy": args.local_only = True - # install() + install() + print("[Proxy installed, you may need to restart/reload your MCP clients to use it]") server(on_exit=uninstall) else: # This shouldn't happen due to argparse's handling diff --git a/src/mcp_scan/gateway.py b/src/mcp_scan/gateway.py index 641ad3e..8fef437 100644 --- a/src/mcp_scan/gateway.py +++ b/src/mcp_scan/gateway.py @@ -1,5 +1,6 @@ import argparse import os +from typing import Optional import rich from pydantic import BaseModel @@ -8,6 +9,7 @@ from mcp_scan.mcp_client import scan_mcp_config_file from mcp_scan.models import MCPConfig, SSEServer, StdioServer +from mcp_scan.paths import get_client_from_path from mcp_scan.printer import format_path_line parser = argparse.ArgumentParser( @@ -15,15 +17,6 @@ prog="invariant-gateway@latest mcp", ) -parser.add_argument( - "--project-name", - type=str, - required=True, -) -parser.add_argument( - "--push-explorer", - action="store_true", -) parser.add_argument("--exec", type=str, required=True, nargs=argparse.REMAINDER) @@ -40,35 +33,70 @@ class MCPGatewayConfig(BaseModel): push_explorer: bool api_key: str + # the source directory of the gateway implementation to use + # (if None, uses the published package) + source_dir: Optional[str] = None + def is_invariant_installed(server: StdioServer) -> bool: if server.args is None: return False if not server.args: return False - return server.args[0] == "invariant-gateway@latest" + return any("invariant-gateway" in a for a in server.args) def install_gateway( server: StdioServer, config: MCPGatewayConfig, invariant_api_url: str = "https://explorer.invariantlabs.ai", + extra_metadata: dict[str, str] = {}, ) -> StdioServer: """Install the gateway for the given server.""" if is_invariant_installed(server): raise MCPServerAlreadyGateway() - return StdioServer( - command="uvx", - args=[ - "invariant-gateway@latest", + + env = server.env | {"INVARIANT_API_KEY": config.api_key or "", "INVARIANT_API_URL": invariant_api_url, "GUARDRAILS_API_URL": invariant_api_url} + + cmd = "uvx" + base_args = [ + "invariant-gateway@latest", + "mcp", + ] + + # if running gateway from source-dir, use 'uv run' instead + if config.source_dir: + cmd = "uv" + base_args = [ + "run", + "--with", "mcp", + "--directory", + config.source_dir, + "invariant-gateway", + "mcp" + ] + + flags = [ "--project-name", config.project_name, + *(["--push-explorer"] if config.push_explorer else []), ] - + (["--push-explorer"] if config.push_explorer else []) - + ["--exec", server.command] - + (server.args if server.args else []), - env=server.env | {"INVARIANT_API_KEY": config.api_key, "INVARIANT_API_URL": invariant_api_url}, + # add extra metadata flags + for k, v in extra_metadata.items(): + flags.append(f"--metadata-{k}={v}") + + # add exec section + flags += [ + *["--exec", server.command], + *(server.args if server.args else []) + ] + + # return new server config + return StdioServer( + command=cmd, + args=base_args + flags, + env=env ) @@ -80,8 +108,8 @@ def uninstall_gateway( raise MCPServerIsNotGateway() assert isinstance(server.args, list), "args is not a list" - args = parser.parse_args(server.args[2:]) - new_env = {k: v for k, v in server.env.items() if k != "INVARIANT_API_KEY" and k != "INVARIANT_API_URL"} + args, unknown = parser.parse_known_args(server.args[2:]) + new_env = {k: v for k, v in server.env.items() if k != "INVARIANT_API_KEY" and k != "INVARIANT_API_URL" and k != "GUARDRAILS_API_URL"} assert args.exec is not None, "exec is None" assert args.exec, "exec is empty" return StdioServer( @@ -138,13 +166,14 @@ def install( for name, server in config.get_servers().items(): if isinstance(server, StdioServer): try: - new_servers[name] = install_gateway(server, gateway_config, self.invariant_api_url) + new_servers[name] = install_gateway(server, gateway_config, self.invariant_api_url, {"server": name, "client": get_client_from_path(path)}) path_print_tree.add(format_install_line(server=name, status="Installed", success=True)) except MCPServerAlreadyGateway: new_servers[name] = server path_print_tree.add(format_install_line(server=name, status="Already installed", success=True)) - except Exception: + except Exception as e: new_servers[name] = server + print(f"Failed to install gateway for {name}", e) path_print_tree.add(format_install_line(server=name, status="Failed to install", success=False)) else: diff --git a/src/mcp_scan/paths.py b/src/mcp_scan/paths.py new file mode 100644 index 0000000..680b568 --- /dev/null +++ b/src/mcp_scan/paths.py @@ -0,0 +1,87 @@ +import sys + +if sys.platform == "linux" or sys.platform == "linux2": + # Linux + CLIENT_PATHS = { + 'windsurf': [ + "~/.codeium/windsurf/mcp_config.json" + ], + 'cursor': [ + "~/.cursor/mcp.json" + ], + 'vscode': [ + "~/.vscode/mcp.json", + "~/.config/Code/User/settings.json" + ], + } + WELL_KNOWN_MCP_PATHS = [path for client, paths in CLIENT_PATHS.items() for path in paths] +elif sys.platform == "darwin": + # OS X + CLIENT_PATHS = { + 'windsurf': [ + "~/.codeium/windsurf/mcp_config.json" + ], + 'cursor': [ + "~/.cursor/mcp.json" + ], + 'claude': [ + "~/Library/Application Support/Claude/claude_desktop_config.json" + ], + 'vscode': [ + "~/.vscode/mcp.json", + "~/Library/Application Support/Code/User/settings.json" + ], + } + WELL_KNOWN_MCP_PATHS = [path for client, paths in CLIENT_PATHS.items() for path in paths] +elif sys.platform == "win32": + CLIENT_PATHS = { + 'windsurf': [ + "~/.codeium/windsurf/mcp_config.json" + ], + 'cursor': [ + "~/.cursor/mcp.json" + ], + 'claude': [ + "~/AppData/Roaming/Claude/claude_desktop_config.json" + ], + 'vscode': [ + "~/.vscode/mcp.json", + "~/AppData/Roaming/Code/User/settings.json" + ], + } + + WELL_KNOWN_MCP_PATHS = [path for client, paths in CLIENT_PATHS.items() for path in paths] +else: + WELL_KNOWN_MCP_PATHS = [] + +def get_client_from_path(path: str) -> str: + """ + Returns the client name from a path. + + Args: + path (str): The path to get the client from. + + Returns: + str: The client name or None if it cannot be guessed from the path. + """ + for client, paths in CLIENT_PATHS.items(): + if path in paths: + return client + return None + +def client_shorthands_to_paths(shorthands: list[str]): + """ + Converts a list of client shorthands to a list of paths. + + Does nothing if the shorthands are already paths. + """ + paths = [] + if any("/" in shorthand for shorthand in shorthands): + return shorthands + + for shorthand in shorthands: + if shorthand in CLIENT_PATHS: + paths.extend(CLIENT_PATHS[shorthand]) + else: + raise ValueError(f"{shorthand} is not a valid client shorthand") + return paths diff --git a/src/mcp_scan_server/activity_logger.py b/src/mcp_scan_server/activity_logger.py index 6ef4875..399c61c 100644 --- a/src/mcp_scan_server/activity_logger.py +++ b/src/mcp_scan_server/activity_logger.py @@ -18,6 +18,15 @@ def __init__(self, pretty: Literal["oneline", "compact", "full"] = "compact"): self.cached_metadata = {} # level of pretty printing self.pretty = pretty + + # (session_id, tool_call_id) -> bool + self.logged_header = {} + self.logged_result = {} + + # (session_id, formatted_output) -> bool + self.logged_output = {} + # last logged (session_id, tool_call_id), so we can skip logging tool call headers if it is directly followed by output + self.last_logged_tool = None async def handle_push(self, messages, metadata): """ @@ -34,71 +43,80 @@ async def handle_append(self, trace_id: str, messages: list[dict]): """ Handles an append request with the given trace ID and messages. """ - metadata = self.cached_metadata.get(trace_id, None) + metadata = self.cached_metadata.get(trace_id, {}) await self.log(messages, metadata) - async def log(self, messages, metadata=None): + def empty_metadata(self): + return { + "client": "Unknown Client", + "mcp_server": "Unknown Server", + "user": None + } + + async def log(self, messages, metadata): """ Console-logs the relevant parts of the given messages and metadata. """ - + session_id = metadata.get("session_id", "") client = metadata.get("client", "Unknown Client").capitalize() server = metadata.get("mcp_server", "Unknown Server").capitalize() - user = metadata.get("user", "Unknown User") - - for tc in tool_calls(messages): - name = tc['name'] - - print(Rule()) - print(f"● [bold blue]{client}[/bold blue] (@[bold red]{user}[/bold red]) used [bold green]{server}[/bold green] to [bold green]{name}[/bold green]") - print(Rule()) - - if self.pretty != 'oneline': - args = tc.get("arguments", {}) - result = tc.get("result", "") - - if self.pretty == 'compact': - truncated_result = truncate_preserving_whitespace(result) - - print(Syntax(json.dumps(args, indent=2), "json", theme="monokai")) - print(Rule(style="grey50")) - print(Syntax(truncated_result, "json" if not truncated_result.startswith("Error") else "pytb", theme="monokai")) - print(Rule(style="grey50")) + user = metadata.get("user", None) + + tool_names = {} + + for msg in messages: + if msg.get('role') == 'tool': + if (session_id, 'output-' + msg.get('tool_call_id')) in self.logged_output: + continue + self.logged_output[(session_id, 'output-' + msg.get('tool_call_id'))] = True + + has_header = self.last_logged_tool == (session_id, msg.get('tool_call_id')) + + if not has_header: + print(Rule()) + # left arrow for output + user_portion = "" if user is None else f" ([bold red]{user}[/bold red])" + name = tool_names.get(msg.get('tool_call_id'), "") + print(f"← [bold blue]{client}[/bold blue]{user_portion} used [bold green]{server}[/bold green] to [bold green]{name}[/bold green]") + print(Rule()) + + # tool output + content = message_content(msg) + if type(content) is str and content.startswith("Error"): + print(Syntax(content, "pytb", theme="monokai")) else: - print(Syntax(json.dumps(args, indent=2), "json", theme="monokai")) - print(Rule(style="grey50")) - print(Syntax(result, "json" if not result.startswith("Error") else "pytb", theme="monokai")) - print(Rule(style="grey50")) - - -def tool_calls(messages: list[dict]) -> list[dict]: - calls = {} - - # First pass: index tool call requests - for msg in messages: - if 'tool_calls' in msg: - for call in (msg['tool_calls'] or []): - calls[call['id']] = { - "name": call['function'].get('name', ""), - "arguments": call['function'].get('arguments', {}) - } - - # Second pass: find responses with matching tool_call_id - for msg in messages: - if msg.get('tool_call_id') in calls: - result_texts = [c['text'] for c in msg.get('content', []) if c['type'] == 'text'] - calls[msg['tool_call_id']]["result"] = "\n".join(result_texts) - - return list(calls.values()) - -def truncate_preserving_whitespace(text, max_lines=20, max_chars=2000): - lines = text.splitlines() - truncated = "\n".join(lines[:max_lines]) - if len(truncated) > max_chars: - truncated = truncated[:max_chars] + "\n... [truncated]" - elif len(lines) > max_lines: - truncated += "# \n... [truncated]" - return truncated + print(Syntax(content, "json", theme="monokai")) + print(Rule()) + + else: + for tc in (msg.get('tool_calls') or []): + name = tc.get('function', {}).get('name', "") + tool_names[tc.get('id')] = name + + if (session_id, tc.get('id')) in self.logged_output: + continue + self.logged_output[(session_id, tc.get('id'))] = True + + self.last_logged_tool = (session_id, tc.get('id')) + + # header + user_portion = "" if user is None else f" ([bold red]{user}[/bold red])" + + print(Rule()) + print(f"→ [bold blue]{client}[/bold blue]{user_portion} used [bold green]{server}[/bold green] to [bold green]{name}[/bold green]") + print(Rule()) + + # tool arguments + print(Syntax(json.dumps(tc.get('arguments', {}), indent=2), "json", theme="monokai")) + + +def message_content(msg: dict) -> str: + if type(msg.get('content')) is str: + return msg.get('content', '') + elif type(msg.get('content')) is list: + return "\n".join([c['text'] for c in msg.get('content', []) if c['type'] == 'text']) + else: + return "" async def get_activity_logger(request: Request) -> ActivityLogger: """ diff --git a/src/mcp_scan_server/routes/policies.py b/src/mcp_scan_server/routes/policies.py index 69fb25e..9b3eb5f 100644 --- a/src/mcp_scan_server/routes/policies.py +++ b/src/mcp_scan_server/routes/policies.py @@ -1,10 +1,11 @@ import asyncio +import json import os import fastapi import rich import yaml # type: ignore -from fastapi import APIRouter, Request +from fastapi import APIRouter, Depends, Request from invariant.analyzer.policy import LocalPolicy from invariant.analyzer.runtime.runtime_errors import ( ExcessivePolicyError, @@ -13,6 +14,8 @@ ) from pydantic import ValidationError +from mcp_scan_server.activity_logger import ActivityLogger, get_activity_logger + from ..models import ( BatchCheckRequest, BatchCheckResponse, @@ -129,12 +132,21 @@ def to_json_serializable_dict(obj): @router.post("/policy/check/batch", response_model=BatchCheckResponse) -async def batch_check_policies(request: BatchCheckRequest): +async def batch_check_policies(check_request: BatchCheckRequest, request: fastapi.Request, activity_logger: ActivityLogger = Depends(get_activity_logger)): """Check a policy using the invariant analyzer.""" results = await asyncio.gather( - *[check_policy(policy, request.messages, request.parameters) for policy in request.policies] + *[check_policy(policy, check_request.messages, check_request.parameters) for policy in check_request.policies] ) + metadata = check_request.parameters.get("metadata", {}) + + await activity_logger.log(check_request.messages, { + "client": metadata.get("client"), + "mcp_server": metadata.get("server"), + "user": metadata.get("system_user"), + "session_id": metadata.get("session_id"), + }) + return fastapi.responses.JSONResponse( content={"result": [to_json_serializable_dict(result.to_dict()) for result in results]} ) diff --git a/src/mcp_scan_server/routes/push.py b/src/mcp_scan_server/routes/push.py index dc91e7e..f79b7e6 100644 --- a/src/mcp_scan_server/routes/push.py +++ b/src/mcp_scan_server/routes/push.py @@ -17,7 +17,8 @@ async def push_trace(request: Request, activity_logger: Annotated[ActivityLogger metadata = body.get("metadata", [{}]) messages = body.get("messages", [[]]) - trace_id = await activity_logger.handle_push(messages, metadata) + # trace_id = await activity_logger.handle_push(messages, metadata) + trace_id = str(uuid.uuid4()) # return the trace ID return PushTracesResponse(id=[trace_id], success=True) diff --git a/src/mcp_scan_server/routes/trace.py b/src/mcp_scan_server/routes/trace.py index b042f89..2f2876f 100644 --- a/src/mcp_scan_server/routes/trace.py +++ b/src/mcp_scan_server/routes/trace.py @@ -13,6 +13,6 @@ async def append_messages(trace_id: str, request: Request, activity_logger: Anno body = await request.json() messages = body.get("messages", []) - await activity_logger.handle_append(trace_id, messages) + # await activity_logger.handle_append(trace_id, messages) return {"success": True} From cd5e89e56125aa4630421c21665ea1ec86552150 Mon Sep 17 00:00:00 2001 From: Luca Beurer-Kellner Date: Tue, 6 May 2025 22:11:53 +0200 Subject: [PATCH 19/21] test client --- test-client.py | 8 +++++--- test.sh | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/test-client.py b/test-client.py index 47ebf3c..dbf21d1 100644 --- a/test-client.py +++ b/test-client.py @@ -17,15 +17,17 @@ def load_server_params(key: str) -> StdioServerParameters: ) async def run(): + await asyncio.sleep(1) server_params = load_server_params(SERVER_KEY) async with stdio_client(server_params) as (read, write): async with ClientSession(read, write) as session: await session.initialize() await session.list_tools() - await session.call_tool("list_chats", arguments={"limit": 20, "include_last_message": True, "sort_by": "last_active"}) - + + await session.call_tool("list_chats", arguments={"limit": 20, "include_last_message": True, "sort_by": "last_active"}), + await asyncio.sleep(0.2) + await session.call_tool("send_message", arguments={"chat_id": "123", "message": "Hello, world!"}) - await asyncio.sleep(0.2) if __name__ == "__main__": diff --git a/test.sh b/test.sh index b1d5c62..1f1daac 100644 --- a/test.sh +++ b/test.sh @@ -1,3 +1,3 @@ npx concurrently -p none -k \ - "uv run mcp-scan proxy --pretty full" \ + "uv run mcp-scan proxy --pretty full --gateway-dir /Users/luca/Developer/invariant-gateway" \ "sleep 0.1 && python test-client.py" \ No newline at end of file From fcaafb09ba21a1ff16c9685e32b65f9d753391ce Mon Sep 17 00:00:00 2001 From: Luca Beurer-Kellner Date: Wed, 7 May 2025 09:02:03 +0200 Subject: [PATCH 20/21] tweaks --- src/mcp_scan/cli.py | 14 +++++++------- src/mcp_scan/gateway.py | 10 +++++----- src/mcp_scan_server/activity_logger.py | 24 ++++++------------------ src/mcp_scan_server/routes/push.py | 1 - src/mcp_scan_server/routes/trace.py | 2 -- src/mcp_scan_server/server.py | 6 +++++- 6 files changed, 23 insertions(+), 34 deletions(-) diff --git a/src/mcp_scan/cli.py b/src/mcp_scan/cli.py index f6aea6f..96d12c4 100644 --- a/src/mcp_scan/cli.py +++ b/src/mcp_scan/cli.py @@ -372,7 +372,7 @@ def main(): if not (hasattr(args, "json") and args.json): rich.print(f"[bold blue]Invariant MCP-scan v{version_info}[/bold blue]\n") - def install(): + async def install(): try: check_install_args(args) except argparse.ArgumentError as e: @@ -386,7 +386,7 @@ def install(): installer = MCPGatewayInstaller( paths=args.files, invariant_api_url=invariant_api_url ) - installer.install( + await installer.install( gateway_config=MCPGatewayConfig( project_name=args.project_name, push_explorer=True, @@ -396,9 +396,9 @@ def install(): verbose=True, ) - def uninstall(): + async def uninstall(): installer = MCPGatewayInstaller(paths=args.files) - installer.uninstall(verbose=True) + await installer.uninstall(verbose=True) def server(on_exit=None): sf = StorageFile(args.storage_file) @@ -448,10 +448,10 @@ def server(on_exit=None): asyncio.run(run_scan_inspect(mode="inspect", args=args)) sys.exit(0) elif args.command == "install": - install() + asyncio.run(install()) sys.exit(0) elif args.command == "uninstall": - uninstall() + asyncio.run(uninstall()) sys.exit(0) elif args.command == "whitelist": if args.reset: @@ -475,7 +475,7 @@ def server(on_exit=None): sys.exit(0) elif args.command == "proxy": args.local_only = True - install() + asyncio.run(install()) print("[Proxy installed, you may need to restart/reload your MCP clients to use it]") server(on_exit=uninstall) else: diff --git a/src/mcp_scan/gateway.py b/src/mcp_scan/gateway.py index 8fef437..eb4f866 100644 --- a/src/mcp_scan/gateway.py +++ b/src/mcp_scan/gateway.py @@ -142,7 +142,7 @@ def __init__( self.paths = paths self.invariant_api_url = invariant_api_url - def install( + async def install( self, gateway_config: MCPGatewayConfig, verbose: bool = False, @@ -150,7 +150,7 @@ def install( for path in self.paths: config: MCPConfig | None = None try: - config = scan_mcp_config_file(path) + config = await scan_mcp_config_file(path) status = f"found {len(config.get_servers())} server{'' if len(config.get_servers()) == 1 else 's'}" except FileNotFoundError: status = "file does not exist" @@ -188,11 +188,11 @@ def install( with open(os.path.expanduser(path), "w") as f: f.write(config.model_dump_json(indent=4) + "\n") - def uninstall(self, verbose: bool = False) -> None: + async def uninstall(self, verbose: bool = False) -> None: for path in self.paths: config: MCPConfig | None = None try: - config = scan_mcp_config_file(path) + config = await scan_mcp_config_file(path) status = f"found {len(config.get_servers())} server{'' if len(config.get_servers()) == 1 else 's'}" except FileNotFoundError: status = "file does not exist" @@ -204,7 +204,7 @@ def uninstall(self, verbose: bool = False) -> None: continue path_print_tree = Tree("│") - config = scan_mcp_config_file(path) + config = await scan_mcp_config_file(path) new_servers: dict[str, SSEServer | StdioServer] = {} for name, server in config.get_servers().items(): if isinstance(server, StdioServer): diff --git a/src/mcp_scan_server/activity_logger.py b/src/mcp_scan_server/activity_logger.py index 399c61c..50c18c4 100644 --- a/src/mcp_scan_server/activity_logger.py +++ b/src/mcp_scan_server/activity_logger.py @@ -14,6 +14,12 @@ from textwrap import shorten class ActivityLogger: + """ + Logs trace events as they are received (e.g. tool calls, tool outputs, etc.). + + Ensures that each event is only logged once. Also includes metadata in log output, + like the client, user, server name and tool name. + """ def __init__(self, pretty: Literal["oneline", "compact", "full"] = "compact"): self.cached_metadata = {} # level of pretty printing @@ -27,24 +33,6 @@ def __init__(self, pretty: Literal["oneline", "compact", "full"] = "compact"): self.logged_output = {} # last logged (session_id, tool_call_id), so we can skip logging tool call headers if it is directly followed by output self.last_logged_tool = None - - async def handle_push(self, messages, metadata): - """ - Handles a push request with the given messages and metadata. - """ - for i, batch_items in enumerate(messages): - trace_id = str(uuid.uuid4()) - self.cached_metadata[trace_id] = metadata[i] - await self.log(batch_items, metadata[i]) - - return trace_id - - async def handle_append(self, trace_id: str, messages: list[dict]): - """ - Handles an append request with the given trace ID and messages. - """ - metadata = self.cached_metadata.get(trace_id, {}) - await self.log(messages, metadata) def empty_metadata(self): return { diff --git a/src/mcp_scan_server/routes/push.py b/src/mcp_scan_server/routes/push.py index f79b7e6..69619db 100644 --- a/src/mcp_scan_server/routes/push.py +++ b/src/mcp_scan_server/routes/push.py @@ -17,7 +17,6 @@ async def push_trace(request: Request, activity_logger: Annotated[ActivityLogger metadata = body.get("metadata", [{}]) messages = body.get("messages", [[]]) - # trace_id = await activity_logger.handle_push(messages, metadata) trace_id = str(uuid.uuid4()) # return the trace ID diff --git a/src/mcp_scan_server/routes/trace.py b/src/mcp_scan_server/routes/trace.py index 2f2876f..db5b6b0 100644 --- a/src/mcp_scan_server/routes/trace.py +++ b/src/mcp_scan_server/routes/trace.py @@ -13,6 +13,4 @@ async def append_messages(trace_id: str, request: Request, activity_logger: Anno body = await request.json() messages = body.get("messages", []) - # await activity_logger.handle_append(trace_id, messages) - return {"success": True} diff --git a/src/mcp_scan_server/server.py b/src/mcp_scan_server/server.py index e0ca6fb..57afca1 100644 --- a/src/mcp_scan_server/server.py +++ b/src/mcp_scan_server/server.py @@ -2,6 +2,7 @@ import uvicorn from fastapi import FastAPI from typing import Literal, Optional +import inspect from mcp_scan_server.activity_logger import setup_activity_logger @@ -50,7 +51,10 @@ async def life_span(self, app: FastAPI): yield if callable(self.on_exit): - self.on_exit() + if inspect.iscoroutinefunction(self.on_exit): + await self.on_exit() + else: + self.on_exit() def run(self): """Run the MCP scan server.""" From 2a2c117786e864b4e22a1318079d8b99cbcae99d Mon Sep 17 00:00:00 2001 From: Luca Beurer-Kellner Date: Wed, 7 May 2025 09:04:02 +0200 Subject: [PATCH 21/21] improve shorthand translation condition --- src/mcp_scan/paths.py | 3 ++- test.sh | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/mcp_scan/paths.py b/src/mcp_scan/paths.py index 680b568..7fe3d42 100644 --- a/src/mcp_scan/paths.py +++ b/src/mcp_scan/paths.py @@ -1,4 +1,5 @@ import sys +import re if sys.platform == "linux" or sys.platform == "linux2": # Linux @@ -76,7 +77,7 @@ def client_shorthands_to_paths(shorthands: list[str]): Does nothing if the shorthands are already paths. """ paths = [] - if any("/" in shorthand for shorthand in shorthands): + if any(not re.match(r"^[A-z0-9_-]+$", shorthand) for shorthand in shorthands): return shorthands for shorthand in shorthands: diff --git a/test.sh b/test.sh index 1f1daac..607d275 100644 --- a/test.sh +++ b/test.sh @@ -1,3 +1,3 @@ npx concurrently -p none -k \ - "uv run mcp-scan proxy --pretty full --gateway-dir /Users/luca/Developer/invariant-gateway" \ + "uv run mcp-scan proxy --pretty full --gateway-dir /Users/luca/Developer/invariant-gateway cursor" \ "sleep 0.1 && python test-client.py" \ No newline at end of file