5
5
from collections .abc import AsyncIterator
6
6
from dataclasses import dataclass
7
7
from datetime import datetime , timezone
8
- from http import HTTPStatus
9
- from typing import Any , Literal
10
- from uuid import uuid4
8
+ from typing import Literal
11
9
12
- import anyio
13
10
import uvicorn
14
- from anyio .abc import TaskStatus
15
11
from mcp .client .session import ClientSession
16
12
from mcp .client .stdio import StdioServerParameters , stdio_client
17
13
from mcp .server import Server
18
14
from mcp .server .sse import SseServerTransport
19
- from mcp .server .streamable_http import StreamableHTTPServerTransport
15
+ from mcp .server .streamable_http_manager import StreamableHTTPSessionManager
20
16
from starlette .applications import Starlette
21
17
from starlette .middleware import Middleware
22
18
from starlette .middleware .cors import CORSMiddleware
28
24
from .proxy_server import create_proxy_server
29
25
30
26
logger = logging .getLogger (__name__ )
31
- # Global task group that will be initialized in the lifespan
32
- task_group = None
33
-
34
- MCP_SESSION_ID_HEADER = "mcp-session-id"
35
-
36
-
37
- @contextlib .asynccontextmanager
38
- async def lifespan (_ : Starlette ) -> AsyncIterator [None ]:
39
- """Application lifespan context manager for managing task group."""
40
- global task_group # noqa: PLW0603
41
-
42
- async with anyio .create_task_group () as tg :
43
- task_group = tg
44
- logger .info ("Application started, task group initialized!" )
45
- try :
46
- yield
47
- finally :
48
- logger .info ("Application shutting down, cleaning up resources..." )
49
- if task_group :
50
- tg .cancel_scope .cancel ()
51
- task_group = None
52
- logger .info ("Resources cleaned up successfully." )
53
27
54
28
55
29
@dataclass
@@ -58,13 +32,15 @@ class MCPServerSettings:
58
32
59
33
bind_host : str
60
34
port : int
35
+ stateless : bool = False
61
36
allow_origins : list [str ] | None = None
62
37
log_level : Literal ["DEBUG" , "INFO" , "WARNING" , "ERROR" , "CRITICAL" ] = "INFO"
63
38
64
39
65
- def create_starlette_app ( # noqa: C901, Refactor required for complexity
40
+ def create_starlette_app (
66
41
mcp_server : Server [object ],
67
42
* ,
43
+ stateless : bool = False ,
68
44
allow_origins : list [str ] | None = None ,
69
45
debug : bool = False ,
70
46
) -> Starlette :
@@ -97,57 +73,17 @@ async def handle_sse(request: Request) -> None:
97
73
mcp_server .create_initialization_options (),
98
74
)
99
75
100
- # Refer: https://github.com/modelcontextprotocol/python-sdk/blob/5d8eaf77be00dbd9b33a7fe1e38cb0da77e49401/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py
101
- # We need to store the server instances between requests
102
- server_instances : dict [str , Any ] = {}
103
- # Lock to prevent race conditions when creating new sessions
104
- session_creation_lock = anyio .Lock ()
76
+ # Refer: https://github.com/modelcontextprotocol/python-sdk/blob/v1.8.0/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py
77
+ http = StreamableHTTPSessionManager (
78
+ app = mcp_server ,
79
+ event_store = None ,
80
+ json_response = True ,
81
+ stateless = stateless ,
82
+ )
105
83
106
84
async def handle_streamable_http (scope : Scope , receive : Receive , send : Send ) -> None :
107
85
_update_mcp_activity ()
108
- request = Request (scope , receive )
109
- request_mcp_session_id = request .headers .get (MCP_SESSION_ID_HEADER )
110
- if request_mcp_session_id is not None and request_mcp_session_id in server_instances :
111
- transport = server_instances [request_mcp_session_id ]
112
- logger .debug ("Session already exists, handling request directly" )
113
- await transport .handle_request (scope , receive , send )
114
- elif request_mcp_session_id is None :
115
- # try to establish new session
116
- logger .debug ("Creating new transport" )
117
- # Use lock to prevent race conditions when creating new sessions
118
- async with session_creation_lock :
119
- new_session_id = uuid4 ().hex
120
- http_transport = StreamableHTTPServerTransport (
121
- mcp_session_id = new_session_id ,
122
- is_json_response_enabled = True ,
123
- )
124
- server_instances [new_session_id ] = http_transport
125
- logger .info ("Created new transport with session ID: %s" , new_session_id )
126
-
127
- async def run_server (task_status : TaskStatus [Any ] | None = None ) -> None :
128
- async with http_transport .connect () as streams :
129
- read_stream , write_stream = streams
130
- if task_status :
131
- task_status .started ()
132
- await mcp_server .run (
133
- read_stream ,
134
- write_stream ,
135
- mcp_server .create_initialization_options (),
136
- )
137
-
138
- if not task_group :
139
- raise RuntimeError ("Task group is not initialized" )
140
-
141
- await task_group .start (run_server )
142
-
143
- # Handle the HTTP request and return the response
144
- await http_transport .handle_request (scope , receive , send )
145
- else :
146
- response = Response (
147
- "Bad Request: No valid session ID provided" ,
148
- status_code = HTTPStatus .BAD_REQUEST ,
149
- )
150
- await response (scope , receive , send )
86
+ await http .handle_request (scope , receive , send )
151
87
152
88
async def handle_status (_ : Request ) -> Response :
153
89
"""Health check and service usage monitoring endpoint.
@@ -159,6 +95,16 @@ async def handle_status(_: Request) -> Response:
159
95
"""
160
96
return JSONResponse (status )
161
97
98
+ @contextlib .asynccontextmanager
99
+ async def lifespan (_ : Starlette ) -> AsyncIterator [None ]:
100
+ """Context manager for session manager."""
101
+ async with http .run ():
102
+ logger .info ("Application started with StreamableHTTP session manager!" )
103
+ try :
104
+ yield
105
+ finally :
106
+ logger .info ("Application shutting down..." )
107
+
162
108
middleware : list [Middleware ] = []
163
109
if allow_origins is not None :
164
110
middleware .append (
0 commit comments