5
5
//! ```
6
6
//! use axum::{
7
7
//! extract::ws::{WebSocketUpgrade, WebSocket},
8
- //! routing::get ,
8
+ //! routing::any ,
9
9
//! response::{IntoResponse, Response},
10
10
//! Router,
11
11
//! };
12
12
//!
13
- //! let app = Router::new().route("/ws", get (handler));
13
+ //! let app = Router::new().route("/ws", any (handler));
14
14
//!
15
15
//! async fn handler(ws: WebSocketUpgrade) -> Response {
16
16
//! ws.on_upgrade(handle_socket)
40
40
//! use axum::{
41
41
//! extract::{ws::{WebSocketUpgrade, WebSocket}, State},
42
42
//! response::Response,
43
- //! routing::get ,
43
+ //! routing::any ,
44
44
//! Router,
45
45
//! };
46
46
//!
58
58
//! }
59
59
//!
60
60
//! let app = Router::new()
61
- //! .route("/ws", get (handler))
61
+ //! .route("/ws", any (handler))
62
62
//! .with_state(AppState { /* ... */ });
63
63
//! # let _: Router = app;
64
64
//! ```
@@ -101,7 +101,7 @@ use futures_util::{
101
101
use http:: {
102
102
header:: { self , HeaderMap , HeaderName , HeaderValue } ,
103
103
request:: Parts ,
104
- Method , StatusCode ,
104
+ Method , StatusCode , Version ,
105
105
} ;
106
106
use hyper_util:: rt:: TokioIo ;
107
107
use sha1:: { Digest , Sha1 } ;
@@ -121,17 +121,20 @@ use tokio_tungstenite::{
121
121
122
122
/// Extractor for establishing WebSocket connections.
123
123
///
124
- /// Note: This extractor requires the request method to be `GET` so it should
125
- /// always be used with [`get`](crate::routing::get). Requests with other methods will be
126
- /// rejected .
124
+ /// For HTTP/1.1 requests, this extractor requires the request method to be `GET`;
125
+ /// in later versions, `CONNECT` is used instead.
126
+ /// To support both, it should be used with [`any`](crate::routing::any) .
127
127
///
128
128
/// See the [module docs](self) for an example.
129
+ ///
130
+ /// [`MethodFilter`]: crate::routing::MethodFilter
129
131
#[ cfg_attr( docsrs, doc( cfg( feature = "ws" ) ) ) ]
130
132
pub struct WebSocketUpgrade < F = DefaultOnFailedUpgrade > {
131
133
config : WebSocketConfig ,
132
134
/// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response.
133
135
protocol : Option < HeaderValue > ,
134
- sec_websocket_key : HeaderValue ,
136
+ /// `None` if HTTP/2+ WebSockets are used.
137
+ sec_websocket_key : Option < HeaderValue > ,
135
138
on_upgrade : hyper:: upgrade:: OnUpgrade ,
136
139
on_failed_upgrade : F ,
137
140
sec_websocket_protocol : Option < HeaderValue > ,
@@ -212,12 +215,12 @@ impl<F> WebSocketUpgrade<F> {
212
215
/// ```
213
216
/// use axum::{
214
217
/// extract::ws::{WebSocketUpgrade, WebSocket},
215
- /// routing::get ,
218
+ /// routing::any ,
216
219
/// response::{IntoResponse, Response},
217
220
/// Router,
218
221
/// };
219
222
///
220
- /// let app = Router::new().route("/ws", get (handler));
223
+ /// let app = Router::new().route("/ws", any (handler));
221
224
///
222
225
/// async fn handler(ws: WebSocketUpgrade) -> Response {
223
226
/// ws.protocols(["graphql-ws", "graphql-transport-ws"])
@@ -329,25 +332,34 @@ impl<F> WebSocketUpgrade<F> {
329
332
callback ( socket) . await ;
330
333
} ) ;
331
334
332
- #[ allow( clippy:: declare_interior_mutable_const) ]
333
- const UPGRADE : HeaderValue = HeaderValue :: from_static ( "upgrade" ) ;
334
- #[ allow( clippy:: declare_interior_mutable_const) ]
335
- const WEBSOCKET : HeaderValue = HeaderValue :: from_static ( "websocket" ) ;
336
-
337
- let mut builder = Response :: builder ( )
338
- . status ( StatusCode :: SWITCHING_PROTOCOLS )
339
- . header ( header:: CONNECTION , UPGRADE )
340
- . header ( header:: UPGRADE , WEBSOCKET )
341
- . header (
342
- header:: SEC_WEBSOCKET_ACCEPT ,
343
- sign ( self . sec_websocket_key . as_bytes ( ) ) ,
344
- ) ;
345
-
346
- if let Some ( protocol) = self . protocol {
347
- builder = builder. header ( header:: SEC_WEBSOCKET_PROTOCOL , protocol) ;
348
- }
335
+ if let Some ( sec_websocket_key) = & self . sec_websocket_key {
336
+ // If `sec_websocket_key` was `Some`, we are using HTTP/1.1.
337
+
338
+ #[ allow( clippy:: declare_interior_mutable_const) ]
339
+ const UPGRADE : HeaderValue = HeaderValue :: from_static ( "upgrade" ) ;
340
+ #[ allow( clippy:: declare_interior_mutable_const) ]
341
+ const WEBSOCKET : HeaderValue = HeaderValue :: from_static ( "websocket" ) ;
342
+
343
+ let mut builder = Response :: builder ( )
344
+ . status ( StatusCode :: SWITCHING_PROTOCOLS )
345
+ . header ( header:: CONNECTION , UPGRADE )
346
+ . header ( header:: UPGRADE , WEBSOCKET )
347
+ . header (
348
+ header:: SEC_WEBSOCKET_ACCEPT ,
349
+ sign ( sec_websocket_key. as_bytes ( ) ) ,
350
+ ) ;
351
+
352
+ if let Some ( protocol) = self . protocol {
353
+ builder = builder. header ( header:: SEC_WEBSOCKET_PROTOCOL , protocol) ;
354
+ }
349
355
350
- builder. body ( Body :: empty ( ) ) . unwrap ( )
356
+ builder. body ( Body :: empty ( ) ) . unwrap ( )
357
+ } else {
358
+ // Otherwise, we are HTTP/2+. As established in RFC 9113 section 8.5, we just respond
359
+ // with a 2XX with an empty body:
360
+ // <https://datatracker.ietf.org/doc/html/rfc9113#name-the-connect-method>.
361
+ Response :: new ( Body :: empty ( ) )
362
+ }
351
363
}
352
364
}
353
365
@@ -387,28 +399,49 @@ where
387
399
type Rejection = WebSocketUpgradeRejection ;
388
400
389
401
async fn from_request_parts ( parts : & mut Parts , _state : & S ) -> Result < Self , Self :: Rejection > {
390
- if parts. method != Method :: GET {
391
- return Err ( MethodNotGet . into ( ) ) ;
392
- }
402
+ let sec_websocket_key = if parts. version <= Version :: HTTP_11 {
403
+ if parts. method != Method :: GET {
404
+ return Err ( MethodNotGet . into ( ) ) ;
405
+ }
393
406
394
- if !header_contains ( & parts. headers , header:: CONNECTION , "upgrade" ) {
395
- return Err ( InvalidConnectionHeader . into ( ) ) ;
396
- }
407
+ if !header_contains ( & parts. headers , header:: CONNECTION , "upgrade" ) {
408
+ return Err ( InvalidConnectionHeader . into ( ) ) ;
409
+ }
397
410
398
- if !header_eq ( & parts. headers , header:: UPGRADE , "websocket" ) {
399
- return Err ( InvalidUpgradeHeader . into ( ) ) ;
400
- }
411
+ if !header_eq ( & parts. headers , header:: UPGRADE , "websocket" ) {
412
+ return Err ( InvalidUpgradeHeader . into ( ) ) ;
413
+ }
414
+
415
+ Some (
416
+ parts
417
+ . headers
418
+ . get ( header:: SEC_WEBSOCKET_KEY )
419
+ . ok_or ( WebSocketKeyHeaderMissing ) ?
420
+ . clone ( ) ,
421
+ )
422
+ } else {
423
+ if parts. method != Method :: CONNECT {
424
+ return Err ( MethodNotConnect . into ( ) ) ;
425
+ }
426
+
427
+ // if this feature flag is disabled, we won’t be receiving an HTTP/2 request to begin
428
+ // with.
429
+ #[ cfg( feature = "http2" ) ]
430
+ if parts
431
+ . extensions
432
+ . get :: < hyper:: ext:: Protocol > ( )
433
+ . map_or ( true , |p| p. as_str ( ) != "websocket" )
434
+ {
435
+ return Err ( InvalidProtocolPseudoheader . into ( ) ) ;
436
+ }
437
+
438
+ None
439
+ } ;
401
440
402
441
if !header_eq ( & parts. headers , header:: SEC_WEBSOCKET_VERSION , "13" ) {
403
442
return Err ( InvalidWebSocketVersionHeader . into ( ) ) ;
404
443
}
405
444
406
- let sec_websocket_key = parts
407
- . headers
408
- . get ( header:: SEC_WEBSOCKET_KEY )
409
- . ok_or ( WebSocketKeyHeaderMissing ) ?
410
- . clone ( ) ;
411
-
412
445
let on_upgrade = parts
413
446
. extensions
414
447
. remove :: < hyper:: upgrade:: OnUpgrade > ( )
@@ -706,6 +739,13 @@ pub mod rejection {
706
739
pub struct MethodNotGet ;
707
740
}
708
741
742
+ define_rejection ! {
743
+ #[ status = METHOD_NOT_ALLOWED ]
744
+ #[ body = "Request method must be `CONNECT`" ]
745
+ /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
746
+ pub struct MethodNotConnect ;
747
+ }
748
+
709
749
define_rejection ! {
710
750
#[ status = BAD_REQUEST ]
711
751
#[ body = "Connection header did not include 'upgrade'" ]
@@ -720,6 +760,13 @@ pub mod rejection {
720
760
pub struct InvalidUpgradeHeader ;
721
761
}
722
762
763
+ define_rejection ! {
764
+ #[ status = BAD_REQUEST ]
765
+ #[ body = "`:protocol` pseudo-header did not include 'websocket'" ]
766
+ /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
767
+ pub struct InvalidProtocolPseudoheader ;
768
+ }
769
+
723
770
define_rejection ! {
724
771
#[ status = BAD_REQUEST ]
725
772
#[ body = "`Sec-WebSocket-Version` header did not include '13'" ]
@@ -755,8 +802,10 @@ pub mod rejection {
755
802
/// extractor can fail.
756
803
pub enum WebSocketUpgradeRejection {
757
804
MethodNotGet ,
805
+ MethodNotConnect ,
758
806
InvalidConnectionHeader ,
759
807
InvalidUpgradeHeader ,
808
+ InvalidProtocolPseudoheader ,
760
809
InvalidWebSocketVersionHeader ,
761
810
WebSocketKeyHeaderMissing ,
762
811
ConnectionNotUpgradable ,
@@ -838,14 +887,18 @@ mod tests {
838
887
use std:: future:: ready;
839
888
840
889
use super :: * ;
841
- use crate :: { routing:: get , test_helpers:: spawn_service, Router } ;
890
+ use crate :: { routing:: any , test_helpers:: spawn_service, Router } ;
842
891
use http:: { Request , Version } ;
892
+ use http_body_util:: BodyExt as _;
893
+ use hyper_util:: rt:: TokioExecutor ;
894
+ use tokio:: io:: { AsyncRead , AsyncWrite } ;
895
+ use tokio:: net:: TcpStream ;
843
896
use tokio_tungstenite:: tungstenite;
844
897
use tower:: ServiceExt ;
845
898
846
899
#[ crate :: test]
847
900
async fn rejects_http_1_0_requests ( ) {
848
- let svc = get ( |ws : Result < WebSocketUpgrade , WebSocketUpgradeRejection > | {
901
+ let svc = any ( |ws : Result < WebSocketUpgrade , WebSocketUpgradeRejection > | {
849
902
let rejection = ws. unwrap_err ( ) ;
850
903
assert ! ( matches!(
851
904
rejection,
@@ -874,7 +927,7 @@ mod tests {
874
927
async fn handler ( ws : WebSocketUpgrade ) -> Response {
875
928
ws. on_upgrade ( |_| async { } )
876
929
}
877
- let _: Router = Router :: new ( ) . route ( "/" , get ( handler) ) ;
930
+ let _: Router = Router :: new ( ) . route ( "/" , any ( handler) ) ;
878
931
}
879
932
880
933
#[ allow( dead_code) ]
@@ -883,16 +936,61 @@ mod tests {
883
936
ws. on_failed_upgrade ( |_error : Error | println ! ( "oops!" ) )
884
937
. on_upgrade ( |_| async { } )
885
938
}
886
- let _: Router = Router :: new ( ) . route ( "/" , get ( handler) ) ;
939
+ let _: Router = Router :: new ( ) . route ( "/" , any ( handler) ) ;
887
940
}
888
941
889
942
#[ crate :: test]
890
943
async fn integration_test ( ) {
891
- let app = Router :: new ( ) . route (
892
- "/echo" ,
893
- get ( |ws : WebSocketUpgrade | ready ( ws. on_upgrade ( handle_socket) ) ) ,
894
- ) ;
944
+ let addr = spawn_service ( echo_app ( ) ) ;
945
+ let ( socket, _response) = tokio_tungstenite:: connect_async ( format ! ( "ws://{addr}/echo" ) )
946
+ . await
947
+ . unwrap ( ) ;
948
+ test_echo_app ( socket) . await ;
949
+ }
950
+
951
+ #[ crate :: test]
952
+ #[ cfg( feature = "http2" ) ]
953
+ async fn http2 ( ) {
954
+ let addr = spawn_service ( echo_app ( ) ) ;
955
+ let io = TokioIo :: new ( TcpStream :: connect ( addr) . await . unwrap ( ) ) ;
956
+ let ( mut send_request, conn) =
957
+ hyper:: client:: conn:: http2:: Builder :: new ( TokioExecutor :: new ( ) )
958
+ . handshake ( io)
959
+ . await
960
+ . unwrap ( ) ;
961
+
962
+ // Wait a little for the SETTINGS frame to go through…
963
+ for _ in 0 ..10 {
964
+ tokio:: task:: yield_now ( ) . await ;
965
+ }
966
+ assert ! ( conn. is_extended_connect_protocol_enabled( ) ) ;
967
+ tokio:: spawn ( async {
968
+ conn. await . unwrap ( ) ;
969
+ } ) ;
895
970
971
+ let req = Request :: builder ( )
972
+ . method ( Method :: CONNECT )
973
+ . extension ( hyper:: ext:: Protocol :: from_static ( "websocket" ) )
974
+ . uri ( "/echo" )
975
+ . header ( "sec-websocket-version" , "13" )
976
+ . header ( "Host" , "server.example.com" )
977
+ . body ( Body :: empty ( ) )
978
+ . unwrap ( ) ;
979
+
980
+ let response = send_request. send_request ( req) . await . unwrap ( ) ;
981
+ let status = response. status ( ) ;
982
+ if status != 200 {
983
+ let body = response. into_body ( ) . collect ( ) . await . unwrap ( ) . to_bytes ( ) ;
984
+ let body = std:: str:: from_utf8 ( & body) . unwrap ( ) ;
985
+ panic ! ( "response status was {}: {body}" , status) ;
986
+ }
987
+ let upgraded = hyper:: upgrade:: on ( response) . await . unwrap ( ) ;
988
+ let upgraded = TokioIo :: new ( upgraded) ;
989
+ let socket = WebSocketStream :: from_raw_socket ( upgraded, protocol:: Role :: Client , None ) . await ;
990
+ test_echo_app ( socket) . await ;
991
+ }
992
+
993
+ fn echo_app ( ) -> Router {
896
994
async fn handle_socket ( mut socket : WebSocket ) {
897
995
while let Some ( Ok ( msg) ) = socket. recv ( ) . await {
898
996
match msg {
@@ -908,11 +1006,13 @@ mod tests {
908
1006
}
909
1007
}
910
1008
911
- let addr = spawn_service ( app) ;
912
- let ( mut socket, _response) = tokio_tungstenite:: connect_async ( format ! ( "ws://{addr}/echo" ) )
913
- . await
914
- . unwrap ( ) ;
1009
+ Router :: new ( ) . route (
1010
+ "/echo" ,
1011
+ any ( |ws : WebSocketUpgrade | ready ( ws. on_upgrade ( handle_socket) ) ) ,
1012
+ )
1013
+ }
915
1014
1015
+ async fn test_echo_app < S : AsyncRead + AsyncWrite + Unpin > ( mut socket : WebSocketStream < S > ) {
916
1016
let input = tungstenite:: Message :: Text ( "foobar" . to_owned ( ) ) ;
917
1017
socket. send ( input. clone ( ) ) . await . unwrap ( ) ;
918
1018
let output = socket. next ( ) . await . unwrap ( ) . unwrap ( ) ;
0 commit comments