|
| 1 | +"""Model Context Protocol transport portocol for Server Sent Events (SSE). |
| 2 | +
|
| 3 | +This registers HTTP endpoints that supports SSE as a transport layer |
| 4 | +for the Model Context Protocol. There are two HTTP endpoints: |
| 5 | +
|
| 6 | +- /mcp_server/sse: The SSE endpoint that is used to establish a session |
| 7 | + with the client and glue to the MCP server. This is used to push responses |
| 8 | + to the client. |
| 9 | +- /mcp_server/messages: The endpoint that is used by the client to send |
| 10 | + POST requests with new requests for the MCP server. The request contains |
| 11 | + a session identifier. The response to the client is passed over the SSE |
| 12 | + session started on the other endpoint. |
| 13 | +
|
| 14 | +See https://modelcontextprotocol.io/docs/concepts/transports |
| 15 | +""" |
| 16 | + |
| 17 | +import logging |
| 18 | + |
| 19 | +from aiohttp import web |
| 20 | +from aiohttp.web_exceptions import HTTPBadRequest, HTTPNotFound |
| 21 | +from aiohttp_sse import sse_response |
| 22 | +import anyio |
| 23 | +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream |
| 24 | +from mcp import types |
| 25 | + |
| 26 | +from homeassistant.components import conversation |
| 27 | +from homeassistant.components.http import KEY_HASS, HomeAssistantView |
| 28 | +from homeassistant.config_entries import ConfigEntryState |
| 29 | +from homeassistant.const import CONF_LLM_HASS_API |
| 30 | +from homeassistant.core import HomeAssistant, callback |
| 31 | +from homeassistant.helpers import llm |
| 32 | + |
| 33 | +from .const import DOMAIN |
| 34 | +from .server import create_server |
| 35 | +from .session import Session |
| 36 | +from .types import MCPServerConfigEntry |
| 37 | + |
| 38 | +_LOGGER = logging.getLogger(__name__) |
| 39 | + |
| 40 | +SSE_API = f"/{DOMAIN}/sse" |
| 41 | +MESSAGES_API = f"/{DOMAIN}/messages/{{session_id}}" |
| 42 | + |
| 43 | + |
| 44 | +@callback |
| 45 | +def async_register(hass: HomeAssistant) -> None: |
| 46 | + """Register the websocket API.""" |
| 47 | + hass.http.register_view(ModelContextProtocolSSEView()) |
| 48 | + hass.http.register_view(ModelContextProtocolMessagesView()) |
| 49 | + |
| 50 | + |
| 51 | +def async_get_config_entry(hass: HomeAssistant) -> MCPServerConfigEntry: |
| 52 | + """Get the first enabled MCP server config entry. |
| 53 | +
|
| 54 | + The ConfigEntry contains a reference to the actual MCP server used to |
| 55 | + serve the Model Context Protocol. |
| 56 | +
|
| 57 | + Will raise an HTTP error if the expected configuration is not present. |
| 58 | + """ |
| 59 | + config_entries: list[MCPServerConfigEntry] = [ |
| 60 | + config_entry |
| 61 | + for config_entry in hass.config_entries.async_entries(DOMAIN) |
| 62 | + if config_entry.state == ConfigEntryState.LOADED |
| 63 | + ] |
| 64 | + if not config_entries: |
| 65 | + raise HTTPNotFound(body="Model Context Protocol server is not configured") |
| 66 | + if len(config_entries) > 1: |
| 67 | + raise HTTPNotFound(body="Found multiple Model Context Protocol configurations") |
| 68 | + return config_entries[0] |
| 69 | + |
| 70 | + |
| 71 | +class ModelContextProtocolSSEView(HomeAssistantView): |
| 72 | + """Model Context Protocol SSE endpoint.""" |
| 73 | + |
| 74 | + name = f"{DOMAIN}:sse" |
| 75 | + url = SSE_API |
| 76 | + |
| 77 | + async def get(self, request: web.Request) -> web.StreamResponse: |
| 78 | + """Process SSE messages for the Model Context Protocol. |
| 79 | +
|
| 80 | + This is a long running request for the lifetime of the client session |
| 81 | + and is the primary transport layer between the client and server. |
| 82 | +
|
| 83 | + Pairs of buffered streams act as a bridge between the transport protocol |
| 84 | + (SSE over HTTP views) and the Model Context Protocol. The MCP SDK |
| 85 | + manages all protocol details and invokes commands on our MCP server. |
| 86 | + """ |
| 87 | + hass = request.app[KEY_HASS] |
| 88 | + entry = async_get_config_entry(hass) |
| 89 | + session_manager = entry.runtime_data |
| 90 | + |
| 91 | + context = llm.LLMContext( |
| 92 | + platform=DOMAIN, |
| 93 | + context=self.context(request), |
| 94 | + user_prompt=None, |
| 95 | + language="*", |
| 96 | + assistant=conversation.DOMAIN, |
| 97 | + device_id=None, |
| 98 | + ) |
| 99 | + llm_api_id = entry.data[CONF_LLM_HASS_API] |
| 100 | + server = await create_server(hass, llm_api_id, context) |
| 101 | + options = await hass.async_add_executor_job( |
| 102 | + server.create_initialization_options # Reads package for version info |
| 103 | + ) |
| 104 | + |
| 105 | + read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] |
| 106 | + read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] |
| 107 | + read_stream_writer, read_stream = anyio.create_memory_object_stream(0) |
| 108 | + |
| 109 | + write_stream: MemoryObjectSendStream[types.JSONRPCMessage] |
| 110 | + write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] |
| 111 | + write_stream, write_stream_reader = anyio.create_memory_object_stream(0) |
| 112 | + |
| 113 | + async with ( |
| 114 | + sse_response(request) as response, |
| 115 | + session_manager.create(Session(read_stream_writer)) as session_id, |
| 116 | + ): |
| 117 | + session_uri = MESSAGES_API.format(session_id=session_id) |
| 118 | + _LOGGER.debug("Sending SSE endpoint: %s", session_uri) |
| 119 | + await response.send(session_uri, event="endpoint") |
| 120 | + |
| 121 | + async def sse_reader() -> None: |
| 122 | + """Forward MCP server responses to the client.""" |
| 123 | + async for message in write_stream_reader: |
| 124 | + _LOGGER.debug("Sending SSE message: %s", message) |
| 125 | + await response.send( |
| 126 | + message.model_dump_json(by_alias=True, exclude_none=True), |
| 127 | + event="message", |
| 128 | + ) |
| 129 | + |
| 130 | + async with anyio.create_task_group() as tg: |
| 131 | + tg.start_soon(sse_reader) |
| 132 | + await server.run(read_stream, write_stream, options) |
| 133 | + return response |
| 134 | + |
| 135 | + |
| 136 | +class ModelContextProtocolMessagesView(HomeAssistantView): |
| 137 | + """Model Context Protocol messages endpoint.""" |
| 138 | + |
| 139 | + name = f"{DOMAIN}:messages" |
| 140 | + url = MESSAGES_API |
| 141 | + |
| 142 | + async def post( |
| 143 | + self, |
| 144 | + request: web.Request, |
| 145 | + session_id: str, |
| 146 | + ) -> web.StreamResponse: |
| 147 | + """Process incoming messages for the Model Context Protocol. |
| 148 | +
|
| 149 | + The request passes a session ID which is used to identify the original |
| 150 | + SSE connection. This view parses incoming messagess from the transport |
| 151 | + layer then writes them to the MCP server stream for the session. |
| 152 | + """ |
| 153 | + hass = request.app[KEY_HASS] |
| 154 | + config_entry = async_get_config_entry(hass) |
| 155 | + |
| 156 | + session_manager = config_entry.runtime_data |
| 157 | + if (session := session_manager.get(session_id)) is None: |
| 158 | + _LOGGER.info("Could not find session ID: '%s'", session_id) |
| 159 | + raise HTTPNotFound(body=f"Could not find session ID '{session_id}'") |
| 160 | + |
| 161 | + json_data = await request.json() |
| 162 | + try: |
| 163 | + message = types.JSONRPCMessage.model_validate(json_data) |
| 164 | + except ValueError as err: |
| 165 | + _LOGGER.info("Failed to parse message: %s", err) |
| 166 | + raise HTTPBadRequest(body="Could not parse message") from err |
| 167 | + |
| 168 | + _LOGGER.debug("Received client message: %s", message) |
| 169 | + await session.read_stream_writer.send(message) |
| 170 | + return web.Response(status=200) |
0 commit comments