Skip to content

Commit 2980a50

Browse files
authored
feat: support passing 'stateless' and 'cwd' arguments (#62)
1. Add support for --stateless parameter configuration 2. Add support for --cwd parameter passing working directory to mcp stdio server 3. Use StreamableHTTPSessionManager from the latest python-mcp-sdk release to manage sessions, simplifying code 4. Optimize test cases
1 parent 8fee3d9 commit 2980a50

File tree

4 files changed

+113
-105
lines changed

4 files changed

+113
-105
lines changed

README.md

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -114,16 +114,18 @@ separator.
114114

115115
Arguments
116116

117-
| Name | Required | Description | Example |
118-
|---------------------------|----------------------------|------------------------------------------------------------------|-----------------------|
119-
| `command_or_url` | Yes | The command to spawn the MCP stdio server | uvx mcp-server-fetch |
120-
| `--port` | No, random available | The MCP server port to listen on | 8080 |
121-
| `--host` | No, `127.0.0.1` by default | The host IP address that the MCP server will listen on | 0.0.0.0 |
122-
| `--env` | No | Additional environment variables to pass to the MCP stdio server | FOO=BAR |
123-
| `--pass-environment` | No | Pass through all environment variables when spawning the server | --no-pass-environment |
124-
| `--allow-origin` | No | Pass through all environment variables when spawning the server | --allow-cors "\*" |
125-
| `--sse-port` (deprecated) | No, random available | The SSE server port to listen on | 8080 |
126-
| `--sse-host` (deprecated) | No, `127.0.0.1` by default | The host IP address that the SSE server will listen on | 0.0.0.0 |
117+
| Name | Required | Description | Example |
118+
|---------------------------|----------------------------|---------------------------------------------------------------------------------------------|-----------------------|
119+
| `command_or_url` | Yes | The command to spawn the MCP stdio server | uvx mcp-server-fetch |
120+
| `--port` | No, random available | The MCP server port to listen on | 8080 |
121+
| `--host` | No, `127.0.0.1` by default | The host IP address that the MCP server will listen on | 0.0.0.0 |
122+
| `--env` | No | Additional environment variables to pass to the MCP stdio server | FOO=BAR |
123+
| `--cwd` | No | The working directory to pass to the MCP stdio server process. | /tmp |
124+
| `--pass-environment` | No | Pass through all environment variables when spawning the server | --no-pass-environment |
125+
| `--allow-origin` | No | Allowed origins for the SSE server. Can be used multiple times. Default is no CORS allowed. | --allow-cors "\*" |
126+
| `--stateless` | No | Enable stateless mode for streamable http transports. Default is False | --no-stateless |
127+
| `--sse-port` (deprecated) | No, random available | The SSE server port to listen on | 8080 |
128+
| `--sse-host` (deprecated) | No, `127.0.0.1` by default | The host IP address that the SSE server will listen on | 0.0.0.0 |
127129

128130
### 2.2 Example usage
129131

@@ -147,7 +149,8 @@ mcp-proxy --host=0.0.0.0 --port=8080 uvx mcp-server-fetch
147149
mcp-proxy --port=8080 -- uvx mcp-server-fetch --user-agent=YourUserAgent
148150
```
149151

150-
This will start an MCP server that can be connected to at `http://127.0.0.1:8080/sse` via SSE, or `http://127.0.0.1:8080/mcp/` via StreamableHttp
152+
This will start an MCP server that can be connected to at `http://127.0.0.1:8080/sse` via SSE, or
153+
`http://127.0.0.1:8080/mcp/` via StreamableHttp
151154

152155
## Installation
153156

src/mcp_proxy/__main__.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,11 @@ def main() -> None:
7777
help="Environment variables used when spawning the server. Can be used multiple times.",
7878
default=[],
7979
)
80+
stdio_client_options.add_argument(
81+
"--cwd",
82+
default=None,
83+
help="The working directory to use when spawning the process.",
84+
)
8085
stdio_client_options.add_argument(
8186
"--pass-environment",
8287
action=argparse.BooleanOptionalAction,
@@ -90,30 +95,36 @@ def main() -> None:
9095
default=False,
9196
)
9297

93-
sse_server_group = parser.add_argument_group("SSE server options")
94-
sse_server_group.add_argument(
98+
mcp_server_group = parser.add_argument_group("SSE server options")
99+
mcp_server_group.add_argument(
95100
"--port",
96101
type=int,
97102
default=None,
98103
help="Port to expose an SSE server on. Default is a random port",
99104
)
100-
sse_server_group.add_argument(
105+
mcp_server_group.add_argument(
101106
"--host",
102107
default=None,
103108
help="Host to expose an SSE server on. Default is 127.0.0.1",
104109
)
105-
sse_server_group.add_argument(
110+
mcp_server_group.add_argument(
111+
"--stateless",
112+
action=argparse.BooleanOptionalAction,
113+
help="Enable stateless mode for streamable http transports. Default is False",
114+
default=False,
115+
)
116+
mcp_server_group.add_argument(
106117
"--sse-port",
107118
type=int,
108119
default=0,
109120
help="Port to expose an SSE server on. Default is a random port",
110121
)
111-
sse_server_group.add_argument(
122+
mcp_server_group.add_argument(
112123
"--sse-host",
113124
default="127.0.0.1",
114125
help="Host to expose an SSE server on. Default is 127.0.0.1",
115126
)
116-
sse_server_group.add_argument(
127+
mcp_server_group.add_argument(
117128
"--allow-origin",
118129
nargs="+",
119130
default=[],
@@ -161,11 +172,13 @@ def main() -> None:
161172
command=args.command_or_url,
162173
args=args.args,
163174
env=env,
175+
cwd=args.cwd if args.cwd else None,
164176
)
165177

166178
mcp_settings = MCPServerSettings(
167179
bind_host=args.host if args.host is not None else args.sse_host,
168180
port=args.port if args.port is not None else args.sse_port,
181+
stateless=args.stateless,
169182
allow_origins=args.allow_origin if len(args.allow_origin) > 0 else None,
170183
log_level="DEBUG" if args.debug else "INFO",
171184
)

src/mcp_proxy/mcp_server.py

Lines changed: 23 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,14 @@
55
from collections.abc import AsyncIterator
66
from dataclasses import dataclass
77
from datetime import datetime, timezone
8-
from http import HTTPStatus
9-
from typing import Any, Literal
10-
from uuid import uuid4
8+
from typing import Literal
119

12-
import anyio
1310
import uvicorn
14-
from anyio.abc import TaskStatus
1511
from mcp.client.session import ClientSession
1612
from mcp.client.stdio import StdioServerParameters, stdio_client
1713
from mcp.server import Server
1814
from mcp.server.sse import SseServerTransport
19-
from mcp.server.streamable_http import StreamableHTTPServerTransport
15+
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
2016
from starlette.applications import Starlette
2117
from starlette.middleware import Middleware
2218
from starlette.middleware.cors import CORSMiddleware
@@ -28,28 +24,6 @@
2824
from .proxy_server import create_proxy_server
2925

3026
logger = logging.getLogger(__name__)
31-
# Global task group that will be initialized in the lifespan
32-
task_group = None
33-
34-
MCP_SESSION_ID_HEADER = "mcp-session-id"
35-
36-
37-
@contextlib.asynccontextmanager
38-
async def lifespan(_: Starlette) -> AsyncIterator[None]:
39-
"""Application lifespan context manager for managing task group."""
40-
global task_group # noqa: PLW0603
41-
42-
async with anyio.create_task_group() as tg:
43-
task_group = tg
44-
logger.info("Application started, task group initialized!")
45-
try:
46-
yield
47-
finally:
48-
logger.info("Application shutting down, cleaning up resources...")
49-
if task_group:
50-
tg.cancel_scope.cancel()
51-
task_group = None
52-
logger.info("Resources cleaned up successfully.")
5327

5428

5529
@dataclass
@@ -58,13 +32,15 @@ class MCPServerSettings:
5832

5933
bind_host: str
6034
port: int
35+
stateless: bool = False
6136
allow_origins: list[str] | None = None
6237
log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"
6338

6439

65-
def create_starlette_app( # noqa: C901, Refactor required for complexity
40+
def create_starlette_app(
6641
mcp_server: Server[object],
6742
*,
43+
stateless: bool = False,
6844
allow_origins: list[str] | None = None,
6945
debug: bool = False,
7046
) -> Starlette:
@@ -97,57 +73,17 @@ async def handle_sse(request: Request) -> None:
9773
mcp_server.create_initialization_options(),
9874
)
9975

100-
# Refer: https://github.com/modelcontextprotocol/python-sdk/blob/5d8eaf77be00dbd9b33a7fe1e38cb0da77e49401/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py
101-
# We need to store the server instances between requests
102-
server_instances: dict[str, Any] = {}
103-
# Lock to prevent race conditions when creating new sessions
104-
session_creation_lock = anyio.Lock()
76+
# Refer: https://github.com/modelcontextprotocol/python-sdk/blob/v1.8.0/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py
77+
http = StreamableHTTPSessionManager(
78+
app=mcp_server,
79+
event_store=None,
80+
json_response=True,
81+
stateless=stateless,
82+
)
10583

10684
async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None:
10785
_update_mcp_activity()
108-
request = Request(scope, receive)
109-
request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER)
110-
if request_mcp_session_id is not None and request_mcp_session_id in server_instances:
111-
transport = server_instances[request_mcp_session_id]
112-
logger.debug("Session already exists, handling request directly")
113-
await transport.handle_request(scope, receive, send)
114-
elif request_mcp_session_id is None:
115-
# try to establish new session
116-
logger.debug("Creating new transport")
117-
# Use lock to prevent race conditions when creating new sessions
118-
async with session_creation_lock:
119-
new_session_id = uuid4().hex
120-
http_transport = StreamableHTTPServerTransport(
121-
mcp_session_id=new_session_id,
122-
is_json_response_enabled=True,
123-
)
124-
server_instances[new_session_id] = http_transport
125-
logger.info("Created new transport with session ID: %s", new_session_id)
126-
127-
async def run_server(task_status: TaskStatus[Any] | None = None) -> None:
128-
async with http_transport.connect() as streams:
129-
read_stream, write_stream = streams
130-
if task_status:
131-
task_status.started()
132-
await mcp_server.run(
133-
read_stream,
134-
write_stream,
135-
mcp_server.create_initialization_options(),
136-
)
137-
138-
if not task_group:
139-
raise RuntimeError("Task group is not initialized")
140-
141-
await task_group.start(run_server)
142-
143-
# Handle the HTTP request and return the response
144-
await http_transport.handle_request(scope, receive, send)
145-
else:
146-
response = Response(
147-
"Bad Request: No valid session ID provided",
148-
status_code=HTTPStatus.BAD_REQUEST,
149-
)
150-
await response(scope, receive, send)
86+
await http.handle_request(scope, receive, send)
15187

15288
async def handle_status(_: Request) -> Response:
15389
"""Health check and service usage monitoring endpoint.
@@ -159,6 +95,16 @@ async def handle_status(_: Request) -> Response:
15995
"""
16096
return JSONResponse(status)
16197

98+
@contextlib.asynccontextmanager
99+
async def lifespan(_: Starlette) -> AsyncIterator[None]:
100+
"""Context manager for session manager."""
101+
async with http.run():
102+
logger.info("Application started with StreamableHTTP session manager!")
103+
try:
104+
yield
105+
finally:
106+
logger.info("Application shutting down...")
107+
162108
middleware: list[Middleware] = []
163109
if allow_origins is not None:
164110
middleware.append(

tests/test_mcp_server.py

Lines changed: 57 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66

77
import pytest
88
import uvicorn
9-
from mcp import types
109
from mcp.client.session import ClientSession
1110
from mcp.client.sse import sse_client
1211
from mcp.client.streamable_http import streamablehttp_client
13-
from mcp.server import Server
12+
from mcp.server import FastMCP
13+
from mcp.types import TextContent
1414

1515
from mcp_proxy.mcp_server import create_starlette_app
1616

@@ -42,19 +42,32 @@ def url(self) -> str:
4242
return f"http://{hostport[0]}:{hostport[1]}"
4343

4444

45-
@pytest.mark.asyncio
46-
async def test_create_starlette_app() -> None:
47-
"""Test basic glue code for the SSE transport and a fake MCP server."""
48-
mcp_server: Server[object] = Server("prompt-server")
45+
def make_background_server(**kwargs) -> BackgroundServer: # noqa: ANN003
46+
"""Create a BackgroundServer instance with specified parameters."""
47+
mcp = FastMCP("TestServer")
48+
49+
@mcp.prompt(name="prompt1")
50+
async def list_prompts() -> str:
51+
return "hello world"
4952

50-
@mcp_server.list_prompts() # type: ignore[no-untyped-call,misc]
51-
async def list_prompts() -> list[types.Prompt]:
52-
return [types.Prompt(name="prompt1")]
53+
@mcp.tool(name="echo")
54+
async def call_tool(message: str) -> str:
55+
return f"Echo: {message}"
5356

54-
app = create_starlette_app(mcp_server, allow_origins=["*"])
57+
app = create_starlette_app(
58+
mcp._mcp_server, # noqa: SLF001
59+
allow_origins=["*"],
60+
**kwargs,
61+
)
5562

5663
config = uvicorn.Config(app, port=0, log_level="info")
57-
server = BackgroundServer(config)
64+
return BackgroundServer(config)
65+
66+
67+
@pytest.mark.asyncio
68+
async def test_sse_transport() -> None:
69+
"""Test basic glue code for the SSE transport and a fake MCP server."""
70+
server = make_background_server(debug=True)
5871
async with server.run_in_background():
5972
sse_url = f"{server.url}/sse"
6073
async with sse_client(url=sse_url) as streams, ClientSession(*streams) as session:
@@ -63,6 +76,33 @@ async def list_prompts() -> list[types.Prompt]:
6376
assert len(response.prompts) == 1
6477
assert response.prompts[0].name == "prompt1"
6578

79+
80+
@pytest.mark.asyncio
81+
async def test_http_transport() -> None:
82+
"""Test HTTP transport layer functionality."""
83+
server = make_background_server(debug=True)
84+
async with server.run_in_background():
85+
http_url = f"{server.url}/mcp/"
86+
async with (
87+
streamablehttp_client(url=http_url) as (read, write, _),
88+
ClientSession(read, write) as session,
89+
):
90+
await session.initialize()
91+
response = await session.list_prompts()
92+
assert len(response.prompts) == 1
93+
assert response.prompts[0].name == "prompt1"
94+
95+
for i in range(3):
96+
tool_result = await session.call_tool("echo", {"message": f"test_{i}"})
97+
assert len(tool_result.content) == 1
98+
assert isinstance(tool_result.content[0], TextContent)
99+
assert tool_result.content[0].text == f"Echo: test_{i}"
100+
101+
102+
async def test_stateless_http_transport() -> None:
103+
"""Test stateless HTTP transport functionality."""
104+
server = make_background_server(debug=True, stateless=True)
105+
async with server.run_in_background():
66106
http_url = f"{server.url}/mcp/"
67107
async with (
68108
streamablehttp_client(url=http_url) as (read, write, _),
@@ -72,3 +112,9 @@ async def list_prompts() -> list[types.Prompt]:
72112
response = await session.list_prompts()
73113
assert len(response.prompts) == 1
74114
assert response.prompts[0].name == "prompt1"
115+
116+
for i in range(3):
117+
tool_result = await session.call_tool("echo", {"message": f"test_{i}"})
118+
assert len(tool_result.content) == 1
119+
assert isinstance(tool_result.content[0], TextContent)
120+
assert tool_result.content[0].text == f"Echo: test_{i}"

0 commit comments

Comments
 (0)