Skip to content

Commit 0730011

Browse files
authored
fix: return 400 if serving ID header contains '.' (#2669)
1 parent cdf118b commit 0730011

File tree

1 file changed

+88
-2
lines changed

1 file changed

+88
-2
lines changed

rust/serving/src/app.rs

Lines changed: 88 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,8 @@ where
163163
// requests don't hang forever.
164164
TimeoutLayer::new(Duration::from_secs(app.settings.drain_timeout_secs)),
165165
)
166+
// add early validations
167+
.layer(middleware::from_fn(validate_request))
166168
// Add auth middleware to all user facing routes
167169
.layer(middleware::from_fn_with_state(
168170
app.settings.api_auth_token.clone(),
@@ -199,8 +201,30 @@ async fn graceful_shutdown(handle: Handle, duration_secs: u64) {
199201

200202
const PUBLISH_ENDPOINTS: [&str; 3] = ["/v1/process/sync", "/v1/process/async", "/v1/process/fetch"];
201203

202-
// auth middleware to do token based authentication for all user facing routes
203-
// if auth is enabled.
204+
/// validate the request before passing it to the handler
205+
pub(crate) async fn validate_request(request: axum::extract::Request, next: Next) -> Response {
206+
// check if the request id contains "."
207+
if let Some(header) = request.headers().get("X-Numaflow-Id") {
208+
// make sure value does not contain "."
209+
if header
210+
.to_str()
211+
.expect("header should be a string")
212+
.contains(".")
213+
{
214+
return Response::builder()
215+
.status(StatusCode::BAD_REQUEST)
216+
.body(Body::from(format!(
217+
"Header-ID should not contain '.', found {}",
218+
header.to_str().expect("header should be a string")
219+
)))
220+
.expect("failed to build response");
221+
}
222+
};
223+
224+
next.run(request).await
225+
}
226+
227+
/// auth middleware to do token based authentication for all user facing routes if auth is enabled.
204228
async fn auth_middleware(
205229
State(api_auth_token): State<Option<String>>,
206230
request: axum::extract::Request,
@@ -382,6 +406,68 @@ mod tests {
382406
Ok(())
383407
}
384408

409+
#[cfg(feature = "nats-tests")]
410+
#[tokio::test]
411+
async fn test_validate_request() -> Result<()> {
412+
let settings = Settings {
413+
api_auth_token: Some("test-token".into()),
414+
..Default::default()
415+
};
416+
417+
let datum_store = InMemoryDataStore::new(None);
418+
let callback_store = InMemoryCallbackStore::new(None);
419+
let store_name = "test_validate_request";
420+
let js_url = "localhost:4222";
421+
let client = async_nats::connect(js_url).await.unwrap();
422+
let context = jetstream::new(client);
423+
let _ = context.delete_key_value(store_name).await;
424+
425+
let _ = context
426+
.create_key_value(jetstream::kv::Config {
427+
bucket: store_name.to_string(),
428+
history: 5,
429+
..Default::default()
430+
})
431+
.await
432+
.unwrap();
433+
434+
let status_tracker = StatusTracker::new(context.clone(), store_name, "0", None)
435+
.await
436+
.unwrap();
437+
438+
let pipeline_spec = PIPELINE_SPEC_ENCODED.parse().unwrap();
439+
let msg_graph = MessageGraph::from_pipeline(&pipeline_spec)?;
440+
let callback_state =
441+
CallbackState::new(msg_graph, datum_store, callback_store, status_tracker).await?;
442+
443+
let nats_connection = async_nats::connect("localhost:4222")
444+
.await
445+
.expect("Failed to establish Jetstream connection");
446+
let js_context = jetstream::new(nats_connection);
447+
448+
let app_state = AppState {
449+
js_context,
450+
settings: Arc::new(settings),
451+
orchestrator_state: callback_state,
452+
cancellation_token: CancellationToken::new(),
453+
};
454+
455+
let router = router_with_auth(app_state).await.unwrap();
456+
let res = router
457+
.oneshot(
458+
axum::extract::Request::builder()
459+
.method("POST")
460+
.uri("/v1/process/sync")
461+
.header("X-Numaflow-Id", "test.id")
462+
.body(Body::empty())
463+
.unwrap(),
464+
)
465+
.await?;
466+
467+
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
468+
Ok(())
469+
}
470+
385471
#[cfg(feature = "nats-tests")]
386472
#[tokio::test]
387473
async fn test_auth_middleware() -> Result<()> {

0 commit comments

Comments
 (0)