Skip to content

Commit c3aa794

Browse files
committed
feat: add query for labels
1 parent 855b126 commit c3aa794

File tree

5 files changed

+171
-9
lines changed

5 files changed

+171
-9
lines changed

src/mcp_scan/MCPScanner.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from .mcp_client import check_server_with_timeout, scan_mcp_config_file
1111
from .StorageFile import StorageFile
12-
from .verify_api import verify_scan_path
12+
from .verify_api import verify_scan_path_and_labels
1313

1414
# Set up logger for this module
1515
logger = logging.getLogger(__name__)
@@ -190,7 +190,9 @@ async def scan_path(self, path: str, inspect_only: bool = False) -> ScanPathResu
190190
logger.debug("Scanning server %d/%d: %s", i + 1, len(path_result.servers), server.name)
191191
path_result.servers[i] = await self.scan_server(server, inspect_only)
192192
logger.debug("Verifying server path: %s", path)
193-
path_result = await verify_scan_path(path_result, base_url=self.base_url, run_locally=self.local_only)
193+
path_result = await verify_scan_path_and_labels(
194+
path_result, base_url=self.base_url, run_locally=self.local_only
195+
)
194196
await self.emit("path_scanned", path_result)
195197
return path_result
196198

src/mcp_scan/models.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from itertools import chain
44
from typing import Any, Literal, TypeAlias
55

6-
from mcp.types import InitializeResult, Prompt, Resource, Tool
6+
from mcp.types import InitializeResult, Prompt, Resource, Tool, ToolAnnotations
77
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_serializer, field_validator
88

99
Entity: TypeAlias = Prompt | Resource | Tool
@@ -272,3 +272,19 @@ def entity_to_tool(
272272
)
273273
else:
274274
raise ValueError(f"Unknown entity type: {type(entity)}")
275+
276+
277+
class ScalarToolLabels(BaseModel):
278+
is_public_sink: int | float
279+
destructive: int | float
280+
untrusted_output: int | float
281+
private_data: int | float
282+
prompt_injection: int | float
283+
284+
285+
class ErrorLabels(BaseModel):
286+
error: str
287+
288+
289+
class ToolAnnotationsWithLabels(ToolAnnotations):
290+
labels: ScalarToolLabels | ErrorLabels

src/mcp_scan/printer.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,22 @@
22
import textwrap
33

44
import rich
5+
from mcp.types import Tool
56
from rich.text import Text
67
from rich.traceback import Traceback as rTraceback
78
from rich.tree import Tree
89

9-
from .models import Entity, EntityScanResult, ScanError, ScanPathResult, entity_type_to_str, hash_entity
10+
from .models import (
11+
Entity,
12+
EntityScanResult,
13+
ErrorLabels,
14+
ScalarToolLabels,
15+
ScanError,
16+
ScanPathResult,
17+
ToolAnnotationsWithLabels,
18+
entity_type_to_str,
19+
hash_entity,
20+
)
1021

1122

1223
def format_exception(e: Exception | None) -> tuple[str, rTraceback | None]:
@@ -51,6 +62,23 @@ def append_status(status: str, new_status: str) -> str:
5162
return f"{new_status}, {status}"
5263

5364

65+
def format_scalar_labels(labels: ScalarToolLabels) -> str:
66+
"""
67+
Format scalar labels into a string.
68+
"""
69+
label_parts = []
70+
if labels.is_public_sink > 0:
71+
label_parts.append(f"[gold1]Public sink: {str(labels.is_public_sink).rstrip('.0')}[/gold1]")
72+
if labels.destructive > 0:
73+
label_parts.append(f"[gold1]Destructive: {str(labels.destructive).rstrip('.0')}[/gold1]")
74+
if labels.untrusted_output > 0:
75+
label_parts.append(f"[gold1]Untrusted output: {str(labels.untrusted_output).rstrip('.0')}[/gold1]")
76+
if labels.private_data > 0:
77+
label_parts.append(f"[gold1]Private data: {str(labels.private_data).rstrip('.0')}[/gold1]")
78+
79+
return " | ".join(label_parts)
80+
81+
5482
def format_entity_line(entity: Entity, result: EntityScanResult | None = None) -> Text:
5583
# is_verified = verified.value
5684
# if is_verified is not None and changed.value is not None:
@@ -60,7 +88,7 @@ def format_entity_line(entity: Entity, result: EntityScanResult | None = None) -
6088
include_description = True
6189
if result is not None:
6290
is_verified = result.verified
63-
status = result.status or ""
91+
status = "| " + result.status if result.status else ""
6492
if result.changed is not None and result.changed:
6593
is_verified = False
6694
status = append_status(status, "[bold]changed since previous scan[/bold]")
@@ -82,7 +110,19 @@ def format_entity_line(entity: Entity, result: EntityScanResult | None = None) -
82110
type = entity_type_to_str(entity)
83111
type = type + " " * (len("resource") - len(type))
84112

85-
text = f"{type} {color}[bold]{name}[/bold] {icon} {status}"
113+
# labels
114+
labels = ""
115+
if (
116+
isinstance(entity, Tool)
117+
and entity.annotations is not None
118+
and isinstance(entity.annotations, ToolAnnotationsWithLabels)
119+
):
120+
if isinstance(entity.annotations.labels, ScalarToolLabels):
121+
labels = format_scalar_labels(entity.annotations.labels)
122+
elif isinstance(entity.annotations.labels, ErrorLabels):
123+
labels = f"[gray62]Error in labels computation: {entity.annotations.labels.error}[/gray62]"
124+
125+
text = f"{type} {color}[bold]{name}[/bold] {icon} {labels} {status}"
86126

87127
if include_description:
88128
if hasattr(entity, "description") and entity.description is not None:

src/mcp_scan/verify_api.py

Lines changed: 80 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,83 @@
11
import ast
2-
from typing import TYPE_CHECKING
2+
import asyncio
3+
import logging
34

45
import aiohttp
56
from invariant.analyzer.policy import LocalPolicy
7+
from mcp.types import Tool
68

79
from .models import (
810
EntityScanResult,
11+
ErrorLabels,
12+
ScalarToolLabels,
913
ScanPathResult,
14+
ServerSignature,
15+
ToolAnnotationsWithLabels,
1016
VerifyServerRequest,
1117
VerifyServerResponse,
1218
entity_to_tool,
1319
)
1420

15-
if TYPE_CHECKING:
16-
from mcp.types import Tool
21+
logger = logging.getLogger(__name__)
22+
1723

1824
POLICY_PATH = "src/mcp_scan/policy.gr"
1925

2026

27+
async def tool_get_labels(tool: Tool, base_url: str) -> Tool:
28+
"""
29+
Get labels from the tool and add them to the tool's metadata.
30+
"""
31+
logger.debug("Getting labels for tool: %s", tool.name)
32+
output_tool = tool.model_copy(deep=True)
33+
url = base_url[:-1] if base_url.endswith("/") else base_url
34+
url = url + "/api/v1/public/labels"
35+
headers = {"Content-Type": "application/json"}
36+
try:
37+
async with aiohttp.ClientSession() as session:
38+
async with session.post(url, headers=headers, data=tool.model_dump_json()) as response:
39+
if response.status == 200:
40+
scalar_tool_labels = ScalarToolLabels.model_validate_json(await response.read())
41+
else:
42+
raise Exception(f"Error: {response.status} - {await response.text()}")
43+
except Exception as e:
44+
output_tool.annotations = ToolAnnotationsWithLabels(
45+
**output_tool.annotations.model_dump() if output_tool.annotations else {},
46+
labels=ErrorLabels(error=str(e) if isinstance(e, Exception) else "Unknown error"),
47+
)
48+
return output_tool
49+
output_tool.annotations = ToolAnnotationsWithLabels(
50+
**output_tool.annotations.model_dump() if output_tool.annotations else {},
51+
labels=scalar_tool_labels,
52+
)
53+
return output_tool
54+
55+
56+
async def server_get_labels(server: ServerSignature, base_url: str) -> ServerSignature:
57+
"""
58+
Get labels from the server and add them to the server's metadata.
59+
"""
60+
logger.debug("Getting labels for server: %s", server.metadata.serverInfo.name)
61+
output_server = server.model_copy(deep=True)
62+
annotated_tools = [tool_get_labels(tool, base_url) for tool in output_server.tools]
63+
output_server.tools = await asyncio.gather(*annotated_tools)
64+
return output_server
65+
66+
67+
async def scan_path_get_labels(servers: list[ServerSignature | None], base_url: str) -> list[ServerSignature | None]:
68+
"""
69+
Get labels for all servers in the scan path.
70+
"""
71+
logger.debug(f"Getting labels for {len(servers)} servers")
72+
73+
async def server_get_labels_or_skip(server: ServerSignature | None) -> ServerSignature | None:
74+
if server is None:
75+
return None
76+
return await server_get_labels(server, base_url)
77+
78+
return await asyncio.gather(*[server_get_labels_or_skip(server) for server in servers])
79+
80+
2181
async def verify_scan_path_public_api(scan_path: ScanPathResult, base_url: str) -> ScanPathResult:
2282
output_path = scan_path.clone()
2383
url = base_url[:-1] if base_url.endswith("/") else base_url
@@ -99,3 +159,20 @@ async def verify_scan_path(scan_path: ScanPathResult, base_url: str, run_locally
99159
return await verify_scan_path_locally(scan_path)
100160
else:
101161
return await verify_scan_path_public_api(scan_path, base_url)
162+
163+
164+
async def verify_scan_path_and_labels(scan_path: ScanPathResult, base_url: str, run_locally: bool) -> ScanPathResult:
165+
"""
166+
Verify the scan path and get labels for all servers in the scan path.
167+
Runs concurrently to speed up the process.
168+
"""
169+
verified_scan_path_task = verify_scan_path(scan_path, base_url, run_locally)
170+
signatures_with_labels_task = scan_path_get_labels([server.signature for server in scan_path.servers], base_url)
171+
verified_scan_path, signatures_with_labels = await asyncio.gather(
172+
verified_scan_path_task,
173+
signatures_with_labels_task,
174+
)
175+
logger.debug("Verified scan path and labels retrieved successfully")
176+
for server, signature in zip(verified_scan_path.servers, signatures_with_labels, strict=False):
177+
server.signature = signature
178+
return verified_scan_path

tests/mcp_servers/math_server.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,5 +34,32 @@ def divide(a: int, b: int) -> int:
3434
return a // b
3535

3636

37+
@mcp.resource(uri="prime_numbers://{n}")
38+
def prime_numbers(n: int) -> str:
39+
"""Lists prime numbers smaller than or equal to n."""
40+
if n < 2:
41+
return "No prime numbers less than 2"
42+
43+
primes = []
44+
for num in range(2, n + 1):
45+
if all(num % i != 0 for i in range(2, int(num**0.5) + 1)):
46+
primes.append(num)
47+
48+
return f"[{', '.join(map(str, primes))}]"
49+
50+
51+
@mcp.prompt()
52+
def math_prompt() -> str:
53+
"""Prompt for math operations."""
54+
return """
55+
You can perform the following operations:
56+
1. Add two numbers: `add(3, 5)`
57+
2. Subtract two numbers: `subtract(10, 4)`
58+
3. Multiply two numbers: `multiply(2, 6)`
59+
4. Divide two numbers: `divide(8, 2)`
60+
You can also use the resource endpoint `prime_numbers://{n}` to get prime numbers up to n.
61+
"""
62+
63+
3764
if __name__ == "__main__":
3865
mcp.run()

0 commit comments

Comments
 (0)