Skip to content

enforce rfc6455 section 4.2.1 requirement to refuse malformed Websocket handshake #10729

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/10729.breaking.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Stopped processing non-``GET`` method in WebSocket connection handshake (this might confuse proxies) as demanded by :rfc:`6455#section-4.2.1` -- by :user:`pajod`.
5 changes: 4 additions & 1 deletion aiohttp/web_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from .log import ws_logger
from .streams import EofStream
from .typedefs import JSONDecoder, JSONEncoder
from .web_exceptions import HTTPBadRequest, HTTPException
from .web_exceptions import HTTPBadRequest, HTTPException, HTTPMethodNotAllowed
from .web_request import BaseRequest
from .web_response import StreamResponse

Expand Down Expand Up @@ -226,6 +226,9 @@ def _handshake(
self, request: BaseRequest
) -> Tuple["CIMultiDict[str]", Optional[str], int, bool]:
headers = request.headers
if request.method != hdrs.METH_GET:
raise HTTPMethodNotAllowed(request.method, {hdrs.METH_GET})

if "websocket" != headers.get(hdrs.UPGRADE, "").lower().strip():
raise HTTPBadRequest(
text=(
Expand Down
12 changes: 9 additions & 3 deletions tests/test_web_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,12 @@ def test_can_prepare_unknown_protocol(make_request: _RequestMaker) -> None:
assert WebSocketReady(True, None) == ws.can_prepare(req)


def test_can_prepare_invalid_method(make_request: _RequestMaker) -> None:
req = make_request("POST", "/")
ws = web.WebSocketResponse()
assert WebSocketReady(False, None) == ws.can_prepare(req)


def test_can_prepare_without_upgrade(make_request: _RequestMaker) -> None:
req = make_request("GET", "/", headers=CIMultiDict({}))
ws = web.WebSocketResponse()
Expand Down Expand Up @@ -369,11 +375,11 @@ async def test_close_idempotent(make_request: _RequestMaker) -> None:
assert close_code == 0


async def test_prepare_post_method_ok(make_request: _RequestMaker) -> None:
async def test_prepare_invalid_method(make_request: _RequestMaker) -> None:
req = make_request("POST", "/")
ws = web.WebSocketResponse()
await ws.prepare(req)
assert ws.prepared
with pytest.raises(web.HTTPMethodNotAllowed):
await ws.prepare(req)


async def test_prepare_without_upgrade(make_request: _RequestMaker) -> None:
Expand Down
54 changes: 47 additions & 7 deletions tests/test_websocket_handshake.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,29 +40,61 @@ def gen_ws_headers(
return hdrs, key


async def test_not_get() -> None:
ws = web.WebSocketResponse()
req = make_mocked_request("POST", "/")
with pytest.raises(web.HTTPMethodNotAllowed):
await ws.prepare(req)


async def test_inappropriate_method() -> None:
ws = web.WebSocketResponse()
req = make_mocked_request(
"HEAD",
"/",
headers=[
("Upgrade", "websocket"),
# expect refusal; not a 1xx status (or two)
("Expect", "100-continue"),
("Expect", "100-continue"),
],
)
with pytest.raises(web.HTTPMethodNotAllowed) as ctx:
await ws.prepare(req)
assert ctx.value.method == "HEAD"
assert ctx.value.allowed_methods == {"GET"}
assert ctx.value.status == 405


async def test_no_upgrade() -> None:
ws = web.WebSocketResponse()
req = make_mocked_request("GET", "/")
with pytest.raises(web.HTTPBadRequest):
with pytest.raises(web.HTTPBadRequest) as ctx:
await ws.prepare(req)
assert ctx.value.text and "UPGRADE" in ctx.value.text
assert ctx.value.status == 400


async def test_no_connection() -> None:
ws = web.WebSocketResponse()
req = make_mocked_request(
"GET", "/", headers={"Upgrade": "websocket", "Connection": "keep-alive"}
)
with pytest.raises(web.HTTPBadRequest):
with pytest.raises(web.HTTPBadRequest) as ctx:
await ws.prepare(req)
assert ctx.value.text and "CONNECTION" in ctx.value.text
assert ctx.value.status == 400


async def test_protocol_version_unset() -> None:
ws = web.WebSocketResponse()
req = make_mocked_request(
"GET", "/", headers={"Upgrade": "websocket", "Connection": "upgrade"}
)
with pytest.raises(web.HTTPBadRequest):
with pytest.raises(web.HTTPBadRequest) as ctx:
await ws.prepare(req)
assert ctx.value.text and "version" in ctx.value.text
assert ctx.value.status == 400


async def test_protocol_version_not_supported() -> None:
Expand All @@ -76,8 +108,10 @@ async def test_protocol_version_not_supported() -> None:
"Sec-Websocket-Version": "1",
},
)
with pytest.raises(web.HTTPBadRequest):
with pytest.raises(web.HTTPBadRequest) as ctx:
await ws.prepare(req)
assert ctx.value.text and "version" in ctx.value.text
assert ctx.value.status == 400


async def test_protocol_key_not_present() -> None:
Expand All @@ -91,8 +125,10 @@ async def test_protocol_key_not_present() -> None:
"Sec-Websocket-Version": "13",
},
)
with pytest.raises(web.HTTPBadRequest):
with pytest.raises(web.HTTPBadRequest) as ctx:
await ws.prepare(req)
assert ctx.value.text and "Handshake" in ctx.value.text
assert ctx.value.status == 400


async def test_protocol_key_invalid() -> None:
Expand All @@ -107,8 +143,10 @@ async def test_protocol_key_invalid() -> None:
"Sec-Websocket-Key": "123",
},
)
with pytest.raises(web.HTTPBadRequest):
with pytest.raises(web.HTTPBadRequest) as ctx:
await ws.prepare(req)
assert ctx.value.text and "Handshake" in ctx.value.text
assert ctx.value.status == 400


async def test_protocol_key_bad_size() -> None:
Expand All @@ -125,8 +163,10 @@ async def test_protocol_key_bad_size() -> None:
"Sec-Websocket-Key": val,
},
)
with pytest.raises(web.HTTPBadRequest):
with pytest.raises(web.HTTPBadRequest) as ctx:
await ws.prepare(req)
assert ctx.value.text and "Handshake" in ctx.value.text
assert ctx.value.status == 400


async def test_handshake_ok() -> None:
Expand Down
Loading