|
1 | 1 | import ast
|
2 |
| -from typing import TYPE_CHECKING |
| 2 | +import asyncio |
| 3 | +import logging |
3 | 4 |
|
4 | 5 | import aiohttp
|
5 | 6 | from invariant.analyzer.policy import LocalPolicy
|
| 7 | +from mcp.types import Tool |
6 | 8 |
|
7 | 9 | from .models import (
|
8 | 10 | EntityScanResult,
|
| 11 | + ErrorLabels, |
| 12 | + ScalarToolLabels, |
9 | 13 | ScanPathResult,
|
| 14 | + ServerSignature, |
| 15 | + ToolAnnotationsWithLabels, |
10 | 16 | VerifyServerRequest,
|
11 | 17 | VerifyServerResponse,
|
12 | 18 | entity_to_tool,
|
13 | 19 | )
|
14 | 20 |
|
15 |
| -if TYPE_CHECKING: |
16 |
| - from mcp.types import Tool |
| 21 | +logger = logging.getLogger(__name__) |
| 22 | + |
17 | 23 |
|
18 | 24 | POLICY_PATH = "src/mcp_scan/policy.gr"
|
19 | 25 |
|
20 | 26 |
|
| 27 | +async def tool_get_labels(tool: Tool, base_url: str) -> Tool: |
| 28 | + """ |
| 29 | + Get labels from the tool and add them to the tool's metadata. |
| 30 | + """ |
| 31 | + logger.debug("Getting labels for tool: %s", tool.name) |
| 32 | + output_tool = tool.model_copy(deep=True) |
| 33 | + url = base_url[:-1] if base_url.endswith("/") else base_url |
| 34 | + url = url + "/api/v1/public/labels" |
| 35 | + headers = {"Content-Type": "application/json"} |
| 36 | + try: |
| 37 | + async with aiohttp.ClientSession() as session: |
| 38 | + async with session.post(url, headers=headers, data=tool.model_dump_json()) as response: |
| 39 | + if response.status == 200: |
| 40 | + scalar_tool_labels = ScalarToolLabels.model_validate_json(await response.read()) |
| 41 | + else: |
| 42 | + raise Exception(f"Error: {response.status} - {await response.text()}") |
| 43 | + except Exception as e: |
| 44 | + output_tool.annotations = ToolAnnotationsWithLabels( |
| 45 | + **output_tool.annotations.model_dump() if output_tool.annotations else {}, |
| 46 | + labels=ErrorLabels(error=str(e) if isinstance(e, Exception) else "Unknown error"), |
| 47 | + ) |
| 48 | + return output_tool |
| 49 | + output_tool.annotations = ToolAnnotationsWithLabels( |
| 50 | + **output_tool.annotations.model_dump() if output_tool.annotations else {}, |
| 51 | + labels=scalar_tool_labels, |
| 52 | + ) |
| 53 | + return output_tool |
| 54 | + |
| 55 | + |
| 56 | +async def server_get_labels(server: ServerSignature, base_url: str) -> ServerSignature: |
| 57 | + """ |
| 58 | + Get labels from the server and add them to the server's metadata. |
| 59 | + """ |
| 60 | + logger.debug("Getting labels for server: %s", server.metadata.serverInfo.name) |
| 61 | + output_server = server.model_copy(deep=True) |
| 62 | + annotated_tools = [tool_get_labels(tool, base_url) for tool in output_server.tools] |
| 63 | + output_server.tools = await asyncio.gather(*annotated_tools) |
| 64 | + return output_server |
| 65 | + |
| 66 | + |
| 67 | +async def scan_path_get_labels(servers: list[ServerSignature | None], base_url: str) -> list[ServerSignature | None]: |
| 68 | + """ |
| 69 | + Get labels for all servers in the scan path. |
| 70 | + """ |
| 71 | + logger.debug(f"Getting labels for {len(servers)} servers") |
| 72 | + |
| 73 | + async def server_get_labels_or_skip(server: ServerSignature | None) -> ServerSignature | None: |
| 74 | + if server is None: |
| 75 | + return None |
| 76 | + return await server_get_labels(server, base_url) |
| 77 | + |
| 78 | + return await asyncio.gather(*[server_get_labels_or_skip(server) for server in servers]) |
| 79 | + |
| 80 | + |
21 | 81 | async def verify_scan_path_public_api(scan_path: ScanPathResult, base_url: str) -> ScanPathResult:
|
22 | 82 | output_path = scan_path.clone()
|
23 | 83 | url = base_url[:-1] if base_url.endswith("/") else base_url
|
@@ -99,3 +159,20 @@ async def verify_scan_path(scan_path: ScanPathResult, base_url: str, run_locally
|
99 | 159 | return await verify_scan_path_locally(scan_path)
|
100 | 160 | else:
|
101 | 161 | return await verify_scan_path_public_api(scan_path, base_url)
|
| 162 | + |
| 163 | + |
| 164 | +async def verify_scan_path_and_labels(scan_path: ScanPathResult, base_url: str, run_locally: bool) -> ScanPathResult: |
| 165 | + """ |
| 166 | + Verify the scan path and get labels for all servers in the scan path. |
| 167 | + Runs concurrently to speed up the process. |
| 168 | + """ |
| 169 | + verified_scan_path_task = verify_scan_path(scan_path, base_url, run_locally) |
| 170 | + signatures_with_labels_task = scan_path_get_labels([server.signature for server in scan_path.servers], base_url) |
| 171 | + verified_scan_path, signatures_with_labels = await asyncio.gather( |
| 172 | + verified_scan_path_task, |
| 173 | + signatures_with_labels_task, |
| 174 | + ) |
| 175 | + logger.debug("Verified scan path and labels retrieved successfully") |
| 176 | + for server, signature in zip(verified_scan_path.servers, signatures_with_labels, strict=False): |
| 177 | + server.signature = signature |
| 178 | + return verified_scan_path |
0 commit comments