19
19
from typing import AsyncIterator , Dict , Optional , Set , Tuple , Union
20
20
21
21
import uvloop
22
- from fastapi import APIRouter , FastAPI , HTTPException , Request
22
+ from fastapi import APIRouter , Depends , FastAPI , HTTPException , Request
23
23
from fastapi .exceptions import RequestValidationError
24
24
from fastapi .middleware .cors import CORSMiddleware
25
25
from fastapi .responses import JSONResponse , Response , StreamingResponse
@@ -252,6 +252,15 @@ def _cleanup_ipc_path():
252
252
multiprocess .mark_process_dead (engine_process .pid )
253
253
254
254
255
+ async def validate_json_request (raw_request : Request ):
256
+ content_type = raw_request .headers .get ("content-type" , "" ).lower ()
257
+ if content_type != "application/json" :
258
+ raise HTTPException (
259
+ status_code = HTTPStatus .UNSUPPORTED_MEDIA_TYPE ,
260
+ detail = "Unsupported Media Type: Only 'application/json' is allowed"
261
+ )
262
+
263
+
255
264
router = APIRouter ()
256
265
257
266
@@ -335,7 +344,7 @@ async def ping(raw_request: Request) -> Response:
335
344
return await health (raw_request )
336
345
337
346
338
- @router .post ("/tokenize" )
347
+ @router .post ("/tokenize" , dependencies = [ Depends ( validate_json_request )] )
339
348
@with_cancellation
340
349
async def tokenize (request : TokenizeRequest , raw_request : Request ):
341
350
handler = tokenization (raw_request )
@@ -350,7 +359,7 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):
350
359
assert_never (generator )
351
360
352
361
353
- @router .post ("/detokenize" )
362
+ @router .post ("/detokenize" , dependencies = [ Depends ( validate_json_request )] )
354
363
@with_cancellation
355
364
async def detokenize (request : DetokenizeRequest , raw_request : Request ):
356
365
handler = tokenization (raw_request )
@@ -379,7 +388,8 @@ async def show_version():
379
388
return JSONResponse (content = ver )
380
389
381
390
382
- @router .post ("/v1/chat/completions" )
391
+ @router .post ("/v1/chat/completions" ,
392
+ dependencies = [Depends (validate_json_request )])
383
393
@with_cancellation
384
394
async def create_chat_completion (request : ChatCompletionRequest ,
385
395
raw_request : Request ):
@@ -400,7 +410,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
400
410
return StreamingResponse (content = generator , media_type = "text/event-stream" )
401
411
402
412
403
- @router .post ("/v1/completions" )
413
+ @router .post ("/v1/completions" , dependencies = [ Depends ( validate_json_request )] )
404
414
@with_cancellation
405
415
async def create_completion (request : CompletionRequest , raw_request : Request ):
406
416
handler = completion (raw_request )
@@ -418,7 +428,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
418
428
return StreamingResponse (content = generator , media_type = "text/event-stream" )
419
429
420
430
421
- @router .post ("/v1/embeddings" )
431
+ @router .post ("/v1/embeddings" , dependencies = [ Depends ( validate_json_request )] )
422
432
@with_cancellation
423
433
async def create_embedding (request : EmbeddingRequest , raw_request : Request ):
424
434
handler = embedding (raw_request )
@@ -464,7 +474,7 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
464
474
assert_never (generator )
465
475
466
476
467
- @router .post ("/pooling" )
477
+ @router .post ("/pooling" , dependencies = [ Depends ( validate_json_request )] )
468
478
@with_cancellation
469
479
async def create_pooling (request : PoolingRequest , raw_request : Request ):
470
480
handler = pooling (raw_request )
@@ -482,7 +492,7 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
482
492
assert_never (generator )
483
493
484
494
485
- @router .post ("/score" )
495
+ @router .post ("/score" , dependencies = [ Depends ( validate_json_request )] )
486
496
@with_cancellation
487
497
async def create_score (request : ScoreRequest , raw_request : Request ):
488
498
handler = score (raw_request )
@@ -500,7 +510,7 @@ async def create_score(request: ScoreRequest, raw_request: Request):
500
510
assert_never (generator )
501
511
502
512
503
- @router .post ("/v1/score" )
513
+ @router .post ("/v1/score" , dependencies = [ Depends ( validate_json_request )] )
504
514
@with_cancellation
505
515
async def create_score_v1 (request : ScoreRequest , raw_request : Request ):
506
516
logger .warning (
@@ -510,7 +520,7 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
510
520
return await create_score (request , raw_request )
511
521
512
522
513
- @router .post ("/rerank" )
523
+ @router .post ("/rerank" , dependencies = [ Depends ( validate_json_request )] )
514
524
@with_cancellation
515
525
async def do_rerank (request : RerankRequest , raw_request : Request ):
516
526
handler = rerank (raw_request )
@@ -527,7 +537,7 @@ async def do_rerank(request: RerankRequest, raw_request: Request):
527
537
assert_never (generator )
528
538
529
539
530
- @router .post ("/v1/rerank" )
540
+ @router .post ("/v1/rerank" , dependencies = [ Depends ( validate_json_request )] )
531
541
@with_cancellation
532
542
async def do_rerank_v1 (request : RerankRequest , raw_request : Request ):
533
543
logger .warning_once (
@@ -538,7 +548,7 @@ async def do_rerank_v1(request: RerankRequest, raw_request: Request):
538
548
return await do_rerank (request , raw_request )
539
549
540
550
541
- @router .post ("/v2/rerank" )
551
+ @router .post ("/v2/rerank" , dependencies = [ Depends ( validate_json_request )] )
542
552
@with_cancellation
543
553
async def do_rerank_v2 (request : RerankRequest , raw_request : Request ):
544
554
return await do_rerank (request , raw_request )
@@ -582,7 +592,7 @@ async def reset_prefix_cache(raw_request: Request):
582
592
return Response (status_code = 200 )
583
593
584
594
585
- @router .post ("/invocations" )
595
+ @router .post ("/invocations" , dependencies = [ Depends ( validate_json_request )] )
586
596
async def invocations (raw_request : Request ):
587
597
"""
588
598
For SageMaker, routes requests to other handlers based on model `task`.
@@ -632,7 +642,8 @@ async def stop_profile(raw_request: Request):
632
642
"Lora dynamic loading & unloading is enabled in the API server. "
633
643
"This should ONLY be used for local development!" )
634
644
635
- @router .post ("/v1/load_lora_adapter" )
645
+ @router .post ("/v1/load_lora_adapter" ,
646
+ dependencies = [Depends (validate_json_request )])
636
647
async def load_lora_adapter (request : LoadLoraAdapterRequest ,
637
648
raw_request : Request ):
638
649
handler = models (raw_request )
@@ -643,7 +654,8 @@ async def load_lora_adapter(request: LoadLoraAdapterRequest,
643
654
644
655
return Response (status_code = 200 , content = response )
645
656
646
- @router .post ("/v1/unload_lora_adapter" )
657
+ @router .post ("/v1/unload_lora_adapter" ,
658
+ dependencies = [Depends (validate_json_request )])
647
659
async def unload_lora_adapter (request : UnloadLoraAdapterRequest ,
648
660
raw_request : Request ):
649
661
handler = models (raw_request )
0 commit comments