diff --git a/pyproject.toml b/pyproject.toml index d2c353c4..1a6501c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,15 +11,20 @@ classifiers = [ dependencies = [ "mcp[cli]>=1.6.0", "rich>=14.0.0", + "pyjson5>=1.6.8", "aiofiles>=23.1.0", "types-aiofiles", - "pyjson5>=1.6.8", "pydantic>=2.11.2", "lark>=1.1.9", "psutil>=5.9.0", + "invariant-ai>=0.3", + "fastapi>=0.115.12", + "uvicorn>=0.34.2", + "invariant-sdk>=0.0.11", + "pyyaml>=6.0.2", "regex>=2024.11.6", "aiohttp>=3.11.16", - "rapidfuzz>=3.13.0", + "rapidfuzz>=3.13.0" ] [project.scripts] @@ -35,6 +40,7 @@ packages = ["mcp_scan"] [project.optional-dependencies] test = [ "pytest>=7.4.0", + "pytest-lazy-fixtures>=1.1.2", "anyio>=4.0.0" ] dev = [ diff --git a/src/mcp_scan/StorageFile.py b/src/mcp_scan/StorageFile.py index a87c8094..f5da64ce 100644 --- a/src/mcp_scan/StorageFile.py +++ b/src/mcp_scan/StorageFile.py @@ -6,8 +6,11 @@ from datetime import datetime import rich +import yaml # type: ignore from pydantic import ValidationError +from mcp_scan_server.models import GuardrailConfig + from .models import Entity, ScannedEntities, ScannedEntity, entity_type_to_str, hash_entity from .utils import upload_whitelist_entry @@ -19,16 +22,22 @@ class StorageFile: def __init__(self, path: str): logger.debug("Initializing StorageFile with path: %s", path) self.path = os.path.expanduser(path) + logger.debug("Expanded path: %s", self.path) # if path is a file self.scanned_entities: ScannedEntities = ScannedEntities({}) self.whitelist: dict[str, str] = {} + self.guardrails_config: GuardrailConfig = GuardrailConfig({}) - if os.path.isfile(path): - msg = f"Legacy storage file detected at {path}, converting to new format" - logger.info(msg) - rich.print(f"[bold]{msg}[/bold]") + if os.path.isfile(self.path): + rich.print(f"[bold]Legacy storage file detected at {self.path}, converting to new format[/bold]") # legacy format + with open(self.path, "r") as f: + legacy_data = json.load(f) + if "__whitelist" in legacy_data: + 
self.whitelist = legacy_data["__whitelist"] + del legacy_data["__whitelist"] + try: logger.debug("Loading legacy format file") with open(path) as f: @@ -52,6 +61,7 @@ def __init__(self, path: str): if os.path.exists(path) and os.path.isdir(path): logger.debug("Path exists and is a directory: %s", path) scanned_entities_path = os.path.join(path, "scanned_entities.json") + if os.path.exists(scanned_entities_path): logger.debug("Loading scanned entities from: %s", scanned_entities_path) with open(scanned_entities_path) as f: @@ -69,6 +79,23 @@ def __init__(self, path: str): self.whitelist = json.load(f) logger.info("Successfully loaded whitelist with %d entries", len(self.whitelist)) + guardrails_config_path = os.path.join(self.path, "guardrails_config.yml") + if os.path.exists(guardrails_config_path): + with open(guardrails_config_path, "r") as f: + try: + guardrails_config_data = yaml.safe_load(f.read()) or {} + self.guardrails_config = GuardrailConfig.model_validate(guardrails_config_data) + except yaml.YAMLError as e: + rich.print( + f"[bold red]Could not parse guardrails config file " + f"{guardrails_config_path}: {e}[/bold red]" + ) + except ValidationError as e: + rich.print( + f"[bold red]Could not validate guardrails config file " + f"{guardrails_config_path}: {e}[/bold red]" + ) + def reset_whitelist(self) -> None: logger.info("Resetting whitelist") self.whitelist = {} @@ -139,6 +166,19 @@ def is_whitelisted(self, entity: Entity) -> bool: logger.debug("Checking if entity %s is whitelisted: %s", entity.name, result) return result + def create_guardrails_config(self) -> str: + """ + If the guardrails config file does not exist, create it with default values. + + Returns the path to the guardrails config file. 
+ """ + guardrails_config_path = os.path.join(self.path, "guardrails_config.yml") + if not os.path.exists(guardrails_config_path): + with open(guardrails_config_path, "w") as f: + if self.guardrails_config is not None: + f.write(self.guardrails_config.model_dump_yaml()) + return guardrails_config_path + def save(self) -> None: logger.info("Saving storage data to %s", self.path) try: diff --git a/src/mcp_scan/cli.py b/src/mcp_scan/cli.py index b8f020cc..96d12c45 100644 --- a/src/mcp_scan/cli.py +++ b/src/mcp_scan/cli.py @@ -4,13 +4,19 @@ import sys import psutil +import asyncio import rich from rich.logging import RichHandler -from mcp_scan.MCPScanner import MCPScanner -from mcp_scan.printer import print_scan_result -from mcp_scan.StorageFile import StorageFile -from mcp_scan.version import version_info +from mcp_scan.gateway import MCPGatewayConfig, MCPGatewayInstaller +from mcp_scan_server.server import MCPScanServer + +from .MCPScanner import MCPScanner +from .printer import print_scan_result +from .StorageFile import StorageFile +from .version import version_info +from .paths import WELL_KNOWN_MCP_PATHS, client_shorthands_to_paths + # Configure logging to suppress all output by default logging.getLogger().setLevel(logging.CRITICAL + 1) # Higher than any standard level @@ -58,35 +64,6 @@ def get_invoking_name(): def str2bool(v: str) -> bool: return v.lower() in ("true", "1", "t", "y", "yes") - -if sys.platform == "linux" or sys.platform == "linux2": - WELL_KNOWN_MCP_PATHS = [ - "~/.codeium/windsurf/mcp_config.json", # windsurf - "~/.cursor/mcp.json", # cursor - "~/.vscode/mcp.json", # vscode - "~/.config/Code/User/settings.json", # vscode linux - ] -elif sys.platform == "darwin": - # OS X - WELL_KNOWN_MCP_PATHS = [ - "~/.codeium/windsurf/mcp_config.json", # windsurf - "~/.cursor/mcp.json", # cursor - "~/Library/Application Support/Claude/claude_desktop_config.json", # Claude Desktop mac - "~/.vscode/mcp.json", # vscode - "~/Library/Application 
Support/Code/User/settings.json", # vscode mac - ] -elif sys.platform == "win32": - WELL_KNOWN_MCP_PATHS = [ - "~/.codeium/windsurf/mcp_config.json", # windsurf - "~/.cursor/mcp.json", # cursor - "~/AppData/Roaming/Claude/claude_desktop_config.json", # Claude Desktop windows - "~/.vscode/mcp.json", # vscode - "~/AppData/Roaming/Code/User/settings.json", # vscode windows - ] -else: - WELL_KNOWN_MCP_PATHS = [] - - def add_common_arguments(parser): """Add arguments that are common to multiple commands.""" parser.add_argument( @@ -140,9 +117,88 @@ def add_server_arguments(parser): help="Suppress stdout/stderr from MCP servers (default: True)", metavar="BOOL", ) + server_group.add_argument( + "--pretty", + type=str, + default="oneline", + choices=["oneline", "compact", "full"], + help="Pretty print the output (default: compact)", + ) + + +def add_install_arguments(parser): + parser.add_argument( + "files", + type=str, + nargs="*", + default=WELL_KNOWN_MCP_PATHS, + help=( + "Different file locations to scan. " + "This can include custom file locations as long as " + "they are in an expected format, including Claude, " + "Cursor or VSCode format." + ), + ) + parser.add_argument( + "--project_name", + type=str, + default="mcp-gateway", + help="Project name for the Invariant Gateway", + ) + parser.add_argument( + "--api-key", + type=str, + help="API key for the Invariant Gateway", + ) + parser.add_argument( + "--local-only", + default=False, + action="store_true", + help="Prevent pushing traces to the explorer.", + ) + parser.add_argument( + "--gateway-dir", + type=str, + help="Source directory for the Invariant Gateway. Set this, if you want to install a custom gateway implementation. 
(default: the published package is used).", + default=None, + ) + parser.add_argument( + "--mcp-scan-server-port", + type=int, + default=8000, + help="MCP scan server port (default: 8000).", + metavar="PORT", + ) -async def main(): +def add_uninstall_arguments(parser): + parser.add_argument( + "files", + type=str, + nargs="*", + default=WELL_KNOWN_MCP_PATHS, + help=( + "Different file locations to scan. " + "This can include custom file locations as long as " + "they are in an expected format, including Claude, Cursor or VSCode format." + ), + ) + + +def check_install_args(args): + if args.command == "install" and not args.local_only and not args.api_key: + # prompt for api key + print("To install mcp-scan with remote logging, you need an Invariant API key (https://explorer.invariantlabs.ai/settings).\n") + args.api_key = input("API key (or just press enter to install with --local-only): ") + if not args.api_key: + args.local_only = True + + # raise argparse.ArgumentError( + # None, "argument --api-key is required when --local-only is not set" + # ) + + +def main(): # Create main parser with description program_name = get_invoking_name() parser = argparse.ArgumentParser( @@ -218,7 +274,8 @@ async def main(): "whitelist", help="Manage the whitelist of approved entities", description=( - "View, add, or reset whitelisted entities. Whitelisted entities bypass security checks during scans." + "View, add, or reset whitelisted entities. " + "Whitelisted entities bypass security checks during scans." 
), ) add_common_arguments(whitelist_parser) @@ -262,6 +319,15 @@ async def main(): help="Hash of the entity to whitelist", metavar="HASH", ) + # install + install_parser = subparsers.add_parser("install", help="Install Invariant Gateway") + add_install_arguments(install_parser) + + # uninstall + uninstall_parser = subparsers.add_parser( + "uninstall", help="Uninstall Invariant Gateway" + ) + add_uninstall_arguments(uninstall_parser) # HELP command help_parser = subparsers.add_parser( # noqa: F841 @@ -270,13 +336,81 @@ async def main(): description="Display detailed help information and examples.", ) + # SERVER command + server_parser = subparsers.add_parser("server", help="Start the MCP scan server") + server_parser.add_argument( + "--port", + type=int, + default=8000, + help="Port to run the server on (default: 8000)", + metavar="PORT", + ) + add_common_arguments(server_parser) + + # PROXY command + proxy_parser = subparsers.add_parser( + "proxy", help="Installs and proxies MCP requests, uninstalls on exit" + ) + proxy_parser.add_argument( + "--port", + type=int, + default=8000, + help="Port to run the server on (default: 8000)", + metavar="PORT", + ) + add_common_arguments(proxy_parser) + add_server_arguments(proxy_parser) + add_install_arguments(proxy_parser) + # Parse arguments (default to 'scan' if no command provided) args = parser.parse_args(["scan"] if len(sys.argv) == 1 else None) + + # postprocess the files argument (if shorthands are used) + args.files = client_shorthands_to_paths(args.files) # Display version banner - if not args.json: + if not (hasattr(args, "json") and args.json): rich.print(f"[bold blue]Invariant MCP-scan v{version_info}[/bold blue]\n") + async def install(): + try: + check_install_args(args) + except argparse.ArgumentError as e: + parser.error(e) + + invariant_api_url = ( + f"http://localhost:{args.mcp_scan_server_port}" + if args.local_only + else "https://explorer.invariantlabs.ai" + ) + installer = MCPGatewayInstaller( + 
paths=args.files, invariant_api_url=invariant_api_url + ) + await installer.install( + gateway_config=MCPGatewayConfig( + project_name=args.project_name, + push_explorer=True, + api_key=args.api_key or "", + source_dir=args.gateway_dir, + ), + verbose=True, + ) + + async def uninstall(): + installer = MCPGatewayInstaller(paths=args.files) + await installer.uninstall(verbose=True) + + def server(on_exit=None): + sf = StorageFile(args.storage_file) + guardrails_config_path = sf.create_guardrails_config() + mcp_scan_server = MCPScanServer( + port=args.port, + config_file_path=guardrails_config_path, + on_exit=on_exit, + pretty=args.pretty + ) + mcp_scan_server.run() + # Set up logging if verbose flag is enabled setup_logging(args.verbose or False) @@ -290,7 +424,9 @@ async def main(): sf.reset_whitelist() rich.print("[bold]Whitelist reset[/bold]") sys.exit(0) - elif all(x is None for x in [args.type, args.name, args.hash]): # no args + elif all( + map(lambda x: x is None, [args.type, args.name, args.hash]) + ): # no args sf.print_whitelist() sys.exit(0) elif all(x is not None for x in [args.type, args.name, args.hash]): @@ -303,11 +439,19 @@ async def main(): sf.print_whitelist() sys.exit(0) else: - rich.print("[bold red]Please provide all three parameters: type, name, and hash.[/bold red]") + rich.print( + "[bold red]Please provide all three parameters: type, name, and hash.[/bold red]" + ) whitelist_parser.print_help() sys.exit(1) elif args.command == "inspect": - await run_scan_inspect(mode="inspect", args=args) + asyncio.run(run_scan_inspect(mode="inspect", args=args)) + sys.exit(0) + elif args.command == "install": + asyncio.run(install()) + sys.exit(0) + elif args.command == "uninstall": + asyncio.run(uninstall()) sys.exit(0) elif args.command == "whitelist": if args.reset: @@ -324,8 +468,16 @@ async def main(): rich.print("[bold red]Please provide a name and hash.[/bold red]") sys.exit(1) elif args.command == "scan" or args.command is None: # default to scan - 
await run_scan_inspect(args=args) + asyncio.run(run_scan_inspect(args=args)) + sys.exit(0) + elif args.command == "server": + server() sys.exit(0) + elif args.command == "proxy": + args.local_only = True + asyncio.run(install()) + print("[Proxy installed, you may need to restart/reload your MCP clients to use it]") + server(on_exit=uninstall) else: # This shouldn't happen due to argparse's handling rich.print(f"[bold red]Unknown command: {args.command}[/bold red]") @@ -341,7 +493,11 @@ async def run_scan_inspect(mode="scan", args=None): elif mode == "inspect": result = await scanner.inspect() if args.json: - result = {r.path: r.model_dump() for r in result} + result = dict((r.path, r.model_dump()) for r in result) print(json.dumps(result, indent=2)) else: - print_scan_result(result, args.print_errors) + print_scan_result(result) + + +if __name__ == "__main__": + main() diff --git a/src/mcp_scan/gateway.py b/src/mcp_scan/gateway.py new file mode 100644 index 00000000..eb4f8664 --- /dev/null +++ b/src/mcp_scan/gateway.py @@ -0,0 +1,233 @@ +import argparse +import os +from typing import Optional + +import rich +from pydantic import BaseModel +from rich.text import Text +from rich.tree import Tree + +from mcp_scan.mcp_client import scan_mcp_config_file +from mcp_scan.models import MCPConfig, SSEServer, StdioServer +from mcp_scan.paths import get_client_from_path +from mcp_scan.printer import format_path_line + +parser = argparse.ArgumentParser( + description="MCP-scan CLI", + prog="invariant-gateway@latest mcp", +) + +parser.add_argument("--exec", type=str, required=True, nargs=argparse.REMAINDER) + + +class MCPServerIsNotGateway(Exception): + pass + + +class MCPServerAlreadyGateway(Exception): + pass + + +class MCPGatewayConfig(BaseModel): + project_name: str + push_explorer: bool + api_key: str + + # the source directory of the gateway implementation to use + # (if None, uses the published package) + source_dir: Optional[str] = None + + +def 
is_invariant_installed(server: StdioServer) -> bool: + if server.args is None: + return False + if not server.args: + return False + return any("invariant-gateway" in a for a in server.args) + + +def install_gateway( + server: StdioServer, + config: MCPGatewayConfig, + invariant_api_url: str = "https://explorer.invariantlabs.ai", + extra_metadata: dict[str, str] = {}, +) -> StdioServer: + """Install the gateway for the given server.""" + if is_invariant_installed(server): + raise MCPServerAlreadyGateway() + + env = server.env | {"INVARIANT_API_KEY": config.api_key or "", "INVARIANT_API_URL": invariant_api_url, "GUARDRAILS_API_URL": invariant_api_url} + + cmd = "uvx" + base_args = [ + "invariant-gateway@latest", + "mcp", + ] + + # if running gateway from source-dir, use 'uv run' instead + if config.source_dir: + cmd = "uv" + base_args = [ + "run", + "--with", + "mcp", + "--directory", + config.source_dir, + "invariant-gateway", + "mcp" + ] + + flags = [ + "--project-name", + config.project_name, + *(["--push-explorer"] if config.push_explorer else []), + ] + # add extra metadata flags + for k, v in extra_metadata.items(): + flags.append(f"--metadata-{k}={v}") + + # add exec section + flags += [ + *["--exec", server.command], + *(server.args if server.args else []) + ] + + # return new server config + return StdioServer( + command=cmd, + args=base_args + flags, + env=env + ) + + +def uninstall_gateway( + server: StdioServer, +) -> StdioServer: + """Uninstall the gateway for the given server.""" + if not is_invariant_installed(server): + raise MCPServerIsNotGateway() + + assert isinstance(server.args, list), "args is not a list" + args, unknown = parser.parse_known_args(server.args[2:]) + new_env = {k: v for k, v in server.env.items() if k != "INVARIANT_API_KEY" and k != "INVARIANT_API_URL" and k != "GUARDRAILS_API_URL"} + assert args.exec is not None, "exec is None" + assert args.exec, "exec is empty" + return StdioServer( + command=args.exec[0], + 
args=args.exec[1:], + env=new_env, + ) + + +def format_install_line(server: str, status: str, success: bool | None) -> Text: + color = {True: "[green]", False: "[red]", None: "[gray62]"}[success] + + if len(server) > 25: + server = server[:22] + "..." + server = server + " " * (25 - len(server)) + icon = {True: ":white_heavy_check_mark:", False: ":cross_mark:", None: ""}[success] + + text = f"{color}[bold]{server}[/bold]{icon} {status}{color.replace('[', '[/')}" + return Text.from_markup(text) + + +class MCPGatewayInstaller: + """A class to install and uninstall the gateway for a given server.""" + + def __init__( + self, + paths: list[str], + invariant_api_url: str = "https://explorer.invariantlabs.ai", + ) -> None: + self.paths = paths + self.invariant_api_url = invariant_api_url + + async def install( + self, + gateway_config: MCPGatewayConfig, + verbose: bool = False, + ) -> None: + for path in self.paths: + config: MCPConfig | None = None + try: + config = await scan_mcp_config_file(path) + status = f"found {len(config.get_servers())} server{'' if len(config.get_servers()) == 1 else 's'}" + except FileNotFoundError: + status = "file does not exist" + except Exception: + status = "could not parse file" + if verbose: + rich.print(format_path_line(path, status, operation="Installing Gateway")) + if config is None: + continue + + path_print_tree = Tree("│") + new_servers: dict[str, SSEServer | StdioServer] = {} + for name, server in config.get_servers().items(): + if isinstance(server, StdioServer): + try: + new_servers[name] = install_gateway(server, gateway_config, self.invariant_api_url, {"server": name, "client": get_client_from_path(path)}) + path_print_tree.add(format_install_line(server=name, status="Installed", success=True)) + except MCPServerAlreadyGateway: + new_servers[name] = server + path_print_tree.add(format_install_line(server=name, status="Already installed", success=True)) + except Exception as e: + new_servers[name] = server + print(f"Failed to 
install gateway for {name}", e) + path_print_tree.add(format_install_line(server=name, status="Failed to install", success=False)) + + else: + new_servers[name] = server + path_print_tree.add( + format_install_line(server=name, status="sse servers not supported yet", success=False) + ) + + if verbose: + rich.print(path_print_tree) + config.set_servers(new_servers) + with open(os.path.expanduser(path), "w") as f: + f.write(config.model_dump_json(indent=4) + "\n") + + async def uninstall(self, verbose: bool = False) -> None: + for path in self.paths: + config: MCPConfig | None = None + try: + config = await scan_mcp_config_file(path) + status = f"found {len(config.get_servers())} server{'' if len(config.get_servers()) == 1 else 's'}" + except FileNotFoundError: + status = "file does not exist" + except Exception: + status = "could not parse file" + if verbose: + rich.print(format_path_line(path, status, operation="Installing Gateway")) + if config is None: + continue + + path_print_tree = Tree("│") + config = await scan_mcp_config_file(path) + new_servers: dict[str, SSEServer | StdioServer] = {} + for name, server in config.get_servers().items(): + if isinstance(server, StdioServer): + try: + new_servers[name] = uninstall_gateway(server) + path_print_tree.add(format_install_line(server=name, status="Uninstalled", success=True)) + except MCPServerIsNotGateway: + new_servers[name] = server + path_print_tree.add( + format_install_line(server=name, status="Already not installed", success=True) + ) + except Exception: + new_servers[name] = server + path_print_tree.add( + format_install_line(server=name, status="Failed to uninstall", success=False) + ) + else: + new_servers[name] = server + path_print_tree.add( + format_install_line(server=name, status="sse servers not supported yet", success=None) + ) + config.set_servers(new_servers) + if verbose: + rich.print(path_print_tree) + with open(os.path.expanduser(path), "w") as f: + f.write(config.model_dump_json(indent=4) + 
"\n") diff --git a/src/mcp_scan/mcp_client.py b/src/mcp_scan/mcp_client.py index bff4e687..5e982f2f 100644 --- a/src/mcp_scan/mcp_client.py +++ b/src/mcp_scan/mcp_client.py @@ -3,8 +3,8 @@ import os from typing import AsyncContextManager # noqa: UP035 -import aiofiles # type: ignore import pyjson5 +import aiofiles from mcp import ClientSession, StdioServerParameters from mcp.client.sse import sse_client from mcp.client.stdio import stdio_client @@ -145,4 +145,4 @@ def parse_and_validate(config: dict) -> MCPConfig: return result except Exception: logger.exception("Error processing config file") - raise + raise \ No newline at end of file diff --git a/src/mcp_scan/paths.py b/src/mcp_scan/paths.py new file mode 100644 index 00000000..7fe3d424 --- /dev/null +++ b/src/mcp_scan/paths.py @@ -0,0 +1,88 @@ +import sys +import re + +if sys.platform == "linux" or sys.platform == "linux2": + # Linux + CLIENT_PATHS = { + 'windsurf': [ + "~/.codeium/windsurf/mcp_config.json" + ], + 'cursor': [ + "~/.cursor/mcp.json" + ], + 'vscode': [ + "~/.vscode/mcp.json", + "~/.config/Code/User/settings.json" + ], + } + WELL_KNOWN_MCP_PATHS = [path for client, paths in CLIENT_PATHS.items() for path in paths] +elif sys.platform == "darwin": + # OS X + CLIENT_PATHS = { + 'windsurf': [ + "~/.codeium/windsurf/mcp_config.json" + ], + 'cursor': [ + "~/.cursor/mcp.json" + ], + 'claude': [ + "~/Library/Application Support/Claude/claude_desktop_config.json" + ], + 'vscode': [ + "~/.vscode/mcp.json", + "~/Library/Application Support/Code/User/settings.json" + ], + } + WELL_KNOWN_MCP_PATHS = [path for client, paths in CLIENT_PATHS.items() for path in paths] +elif sys.platform == "win32": + CLIENT_PATHS = { + 'windsurf': [ + "~/.codeium/windsurf/mcp_config.json" + ], + 'cursor': [ + "~/.cursor/mcp.json" + ], + 'claude': [ + "~/AppData/Roaming/Claude/claude_desktop_config.json" + ], + 'vscode': [ + "~/.vscode/mcp.json", + "~/AppData/Roaming/Code/User/settings.json" + ], + } + + WELL_KNOWN_MCP_PATHS = 
[path for client, paths in CLIENT_PATHS.items() for path in paths] +else: + WELL_KNOWN_MCP_PATHS = [] + +def get_client_from_path(path: str) -> str: + """ + Returns the client name from a path. + + Args: + path (str): The path to get the client from. + + Returns: + str: The client name or None if it cannot be guessed from the path. + """ + for client, paths in CLIENT_PATHS.items(): + if path in paths: + return client + return None + +def client_shorthands_to_paths(shorthands: list[str]): + """ + Converts a list of client shorthands to a list of paths. + + Does nothing if the shorthands are already paths. + """ + paths = [] + if any(not re.match(r"^[A-z0-9_-]+$", shorthand) for shorthand in shorthands): + return shorthands + + for shorthand in shorthands: + if shorthand in CLIENT_PATHS: + paths.extend(CLIENT_PATHS[shorthand]) + else: + raise ValueError(f"{shorthand} is not a valid client shorthand") + return paths diff --git a/src/mcp_scan_server/__init__.py b/src/mcp_scan_server/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/mcp_scan_server/activity_logger.py b/src/mcp_scan_server/activity_logger.py new file mode 100644 index 00000000..50c18c43 --- /dev/null +++ b/src/mcp_scan_server/activity_logger.py @@ -0,0 +1,120 @@ +from typing import Literal +import uuid +import rich +import json + +from fastapi import APIRouter, FastAPI, Request +from invariant_sdk.types.push_traces import PushTracesResponse + +from rich import print +from rich.panel import Panel +from rich.syntax import Syntax +from rich.markdown import Markdown +from rich.rule import Rule +from textwrap import shorten + +class ActivityLogger: + """ + Logs trace events as they are received (e.g. tool calls, tool outputs, etc.). + + Ensures that each event is only logged once. Also includes metadata in log output, + like the client, user, server name and tool name. 
+ """ + def __init__(self, pretty: Literal["oneline", "compact", "full"] = "compact"): + self.cached_metadata = {} + # level of pretty printing + self.pretty = pretty + + # (session_id, tool_call_id) -> bool + self.logged_header = {} + self.logged_result = {} + + # (session_id, formatted_output) -> bool + self.logged_output = {} + # last logged (session_id, tool_call_id), so we can skip logging tool call headers if it is directly followed by output + self.last_logged_tool = None + + def empty_metadata(self): + return { + "client": "Unknown Client", + "mcp_server": "Unknown Server", + "user": None + } + + async def log(self, messages, metadata): + """ + Console-logs the relevant parts of the given messages and metadata. + """ + session_id = metadata.get("session_id", "") + client = metadata.get("client", "Unknown Client").capitalize() + server = metadata.get("mcp_server", "Unknown Server").capitalize() + user = metadata.get("user", None) + + tool_names = {} + + for msg in messages: + if msg.get('role') == 'tool': + if (session_id, 'output-' + msg.get('tool_call_id')) in self.logged_output: + continue + self.logged_output[(session_id, 'output-' + msg.get('tool_call_id'))] = True + + has_header = self.last_logged_tool == (session_id, msg.get('tool_call_id')) + + if not has_header: + print(Rule()) + # left arrow for output + user_portion = "" if user is None else f" ([bold red]{user}[/bold red])" + name = tool_names.get(msg.get('tool_call_id'), "") + print(f"← [bold blue]{client}[/bold blue]{user_portion} used [bold green]{server}[/bold green] to [bold green]{name}[/bold green]") + print(Rule()) + + # tool output + content = message_content(msg) + if type(content) is str and content.startswith("Error"): + print(Syntax(content, "pytb", theme="monokai")) + else: + print(Syntax(content, "json", theme="monokai")) + print(Rule()) + + else: + for tc in (msg.get('tool_calls') or []): + name = tc.get('function', {}).get('name', "") + tool_names[tc.get('id')] = name + + if 
(session_id, tc.get('id')) in self.logged_output: + continue + self.logged_output[(session_id, tc.get('id'))] = True + + self.last_logged_tool = (session_id, tc.get('id')) + + # header + user_portion = "" if user is None else f" ([bold red]{user}[/bold red])" + + print(Rule()) + print(f"→ [bold blue]{client}[/bold blue]{user_portion} used [bold green]{server}[/bold green] to [bold green]{name}[/bold green]") + print(Rule()) + + # tool arguments + print(Syntax(json.dumps(tc.get('arguments', {}), indent=2), "json", theme="monokai")) + + +def message_content(msg: dict) -> str: + if type(msg.get('content')) is str: + return msg.get('content', '') + elif type(msg.get('content')) is list: + return "\n".join([c['text'] for c in msg.get('content', []) if c['type'] == 'text']) + else: + return "" + +async def get_activity_logger(request: Request) -> ActivityLogger: + """ + Returns a singleton instance of the ActivityLogger. + """ + return request.app.state.activity_logger + +def setup_activity_logger(app: FastAPI, pretty: Literal["oneline", "compact", "full"] = "compact"): + """ + Sets up the ActivityLogger as a dependency for the given FastAPI app. 
+ """ + app.state.activity_logger = ActivityLogger(pretty=pretty) + diff --git a/src/mcp_scan_server/models.py b/src/mcp_scan_server/models.py new file mode 100644 index 00000000..d954873d --- /dev/null +++ b/src/mcp_scan_server/models.py @@ -0,0 +1,112 @@ +import datetime +from enum import Enum +from typing import Optional + +import yaml # type: ignore +from invariant.analyzer.policy import AnalysisResult +from pydantic import BaseModel, ConfigDict, Field, RootModel + + +class PolicyRunsOn(str, Enum): + """Policy runs on enum.""" + + local = "local" + remote = "remote" + + +class Policy(BaseModel): + """Policy model.""" + + name: str = Field(description="The name of the policy.") + runs_on: PolicyRunsOn = Field(description="The environment to run the policy on.") + policy: str = Field(description="The policy.") + + +class PolicyCheckResult(BaseModel): + """Policy check result model.""" + + policy: str = Field(description="The policy that was applied.") + result: Optional[AnalysisResult] = None + success: bool = Field(description="Whether this policy check was successful (loaded and ran).") + error_message: str = Field( + default="", + description="Error message in case of failure to load or execute the policy.", + ) + + def to_dict(self): + """Convert the object to a dictionary.""" + return { + "policy": self.policy, + "errors": [e.to_dict() for e in self.result.errors] if self.result else [], + "success": self.success, + "error_message": self.error_message, + } + + +class BatchCheckRequest(BaseModel): + """Batch check request model.""" + + messages: list[dict] = Field( + examples=['[{"role": "user", "content": "ignore all previous instructions"}]'], + description="The agent trace to apply the policy to.", + ) + policies: list[str] = Field( + examples=[ + [ + """raise Violation("Disallowed message content", reason="found ignore keyword") if:\n + (msg: Message)\n "ignore" in msg.content\n""", + """raise "get_capital is called with France as argument" if:\n + (call: 
ToolCall)\n call is tool:get_capital\n + call.function.arguments["country_name"] == "France" + """, + ] + ], + description="The policy (rules) to check for.", + ) + parameters: dict = Field( + default={}, + description="The parameters to pass to the policy analyze call (optional).", + ) + + +class BatchCheckResponse(BaseModel): + """Batch check response model.""" + + results: list[PolicyCheckResult] = Field(default=[], description="List of results for each policy.") + + +class DatasetPolicy(BaseModel): + """Describes a policy associated with a Dataset.""" + + id: str + name: str + content: str + # whether this policy is enabled + enabled: bool + # the mode of this policy (e.g. block, log, etc.) + action: str + # extra metadata for the policy (can be used to store internal extra data about a guardrail) + extra_metadata: dict = Field(default_factory=dict) + + last_updated_time: str = Field(default_factory=lambda: datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")) + + def to_dict(self) -> dict: + """Represent the object as a dictionary.""" + return self.model_dump() + + +class ServerGuardrails(BaseModel): + """Server guardrails model.""" + + guardrails: list[DatasetPolicy] + + +class GuardrailConfig(RootModel[dict[str, dict[str, ServerGuardrails]]]): + """Guardrail config model.""" + + model_config = ConfigDict(populate_by_name=True) + + def model_dump_yaml(self) -> str: + """Dump the object as a YAML string.""" + data = self.model_dump() + return yaml.dump(data, sort_keys=False, default_flow_style=False) diff --git a/src/mcp_scan_server/routes/__init__.py b/src/mcp_scan_server/routes/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/mcp_scan_server/routes/policies.py b/src/mcp_scan_server/routes/policies.py new file mode 100644 index 00000000..9b3eb5fa --- /dev/null +++ b/src/mcp_scan_server/routes/policies.py @@ -0,0 +1,152 @@ +import asyncio +import json +import os + +import fastapi +import rich +import yaml # type: ignore +from fastapi 
async def get_all_policies(config_file_path: str) -> list[DatasetPolicy]:
    """Get all policies from the local guardrails config file.

    If the file does not exist, an empty config is written first so that
    subsequent reads succeed.

    Args:
        config_file_path: Path to the guardrails YAML config file.

    Returns:
        A flat list of DatasetPolicy objects, one per configured guardrail,
        each annotated with the platform and tool it belongs to.

    Raises:
        fastapi.HTTPException: 400 when the file cannot be parsed/validated.
    """
    if not os.path.exists(config_file_path):
        rich.print(
            f"""[bold red]Guardrail config file not found: {config_file_path}. Creating an empty one.[/bold red]"""
        )
        empty_config = GuardrailConfig.model_validate({})
        with open(config_file_path, "w") as f:
            f.write(empty_config.model_dump_yaml())

    with open(config_file_path, "r") as f:
        # safe_load: the config is plain data; never construct arbitrary
        # Python objects from it (FullLoader would allow more tags).
        raw_config = yaml.safe_load(f)
    try:
        config = GuardrailConfig.model_validate(raw_config)
    except Exception as e:
        # Covers pydantic ValidationError and any other failure alike; the
        # original had two identical except clauses doing exactly this.
        raise fastapi.HTTPException(status_code=400, detail=str(e)) from e

    return [
        DatasetPolicy(
            id=guardrail.id,
            name=guardrail.name,
            content=guardrail.content,
            enabled=guardrail.enabled,
            action=guardrail.action,
            extra_metadata={
                "platform": platform_name,
                "tool": tool_name,
            },
        )
        for platform_name, platform_data in config.root.items()
        for tool_name, server_guardrails in platform_data.items()
        for guardrail in server_guardrails.guardrails
    ]


@router.get("/dataset/byuser/{username}/{dataset_name}/policy")
async def get_policy(username: str, dataset_name: str, request: Request):
    """Get all policies from the local config file configured on the app."""
    policies = await get_all_policies(request.app.state.config_file_path)
    return {"policies": policies}
async def check_policy(
    policy_str: str, messages: list[dict], parameters: dict | None = None
) -> PolicyCheckResult:
    """
    Check a policy using the invariant analyzer.

    Args:
        policy_str: The policy to check.
        messages: The messages to check the policy against.
        parameters: The parameters to pass to the policy analyze call (optional).
            (Original used a mutable ``{}`` default — replaced with a None
            sentinel; behavior for callers is unchanged.)

    Returns:
        A PolicyCheckResult object.
    """
    parameters = parameters or {}
    try:
        policy = LocalPolicy.from_string(policy_str)

        # from_string reports parse failures by returning the exception.
        if isinstance(policy, Exception):
            return PolicyCheckResult(
                policy=policy_str,
                success=False,
                error_message=str(policy),
            )

        # Analyze only the last message as "pending", with the earlier
        # messages supplied as already-seen context.
        result = await policy.a_analyze_pending(messages[:-1], [messages[-1]], **parameters)

        return PolicyCheckResult(
            policy=policy_str,
            result=result,
            success=True,
        )

    except (MissingPolicyParameter, ExcessivePolicyError, InvariantAttributeError) as e:
        # Known analyzer errors: report cleanly as an unsuccessful check.
        return PolicyCheckResult(
            policy=policy_str,
            success=False,
            error_message=str(e),
        )
    except Exception as e:
        return PolicyCheckResult(
            policy=policy_str,
            success=False,
            error_message="Unexpected error: " + str(e),
        )


def to_json_serializable_dict(obj):
    """Recursively convert *obj* into a JSON-serializable structure.

    dicts and lists are converted element-wise; JSON primitives
    (str/int/float/bool) pass through; anything else is rendered as
    "TypeName(value)".
    """
    if isinstance(obj, dict):
        return {k: to_json_serializable_dict(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [to_json_serializable_dict(v) for v in obj]
    if isinstance(obj, (str, int, float, bool)):
        return obj
    return type(obj).__name__ + "(" + str(obj) + ")"
@router.post("/policy/check/batch", response_model=BatchCheckResponse)
async def batch_check_policies(
    check_request: BatchCheckRequest,
    request: fastapi.Request,
    activity_logger: ActivityLogger = Depends(get_activity_logger),
):
    """Check every policy in the batch request against the supplied messages.

    All policies are evaluated concurrently; results come back in the same
    order as the submitted policies. The checked messages are also forwarded
    to the activity logger together with any client metadata tunneled
    through ``check_request.parameters["metadata"]``.
    """
    results = await asyncio.gather(
        *[check_policy(policy, check_request.messages, check_request.parameters) for policy in check_request.policies]
    )

    metadata = check_request.parameters.get("metadata", {})

    await activity_logger.log(
        check_request.messages,
        {
            "client": metadata.get("client"),
            "mcp_server": metadata.get("server"),
            "user": metadata.get("system_user"),
            "session_id": metadata.get("session_id"),
        },
    )

    return fastapi.responses.JSONResponse(
        content={"result": [to_json_serializable_dict(result.to_dict()) for result in results]}
    )


# --- routes/push.py ---------------------------------------------------------


@router.post("/trace")
async def push_trace(
    request: Request,
    activity_logger: Annotated[ActivityLogger, Depends(get_activity_logger)],
) -> PushTracesResponse:
    """Push a trace. For now, this is a dummy response."""
    # Parse the body so malformed JSON is rejected; the payload itself is not
    # stored yet (the original bound it to unused locals — removed).
    await request.json()

    trace_id = str(uuid.uuid4())

    # return the trace ID
    return PushTracesResponse(id=[trace_id], success=True)
@router.post("/{trace_id}/messages")
async def append_messages(
    trace_id: str,
    request: Request,
    activity_logger: Annotated[ActivityLogger, Depends(get_activity_logger)],
):
    """Append messages to a trace. For now this is a dummy response."""
    # Parse (and thereby validate) the body; the messages are not persisted
    # yet, so the original's unused local was removed.
    await request.json()

    return {"success": True}


# --- routes/user.py ---------------------------------------------------------


@router.get("/identity")
async def identity():
    """Get the identity of the user. For now, this is a dummy response."""
    return {"username": "user"}
+ """ + + def __init__(self, port: int = 8000, config_file_path: str | None = None, on_exit: Optional[callable] = None, log_level: str = "error", pretty: Literal["oneline", "compact", "full"] = "compact"): + self.port = port + self.config_file_path = config_file_path + self.on_exit = on_exit + self.log_level = log_level + self.pretty = pretty + + self.app = FastAPI(lifespan=self.life_span) + self.app.state.config_file_path = config_file_path + + self.app.include_router(policies_router, prefix="/api/v1") + self.app.include_router(push_router, prefix="/api/v1/push") + self.app.include_router(dataset_trace_router, prefix="/api/v1/trace") + self.app.include_router(user_router, prefix="/api/v1/user") + + async def on_startup(self): + """Startup event for the FastAPI app.""" + rich.print("[bold green]MCP-scan server started.[/bold green]") + + setup_activity_logger(self.app, pretty=self.pretty) + + async def life_span(self, app: FastAPI): + """Lifespan event for the FastAPI app.""" + await self.on_startup() + + yield + + if callable(self.on_exit): + if inspect.iscoroutinefunction(self.on_exit): + await self.on_exit() + else: + self.on_exit() + + def run(self): + """Run the MCP scan server.""" + uvicorn.run(self.app, host="0.0.0.0", port=self.port, log_level=self.log_level) diff --git a/test-client.py b/test-client.py new file mode 100644 index 00000000..dbf21d19 --- /dev/null +++ b/test-client.py @@ -0,0 +1,34 @@ +import asyncio +import json +from pathlib import Path +from mcp import ClientSession, StdioServerParameters, types +from mcp.client.stdio import stdio_client + +CONFIG_PATH = Path("/Users/luca/.cursor/mcp.json") +SERVER_KEY = "whatsapp-mcp" + +def load_server_params(key: str) -> StdioServerParameters: + config = json.loads(CONFIG_PATH.read_text()) + server_cfg = config["mcpServers"][key] + return StdioServerParameters( + command=server_cfg["command"], + args=server_cfg.get("args", []), + env=server_cfg.get("env", {}) + ) + +async def run(): + await 
async def run():
    """Connect to the configured MCP server and exercise a couple of tools."""
    await asyncio.sleep(1)
    server_params = load_server_params(SERVER_KEY)
    async with stdio_client(server_params) as (read, write):
        async with ClientSession(read, write) as session:
            await session.initialize()
            await session.list_tools()

            # NOTE: the original line ended with a stray trailing comma,
            # which wrapped the awaited result in a 1-tuple; removed.
            await session.call_tool(
                "list_chats",
                arguments={"limit": 20, "include_last_message": True, "sort_by": "last_active"},
            )
            await asyncio.sleep(0.2)

            await session.call_tool("send_message", arguments={"chat_id": "123", "message": "Hello, world!"})
            await asyncio.sleep(0.2)


if __name__ == "__main__":
    asyncio.run(run())
@pytest.fixture
def temp_file():
    """Yield the path to a temporary file that is removed after the test."""
    with tempfile.NamedTemporaryFile(delete=False) as tf:
        yield tf.name
    os.remove(tf.name)


@pytest.mark.parametrize("sample_config", [lf("claudestyle_config"), lf("vscode_mcp_config"), lf("vscode_config")])
async def test_install_gateway(sample_config, temp_file):
    """Install and uninstall the gateway, verifying the config round-trips."""
    with open(temp_file, "w") as f:
        f.write(sample_config)

    # Snapshot of the config as written, before installation.
    config_dict = pyjson5.loads(sample_config)
    installer = MCPGatewayInstaller(paths=[temp_file])
    for server in (await scan_mcp_config_file(temp_file)).get_servers().values():
        if isinstance(server, StdioServer):
            assert not is_invariant_installed(server), "Invariant should not be installed"
    installer.install(
        gateway_config=MCPGatewayConfig(project_name="test", push_explorer=True, api_key="my-very-secret-api-key"),
        verbose=True,
    )

    # The installed config on disk must still be parseable JSON5.
    # (BUG FIX: the original re-parsed the in-memory string, testing nothing.)
    with open(temp_file) as f:
        pyjson5.loads(f.read())

    for server in (await scan_mcp_config_file(temp_file)).get_servers().values():
        if isinstance(server, StdioServer):
            assert is_invariant_installed(server), "Invariant should be installed"

    installer.uninstall(verbose=True)

    for server in (await scan_mcp_config_file(temp_file)).get_servers().values():
        if isinstance(server, StdioServer):
            assert not is_invariant_installed(server), "Invariant should be uninstalled"

    # BUG FIX: compare what is actually on disk after uninstall against the
    # pre-install snapshot; the original parsed `sample_config` twice, so the
    # assertion could never fail.
    with open(temp_file) as f:
        config_dict_uninstalled = pyjson5.loads(f.read())

    assert (
        config_dict_uninstalled == config_dict
    ), "Installation and uninstallation of the gateway should not change the config file"
"config.yaml" + config_file.write_text( + """ +cursor: + browsermcp: + guardrails: + - name: "Guardrail 1" + id: "guardrail_1" + runs-on: "local" + enabled: true + action: "block" + content: | + raise "error" if: + (msg: ToolOutput) + "Test1" in msg.content + - name: "Guardrail 2" + id: "guardrail_2" + runs-on: "local" + enabled: true + action: "block" + content: | + raise "error" if: + (msg: ToolOutput) + "Test2" in msg.content +""" + ) + return str(config_file) + + +@pytest.fixture +def invalid_guardrail_config_file(tmp_path): + """Fixture that creates a temporary invalid config file and returns its path.""" + config_file = tmp_path / "config.yaml" + config_file.write_text( + """ +cursor: + browsermcp: + guardrails: + - name: "Guardrail 1" + id: "guardrail_1" + runs-on: "local" + enabled: true + action: "block" +""" + ) + return str(config_file) + + +@pytest.mark.anyio +async def test_get_all_policies_valid_config(valid_guardrail_config_file): + """Test that the get_all_policies function returns the correct policies for a valid config file.""" + policies = await get_all_policies(valid_guardrail_config_file) + print(policies) + assert len(policies) == 2 + assert all(isinstance(policy, DatasetPolicy) for policy in policies) + assert policies[0].id == "guardrail_1" + assert policies[1].id == "guardrail_2" + assert policies[0].name == "Guardrail 1" + assert policies[1].name == "Guardrail 2" + + +@pytest.mark.anyio +async def test_get_all_policies_invalid_config(invalid_guardrail_config_file): + """Test that the get_all_policies function raises an HTTPException for an invalid config file.""" + with pytest.raises(HTTPException): + await get_all_policies(invalid_guardrail_config_file) + + +@pytest.mark.anyio +async def test_get_all_policies_creates_file_when_missing(tmp_path): + """Test that get_all_policies creates a config file if it doesn't exist.""" + # Create a path to a non-existent file + config_file_path = str(tmp_path / "nonexistent_config.yaml") + + # Verify 
@pytest.mark.anyio
async def test_get_all_policies_creates_file_when_missing(tmp_path):
    """Test that get_all_policies creates a config file if it doesn't exist."""
    config_file_path = str(tmp_path / "nonexistent_config.yaml")
    assert not os.path.exists(config_file_path)

    await get_all_policies(config_file_path)

    # The function must have created a valid (empty) config file on disk.
    assert os.path.exists(config_file_path)
    with open(config_file_path, "r") as f:
        loaded_config = yaml.safe_load(f.read())
    GuardrailConfig.model_validate(loaded_config)


async def mock_get_all_policies(config_file_path: str) -> list[str]:
    # Helper stand-in for get_all_policies; the @pytest.mark.anyio on the
    # original was spurious (this is not a test) and has been removed.
    return ["some_guardrail"]


@patch("mcp_scan_server.routes.policies.get_all_policies", mock_get_all_policies)
def test_get_policy_endpoint():
    """Test that the get_policy returns a dict with a list of policies."""
    response = client.get("/api/v1/dataset/byuser/testuser/test_dataset/policy")
    assert response.status_code == 200
    assert response.json() == {"policies": ["some_guardrail"]}


@pytest.fixture
def error_one_policy_str():
    """Policy that raises "error_one" when a message contains that substring."""
    return """
    raise "error_one" if:
        (msg: Message)
        "error_one" in msg.content
    """


@pytest.fixture
def error_two_policy_str():
    """Policy that raises "error_two" when a message contains that substring."""
    return """
    raise "error_two" if:
        (msg: Message)
        "error_two" in msg.content
    """


@pytest.fixture
def detect_random_policy_str():
    """Policy that raises "error_random" when a message contains "random"."""
    return """
    raise "error_random" if:
        (msg: Message)
        "random" in msg.content
    """


@pytest.fixture
def detect_simple_flow_policy_str():
    """Policy that raises "error_flow" on a request->tool-output sequence."""
    return """
    raise "error_flow" if:
        (msg1: Message)
        (msg2: ToolOutput)
        msg1.content == "request_tool"
        msg2.content == "tool_output"
    """


@pytest.fixture
def simple_trace():
    """Two user messages, each triggering a different error policy."""
    return [
        {"content": "error_one", "role": "user"},
        {"content": "error_two", "role": "user"},
    ]


@pytest.fixture
def simple_flow_trace():
    """A request/response/tool-output trace matching the flow policy."""
    return [
        {"content": "request_tool", "role": "user"},
        {"content": "some_response", "role": "assistant"},
        {"content": "tool_output", "role": "tool"},
    ]
@pytest.mark.anyio
async def test_check_policy_raises_exception_when_trace_violates_policy(error_two_policy_str, simple_trace):
    """A violating trace must produce exactly one matching policy error."""
    outcome = await check_policy(error_two_policy_str, simple_trace)
    errors = outcome.result.errors
    assert len(errors) == 1
    assert errors[0].args[0] == "error_two"


@pytest.mark.anyio
async def test_check_policy_only_raises_error_on_last_message(error_one_policy_str, error_two_policy_str, simple_trace):
    """Only the final message of the trace is treated as pending input."""
    # "error_one" appears earlier in the trace, not in the last message,
    # so the first policy must not fire.
    outcome_one = await check_policy(error_one_policy_str, simple_trace)
    assert len(outcome_one.result.errors) == 0
    assert outcome_one.error_message == ""

    # The last message does contain "error_two", so the second policy fires.
    outcome_two = await check_policy(error_two_policy_str, simple_trace)
    assert len(outcome_two.result.errors) == 1
    assert outcome_two.result.errors[0].args[0] == "error_two"


@pytest.mark.anyio
async def test_check_policy_returns_success_when_trace_does_not_violate_policy(detect_random_policy_str, simple_trace):
    """A non-matching policy yields no errors and no error message."""
    outcome = await check_policy(detect_random_policy_str, simple_trace)
    assert len(outcome.result.errors) == 0
    assert outcome.error_message == ""


@pytest.mark.anyio
async def test_check_policy_catches_flow_violations(detect_simple_flow_policy_str, simple_flow_trace):
    """Multi-message (flow) patterns across the trace are detected."""
    outcome = await check_policy(detect_simple_flow_policy_str, simple_flow_trace)
    errors = outcome.result.errors
    assert len(errors) == 1
    assert errors[0].args[0] == "error_flow"