diff --git a/strawberry/quart/views.py b/strawberry/quart/views.py index b0ceb4acca..03a57c6bea 100644 --- a/strawberry/quart/views.py +++ b/strawberry/quart/views.py @@ -1,15 +1,29 @@ +import asyncio import warnings -from collections.abc import AsyncGenerator, Mapping +from collections.abc import AsyncGenerator, Mapping, Sequence +from datetime import timedelta +from json.decoder import JSONDecodeError from typing import TYPE_CHECKING, Callable, ClassVar, Optional, cast from typing_extensions import TypeGuard -from quart import Request, Response, request +from quart import Quart, Request, Response, request, websocket +from quart.ctx import has_websocket_context from quart.views import View -from strawberry.http.async_base_view import AsyncBaseHTTPView, AsyncHTTPRequestAdapter -from strawberry.http.exceptions import HTTPException +from strawberry.http.async_base_view import ( + AsyncBaseHTTPView, + AsyncHTTPRequestAdapter, + AsyncWebSocketAdapter, +) +from strawberry.http.exceptions import ( + HTTPException, + NonJsonMessageReceived, + NonTextMessageReceived, + WebSocketDisconnected, +) from strawberry.http.ides import GraphQL_IDE from strawberry.http.types import FormData, HTTPMethod, QueryParams from strawberry.http.typevars import Context, RootValue +from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL if TYPE_CHECKING: from quart.typing import ResponseReturnValue @@ -46,6 +60,34 @@ async def get_form_data(self) -> FormData: return FormData(files=files, form=form) +class QuartWebSocketAdapter(AsyncWebSocketAdapter): + def __init__(self, view: AsyncBaseHTTPView, request, ws) -> None: + super().__init__(view) + self.ws = websocket + + async def iter_json( + self, *, ignore_parsing_errors: bool = False + ) -> AsyncGenerator[object, None]: + while True: + message = await self.ws.receive() + if type(message) is bytes: + raise NonTextMessageReceived + try: + yield self.view.decode_json(message) + except JSONDecodeError as e: + if not ignore_parsing_errors: + raise NonJsonMessageReceived from e + + async def send_json(self, message: Mapping[str, object]) -> None: + try: + await self.ws.send(self.view.encode_json(message)) + except asyncio.CancelledError as exc: + raise WebSocketDisconnected from exc + + async def close(self, code: int, reason: str) -> None: + await self.ws.close(code, reason=reason) + + class GraphQLView( AsyncBaseHTTPView[ Request, Response, Response, Request, Response, Context, RootValue @@ -55,6 +97,7 @@ class GraphQLView( methods: ClassVar[list[str]] = ["GET", "POST"] allow_queries_via_get: bool = True request_adapter_class = QuartHTTPRequestAdapter + websocket_adapter_class = QuartWebSocketAdapter def __init__( self, @@ -62,10 +105,23 @@ def __init__( graphiql: Optional[bool] = None, graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, + keep_alive: bool = True, + keep_alive_interval: float = 1, + debug: bool = False, + subscription_protocols: Sequence[str] = [ + GRAPHQL_TRANSPORT_WS_PROTOCOL, + GRAPHQL_WS_PROTOCOL, + ], + connection_init_wait_timeout: timedelta = timedelta(minutes=1), multipart_uploads_enabled: bool = False, ) -> None: self.schema = schema self.allow_queries_via_get = allow_queries_via_get + self.keep_alive = keep_alive + self.keep_alive_interval = keep_alive_interval + self.debug = debug + self.subscription_protocols = subscription_protocols + self.connection_init_wait_timeout = connection_init_wait_timeout self.multipart_uploads_enabled = multipart_uploads_enabled if graphiql is not None: @@ -123,15 +179,53 @@ async def create_streaming_response( ) def is_websocket_request(self, request: Request) -> TypeGuard[Request]: - return False + if has_websocket_context(): + return True + + # Check if the request is a WebSocket upgrade request + connection = request.headers.get("Connection", "").lower() + upgrade = request.headers.get("Upgrade", "").lower() + + return "upgrade" in connection and "websocket" in upgrade async def pick_websocket_subprotocol(self, request: Request) -> Optional[str]: - raise NotImplementedError + # Get the requested protocols + protocols_header = websocket.headers.get("Sec-WebSocket-Protocol", "") + if not protocols_header: + return None + + # Find the first matching protocol + requested_protocols = [p.strip() for p in protocols_header.split(",")] + for protocol in requested_protocols: + if protocol in self.subscription_protocols: + return protocol + + return None async def create_websocket_response( self, request: Request, subprotocol: Optional[str] ) -> Response: - raise NotImplementedError + await websocket.accept(subprotocol=subprotocol) + # Return the current websocket context as the "response" + return None + + @classmethod + def register_route(cls, app: Quart, rule_name: str, path: str, **kwargs): + """Helper method to register both HTTP and WebSocket handlers for a given path. + + Args: + app: The Quart application + rule_name: The name of the rule + path: The path to register the handlers for + **kwargs: Parameters to pass to the GraphQLView constructor + """ + # Register both HTTP and WebSocket handler at the same path + view_func = cls.as_view(rule_name, **kwargs) + app.add_url_rule(path, view_func=view_func, methods=["GET", "POST"]) + + # Register the WebSocket handler using the same view function + # Quart will handle routing based on the WebSocket upgrade header + app.add_url_rule(path, view_func=view_func, methods=["GET"], websocket=True) __all__ = ["GraphQLView"] diff --git a/tests/http/clients/quart.py b/tests/http/clients/quart.py index 1711e58b45..a757c6f7e2 100644 --- a/tests/http/clients/quart.py +++ b/tests/http/clients/quart.py @@ -1,30 +1,51 @@ +import asyncio +import contextlib import json import urllib.parse +from collections.abc import AsyncGenerator, Mapping from io import BytesIO -from typing import Any, Optional +from typing import Any, Optional, Union from typing_extensions import Literal +from asgiref.typing import ASGISendEvent +from hypercorn.typing import WebsocketScope + from quart import Quart from quart import Request as QuartRequest from quart import Response as QuartResponse from quart.datastructures import FileStorage +from quart.testing.connections import TestWebsocketConnection +from quart.typing import TestWebsocketConnectionProtocol +from quart.utils import decode_headers +from strawberry.exceptions import ConnectionRejectionError from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE from strawberry.quart.views import GraphQLView as BaseGraphQLView from strawberry.types import ExecutionResult +from strawberry.types.unset import UNSET, UnsetType from tests.http.context import get_context from tests.views.schema import Query, schema -from .base import JSON, HttpClient, Response, ResultOverrideFunction +from .base import ( + JSON, + DebuggableGraphQLTransportWSHandler, + DebuggableGraphQLWSHandler, + HttpClient, + Message, + Response, + ResultOverrideFunction, + WebSocketClient, +) class GraphQLView(BaseGraphQLView[dict[str, object], object]): methods = ["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD"] - + graphql_transport_ws_handler_class = DebuggableGraphQLTransportWSHandler + graphql_ws_handler_class = DebuggableGraphQLWSHandler result_override: ResultOverrideFunction = None def __init__(self, *args: Any, **kwargs: Any): - self.result_override = kwargs.pop("result_override") + self.result_override = kwargs.pop("result_override", None) super().__init__(*args, **kwargs) async def get_root_value(self, request: QuartRequest) -> Query: @@ -46,6 +67,28 @@ async def process_result( return await super().process_result(request, result) + async def on_ws_connect( + self, context: dict[str, object] + ) -> Union[UnsetType, None, dict[str, object]]: + connection_params = context["connection_params"] + + if isinstance(connection_params, dict): + if connection_params.get("test-reject"): + if "err-payload" in connection_params: + raise ConnectionRejectionError(connection_params["err-payload"]) + raise ConnectionRejectionError + + if connection_params.get("test-accept"): + if "ack-payload" in connection_params: + return connection_params["ack-payload"] + return UNSET + + if connection_params.get("test-modify"): + connection_params["modified"] = True + return UNSET + + return await super().on_ws_connect(context) + class QuartHttpClient(HttpClient): def __init__( @@ -73,6 +116,23 @@ def __init__( "/graphql", view_func=view, ) + self.app.add_url_rule( + "/graphql", view_func=view, methods=["GET"], websocket=True + ) + + def create_app(self, **kwargs: Any) -> None: + self.app = Quart(__name__) + self.app.debug = True + + view = GraphQLView.as_view("graphql_view", schema=schema, **kwargs) + + self.app.add_url_rule( + "/graphql", + view_func=view, + ) + self.app.add_url_rule( + "/graphql", view_func=view, methods=["GET"], websocket=True + ) async def _graphql_request( self, @@ -140,3 +200,100 @@ async def post( return await self.request( url, "post", **{k: v for k, v in kwargs.items() if v is not None} ) + + @contextlib.asynccontextmanager + async def ws_connect( + self, + url: str, + *, + protocols: list[str], + ) -> AsyncGenerator[WebSocketClient, None]: + headers = { + "sec-websocket-protocol": ", ".join(protocols), + } + async with self.app.test_app() as test_app: + client = test_app.test_client() + client.websocket_connection_class = QuartTestWebsocketConnection + async with client.websocket( + url, headers=headers, subprotocols=protocols + ) as ws: + yield QuartWebSocketClient(ws) + + +class QuartTestWebsocketConnection(TestWebsocketConnection): + def __init__(self, app: Quart, scope: WebsocketScope) -> None: + scope["asgi"] = {"spec_version": "2.3"} + super().__init__(app, scope) + + async def _asgi_send(self, message: ASGISendEvent) -> None: + if message["type"] == "websocket.accept": + self.accepted = True + elif message["type"] == "websocket.send": + await self._receive_queue.put(message.get("bytes") or message.get("text")) + elif message["type"] == "websocket.http.response.start": + self.headers = decode_headers(message["headers"]) + self.status_code = message["status"] + elif message["type"] == "websocket.http.response.body": + self.response_data.extend(message["body"]) + elif message["type"] == "websocket.close": + await self._receive_queue.put(json.dumps(message)) + + +class QuartWebSocketClient(WebSocketClient): + def __init__(self, ws: TestWebsocketConnectionProtocol): + self.ws = ws + self._closed: bool = False + self._close_code: Optional[int] = None + self._close_reason: Optional[str] = None + + async def send_text(self, payload: str) -> None: + await self.ws.send(payload) + + async def send_json(self, payload: Mapping[str, object]) -> None: + await self.ws.send_json(payload) + + async def send_bytes(self, payload: bytes) -> None: + await self.ws.send(payload) + + async def receive(self, timeout: Optional[float] = None) -> Message: + if self._closed: + # if close was received via exception, fake it so that recv works + return Message( + type="websocket.close", data=self._close_code, extra=self._close_reason + ) + m = await asyncio.wait_for(self.ws.receive_json(), timeout=timeout) + if m["type"] == "websocket.close": + self._closed = True + self._close_code = m["code"] + self._close_reason = m.get("reason", None) + return Message(type=m["type"], data=m["code"], extra=m.get("reason", None)) + if m["type"] == "websocket.send": + return Message(type=m["type"], data=m["text"]) + if m["type"] == "connection_ack": + return Message(type=m["type"], data="") + return Message(type=m["type"], data=m["data"], extra=m["extra"]) + + async def receive_json(self, timeout: Optional[float] = None) -> Any: + m = await asyncio.wait_for(self.ws.receive_json(), timeout=timeout) + return m + + async def close(self) -> None: + await self.ws.close(1000) + self._closed = True + + @property + def accepted_subprotocol(self) -> Optional[str]: + return "" + + @property + def closed(self) -> bool: + return self._closed + + @property + def close_code(self) -> int: + assert self._close_code is not None + return self._close_code + + @property + def close_reason(self) -> Optional[str]: + return self._close_reason diff --git a/tests/websockets/conftest.py b/tests/websockets/conftest.py index 7b784c2168..9fd56317b2 100644 --- a/tests/websockets/conftest.py +++ b/tests/websockets/conftest.py @@ -14,6 +14,7 @@ def _get_http_client_classes() -> Generator[Any, None, None]: ("ChannelsHttpClient", "channels", [pytest.mark.channels]), ("FastAPIHttpClient", "fastapi", [pytest.mark.fastapi]), ("LitestarHttpClient", "litestar", [pytest.mark.litestar]), + ("QuartHttpClient", "quart", [pytest.mark.quart]), ]: try: client_class = getattr( diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index 2b5ea8afe4..3a7f94df0a 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -72,7 +72,6 @@ def assert_next( async def test_unknown_message_type(ws_raw: WebSocketClient): ws = ws_raw - await ws.send_json({"type": "NOT_A_MESSAGE_TYPE"}) await ws.receive(timeout=2)