Skip to content

Commit c9f2481

Browse files
authored
Replace fastrtc with aiortc-based WebRTC implementation (#4610)
1 parent 6dbad4d commit c9f2481

File tree

18 files changed

+798
-652
lines changed

18 files changed

+798
-652
lines changed

backend/app/api/dependencies.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
from typing import Annotated
66
from uuid import UUID
77

8-
from fastapi import Depends, HTTPException, status
8+
from fastapi import Depends, HTTPException, Request, status
99

1010
from app.core import Scheduler
1111
from app.services import ActivePipelineService, ConfigurationService, ModelService, PipelineService, SystemService
12+
from app.webrtc.manager import WebRTCManager
1213

1314

1415
def is_valid_uuid(identifier: str) -> bool:
@@ -59,9 +60,9 @@ def get_active_pipeline_service() -> ActivePipelineService:
5960
return ActivePipelineService()
6061

6162

62-
def get_scheduler() -> Scheduler:
63+
def get_scheduler(request: Request) -> Scheduler:
6364
"""Provides the global Scheduler instance."""
64-
return Scheduler()
65+
return request.app.state.scheduler
6566

6667

6768
@lru_cache
@@ -102,3 +103,8 @@ def get_model_service(
102103
return ModelService(
103104
mp_model_reload_event=scheduler.mp_model_reload_event,
104105
)
106+
107+
108+
def get_webrtc_manager(request: Request) -> WebRTCManager:
109+
"""Provides the global WebRTCManager instance from FastAPI application's state."""
110+
return request.app.state.webrtc_manager
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright (C) 2025 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""WebRTC API Endpoints"""
5+
6+
import logging
7+
from typing import Annotated
8+
9+
from fastapi import APIRouter, Depends, status
10+
from fastapi.exceptions import HTTPException
11+
12+
from app.api.dependencies import get_webrtc_manager as get_webrtc
13+
from app.schemas.webrtc import Answer, InputData, Offer
14+
from app.webrtc.manager import WebRTCManager
15+
16+
logger = logging.getLogger(__name__)
17+
router = APIRouter(prefix="/api/webrtc", tags=["WebRTC"])
18+
19+
20+
@router.post(
21+
"/offer",
22+
response_model=Answer,
23+
responses={
24+
status.HTTP_200_OK: {"description": "WebRTC Answer"},
25+
status.HTTP_500_INTERNAL_SERVER_ERROR: {"description": "Internal Server Error"},
26+
},
27+
)
28+
async def create_webrtc_offer(offer: Offer, webrtc_manager: Annotated[WebRTCManager, Depends(get_webrtc)]) -> Answer:
29+
"""Create a WebRTC offer"""
30+
try:
31+
return await webrtc_manager.handle_offer(offer)
32+
except Exception as e:
33+
logger.error("Error processing WebRTC offer: %s", e)
34+
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
35+
36+
37+
@router.post(
38+
"/input_hook",
39+
responses={
40+
status.HTTP_200_OK: {"description": "WebRTC input data updated"},
41+
},
42+
)
43+
async def webrtc_input_hook(data: InputData, webrtc_manager: Annotated[WebRTCManager, Depends(get_webrtc)]) -> None:
44+
"""Update webrtc input with user data"""
45+
webrtc_manager.set_input(data)

backend/app/core/lifecycle.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,13 @@
1212
from app.core.scheduler import Scheduler
1313
from app.db import migration_manager
1414
from app.settings import get_settings
15+
from app.webrtc.manager import WebRTCManager
1516

1617
logger = logging.getLogger(__name__)
1718

1819

1920
@asynccontextmanager
20-
async def lifespan(_: FastAPI) -> AsyncGenerator[None]:
21+
async def lifespan(app: FastAPI) -> AsyncGenerator[None]:
2122
"""FastAPI lifespan context manager"""
2223
# Startup
2324
settings = get_settings()
@@ -30,15 +31,17 @@ async def lifespan(_: FastAPI) -> AsyncGenerator[None]:
3031

3132
# Initialize Scheduler
3233
app_scheduler = Scheduler()
33-
34-
# Start worker processes
3534
app_scheduler.start_workers()
35+
app.state.scheduler = app_scheduler
3636

37+
webrtc_manager = WebRTCManager(app_scheduler.rtc_stream_queue)
38+
app.state.webrtc_manager = webrtc_manager
3739
logger.info("Application startup completed")
3840

3941
yield
4042

4143
# Shutdown
4244
logger.info("Shutting down %s application...", settings.app_name)
45+
await webrtc_manager.cleanup()
4346
app_scheduler.shutdown()
4447
logger.info("Application shutdown completed")

backend/app/main.py

Lines changed: 10 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -9,21 +9,15 @@
99
# - docker compose -f docker-compose.dev.yaml up
1010

1111
import logging
12-
from collections.abc import Iterator
1312
from pathlib import Path
1413

15-
import anyio
16-
import gradio as gr
17-
import numpy as np
1814
import uvicorn
1915
from fastapi import FastAPI
2016
from fastapi.middleware.cors import CORSMiddleware
21-
from fastapi.responses import HTMLResponse
22-
from fastrtc import AdditionalOutputs, Stream
23-
from pydantic import BaseModel, Field
17+
from fastapi.responses import FileResponse
2418

25-
from app.api.endpoints import models, pipelines, sinks, sources, system
26-
from app.core import Scheduler, lifespan
19+
from app.api.endpoints import models, pipelines, sinks, sources, system, webrtc
20+
from app.core import lifespan
2721
from app.settings import get_settings
2822

2923
settings = get_settings()
@@ -34,25 +28,6 @@
3428
)
3529
logger = logging.getLogger(__name__)
3630

37-
38-
def rtc_stream_routine() -> Iterator[tuple[np.ndarray, AdditionalOutputs]]:
39-
"""Iterator to send frames with predictions to the WebRTC visualization stream"""
40-
scheduler = Scheduler()
41-
while not scheduler.mp_stop_event.is_set():
42-
yield scheduler.rtc_stream_queue.get()
43-
logger.info("Stopped RTC stream routine")
44-
45-
46-
stream = Stream(
47-
handler=rtc_stream_routine,
48-
modality="video",
49-
mode="receive",
50-
additional_outputs=[
51-
gr.Textbox(label="Predictions"),
52-
],
53-
additional_outputs_handler=lambda _c1, pred: pred,
54-
)
55-
5631
app = FastAPI(
5732
title=settings.app_name,
5833
version=settings.version,
@@ -78,28 +53,21 @@ def rtc_stream_routine() -> Iterator[tuple[np.ndarray, AdditionalOutputs]]:
7853
app.include_router(pipelines.router)
7954
app.include_router(models.router)
8055
app.include_router(system.router)
56+
app.include_router(webrtc.router)
8157

8258
cur_dir = Path(__file__).parent
8359

8460

8561
@app.get("/api/docs", include_in_schema=False)
86-
async def get_scalar_docs() -> HTMLResponse:
62+
async def get_scalar_docs() -> FileResponse:
8763
"""Shows docs for our OpenAPI specification using scalar"""
88-
async with await anyio.open_file(cur_dir / "scalar.html") as file:
89-
html_content = await file.read()
90-
return HTMLResponse(content=html_content)
64+
return FileResponse(cur_dir / "static" / "scalar.html")
9165

9266

93-
class InputData(BaseModel):
94-
webrtc_id: str
95-
conf_threshold: float = Field(ge=0, le=1)
96-
97-
98-
# TODO remove this endpoint, make sure the UI does not require it
99-
@app.post("/api/input_hook", tags=["webrtc"])
100-
async def webrtc_input_hook(data: InputData) -> None:
101-
"""Update webrtc input for user"""
102-
stream.set_input(data.webrtc_id, data.conf_threshold)
67+
@app.get("/stream", include_in_schema=False)
68+
async def get_webrtc_stream() -> FileResponse:
69+
"""Get webrtc player"""
70+
return FileResponse(cur_dir / "static" / "webrtc.html")
10371

10472

10573
@app.get("/health")
@@ -108,9 +76,6 @@ async def health_check() -> dict[str, str]:
10876
return {"status": "ok"}
10977

11078

111-
stream.mount(app, "/api")
112-
113-
11479
def main() -> None:
11580
"""Main application entry point"""
11681
logger.info(f"Starting {settings.app_name} in {settings.environment} mode")

backend/app/schemas/webrtc.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright (C) 2025 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from pydantic import BaseModel, Field
5+
6+
7+
class InputData(BaseModel):
8+
webrtc_id: str
9+
conf_threshold: float = Field(ge=0, le=1)
10+
11+
12+
class Offer(BaseModel):
13+
webrtc_id: str
14+
sdp: str
15+
type: str
16+
17+
18+
class Answer(BaseModel):
19+
sdp: str
20+
type: str

backend/app/scalar.html renamed to backend/app/static/scalar.html

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
<!--Copyright (C) 2025 Intel Corporation-->
2+
<!--SPDX-License-Identifier: Apache-2.0-->
3+
14
<!doctype html>
25
<html>
36
<head>

0 commit comments

Comments
 (0)