@@ -163,6 +163,8 @@ where
163
163
// requests don't hang forever.
164
164
TimeoutLayer :: new ( Duration :: from_secs ( app. settings . drain_timeout_secs ) ) ,
165
165
)
166
+ // add early validations
167
+ . layer ( middleware:: from_fn ( validate_request) )
166
168
// Add auth middleware to all user facing routes
167
169
. layer ( middleware:: from_fn_with_state (
168
170
app. settings . api_auth_token . clone ( ) ,
@@ -199,8 +201,30 @@ async fn graceful_shutdown(handle: Handle, duration_secs: u64) {
199
201
200
202
const PUBLISH_ENDPOINTS : [ & str ; 3 ] = [ "/v1/process/sync" , "/v1/process/async" , "/v1/process/fetch" ] ;
201
203
202
- // auth middleware to do token based authentication for all user facing routes
203
- // if auth is enabled.
204
+ /// validate the request before passing it to the handler
205
+ pub ( crate ) async fn validate_request ( request : axum:: extract:: Request , next : Next ) -> Response {
206
+ // check if the request id contains "."
207
+ if let Some ( header) = request. headers ( ) . get ( "X-Numaflow-Id" ) {
208
+ // make sure value does not contain "."
209
+ if header
210
+ . to_str ( )
211
+ . expect ( "header should be a string" )
212
+ . contains ( "." )
213
+ {
214
+ return Response :: builder ( )
215
+ . status ( StatusCode :: BAD_REQUEST )
216
+ . body ( Body :: from ( format ! (
217
+ "Header-ID should not contain '.', found {}" ,
218
+ header. to_str( ) . expect( "header should be a string" )
219
+ ) ) )
220
+ . expect ( "failed to build response" ) ;
221
+ }
222
+ } ;
223
+
224
+ next. run ( request) . await
225
+ }
226
+
227
+ /// auth middleware to do token based authentication for all user facing routes if auth is enabled.
204
228
async fn auth_middleware (
205
229
State ( api_auth_token) : State < Option < String > > ,
206
230
request : axum:: extract:: Request ,
@@ -382,6 +406,68 @@ mod tests {
382
406
Ok ( ( ) )
383
407
}
384
408
409
+ #[ cfg( feature = "nats-tests" ) ]
410
+ #[ tokio:: test]
411
+ async fn test_validate_request ( ) -> Result < ( ) > {
412
+ let settings = Settings {
413
+ api_auth_token : Some ( "test-token" . into ( ) ) ,
414
+ ..Default :: default ( )
415
+ } ;
416
+
417
+ let datum_store = InMemoryDataStore :: new ( None ) ;
418
+ let callback_store = InMemoryCallbackStore :: new ( None ) ;
419
+ let store_name = "test_validate_request" ;
420
+ let js_url = "localhost:4222" ;
421
+ let client = async_nats:: connect ( js_url) . await . unwrap ( ) ;
422
+ let context = jetstream:: new ( client) ;
423
+ let _ = context. delete_key_value ( store_name) . await ;
424
+
425
+ let _ = context
426
+ . create_key_value ( jetstream:: kv:: Config {
427
+ bucket : store_name. to_string ( ) ,
428
+ history : 5 ,
429
+ ..Default :: default ( )
430
+ } )
431
+ . await
432
+ . unwrap ( ) ;
433
+
434
+ let status_tracker = StatusTracker :: new ( context. clone ( ) , store_name, "0" , None )
435
+ . await
436
+ . unwrap ( ) ;
437
+
438
+ let pipeline_spec = PIPELINE_SPEC_ENCODED . parse ( ) . unwrap ( ) ;
439
+ let msg_graph = MessageGraph :: from_pipeline ( & pipeline_spec) ?;
440
+ let callback_state =
441
+ CallbackState :: new ( msg_graph, datum_store, callback_store, status_tracker) . await ?;
442
+
443
+ let nats_connection = async_nats:: connect ( "localhost:4222" )
444
+ . await
445
+ . expect ( "Failed to establish Jetstream connection" ) ;
446
+ let js_context = jetstream:: new ( nats_connection) ;
447
+
448
+ let app_state = AppState {
449
+ js_context,
450
+ settings : Arc :: new ( settings) ,
451
+ orchestrator_state : callback_state,
452
+ cancellation_token : CancellationToken :: new ( ) ,
453
+ } ;
454
+
455
+ let router = router_with_auth ( app_state) . await . unwrap ( ) ;
456
+ let res = router
457
+ . oneshot (
458
+ axum:: extract:: Request :: builder ( )
459
+ . method ( "POST" )
460
+ . uri ( "/v1/process/sync" )
461
+ . header ( "X-Numaflow-Id" , "test.id" )
462
+ . body ( Body :: empty ( ) )
463
+ . unwrap ( ) ,
464
+ )
465
+ . await ?;
466
+
467
+ assert_eq ! ( res. status( ) , StatusCode :: BAD_REQUEST ) ;
468
+ Ok ( ( ) )
469
+ }
470
+
385
471
#[ cfg( feature = "nats-tests" ) ]
386
472
#[ tokio:: test]
387
473
async fn test_auth_middleware ( ) -> Result < ( ) > {
0 commit comments