Feat/toxic flow tux #65

Merged
merged 16 commits into from
Jul 18, 2025
7 changes: 3 additions & 4 deletions README.md
@@ -3,11 +3,11 @@
[Documentation](https://explorer.invariantlabs.ai/docs/mcp-scan) | [Support Discord](https://discord.gg/dZuZfhKnJ4)


MCP-Scan is a security scanning tool to both statically and dynamically scan and monitor your MCP connections. It checks them for common security vulnerabilities like [prompt injections](https://invariantlabs.ai/blog/mcp-security-notification-tool-poisoning-attacks), [tool poisoning](https://invariantlabs.ai/blog/mcp-security-notification-tool-poisoning-attacks) and [cross-origin escalations](https://invariantlabs.ai/blog/mcp-security-notification-tool-poisoning-attacks).
MCP-Scan is a security scanning tool to both statically and dynamically scan and monitor your MCP connections. It checks them for common security vulnerabilities like [prompt injections](https://invariantlabs.ai/blog/mcp-security-notification-tool-poisoning-attacks), [tool poisoning](https://invariantlabs.ai/blog/mcp-security-notification-tool-poisoning-attacks) and [toxic flows](https://invariantlabs.ai/blog/mcp-github-vulnerability).

It operates in two main modes which can be used jointly or separately:

1. `mcp-scan scan` statically scans all your installed servers for malicious tool descriptions and tools (e.g. [tool poisoning attacks](https://invariantlabs.ai/blog/mcp-security-notification-tool-poisoning-attacks), cross-origin escalation, rug pull attacks).
1. `mcp-scan scan` statically scans all your installed servers for malicious tool descriptions and tools (e.g. [tool poisoning attacks](https://invariantlabs.ai/blog/mcp-security-notification-tool-poisoning-attacks), cross-origin escalation, rug pull attacks, toxic flows).

[Quickstart →](#server-scanning).

@@ -93,7 +93,6 @@ MCP-Scan `scan` searches through your configuration files to find MCP server con
It then scans tool descriptions, both with local checks and by invoking Invariant Guardrailing via an API. For this, tool names and descriptions are shared with invariantlabs.ai. By using MCP-Scan, you agree to the invariantlabs.ai [terms of use](https://explorer.invariantlabs.ai/terms) and [privacy policy](https://invariantlabs.ai/privacy-policy).

Invariant Labs is collecting data for security research purposes (only about tool descriptions and how they change over time, not your user data). Don't use MCP-scan if you don't want to share your tools.
You can run MCP-scan locally by using the `--local-only` flag. This will only run local checks and will not invoke the Invariant Guardrailing API, however it will not provide as accurate results as it just runs a local LLM-based policy check. This option requires an `OPENAI_API_KEY` environment variable to be set.

MCP-scan does not store or log any usage data, i.e. the contents and results of your MCP tool calls.

@@ -120,6 +119,7 @@ These options are available for all commands:
--base-url URL Base URL for the verification server
--verbose Enable detailed logging output
--print-errors Show error details and tracebacks
--full-toxic-flows Show all tools that could take part in a toxic flow. By default, only the top 3 are shown.
--json Output results in JSON format instead of rich text
```

@@ -138,7 +138,6 @@ Options:
--checks-per-server NUM Number of checks to perform on each server (default: 1)
--server-timeout SECONDS Seconds to wait before timing out server connections (default: 10)
--suppress-mcpserver-io BOOL Suppress stdout/stderr from MCP servers (default: True)
--local-only BOOL Only run verification locally. Does not run all checks, results will be less accurate (default: False)
```

#### proxy
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "mcp-scan"
version = "0.2.3"
version = "0.3.0"
description = "MCP Scan tool"
readme = "README.md"
requires-python = ">=3.10"
77 changes: 40 additions & 37 deletions src/mcp_scan/MCPScanner.py
@@ -5,11 +5,11 @@
from collections.abc import Callable
from typing import Any

from mcp_scan.models import ScanError, ScanPathResult, ServerScanResult
from mcp_scan.models import Issue, ScanError, ScanPathResult, ServerScanResult

from .mcp_client import check_server_with_timeout, scan_mcp_config_file
from .StorageFile import StorageFile
from .verify_api import verify_scan_path
from .verify_api import analyze_scan_path

# Set up logger for this module
logger = logging.getLogger(__name__)
@@ -56,7 +56,6 @@ def __init__(
storage_file: str = "~/.mcp-scan",
server_timeout: int = 10,
suppress_mcpserver_io: bool = True,
local_only: bool = False,
**kwargs: Any,
):
logger.info("Initializing MCPScanner")
@@ -70,7 +69,6 @@ def __init__(
self.server_timeout = server_timeout
self.suppress_mcpserver_io = suppress_mcpserver_io
self.context_manager = None
self.local_only = local_only
logger.debug(
"MCPScanner initialized with timeout: %d, checks_per_server: %d", server_timeout, checks_per_server
)
@@ -125,31 +123,38 @@ async def get_servers_from_path(self, path: str) -> ScanPathResult:
result.error = ScanError(message=error_msg, exception=e)
return result

async def check_server_changed(self, server: ServerScanResult) -> ServerScanResult:
logger.debug("Checking for changes in server: %s %s", server.name, server.result)
output_server = server.clone()
for i, (entity, entity_result) in enumerate(server.entities_with_result):
if entity_result is None:
continue
c, messages = self.storage_file.check_and_update(server.name or "", entity, entity_result.verified)
output_server.result[i].changed = c # type: ignore
if c:
logger.info("Entity %s in server %s has changed", entity.name, server.name)
output_server.result[i].messages.extend(messages) # type: ignore
return output_server

async def check_whitelist(self, server: ServerScanResult) -> ServerScanResult:
logger.debug("Checking whitelist for server: %s", server.name)
output_server = server.clone()
for i, (entity, entity_result) in enumerate(server.entities_with_result):
if entity_result is None:
continue
if self.storage_file.is_whitelisted(entity):
logger.debug("Entity %s is whitelisted", entity.name)
output_server.result[i].whitelisted = True # type: ignore
else:
output_server.result[i].whitelisted = False # type: ignore
return output_server
    def check_server_changed(self, path_result: ScanPathResult) -> list[Issue]:
        logger.debug("Checking server changed: %s", path_result.path)
        issues: list[Issue] = []
        for server_idx, server in enumerate(path_result.servers):
            logger.debug(
                "Checking for changes in server %d/%d: %s", server_idx + 1, len(path_result.servers), server.name
            )
            for entity_idx, entity in enumerate(server.entities):
                c, messages = self.storage_file.check_and_update(server.name or "", entity)
                if c:
                    logger.info("Entity %s in server %s has changed", entity.name, server.name)
                    issues.append(
                        Issue(
                            code="W003",
                            message="Entity has changed. " + ", ".join(messages),
                            reference=(server_idx, entity_idx),
                        )
                    )
        return issues

    def check_whitelist(self, path_result: ScanPathResult) -> list[Issue]:
        logger.debug("Checking whitelist for path: %s", path_result.path)
        issues: list[Issue] = []
        for server_idx, server in enumerate(path_result.servers):
            for entity_idx, entity in enumerate(server.entities):
                if self.storage_file.is_whitelisted(entity):
                    issues.append(
                        Issue(
                            code="X002", message="This entity has been whitelisted", reference=(server_idx, entity_idx)
                        )
                    )
        return issues

async def emit(self, signal: str, data: Any):
logger.debug("Emitting signal: %s", signal)
@@ -170,12 +175,6 @@ async def scan_server(self, server: ServerScanResult, inspect_only: bool = False
len(result.signature.resources),
len(result.signature.tools),
)

if not inspect_only:
logger.debug("Checking if server has changed: %s", server.name)
result = await self.check_server_changed(result)
logger.debug("Checking whitelist for server: %s", server.name)
result = await self.check_whitelist(result)
except Exception as e:
error_msg = "could not start server"
logger.exception("%s: %s", error_msg, server.name)
@@ -189,8 +188,12 @@ async def scan_path(self, path: str, inspect_only: bool = False) -> ScanPathResu
for i, server in enumerate(path_result.servers):
logger.debug("Scanning server %d/%d: %s", i + 1, len(path_result.servers), server.name)
path_result.servers[i] = await self.scan_server(server, inspect_only)
logger.debug("Verifying server path: %s", path)
path_result = await verify_scan_path(path_result, base_url=self.base_url, run_locally=self.local_only)
logger.debug("Checking whitelist for path: %s", path)
path_result.issues += self.check_whitelist(path_result)
logger.debug("Checking for changed entities in path: %s", path)
path_result.issues += self.check_server_changed(path_result)
logger.debug("Verifying server path: %s", path)
path_result = await analyze_scan_path(path_result, base_url=self.base_url)
await self.emit("path_scanned", path_result)
return path_result

4 changes: 1 addition & 3 deletions src/mcp_scan/StorageFile.py
@@ -100,17 +100,15 @@ def reset_whitelist(self) -> None:
self.whitelist = {}
self.save()

def check_and_update(self, server_name: str, entity: Entity, verified: bool | None) -> tuple[bool, list[str]]:
def check_and_update(self, server_name: str, entity: Entity) -> tuple[bool, list[str]]:
logger.debug("Checking entity: %s in server: %s", entity.name, server_name)
entity_type = entity_type_to_str(entity)
key = f"{server_name}.{entity_type}.{entity.name}"
hash = hash_entity(entity)
logger.debug("Entity key: %s, hash: %s", key, hash)

new_data = ScannedEntity(
hash=hash,
type=entity_type,
verified=verified,
timestamp=datetime.now(),
description=entity.description,
)
22 changes: 5 additions & 17 deletions src/mcp_scan/cli.py
@@ -263,10 +263,10 @@ def main():
metavar="NUM",
)
scan_parser.add_argument(
"--local-only",
"--full-toxic-flows",
default=False,
action="store_true",
help="Only run verification locally. Does not run all checks, results will be less accurate.",
help="Show all tools in the toxic flows; by default, only the first 3 are shown.",
)

# INSPECT command
@@ -457,20 +457,6 @@ def server(on_exit=None):
elif args.command == "uninstall":
asyncio.run(uninstall())
sys.exit(0)
elif args.command == "whitelist":
if args.reset:
MCPScanner(**vars(args)).reset_whitelist()
sys.exit(0)
elif all(x is None for x in [args.name, args.hash]): # no args
MCPScanner(**vars(args)).print_whitelist()
sys.exit(0)
elif all(x is not None for x in [args.name, args.hash]):
MCPScanner(**vars(args)).whitelist(args.name, args.hash, args.local_only)
MCPScanner(**vars(args)).print_whitelist()
sys.exit(0)
else:
rich.print("[bold red]Please provide a name and hash.[/bold red]")
sys.exit(1)
elif args.command == "scan" or args.command is None: # default to scan
asyncio.run(run_scan_inspect(args=args))
sys.exit(0)
@@ -499,11 +485,13 @@ async def run_scan_inspect(mode="scan", args=None):
result = await scanner.scan()
elif mode == "inspect":
result = await scanner.inspect()
else:
raise ValueError(f"Unknown mode: {mode}, expected 'scan' or 'inspect'")
if args.json:
result = {r.path: r.model_dump() for r in result}
print(json.dumps(result, indent=2))
else:
print_scan_result(result)
print_scan_result(result, args.print_errors, args.full_toxic_flows)
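
Reviewer note: the call above reads `args.full_toxic_flows` although the flag is spelled `--full-toxic-flows`. That works because `argparse` derives the attribute name by replacing dashes with underscores. A minimal, standalone sketch (the `build_parser` helper and the `mcp-scan-sketch` program name are made up for illustration and are not the project's actual parser setup):

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical, reduced parser: only the scan subcommand and the new flag.
    parser = argparse.ArgumentParser(prog="mcp-scan-sketch")
    subparsers = parser.add_subparsers(dest="command")
    scan_parser = subparsers.add_parser("scan")
    scan_parser.add_argument(
        "--full-toxic-flows",
        default=False,
        action="store_true",
        help="Show all tools in the toxic flows.",
    )
    return parser


args = build_parser().parse_args(["scan", "--full-toxic-flows"])
print(args.full_toxic_flows)
```

In this sketch the attribute exists only when the `scan` subcommand is parsed, because the flag is registered on `scan_parser`, as it is in the diff.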


if __name__ == "__main__":
6 changes: 3 additions & 3 deletions src/mcp_scan/mcp_client.py
@@ -86,22 +86,22 @@ async def _check_server(verbose: bool) -> ServerSignature:
if isinstance(server_config, StdioServer) or meta.capabilities.prompts:
logger.debug("Fetching prompts")
try:
prompts = (await session.list_prompts()).prompts
prompts += (await session.list_prompts()).prompts
logger.debug("Found %d prompts", len(prompts))
except Exception:
logger.exception("Failed to list prompts")

if isinstance(server_config, StdioServer) or meta.capabilities.resources:
logger.debug("Fetching resources")
try:
resources = (await session.list_resources()).resources
resources += (await session.list_resources()).resources
logger.debug("Found %d resources", len(resources))
except Exception:
logger.exception("Failed to list resources")
if isinstance(server_config, StdioServer) or meta.capabilities.tools:
logger.debug("Fetching tools")
try:
tools = (await session.list_tools()).tools
tools += (await session.list_tools()).tools
logger.debug("Found %d tools", len(tools))
except Exception:
logger.exception("Failed to list tools")
53 changes: 29 additions & 24 deletions src/mcp_scan/models.py
@@ -33,7 +33,6 @@ class ScannedEntity(BaseModel):
model_config = ConfigDict()
hash: str
type: str
verified: bool | None
timestamp: datetime
description: str | None = None

@@ -147,13 +146,17 @@ def clone(self) -> "ScanError":
)


class EntityScanResult(BaseModel):
model_config = ConfigDict()
verified: bool | None = None
changed: bool | None = None
whitelisted: bool | None = None
status: str | None = None
messages: list[str] = []
class Issue(BaseModel):
    code: str
    message: str
    reference: tuple[int, int] | None = Field(
        default=None,
        description="The index of the tool the issue references. None if it is global",
    )
    extra_data: dict[str, Any] | None = Field(
        default=None,
        description="Extra data to provide more context about the issue.",
    )


class ServerSignature(BaseModel):
@@ -167,20 +170,15 @@ def entities(self) -> list[Entity]:
return self.prompts + self.resources + self.tools


class VerifyServerResponse(RootModel):
root: list[list[EntityScanResult]]


class VerifyServerRequest(RootModel):
root: list[ServerSignature]
class VerifyServerRequest(RootModel[list[ServerSignature | None]]):
pass


class ServerScanResult(BaseModel):
model_config = ConfigDict()
name: str | None = None
server: SSEServer | StdioServer | StreamableHTTPServer
signature: ServerSignature | None = None
result: list[EntityScanResult] | None = None
error: ScanError | None = None

@property
@@ -194,13 +192,6 @@ def entities(self) -> list[Entity]:
def is_verified(self) -> bool:
return self.result is not None

@property
def entities_with_result(self) -> list[tuple[Entity, EntityScanResult | None]]:
if self.result is not None:
return list(zip(self.entities, self.result, strict=False))
else:
return [(entity, None) for entity in self.entities]

def clone(self) -> "ServerScanResult":
"""
Create a copy of the ServerScanResult instance. This is not the same as `model_copy(deep=True)`, because it does not
@@ -210,7 +201,6 @@ def clone(self) -> "ServerScanResult":
name=self.name,
server=self.server.model_copy(deep=True),
signature=self.signature.model_copy(deep=True) if self.signature else None,
result=[result.model_copy(deep=True) for result in self.result] if self.result else None,
error=self.error.clone() if self.error else None,
)
return output
@@ -219,7 +209,8 @@ def clone(self) -> "ServerScanResult":
class ScanPathResult(BaseModel):
model_config = ConfigDict()
path: str
servers: list[ServerScanResult] = []
servers: list[ServerScanResult] = Field(default_factory=list)
issues: list[Issue] = Field(default_factory=list)
error: ScanError | None = None

@property
@@ -234,6 +225,7 @@ def clone(self) -> "ScanPathResult":
output = ScanPathResult(
path=self.path,
servers=[server.clone() for server in self.servers],
issues=[issue.model_copy(deep=True) for issue in self.issues],
error=self.error.clone() if self.error else None,
)
return output
@@ -272,3 +264,16 @@ def entity_to_tool(
)
else:
raise ValueError(f"Unknown entity type: {type(entity)}")


class ToolReferenceWithLabel(BaseModel):
    reference: tuple[int, int]
    label_value: float


class ToxicFlowExtraData(RootModel[dict[str, list[ToolReferenceWithLabel]]]):
    pass


class AnalysisServerResponse(BaseModel):
    issues: list[Issue]
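
Reviewer note: `ToxicFlowExtraData` maps a label to a list of scored tool references, and the README change says only the top 3 tools are shown unless `--full-toxic-flows` is passed. The printing code is not part of this diff, so the sketch below is only a guess at how that truncation might work. `select_tools`, the label name `untrusted_content`, and the assumption that a higher `label_value` ranks first are all hypothetical.

```python
from dataclasses import dataclass


@dataclass
class ToolReferenceWithLabel:
    # Simplified stand-in for the pydantic model of the same name.
    reference: tuple[int, int]  # (server_idx, entity_idx)
    label_value: float


def select_tools(
    extra_data: dict[str, list[ToolReferenceWithLabel]],
    full_toxic_flows: bool = False,
    limit: int = 3,
) -> dict[str, list[ToolReferenceWithLabel]]:
    """Keep every tool per label, or only the `limit` highest-scoring ones."""
    selected = {}
    for label, refs in extra_data.items():
        # Assumption: a higher label_value means the tool is more relevant to the flow.
        ranked = sorted(refs, key=lambda ref: ref.label_value, reverse=True)
        selected[label] = ranked if full_toxic_flows else ranked[:limit]
    return selected


extra_data = {
    "untrusted_content": [  # hypothetical label name
        ToolReferenceWithLabel((0, 0), 0.2),
        ToolReferenceWithLabel((0, 1), 0.9),
        ToolReferenceWithLabel((1, 0), 0.7),
        ToolReferenceWithLabel((1, 2), 0.4),
    ]
}
top = select_tools(extra_data)
print([ref.reference for ref in top["untrusted_content"]])
```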