diff --git a/CHANGES/10729.breaking.rst b/CHANGES/10729.breaking.rst new file mode 100644 index 00000000000..1925091ba50 --- /dev/null +++ b/CHANGES/10729.breaking.rst @@ -0,0 +1 @@ +Stopped processing non-``GET`` method in WebSocket connection handshake (this might confuse proxies) as demanded by :rfc:`6455#section-4.2.1` -- by :user:`pajod`. diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index 9421dc2ac76..c084f214712 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -36,7 +36,7 @@ from .log import ws_logger from .streams import EofStream from .typedefs import JSONDecoder, JSONEncoder -from .web_exceptions import HTTPBadRequest, HTTPException +from .web_exceptions import HTTPBadRequest, HTTPException, HTTPMethodNotAllowed from .web_request import BaseRequest from .web_response import StreamResponse @@ -226,6 +226,9 @@ def _handshake( self, request: BaseRequest ) -> Tuple["CIMultiDict[str]", Optional[str], int, bool]: headers = request.headers + if request.method != hdrs.METH_GET: + raise HTTPMethodNotAllowed(request.method, {hdrs.METH_GET}) + if "websocket" != headers.get(hdrs.UPGRADE, "").lower().strip(): raise HTTPBadRequest( text=( diff --git a/tests/test_web_websocket.py b/tests/test_web_websocket.py index 139d5fa073e..0e65d72239b 100644 --- a/tests/test_web_websocket.py +++ b/tests/test_web_websocket.py @@ -215,6 +215,12 @@ def test_can_prepare_unknown_protocol(make_request: _RequestMaker) -> None: assert WebSocketReady(True, None) == ws.can_prepare(req) +def test_can_prepare_invalid_method(make_request: _RequestMaker) -> None: + req = make_request("POST", "/") + ws = web.WebSocketResponse() + assert WebSocketReady(False, None) == ws.can_prepare(req) + + def test_can_prepare_without_upgrade(make_request: _RequestMaker) -> None: req = make_request("GET", "/", headers=CIMultiDict({})) ws = web.WebSocketResponse() @@ -369,11 +375,11 @@ async def test_close_idempotent(make_request: _RequestMaker) -> None: assert close_code == 0 -async def test_prepare_post_method_ok(make_request: _RequestMaker) -> None: +async def test_prepare_invalid_method(make_request: _RequestMaker) -> None: req = make_request("POST", "/") ws = web.WebSocketResponse() - await ws.prepare(req) - assert ws.prepared + with pytest.raises(web.HTTPMethodNotAllowed): + await ws.prepare(req) async def test_prepare_without_upgrade(make_request: _RequestMaker) -> None: diff --git a/tests/test_websocket_handshake.py b/tests/test_websocket_handshake.py index e069795af73..d9404019529 100644 --- a/tests/test_websocket_handshake.py +++ b/tests/test_websocket_handshake.py @@ -40,11 +40,39 @@ def gen_ws_headers( return hdrs, key +async def test_not_get() -> None: + ws = web.WebSocketResponse() + req = make_mocked_request("POST", "/") + with pytest.raises(web.HTTPMethodNotAllowed): + await ws.prepare(req) + + +async def test_inappropriate_method() -> None: + ws = web.WebSocketResponse() + req = make_mocked_request( + "HEAD", + "/", + headers=[ + ("Upgrade", "websocket"), + # expect refusal; not a 1xx status (or two) + ("Expect", "100-continue"), + ("Expect", "100-continue"), + ], + ) + with pytest.raises(web.HTTPMethodNotAllowed) as ctx: + await ws.prepare(req) + assert ctx.value.method == "HEAD" + assert ctx.value.allowed_methods == {"GET"} + assert ctx.value.status == 405 + + async def test_no_upgrade() -> None: ws = web.WebSocketResponse() req = make_mocked_request("GET", "/") - with pytest.raises(web.HTTPBadRequest): + with pytest.raises(web.HTTPBadRequest) as ctx: await ws.prepare(req) + assert ctx.value.text and "UPGRADE" in ctx.value.text + assert ctx.value.status == 400 async def test_no_connection() -> None: @@ -52,8 +80,10 @@ async def test_no_connection() -> None: req = make_mocked_request( "GET", "/", headers={"Upgrade": "websocket", "Connection": "keep-alive"} ) - with pytest.raises(web.HTTPBadRequest): + with pytest.raises(web.HTTPBadRequest) as ctx: await ws.prepare(req) + assert ctx.value.text and "CONNECTION" in ctx.value.text + assert ctx.value.status == 400 async def test_protocol_version_unset() -> None: @@ -61,8 +91,10 @@ async def test_protocol_version_unset() -> None: req = make_mocked_request( "GET", "/", headers={"Upgrade": "websocket", "Connection": "upgrade"} ) - with pytest.raises(web.HTTPBadRequest): + with pytest.raises(web.HTTPBadRequest) as ctx: await ws.prepare(req) + assert ctx.value.text and "version" in ctx.value.text + assert ctx.value.status == 400 async def test_protocol_version_not_supported() -> None: @@ -76,8 +108,10 @@ async def test_protocol_version_not_supported() -> None: "Sec-Websocket-Version": "1", }, ) - with pytest.raises(web.HTTPBadRequest): + with pytest.raises(web.HTTPBadRequest) as ctx: await ws.prepare(req) + assert ctx.value.text and "version" in ctx.value.text + assert ctx.value.status == 400 async def test_protocol_key_not_present() -> None: @@ -91,8 +125,10 @@ async def test_protocol_key_not_present() -> None: "Sec-Websocket-Version": "13", }, ) - with pytest.raises(web.HTTPBadRequest): + with pytest.raises(web.HTTPBadRequest) as ctx: await ws.prepare(req) + assert ctx.value.text and "Handshake" in ctx.value.text + assert ctx.value.status == 400 async def test_protocol_key_invalid() -> None: @@ -107,8 +143,10 @@ async def test_protocol_key_invalid() -> None: "Sec-Websocket-Key": "123", }, ) - with pytest.raises(web.HTTPBadRequest): + with pytest.raises(web.HTTPBadRequest) as ctx: await ws.prepare(req) + assert ctx.value.text and "Handshake" in ctx.value.text + assert ctx.value.status == 400 async def test_protocol_key_bad_size() -> None: @@ -125,8 +163,10 @@ async def test_protocol_key_bad_size() -> None: "Sec-Websocket-Key": val, }, ) - with pytest.raises(web.HTTPBadRequest): + with pytest.raises(web.HTTPBadRequest) as ctx: await ws.prepare(req) + assert ctx.value.text and "Handshake" in ctx.value.text + assert ctx.value.status == 400 async def test_handshake_ok() -> None: