Skip to content

Commit 37dfa60

Browse files
[Bugfix] Missing Content Type returns 500 Internal Server Error (#13193)
1 parent 1bc3b5e commit 37dfa60

File tree

2 files changed

+43
-15
lines changed

2 files changed

+43
-15
lines changed

tests/entrypoints/openai/test_basic.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,3 +156,19 @@ async def test_request_cancellation(server: RemoteOpenAIServer):
156156
max_tokens=10)
157157

158158
assert len(response.choices) == 1
159+
160+
161+
@pytest.mark.asyncio
async def test_request_wrong_content_type(server: RemoteOpenAIServer):
    """A non-JSON Content-Type must yield an API error (415), not a 500."""
    messages = [{"role": "user", "content": "Write a long story"}]
    client = server.get_async_client()

    # The server should reject the request up front based on the header,
    # so the client surfaces an APIStatusError instead of a server crash.
    with pytest.raises(openai.APIStatusError):
        await client.chat.completions.create(
            model=MODEL_NAME,
            messages=messages,
            max_tokens=10000,
            extra_headers={
                "Content-Type": "application/x-www-form-urlencoded"
            },
        )

vllm/entrypoints/openai/api_server.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from typing import AsyncIterator, Dict, Optional, Set, Tuple, Union
2020

2121
import uvloop
22-
from fastapi import APIRouter, FastAPI, HTTPException, Request
22+
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
2323
from fastapi.exceptions import RequestValidationError
2424
from fastapi.middleware.cors import CORSMiddleware
2525
from fastapi.responses import JSONResponse, Response, StreamingResponse
@@ -252,6 +252,15 @@ def _cleanup_ipc_path():
252252
multiprocess.mark_process_dead(engine_process.pid)
253253

254254

255+
async def validate_json_request(raw_request: Request):
    """FastAPI dependency that rejects non-JSON request bodies.

    Without this check, FastAPI fails while parsing a non-JSON body and the
    client sees a 500 Internal Server Error; this turns that into an explicit
    415 Unsupported Media Type.

    Raises:
        HTTPException: 415 when the request's media type is not
            ``application/json``.
    """
    content_type = raw_request.headers.get("content-type", "").lower()
    # Compare only the media type itself: a Content-Type header may legally
    # carry parameters (e.g. "application/json; charset=utf-8"), which the
    # original exact-match comparison wrongly rejected.
    media_type = content_type.split(";", maxsplit=1)[0].strip()
    if media_type != "application/json":
        raise HTTPException(
            status_code=HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
            detail="Unsupported Media Type: Only 'application/json' is allowed"
        )
262+
263+
255264
router = APIRouter()
256265

257266

@@ -335,7 +344,7 @@ async def ping(raw_request: Request) -> Response:
335344
return await health(raw_request)
336345

337346

338-
@router.post("/tokenize")
347+
@router.post("/tokenize", dependencies=[Depends(validate_json_request)])
339348
@with_cancellation
340349
async def tokenize(request: TokenizeRequest, raw_request: Request):
341350
handler = tokenization(raw_request)
@@ -350,7 +359,7 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):
350359
assert_never(generator)
351360

352361

353-
@router.post("/detokenize")
362+
@router.post("/detokenize", dependencies=[Depends(validate_json_request)])
354363
@with_cancellation
355364
async def detokenize(request: DetokenizeRequest, raw_request: Request):
356365
handler = tokenization(raw_request)
@@ -379,7 +388,8 @@ async def show_version():
379388
return JSONResponse(content=ver)
380389

381390

382-
@router.post("/v1/chat/completions")
391+
@router.post("/v1/chat/completions",
392+
dependencies=[Depends(validate_json_request)])
383393
@with_cancellation
384394
async def create_chat_completion(request: ChatCompletionRequest,
385395
raw_request: Request):
@@ -400,7 +410,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
400410
return StreamingResponse(content=generator, media_type="text/event-stream")
401411

402412

403-
@router.post("/v1/completions")
413+
@router.post("/v1/completions", dependencies=[Depends(validate_json_request)])
404414
@with_cancellation
405415
async def create_completion(request: CompletionRequest, raw_request: Request):
406416
handler = completion(raw_request)
@@ -418,7 +428,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
418428
return StreamingResponse(content=generator, media_type="text/event-stream")
419429

420430

421-
@router.post("/v1/embeddings")
431+
@router.post("/v1/embeddings", dependencies=[Depends(validate_json_request)])
422432
@with_cancellation
423433
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
424434
handler = embedding(raw_request)
@@ -464,7 +474,7 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
464474
assert_never(generator)
465475

466476

467-
@router.post("/pooling")
477+
@router.post("/pooling", dependencies=[Depends(validate_json_request)])
468478
@with_cancellation
469479
async def create_pooling(request: PoolingRequest, raw_request: Request):
470480
handler = pooling(raw_request)
@@ -482,7 +492,7 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
482492
assert_never(generator)
483493

484494

485-
@router.post("/score")
495+
@router.post("/score", dependencies=[Depends(validate_json_request)])
486496
@with_cancellation
487497
async def create_score(request: ScoreRequest, raw_request: Request):
488498
handler = score(raw_request)
@@ -500,7 +510,7 @@ async def create_score(request: ScoreRequest, raw_request: Request):
500510
assert_never(generator)
501511

502512

503-
@router.post("/v1/score")
513+
@router.post("/v1/score", dependencies=[Depends(validate_json_request)])
504514
@with_cancellation
505515
async def create_score_v1(request: ScoreRequest, raw_request: Request):
506516
logger.warning(
@@ -510,7 +520,7 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
510520
return await create_score(request, raw_request)
511521

512522

513-
@router.post("/rerank")
523+
@router.post("/rerank", dependencies=[Depends(validate_json_request)])
514524
@with_cancellation
515525
async def do_rerank(request: RerankRequest, raw_request: Request):
516526
handler = rerank(raw_request)
@@ -527,7 +537,7 @@ async def do_rerank(request: RerankRequest, raw_request: Request):
527537
assert_never(generator)
528538

529539

530-
@router.post("/v1/rerank")
540+
@router.post("/v1/rerank", dependencies=[Depends(validate_json_request)])
531541
@with_cancellation
532542
async def do_rerank_v1(request: RerankRequest, raw_request: Request):
533543
logger.warning_once(
@@ -538,7 +548,7 @@ async def do_rerank_v1(request: RerankRequest, raw_request: Request):
538548
return await do_rerank(request, raw_request)
539549

540550

541-
@router.post("/v2/rerank")
551+
@router.post("/v2/rerank", dependencies=[Depends(validate_json_request)])
542552
@with_cancellation
543553
async def do_rerank_v2(request: RerankRequest, raw_request: Request):
544554
return await do_rerank(request, raw_request)
@@ -582,7 +592,7 @@ async def reset_prefix_cache(raw_request: Request):
582592
return Response(status_code=200)
583593

584594

585-
@router.post("/invocations")
595+
@router.post("/invocations", dependencies=[Depends(validate_json_request)])
586596
async def invocations(raw_request: Request):
587597
"""
588598
For SageMaker, routes requests to other handlers based on model `task`.
@@ -632,7 +642,8 @@ async def stop_profile(raw_request: Request):
632642
"Lora dynamic loading & unloading is enabled in the API server. "
633643
"This should ONLY be used for local development!")
634644

635-
@router.post("/v1/load_lora_adapter")
645+
@router.post("/v1/load_lora_adapter",
646+
dependencies=[Depends(validate_json_request)])
636647
async def load_lora_adapter(request: LoadLoraAdapterRequest,
637648
raw_request: Request):
638649
handler = models(raw_request)
@@ -643,7 +654,8 @@ async def load_lora_adapter(request: LoadLoraAdapterRequest,
643654

644655
return Response(status_code=200, content=response)
645656

646-
@router.post("/v1/unload_lora_adapter")
657+
@router.post("/v1/unload_lora_adapter",
658+
dependencies=[Depends(validate_json_request)])
647659
async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
648660
raw_request: Request):
649661
handler = models(raw_request)

0 commit comments

Comments
 (0)