Skip to content

Commit 64e6eda

Browse files
Add support for WebSockets over HTTP/2 (#2894)
1 parent d783a8b commit 64e6eda

File tree

14 files changed

+375
-87
lines changed

14 files changed

+375
-87
lines changed

axum/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1515
- **breaking:** Upgrade matchit to 0.8, changing the path parameter syntax from `/:single` and `/*many`
1616
to `/{single}` and `/{*many}`; the old syntax produces a panic to avoid silent change in behavior ([#2645])
1717
- **change:** Update minimum rust version to 1.75 ([#2943])
18+
- **added:** Add support WebSockets over HTTP/2.
19+
They can be enabled by changing `get(ws_endpoint)` handlers to `any(ws_endpoint)`.
1820

1921
[#2473]: https://github.com/tokio-rs/axum/pull/2473
2022
[#2645]: https://github.com/tokio-rs/axum/pull/2645

axum/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ features = [
112112
[dev-dependencies]
113113
anyhow = "1.0"
114114
axum-macros = { path = "../axum-macros", features = ["__private"] }
115+
hyper = { version = "1.1.0", features = ["client"] }
115116
quickcheck = "1.0"
116117
quickcheck_macros = "1.0"
117118
reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "multipart"] }

axum/src/extract/ws.rs

Lines changed: 156 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
//! ```
66
//! use axum::{
77
//! extract::ws::{WebSocketUpgrade, WebSocket},
8-
//! routing::get,
8+
//! routing::any,
99
//! response::{IntoResponse, Response},
1010
//! Router,
1111
//! };
1212
//!
13-
//! let app = Router::new().route("/ws", get(handler));
13+
//! let app = Router::new().route("/ws", any(handler));
1414
//!
1515
//! async fn handler(ws: WebSocketUpgrade) -> Response {
1616
//! ws.on_upgrade(handle_socket)
@@ -40,7 +40,7 @@
4040
//! use axum::{
4141
//! extract::{ws::{WebSocketUpgrade, WebSocket}, State},
4242
//! response::Response,
43-
//! routing::get,
43+
//! routing::any,
4444
//! Router,
4545
//! };
4646
//!
@@ -58,7 +58,7 @@
5858
//! }
5959
//!
6060
//! let app = Router::new()
61-
//! .route("/ws", get(handler))
61+
//! .route("/ws", any(handler))
6262
//! .with_state(AppState { /* ... */ });
6363
//! # let _: Router = app;
6464
//! ```
@@ -101,7 +101,7 @@ use futures_util::{
101101
use http::{
102102
header::{self, HeaderMap, HeaderName, HeaderValue},
103103
request::Parts,
104-
Method, StatusCode,
104+
Method, StatusCode, Version,
105105
};
106106
use hyper_util::rt::TokioIo;
107107
use sha1::{Digest, Sha1};
@@ -121,17 +121,20 @@ use tokio_tungstenite::{
121121

122122
/// Extractor for establishing WebSocket connections.
123123
///
124-
/// Note: This extractor requires the request method to be `GET` so it should
125-
/// always be used with [`get`](crate::routing::get). Requests with other methods will be
126-
/// rejected.
124+
/// For HTTP/1.1 requests, this extractor requires the request method to be `GET`;
125+
/// in later versions, `CONNECT` is used instead.
126+
/// To support both, it should be used with [`any`](crate::routing::any).
127127
///
128128
/// See the [module docs](self) for an example.
129+
///
130+
/// [`MethodFilter`]: crate::routing::MethodFilter
129131
#[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
130132
pub struct WebSocketUpgrade<F = DefaultOnFailedUpgrade> {
131133
config: WebSocketConfig,
132134
/// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response.
133135
protocol: Option<HeaderValue>,
134-
sec_websocket_key: HeaderValue,
136+
/// `None` if HTTP/2+ WebSockets are used.
137+
sec_websocket_key: Option<HeaderValue>,
135138
on_upgrade: hyper::upgrade::OnUpgrade,
136139
on_failed_upgrade: F,
137140
sec_websocket_protocol: Option<HeaderValue>,
@@ -212,12 +215,12 @@ impl<F> WebSocketUpgrade<F> {
212215
/// ```
213216
/// use axum::{
214217
/// extract::ws::{WebSocketUpgrade, WebSocket},
215-
/// routing::get,
218+
/// routing::any,
216219
/// response::{IntoResponse, Response},
217220
/// Router,
218221
/// };
219222
///
220-
/// let app = Router::new().route("/ws", get(handler));
223+
/// let app = Router::new().route("/ws", any(handler));
221224
///
222225
/// async fn handler(ws: WebSocketUpgrade) -> Response {
223226
/// ws.protocols(["graphql-ws", "graphql-transport-ws"])
@@ -329,25 +332,34 @@ impl<F> WebSocketUpgrade<F> {
329332
callback(socket).await;
330333
});
331334

332-
#[allow(clippy::declare_interior_mutable_const)]
333-
const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade");
334-
#[allow(clippy::declare_interior_mutable_const)]
335-
const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");
336-
337-
let mut builder = Response::builder()
338-
.status(StatusCode::SWITCHING_PROTOCOLS)
339-
.header(header::CONNECTION, UPGRADE)
340-
.header(header::UPGRADE, WEBSOCKET)
341-
.header(
342-
header::SEC_WEBSOCKET_ACCEPT,
343-
sign(self.sec_websocket_key.as_bytes()),
344-
);
345-
346-
if let Some(protocol) = self.protocol {
347-
builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol);
348-
}
335+
if let Some(sec_websocket_key) = &self.sec_websocket_key {
336+
// If `sec_websocket_key` was `Some`, we are using HTTP/1.1.
337+
338+
#[allow(clippy::declare_interior_mutable_const)]
339+
const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade");
340+
#[allow(clippy::declare_interior_mutable_const)]
341+
const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");
342+
343+
let mut builder = Response::builder()
344+
.status(StatusCode::SWITCHING_PROTOCOLS)
345+
.header(header::CONNECTION, UPGRADE)
346+
.header(header::UPGRADE, WEBSOCKET)
347+
.header(
348+
header::SEC_WEBSOCKET_ACCEPT,
349+
sign(sec_websocket_key.as_bytes()),
350+
);
351+
352+
if let Some(protocol) = self.protocol {
353+
builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol);
354+
}
349355

350-
builder.body(Body::empty()).unwrap()
356+
builder.body(Body::empty()).unwrap()
357+
} else {
358+
// Otherwise, we are HTTP/2+. As established in RFC 9113 section 8.5, we just respond
359+
// with a 2XX with an empty body:
360+
// <https://datatracker.ietf.org/doc/html/rfc9113#name-the-connect-method>.
361+
Response::new(Body::empty())
362+
}
351363
}
352364
}
353365

@@ -387,28 +399,49 @@ where
387399
type Rejection = WebSocketUpgradeRejection;
388400

389401
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
390-
if parts.method != Method::GET {
391-
return Err(MethodNotGet.into());
392-
}
402+
let sec_websocket_key = if parts.version <= Version::HTTP_11 {
403+
if parts.method != Method::GET {
404+
return Err(MethodNotGet.into());
405+
}
393406

394-
if !header_contains(&parts.headers, header::CONNECTION, "upgrade") {
395-
return Err(InvalidConnectionHeader.into());
396-
}
407+
if !header_contains(&parts.headers, header::CONNECTION, "upgrade") {
408+
return Err(InvalidConnectionHeader.into());
409+
}
397410

398-
if !header_eq(&parts.headers, header::UPGRADE, "websocket") {
399-
return Err(InvalidUpgradeHeader.into());
400-
}
411+
if !header_eq(&parts.headers, header::UPGRADE, "websocket") {
412+
return Err(InvalidUpgradeHeader.into());
413+
}
414+
415+
Some(
416+
parts
417+
.headers
418+
.get(header::SEC_WEBSOCKET_KEY)
419+
.ok_or(WebSocketKeyHeaderMissing)?
420+
.clone(),
421+
)
422+
} else {
423+
if parts.method != Method::CONNECT {
424+
return Err(MethodNotConnect.into());
425+
}
426+
427+
// if this feature flag is disabled, we won’t be receiving an HTTP/2 request to begin
428+
// with.
429+
#[cfg(feature = "http2")]
430+
if parts
431+
.extensions
432+
.get::<hyper::ext::Protocol>()
433+
.map_or(true, |p| p.as_str() != "websocket")
434+
{
435+
return Err(InvalidProtocolPseudoheader.into());
436+
}
437+
438+
None
439+
};
401440

402441
if !header_eq(&parts.headers, header::SEC_WEBSOCKET_VERSION, "13") {
403442
return Err(InvalidWebSocketVersionHeader.into());
404443
}
405444

406-
let sec_websocket_key = parts
407-
.headers
408-
.get(header::SEC_WEBSOCKET_KEY)
409-
.ok_or(WebSocketKeyHeaderMissing)?
410-
.clone();
411-
412445
let on_upgrade = parts
413446
.extensions
414447
.remove::<hyper::upgrade::OnUpgrade>()
@@ -706,6 +739,13 @@ pub mod rejection {
706739
pub struct MethodNotGet;
707740
}
708741

742+
define_rejection! {
743+
#[status = METHOD_NOT_ALLOWED]
744+
#[body = "Request method must be `CONNECT`"]
745+
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
746+
pub struct MethodNotConnect;
747+
}
748+
709749
define_rejection! {
710750
#[status = BAD_REQUEST]
711751
#[body = "Connection header did not include 'upgrade'"]
@@ -720,6 +760,13 @@ pub mod rejection {
720760
pub struct InvalidUpgradeHeader;
721761
}
722762

763+
define_rejection! {
764+
#[status = BAD_REQUEST]
765+
#[body = "`:protocol` pseudo-header did not include 'websocket'"]
766+
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
767+
pub struct InvalidProtocolPseudoheader;
768+
}
769+
723770
define_rejection! {
724771
#[status = BAD_REQUEST]
725772
#[body = "`Sec-WebSocket-Version` header did not include '13'"]
@@ -755,8 +802,10 @@ pub mod rejection {
755802
/// extractor can fail.
756803
pub enum WebSocketUpgradeRejection {
757804
MethodNotGet,
805+
MethodNotConnect,
758806
InvalidConnectionHeader,
759807
InvalidUpgradeHeader,
808+
InvalidProtocolPseudoheader,
760809
InvalidWebSocketVersionHeader,
761810
WebSocketKeyHeaderMissing,
762811
ConnectionNotUpgradable,
@@ -838,14 +887,18 @@ mod tests {
838887
use std::future::ready;
839888

840889
use super::*;
841-
use crate::{routing::get, test_helpers::spawn_service, Router};
890+
use crate::{routing::any, test_helpers::spawn_service, Router};
842891
use http::{Request, Version};
892+
use http_body_util::BodyExt as _;
893+
use hyper_util::rt::TokioExecutor;
894+
use tokio::io::{AsyncRead, AsyncWrite};
895+
use tokio::net::TcpStream;
843896
use tokio_tungstenite::tungstenite;
844897
use tower::ServiceExt;
845898

846899
#[crate::test]
847900
async fn rejects_http_1_0_requests() {
848-
let svc = get(|ws: Result<WebSocketUpgrade, WebSocketUpgradeRejection>| {
901+
let svc = any(|ws: Result<WebSocketUpgrade, WebSocketUpgradeRejection>| {
849902
let rejection = ws.unwrap_err();
850903
assert!(matches!(
851904
rejection,
@@ -874,7 +927,7 @@ mod tests {
874927
async fn handler(ws: WebSocketUpgrade) -> Response {
875928
ws.on_upgrade(|_| async {})
876929
}
877-
let _: Router = Router::new().route("/", get(handler));
930+
let _: Router = Router::new().route("/", any(handler));
878931
}
879932

880933
#[allow(dead_code)]
@@ -883,16 +936,61 @@ mod tests {
883936
ws.on_failed_upgrade(|_error: Error| println!("oops!"))
884937
.on_upgrade(|_| async {})
885938
}
886-
let _: Router = Router::new().route("/", get(handler));
939+
let _: Router = Router::new().route("/", any(handler));
887940
}
888941

889942
#[crate::test]
890943
async fn integration_test() {
891-
let app = Router::new().route(
892-
"/echo",
893-
get(|ws: WebSocketUpgrade| ready(ws.on_upgrade(handle_socket))),
894-
);
944+
let addr = spawn_service(echo_app());
945+
let (socket, _response) = tokio_tungstenite::connect_async(format!("ws://{addr}/echo"))
946+
.await
947+
.unwrap();
948+
test_echo_app(socket).await;
949+
}
950+
951+
#[crate::test]
952+
#[cfg(feature = "http2")]
953+
async fn http2() {
954+
let addr = spawn_service(echo_app());
955+
let io = TokioIo::new(TcpStream::connect(addr).await.unwrap());
956+
let (mut send_request, conn) =
957+
hyper::client::conn::http2::Builder::new(TokioExecutor::new())
958+
.handshake(io)
959+
.await
960+
.unwrap();
961+
962+
// Wait a little for the SETTINGS frame to go through…
963+
for _ in 0..10 {
964+
tokio::task::yield_now().await;
965+
}
966+
assert!(conn.is_extended_connect_protocol_enabled());
967+
tokio::spawn(async {
968+
conn.await.unwrap();
969+
});
895970

971+
let req = Request::builder()
972+
.method(Method::CONNECT)
973+
.extension(hyper::ext::Protocol::from_static("websocket"))
974+
.uri("/echo")
975+
.header("sec-websocket-version", "13")
976+
.header("Host", "server.example.com")
977+
.body(Body::empty())
978+
.unwrap();
979+
980+
let response = send_request.send_request(req).await.unwrap();
981+
let status = response.status();
982+
if status != 200 {
983+
let body = response.into_body().collect().await.unwrap().to_bytes();
984+
let body = std::str::from_utf8(&body).unwrap();
985+
panic!("response status was {}: {body}", status);
986+
}
987+
let upgraded = hyper::upgrade::on(response).await.unwrap();
988+
let upgraded = TokioIo::new(upgraded);
989+
let socket = WebSocketStream::from_raw_socket(upgraded, protocol::Role::Client, None).await;
990+
test_echo_app(socket).await;
991+
}
992+
993+
fn echo_app() -> Router {
896994
async fn handle_socket(mut socket: WebSocket) {
897995
while let Some(Ok(msg)) = socket.recv().await {
898996
match msg {
@@ -908,11 +1006,13 @@ mod tests {
9081006
}
9091007
}
9101008

911-
let addr = spawn_service(app);
912-
let (mut socket, _response) = tokio_tungstenite::connect_async(format!("ws://{addr}/echo"))
913-
.await
914-
.unwrap();
1009+
Router::new().route(
1010+
"/echo",
1011+
any(|ws: WebSocketUpgrade| ready(ws.on_upgrade(handle_socket))),
1012+
)
1013+
}
9151014

1015+
async fn test_echo_app<S: AsyncRead + AsyncWrite + Unpin>(mut socket: WebSocketStream<S>) {
9161016
let input = tungstenite::Message::Text("foobar".to_owned());
9171017
socket.send(input.clone()).await.unwrap();
9181018
let output = socket.next().await.unwrap().unwrap();

axum/src/routing/method_routing.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1035,13 +1035,11 @@ where
10351035
match $svc {
10361036
MethodEndpoint::None => {}
10371037
MethodEndpoint::Route(route) => {
1038-
return RouteFuture::from_future(route.clone().oneshot_inner($req))
1039-
.strip_body($method == Method::HEAD);
1038+
return route.clone().oneshot_inner($req);
10401039
}
10411040
MethodEndpoint::BoxedHandler(handler) => {
1042-
let route = handler.clone().into_route(state);
1043-
return RouteFuture::from_future(route.clone().oneshot_inner($req))
1044-
.strip_body($method == Method::HEAD);
1041+
let mut route = handler.clone().into_route(state);
1042+
return route.oneshot_inner($req);
10451043
}
10461044
}
10471045
}

0 commit comments

Comments
 (0)