diff --git a/docs/book/getting-started/deploying-zenml/deploy-with-docker.md b/docs/book/getting-started/deploying-zenml/deploy-with-docker.md index 622c341a7be..69dca980cb8 100644 --- a/docs/book/getting-started/deploying-zenml/deploy-with-docker.md +++ b/docs/book/getting-started/deploying-zenml/deploy-with-docker.md @@ -256,7 +256,6 @@ The following secure headers environment variables are supported: * **ZENML\_SERVER\_SECURE\_HEADERS\_SERVER**: The `Server` HTTP header value used to identify the server. The default value is the ZenML server ID. * **ZENML\_SERVER\_SECURE\_HEADERS\_HSTS**: The `Strict-Transport-Security` HTTP header value. The default value is `max-age=63072000; includeSubDomains`. * **ZENML\_SERVER\_SECURE\_HEADERS\_XFO**: The `X-Frame-Options` HTTP header value. The default value is `SAMEORIGIN`. -* **ZENML\_SERVER\_SECURE\_HEADERS\_XXP**: The `X-XSS-Protection` HTTP header value. The default value is `0`. NOTE: this header is deprecated and should not be customized anymore. The `Content-Security-Policy` header should be used instead. * **ZENML\_SERVER\_SECURE\_HEADERS\_CONTENT**: The `X-Content-Type-Options` HTTP header value. The default value is `nosniff`. * **ZENML\_SERVER\_SECURE\_HEADERS\_CSP**: The `Content-Security-Policy` HTTP header value. This is by default set to a strict CSP policy that only allows content from the origins required by the ZenML dashboard. NOTE: customizing this header is discouraged, as it may cause the ZenML dashboard to malfunction. * **ZENML\_SERVER\_SECURE\_HEADERS\_REFERRER**: The `Referrer-Policy` HTTP header value. The default value is `no-referrer-when-downgrade`. 
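+
+For example, these variables can be supplied to the server container like any other environment variable. The following sketch assumes the `zenmldocker/zenml-server` image and uses illustrative header values:
+
+```shell
+docker run -d -p 8080:8080 \
+  -e ZENML_SERVER_SECURE_HEADERS_HSTS="max-age=31536000; includeSubDomains" \
+  -e ZENML_SERVER_SECURE_HEADERS_REFERRER="no-referrer" \
+  zenmldocker/zenml-server
+```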
diff --git a/docs/book/how-to/containerization/containerization.md b/docs/book/how-to/containerization/containerization.md index 0ac1581f1cf..1399aef3333 100644 --- a/docs/book/how-to/containerization/containerization.md +++ b/docs/book/how-to/containerization/containerization.md @@ -354,7 +354,16 @@ you already want this automatic detection in current versions of ZenML, set `dis docker_settings = DockerSettings(install_stack_requirements=False) ``` -7. **Install Local Projects**: +7. **Control Deployment Requirements**: + By default, if you have a Deployer stack component in your active stack, ZenML installs the requirements needed by the deployment application configured in your deployment settings. You can disable this behavior if needed: + + ```python + from zenml.config import DockerSettings + + docker_settings = DockerSettings(install_deployment_requirements=False) + ``` + +8. **Install Local Projects**: If your code requires the installation of some local code files as a python package, you can specify a command that installs it as follows: ```python diff --git a/docs/book/how-to/deployment/deployment.md b/docs/book/how-to/deployment/deployment.md index 511a367fa8f..9b9688b0441 100644 --- a/docs/book/how-to/deployment/deployment.md +++ b/docs/book/how-to/deployment/deployment.md @@ -25,6 +25,41 @@ Pipeline deployments are ideal for scenarios requiring real-time, on-demand exec **Multi-step Business Workflows**: Orchestrate complex processes involving multiple AI/ML components, like document processing pipelines that combine OCR, entity extraction, sentiment analysis, and classification into a single deployable service. +## Traditional Model Serving vs. Deployed Pipelines + +If you're reaching for tools like Seldon or KServe, consider this: deployed +pipelines give you all the core serving primitives, plus the power of a full +application runtime. 
+ +- Equivalent functionality: A pipeline handles the end-to-end inference path + out of the box — request validation, feature pre-processing, model loading + and inference, post-processing, and response shaping. +- More flexible: Deployed pipelines are unopinionated, so you can layer in + retrieval, guardrails, rules, A/B routing, canary logic, human-in-the-loop, + or any custom orchestration. You're not constrained by a model-server template. +- More customizable: The deployment is a real ASGI app. Tailor endpoints, + authentication, authorization, rate limiting, structured logging, tracing, + correlation IDs, or SSO/OIDC — all with first-class middleware and + framework-level hooks. +- More features: Serve single-page apps alongside the API. Ship admin/ops + dashboards, experiment playgrounds, model cards, or customer-facing UIs + from the very same deployment for tighter operational feedback loops. + +This approach aligns better with production realities: inference is rarely +"just call a model." There are policies, data dependencies, and integrations +that need a programmable, evolvable surface. Deployed pipelines give you that +without sacrificing the convenience of a managed deployer and a clean HTTP +contract. + +{% hint style="info" %} +Deprecation notice: ZenML is phasing out the Model Deployer stack components +in favor of pipeline deployments. Pipeline deployments are the strategic +direction for real-time serving: they are more dynamic, more extensible, and +offer deeper integration points with your security, observability, and product +requirements. Existing model deployers will continue to function during the +transition period, but new investments will focus on pipeline deployments. 
+{% endhint %}
+
 ## How Deployments Work
 
 To deploy a pipeline or snapshot, a **Deployer** stack component needs to be in your active stack:
@@ -420,6 +455,50 @@ The following happens when the pipeline is deployed and then later invoked:
 
 This mechanism can be used to initialize and share global state between all the HTTP requests made to the deployment or to execute long-running initialization or cleanup operations when the deployment is started or stopped rather than on each HTTP request.
 
+## Deployment Configuration
+
+The deployer settings cover aspects of the pipeline deployment process and the specific back-end infrastructure used to provision and manage the resources required to run the deployment servers. Independently of that, `DeploymentSettings` can be used to fully customize all aspects of the deployment ASGI application itself, including:
+
+* HTTP endpoints
+* middleware
+* secure headers
+* CORS settings
+* mounting and serving static files to support deploying single-page applications alongside the pipeline
+* for more advanced cases, even the ASGI framework (e.g. FastAPI, Django, Flask, Falcon, Quart, BlackSheep, etc.) and its configuration can be customized
+
+Example:
+
+```python
+from typing import Any, Dict
+
+from zenml.config import DeploymentSettings, EndpointSpec, EndpointMethod
+from zenml import pipeline
+
+async def custom_health_check() -> Dict[str, Any]:
+    from zenml.client import Client
+
+    client = Client()
+    return {
+        "status": "healthy",
+        "info": client.zen_store.get_store_info().model_dump(),
+    }
+
+@pipeline(settings={"deployment": DeploymentSettings(
+    custom_endpoints=[
+        EndpointSpec(
+            path="/health",
+            method=EndpointMethod.GET,
+            handler=custom_health_check,
+            auth_required=False,
+        ),
+    ],
+)})
+def my_pipeline():
+    ...
+```
+
+For more detailed information on deployment options, see the [deployment settings guide](./deployment_settings.md).
+
 ## Best Practices
 
 1.
**Design for Parameters**: Structure your pipelines to accept meaningful parameters that control behavior diff --git a/docs/book/how-to/deployment/deployment_settings.md b/docs/book/how-to/deployment/deployment_settings.md new file mode 100644 index 00000000000..4aed6180e67 --- /dev/null +++ b/docs/book/how-to/deployment/deployment_settings.md @@ -0,0 +1,1051 @@ +--- +description: Customize the pipeline deployment ASGI application with DeploymentSettings. +--- + + +## Deployment servers and ASGI apps + +ZenML pipeline deployments run an ASGI application under a production-grade +`uvicorn` server. This makes your pipelines callable over HTTP for online +workloads like real-time ML inference, LLM agents/workflows, and even full +web apps co-located with pipelines. + +At runtime, three core components work together: + +- the ASGI application: the HTTP surface that exposes endpoints (health, invoke, + metrics, docs) and any custom routes or middleware you configure. This is powered by an ASGI framework like FastAPI, Starlette, Django, Flask, etc. +- the ASGI application factory (aka the Deployment App Runner): this component is responsible for constructing the ASGI application piece by piece based on the instructions provided by users via runtime configuration. +- the Deployment Service: the component responsible for the business logic that + backs the pipeline deployment and its invocation lifecycle. + +Both the Deployment App Runner and the Deployment Service are customizable at runtime, through the `DeploymentSettings` configuration mechanism. They can also be extended via inheritance to support different ASGI frameworks or to tweak existing functionality. + +The `DeploymentSettings` class lets you shape both server behavior and the +ASGI app composition without changing framework code. Typical reasons to +customize include: + +- Tight security posture: CORS controls, strict headers, authentication, + API surface minimization. 
+- Observability: request/response logging, tracing, metrics, correlation + identifiers. +- Enterprise integration: policy gateways, SSO/OIDC/OAuth, audit logging, + routing and network architecture constraints. +- Product UX: single-page application (SPA) static files served alongside + deployment APIs or custom docs paths. +- Performance/SRE: thread pool sizing, uvicorn worker settings, log levels, + max request sizes and platform-specific fine-tuning. + +All `DeploymentSettings` are pipeline-level settings. They apply to the +deployment that serves the pipeline as a whole. They are not available at +step-level. + +## Configuration overview + +You can configure `DeploymentSettings` in Python or via YAML, the same way as +other settings classes. The settings can be attached to a pipeline decorator +or via `with_options`. These settings are only valid at pipeline level. + +### Python configuration + +Use the `DeploymentSettings` class to configure the deployment settings for your +pipeline in-code + +```python +from zenml import pipeline +from zenml.config import DeploymentSettings + +deploy_settings = DeploymentSettings( + app_title="Fraud Scoring Service", + app_description=( + "Online scoring API exposing synchronous and batch inference" + ), + app_version="1.2.0", + root_url_path="", + api_url_path="", + docs_url_path="/docs", + redoc_url_path="/redoc", + invoke_url_path="/invoke", + health_url_path="/health", + info_url_path="/info", + metrics_url_path="/metrics", + cors={ + "allow_origins": ["https://app.example.com"], + "allow_methods": ["GET", "POST", "OPTIONS"], + "allow_headers": ["*"], + "allow_credentials": True, + }, + thread_pool_size=32, + uvicorn_host="0.0.0.0", + uvicorn_port=8080, + uvicorn_workers=2, +) + +@pipeline(settings={"deployment": deploy_settings}) +def scoring_pipeline() -> None: + ... 
+ +# Alternatively +scoring_pipeline = scoring_pipeline.with_options( + settings={"deployment": deploy_settings} +) +``` + +### YAML configuration + +Define settings in a YAML configuration file for better separation of code and configuration: + +```yaml +settings: + deployment: + app_title: Fraud Scoring Service + app_description: >- + Online scoring API exposing synchronous and batch inference + app_version: "1.2.0" + root_url_path: "" + api_url_path: "" + docs_url_path: "/docs" + redoc_url_path: "/redoc" + invoke_url_path: "/invoke" + health_url_path: "/health" + info_url_path: "/info" + metrics_url_path: "/metrics" + cors: + allow_origins: ["https://app.example.com"] + allow_methods: ["GET", "POST", "OPTIONS"] + allow_headers: ["*"] + allow_credentials: true + thread_pool_size: 32 + uvicorn_host: 0.0.0.0 + uvicorn_port: 8080 + uvicorn_workers: 2 +``` + +Check out [this page](https://docs.zenml.io/concepts/steps_and_pipelines/configuration) for more information on the hierarchy and precedence of the various ways in which you can supply the settings. + +## Basic customization options + +`DeploymentSettings` expose the following basic customization options. The sections below provide +short examples and guidance. 
+ +- application metadata and paths +- built-in endpoints and middleware toggles +- static files (SPAs) and dashboards +- CORS +- secure headers +- startup and shutdown hooks +- uvicorn server options, logging level, and thread pool size + +### Application metadata + +You can set `app_title`, `app_description`, and `app_version` to be reflected in the ASGI application's metadata: + +```python +from zenml.config import DeploymentSettings + +settings = DeploymentSettings( + app_title="LLM Agent Service", + app_description=( + "Agent endpoints for tools, state inspection, and tracing" + ), + app_version="0.7.0", +) +``` + +### Default URL paths, endpoints and middleware + +The ASGI application exposes the following built-in endpoints by default: + +* documentation endpoints: + * `/docs` - The OpenAPI documentation UI generated based on the endpoints and their signatures. + * `/redoc` - The ReDoc documentation UI generated based on the endpoints and their signatures. +* REST API endpoints: + * `/invoke` - The main pipeline invocation endpoint for synchronous inference. + * `/health` - The health check endpoint. + * `/info` - The info endpoint providing extensive information about the deployment and its service. + * `/metrics` - Simple metrics endpoint. +* dashboard endpoints - present only if the accompanying UI is enabled: + * `/`, `/index.html`, `/static` - Endpoints for serving the dashboard files from the `dashboard_files_path` directory. + +The ASGI application includes the following built-in middleware by default: +* secure headers middleware: for setting security headers. +* CORS middleware: for handling CORS requests. + +You can include or exclude these default endpoints and middleware either globally or individually by setting the `include_default_endpoints` and `include_default_middleware` settings. It is also possible to remap the built-in endpoint URL paths. 
+ +```python +from zenml.config import ( + DeploymentSettings, + DeploymentDefaultEndpoints, + DeploymentDefaultMiddleware, +) + +settings = DeploymentSettings( + # Include only the endpoints you need + include_default_endpoints=( + DeploymentDefaultEndpoints.DOCS + | DeploymentDefaultEndpoints.INVOKE + | DeploymentDefaultEndpoints.HEALTH + ), + # Customize the root URL path + root_url_path="/pipeline", + # Include only the middleware you need + include_default_middleware=DeploymentDefaultMiddleware.CORS, + # Customize the base API URL path used for all REST API endpoints + api_url_path="/api", + # Customize the documentation URL path + docs_url_path="/documentation", + # Customize the health check URL path + health_url_path="/healthz", +) +``` + +With the above settings, the ASGI application will only expose the following endpoints and middleware: + +- `/pipeline/documentation` - The API documentation (OpenAPI schema) +- `/pipeline/api/invoke` - The REST API pipeline invocation endpoint +- `/pipeline/api/healthz` - The REST API health check endpoint +- CORS middleware: for handling CORS requests + +### Static files (single-page applications) + +Deployed pipelines can serve full single-page applications (React/Vue/Svelte) +from the same origin as your inference API. This eliminates CORS/auth/routing +friction and lets you ship user-facing UI components alongside +your endpoints, such as: + +* operator dashboards +* governance portals +* experiment browsers +* feature explorers +* custom data labeling interfaces +* model cards +* observability dashboards +* customer-facing playgrounds + +Co-locating UI and API streamlines delivery (one image, one URL, one CI/CD), +improves latency, and keeps telemetry and auth consistent. + +To enable this, point `dashboard_files_path` to a directory containing an +`index.html` and any static assets. 
The path must be relative to the +[source root](../steps-pipelines/sources.md#source-root): + +```python +settings = DeploymentSettings( + dashboard_files_path="web/build" # contains index.html and assets/ +) +``` + +A rudimentary playground dashboard is included with the ZenML python package that features a simple UI useful for sending pipeline invocations and viewing the pipeline's response. + +{% hint style="info" %} +When supplying your own custom dashboard, you may also need to [customize the security headers](./deployment_settings#secure-headers) to allow the dashboard to access various resources. For example, you may want to tweak the `Content-Security-Policy` header to allow the dashboard to access external javascript libraries, images, etc. +{% endhint %} + +### CORS + +Fine-tune cross-origin access: + +```python +from zenml.config import DeploymentSettings, CORSConfig + +settings = DeploymentSettings( + cors=CORSConfig( + allow_origins=["https://app.example.com", "https://admin.example.com"], + allow_methods=["GET", "POST", "OPTIONS"], + allow_headers=["authorization", "content-type", "x-request-id"], + allow_credentials=True, + ) +) +``` + +### Secure headers + +Harden responses with strict headers. Each field supports either a boolean or +string. Using `True` selects a safe default, `False` disables the header, and +custom strings allow fully custom policies: + +```python +from zenml.config import ( + DeploymentSettings, + SecureHeadersConfig, +) + +settings = DeploymentSettings( + secure_headers=SecureHeadersConfig( + server=True, # emit default ZenML server header value + hsts=True, # default: 63072000; includeSubdomains + xfo=True, # default: SAMEORIGIN + content=True, # default: nosniff + csp=( + "default-src 'none'; connect-src 'self' https://api.example.com; " + "img-src 'self' data:; style-src 'self' 'unsafe-inline'" + ), + referrer=True, + cache=True, + permissions=True, + ) +) +``` + +Set any field to `False` to omit that header. 
Set to a string for a custom +value. The defaults are strong, production-safe policies. + + +### Startup and shutdown hooks + +Lifecycle startup and shutdown hooks are called as part of the ASGI application's lifespan. This is an alternative to [the `on_init` and `on_cleanup` hooks that can be configured at pipeline level](./deployment.md#deployment-initialization-cleanup-and-state). + +Common use-cases: + +- Model inference + - load models/tokenizers and warm caches (JIT/ONNX/TensorRT, HF, sklearn) + - hydrate feature stores, connect to vector DBs (FAISS, Milvus, PGVector) + - initialize GPU memory pools and thread/process pools + - set global config, download artifacts from registry or object store + - prefetch embeddings, label maps, lookup tables + - create connection pools for databases, Redis, Kafka, SQS, Pub/Sub + +- LLM agent workflows + - initialize LLM client(s), tool registry, and router/policy engine + - build or load RAG indexes; warm retrieval caches and prompts + - configure rate limiting, concurrency guards, circuit breakers + - load guardrails (PII filters, toxicity, jailbreak detection) + - configure tracing/observability for token usage and tool calls + +- Shutdown + - flush metrics/traces/logs, close pools/clients, persist state/caches + - graceful draining: wait for in-flight requests before teardown + +Hooks can be provided as: + +- A Python callable object +- A source path string to be loaded dynamically (e.g. `my_project.runtime.hooks.on_startup`) + +The callable must accept an `app_runner` argument of type `BaseDeploymentAppRunner` and any additional keyword arguments. The `app_runner` argument is the application factory that is responsible for building the ASGI application. You can use it to access information such as: + +* the ASGI application instance that is being built +* the deployment service instance that is being deployed +* the `DeploymentResponse` object itself, which also contains details about the snapshot, pipeline, etc. 
+ +```python +from zenml.deployers.server import BaseDeploymentAppRunner + +def on_startup(app_runner: BaseDeploymentAppRunner, warm: bool = False) -> None: + # e.g., warm model cache, connect tracer, prefetch embeddings + ... + +def on_shutdown(app_runner: BaseDeploymentAppRunner, drain_timeout_s: int = 2) -> None: + # e.g., flush metrics, close clients + ... + +settings = DeploymentSettings( + startup_hook=on_startup, + shutdown_hook=on_shutdown, + startup_hook_kwargs={"warm": True}, + shutdown_hook_kwargs={"drain_timeout_s": 2}, +) +``` + +YAML using source strings: + +```yaml +settings: + deployment: + startup_hook: my_project.runtime.hooks.on_startup + shutdown_hook: my_project.runtime.hooks.on_shutdown + startup_hook_kwargs: + warm: true + shutdown_hook_kwargs: + drain_timeout_s: 2 +``` + +### Uvicorn and threading + +Tune server runtime parameters for performance and topology: + +```python +from zenml.config import DeploymentSettings +from zenml.enums import LoggingLevels + +settings = DeploymentSettings( + thread_pool_size=64, # CPU-bound work offload + uvicorn_host="0.0.0.0", + uvicorn_port=8000, + uvicorn_workers=2, # multi-process model + log_level=LoggingLevels.INFO, + uvicorn_kwargs={ + "proxy_headers": True, + "forwarded_allow_ips": "*", + "timeout_keep_alive": 15, + }, +) +``` + +## Advanced customization options + +When the built-in ASGI application, endpoints and middleware are not enough, you can take customizing your deployment to the next level by providing your own implementation for endpoints, middleware and other ASGI application extensions. ZenML `DeploymentSettings` provides a flexible and extensible mechanism to inject your own custom code into the ASGI application at runtime: + +- custom endpoints - to expose your own HTTP endpoints. +- custom middleware - to insert your own ASGI middleware. 
+- free-form ASGI application building extensions - to take full control of the ASGI application and its lifecycle for truly advanced use-cases when endpoints and middleware are not enough. + +### Custom endpoints + +In production, custom endpoints are often required alongside the main +pipeline invoke route. Common use-cases include: + +- Online inference controls + - model (re)load, warm-up, and cache priming + - dynamic model/version switching and traffic shaping (A/B, canary) + - async/batch prediction submission and job-status polling + - feature store materialization/backfills and online/offline sync triggers + +- Enterprise integration + - authentication bootstrap (API key issuance/rotation), JWKS rotation + - OIDC/OAuth device-code flows and SSO callback handlers + - external system webhooks (CRM, billing, ticketing, audit sink) + +- Observability and operations + - detailed health/readiness endpoints (subsystems, dependencies) + - metrics/traces/log shipping toggles; log level switch (INFO/DEBUG) + - maintenance-mode enable/disable and graceful drain controls + +- LLM agent serving + - tool registry CRUD, tool execution sandboxes, guardrail toggles + - RAG index CRUD (upsert documents, rebuild embeddings, vacuum/compact) + - prompt template catalogs and runtime overrides + - session memory inspection/reset, conversation export/import + +- Governance and data management + - payload redaction policy updates and capture sampling controls + - schema/contract discovery (sample payloads, test vectors) + - tenant provisioning, quotas/limits, and per-tenant configuration + +You can configure `custom_endpoints` in `DeploymentSettings` to expose your own HTTP endpoints. + +Endpoints support multiple definition modes (see code examples below): + +1) Direct callable - a simple function that takes in request parameters and returns a response. Framework-specific arguments such as FastAPI's `Request`, `Response` and dependency injection patterns are supported. 
+2) Builder class - a callable class with a `__call__` method that is the actual endpoint callable described at 1). The builder class constructor is called by the ASGI application factory and can be leveraged to execute any global initialization logic before the endpoint is called. +3) Builder function - a function that returns the actual endpoint callable described at 1). Similar to the builder class. +4) Native framework-specific object (`native=True`). This can vary from ASGI framework to framework. + +Definitions can be provided as Python objects or as loadable source path strings. + +The builder class and builder function must accept an `app_runner` argument of type `BaseDeploymentAppRunner`. This is the application factory that is responsible for building the ASGI application. You can use it to access information such as: + +* the ASGI application instance that is being built +* the deployment service instance that is being deployed +* the `DeploymentResponse` object itself, which also contains details about the snapshot, pipeline, etc. + +The final endpoint callable can take any input arguments and return any output that are JSON-serializable or Pydantic models. The application factory will handle converting these into the appropriate schema for the ASGI application. + +You can also use framework-specific request/response types (e.g. FastAPI `Request`, `Response`) or dependency injection patterns for your endpoint callable if needed. However, this will limit the portability of your endpoint to other frameworks. + +The following code examples demonstrate the different definition modes for custom endpoints: + +1. 
a custom detailed health check endpoint implemented as a direct callable + +```python +from typing import Any, Callable, Dict, List +from pydantic import BaseModel +from zenml.client import Client +from zenml.config import ( + DeploymentSettings, + EndpointSpec, + EndpointMethod, +) +from zenml.deployers.server import BaseDeploymentAppRunner +from zenml.models import DeploymentResponse + +async def health_detailed() -> Dict[str, Any]: + import psutil + + client = Client() + + return { + "status": "healthy", + "cpu_percent": psutil.cpu_percent(), + "memory_percent": psutil.virtual_memory().percent, + "disk_percent": psutil.disk_usage("/").percent, + "zenml": client.zen_store.get_store_info().model_dump(), + } + +settings = DeploymentSettings( + custom_endpoints=[ + EndpointSpec( + path="/health", + method=EndpointMethod.GET, + handler=health_detailed, + auth_required=False, + ), + ] +) +``` + +2. a custom ML model inference endpoint, implemented as a builder function. Note how the builder function loads the model only once at runtime, and then reuses it for all subsequent requests. 
+ + +```python +from typing import Any, Callable, Dict, List +from pydantic import BaseModel +from zenml.client import Client +from zenml.config import ( + DeploymentSettings, + EndpointSpec, + EndpointMethod, +) +from zenml.deployers.server import BaseDeploymentAppRunner +from zenml.models import DeploymentResponse + +class PredictionRequest(BaseModel): + features: List[float] + +class PredictionResponse(BaseModel): + prediction: float + confidence: float + +def build_predict_endpoint( + app_runner: BaseDeploymentAppRunner, + model_name: str, + model_version: str, + model_artifact: str, +) -> Callable[[PredictionRequest], PredictionResponse]: + + stored_model_version = Client().get_model_version(model_name, model_version) + stored_model_artifact = stored_model_version.get_artifact(model_artifact) + model = stored_model_artifact.load() + + def predict( + request: PredictionRequest, + ) -> PredictionResponse: + pred = float(model.predict([request.features])[0]) + # Example: return fixed confidence if model lacks proba + return PredictionResponse(prediction=pred, confidence=0.9) + + return predict + +settings = DeploymentSettings( + custom_endpoints=[ + EndpointSpec( + path="/predict/custom", + method=EndpointMethod.POST, + handler=build_predict_endpoint, + init_kwargs={ + "model_name": "fraud-classifier", + "model_version": "v1", + "model_artifact": "sklearn_model", + }, + auth_required=True, + ), + ] +) +``` + +NOTE: a similar way to do this is to implement a proper ZenML pipeline that loads the model in the `on_init` hook and then runs pre-processing and inference steps in the pipeline. + +3. 
a custom deployment info endpoint implemented as a builder class + + +```python +from typing import Any, Awaitable, Callable, Dict, List +from pydantic import BaseModel +from zenml.client import Client +from zenml.config import ( + DeploymentSettings, + EndpointSpec, + EndpointMethod, +) +from zenml.deployers.server import BaseDeploymentAppRunner +from zenml.models import DeploymentResponse + +def build_deployment_info(app_runner: BaseDeploymentAppRunner) -> Callable[[], Awaitable[DeploymentResponse]]: + async def endpoint() -> DeploymentResponse: + return app_runner.deployment + + return endpoint + +settings = DeploymentSettings( + custom_endpoints=[ + EndpointSpec( + path="/deployment", + method=EndpointMethod.GET, + handler=build_deployment_info, + auth_required=True, + ), + ] +) +``` + +4. a custom model selection endpoint, implemented as a FastAPI router. This example is more involved and demonstrates how to coordinate multiple endpoints with the main pipeline invoke endpoint. + +```python +# my_project.fastapi_endpoints +from __future__ import annotations + +from typing import List, Optional + +from fastapi import APIRouter, HTTPException, status +from pydantic import BaseModel, Field +from sklearn.base import ClassifierMixin +from zenml.client import Client +from zenml.models import ArtifactVersionResponse +from zenml.config import DeploymentSettings, EndpointSpec, EndpointMethod + +model_router = APIRouter() + +# Global, process-local model registry for inference +CURRENT_MODEL: Optional[Any] = None +CURRENT_MODEL_ARTIFACT: Optional[ArtifactVersionResponse] = None + + +class LoadModelRequest(BaseModel): + """Request to load/replace the in-memory model version.""" + + model_name: str = Field(default="fraud-classifier") + version_name: str = Field(default="v1") + artifact_name: str = Field(default="sklearn_model") + + +@model_router.post("/load", response_model=ArtifactVersionResponse) +def load_model(req: LoadModelRequest) -> ArtifactVersionResponse: + 
"""Load or replace the in-memory model version.""" + global CURRENT_MODEL, CURRENT_MODEL_ARTIFACT + + model_version = Client().get_model_version( + req.model_name, req.version_name + ) + CURRENT_MODEL_ARTIFACT = model_version.get_artifact(req.artifact_name) + CURRENT_MODEL = CURRENT_MODEL_ARTIFACT.load() + + return CURRENT_MODEL_ARTIFACT + + +@model_router.get("/current", response_model=ArtifactVersionResponse) +def current_model() -> ArtifactVersionResponse: + """Return the artifact of the currently loaded in-memory model.""" + + if CURRENT_MODEL_ARTIFACT is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="No model loaded. Use /model/load first.", + ) + + return CURRENT_MODEL_ARTIFACT + +deploy_settings = DeploymentSettings( + custom_endpoints=[ + EndpointSpec( + path="/model", + method=EndpointMethod.POST, # method is ignored for native routers + handler="my_project.fastapi_endpoints.model_router", + native=True, + auth_required=True, + ) + ] +) +``` + +And here is a minimal ZenML inference pipeline that uses the globally loaded +model. The prediction step reads the model from the global variable set +by the FastAPI router above. You can invoke this pipeline via the built-in +`/invoke` endpoint once a model has been loaded through `/model/load`. + +```python +from typing import List + +from pydantic import BaseModel +from zenml import pipeline, step + + +class InferenceRequest(BaseModel): + features: List[float] + + +class InferenceResponse(BaseModel): + prediction: float + + +@step +def preprocess_step(request: InferenceRequest) -> List[float]: + # Replace with real transformations, scaling, encoding, etc. + return request.features + +@step +def predict_step(features: List[float]) -> InferenceResponse: + """Run model inference using the globally loaded model.""" + + if GLOBAL_CURRENT_MODEL is None: + raise RuntimeError( + "No model loaded. Call /model/load before invoking." 
+ ) + + pred = float(GLOBAL_CURRENT_MODEL.predict([features])[0]) + return InferenceResponse(prediction=pred) + + +@pipeline(settings={"deployment": deploy_settings}) +def inference_pipeline(request: InferenceRequest) -> InferenceResponse: + processed = preprocess_step(request) + return predict_step(processed) +``` + +### Custom middleware + +Middleware is where you enforce cross-cutting concerns consistently across every endpoint. Common use-cases include: + +- Security and access control + - API key/JWT verification, tenant extraction and context injection + - IP allow/deny lists, basic WAF-style request filtering, mTLS header checks + - Request body/schema validation and max body size enforcement + +- Governance and privacy + - PII detection/redaction on inputs/outputs; payload sampling/scrubbing + - Policy enforcement (data residency, retention, consent) at request time + +- Reliability and traffic shaping + - Rate limiting, quotas, per-tenant concurrency limits + - Idempotency keys, deduplication, retries with backoff, circuit breakers + - Timeouts, slow-request detection, maintenance mode and graceful drain + +- Observability + - Correlation/trace IDs, OpenTelemetry spans, structured logging + - Metrics for latency, throughput, error rates, request/response sizes + +- Performance and caching + - Response caching/ETags, compression (gzip/br), streaming/chunked responses + - Adaptive content negotiation and serialization tuning + +- LLM/agent-specific controls + - Token accounting/limits, cost guards per tenant/user + - Guardrails (toxicity/PII/jailbreak) and output filtering + - Tool execution sandboxing gates and allowlists + +- Data and feature enrichment + - Feature store prefetch, user/tenant profile enrichment, AB bucketing tags + + +You can configure `custom_middlewares` in `DeploymentSettings` to insert your own ASGI middleware. 
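Under the hood, every middleware in this chain speaks the bare ASGI protocol: an application is an async callable taking `scope`, `receive`, and `send`. The framework-agnostic sketch below (illustrative only; the tiny `app` and `add_header_middleware` are hypothetical stand-ins, not ZenML or Starlette APIs) wraps a minimal ASGI app with a callable-style middleware and drives it fully in-memory:

```python
import asyncio

async def app(scope, receive, send):
    # Minimal ASGI app: always replies 200 with a fixed body.
    await send({"type": "http.response.start", "status": 200, "headers": []})
    await send({"type": "http.response.body", "body": b"ok"})

async def add_header_middleware(app, scope, receive, send):
    # Callable-style middleware: appends a header to every HTTP response.
    async def send_wrapper(message):
        if message["type"] == "http.response.start":
            headers = list(message.get("headers", []))
            headers.append((b"x-demo", b"1"))
            message = {**message, "headers": headers}
        await send(message)

    await app(scope, receive, send_wrapper)

async def run_once():
    # Hand-rolled receive/send callables let us drive the stack
    # without any HTTP server and inspect what it emits.
    sent = []

    async def receive():
        return {"type": "http.request", "body": b"", "more_body": False}

    async def send(message):
        sent.append(message)

    await add_header_middleware(app, {"type": "http"}, receive, send)
    return sent

messages = asyncio.run(run_once())
assert messages[0]["headers"] == [(b"x-demo", b"1")]
```

Driving the stack with in-memory `receive`/`send` callables like this is also a convenient way to unit-test custom middleware before wiring it into a `MiddlewareSpec`.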
+ +Middlewares support multiple definition modes (see code examples below): + +1) Middleware class - a standard ASGI middleware class that implements the `__call__` method that takes the traditional `scope`, `receive` and `send` arguments. The constructor must accept an `app` argument of type `ASGIApplication` and any additional keyword arguments. +2) Middleware callable - a callable that takes all arguments in one go: `app`, `scope`, `receive` and `send`. +3) Native framework-specific middleware (`native=True`) - this can vary from ASGI framework to framework. + +Definitions can be provided as Python objects or as loadable source path strings. The `order` parameter controls the insertion order in the middleware chain. Lower `order` values insert the middleware earlier in the chain. + +The following code examples demonstrate the different definition modes for custom middlewares: + +1. a custom middleware that adds a processing time header to every response, implemented as a middleware class: + +```python +import time +from typing import Any +from asgiref.compatibility import guarantee_single_callable +from asgiref.typing import ( + ASGIApplication, + ASGIReceiveCallable, + ASGISendCallable, + ASGISendEvent, + Scope, +) +from zenml.config import DeploymentSettings, MiddlewareSpec + +class RequestTimingMiddleware: + """ASGI middleware to measure request processing time.""" + + def __init__(self, app: ASGIApplication, header_name: str = "x-process-time-ms") -> None: + self.app = guarantee_single_callable(app) + self.header_name = header_name + + async def __call__( + self, + scope: Scope, + receive: ASGIReceiveCallable, + send: ASGISendCallable, + ) -> None: + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + start_time = time.time() + + async def send_wrapper(message: ASGISendEvent) -> None: + if message["type"] == "http.response.start": + process_time = (time.time() - start_time) * 1000 + headers = list(message.get("headers", [])) + 
headers.append((self.header_name.encode(), str(process_time).encode()))
+                message = {**message, "headers": headers}
+
+            await send(message)
+
+        await self.app(scope, receive, send_wrapper)
+
+
+settings = DeploymentSettings(
+    custom_middlewares=[
+        MiddlewareSpec(
+            middleware=RequestTimingMiddleware,
+            order=10,
+            init_kwargs={"header_name": "x-process-time-ms"},
+        ),
+    ]
+)
+```
+
+2. a custom middleware that injects a correlation ID into responses (and generates one if missing), implemented as a middleware callable:
+
+```python
+import uuid
+from asgiref.compatibility import guarantee_single_callable
+from asgiref.typing import (
+    ASGIApplication,
+    ASGIReceiveCallable,
+    ASGISendCallable,
+    ASGISendEvent,
+    Scope,
+)
+from zenml.config import DeploymentSettings, MiddlewareSpec
+
+async def request_id_middleware(
+    app: ASGIApplication,
+    scope: Scope,
+    receive: ASGIReceiveCallable,
+    send: ASGISendCallable,
+    header_name: str = "x-request-id",
+) -> None:
+    """ASGI function middleware that ensures a correlation ID header exists."""
+
+    app = guarantee_single_callable(app)
+
+    if scope["type"] != "http":
+        await app(scope, receive, send)
+        return
+
+    # Reuse existing request ID if present; otherwise generate one
+    request_id = None
+    for k, v in scope.get("headers", []):
+        if k.decode().lower() == header_name:
+            request_id = v.decode()
+            break
+
+    if not request_id:
+        request_id = str(uuid.uuid4())
+
+    async def send_wrapper(message: ASGISendEvent) -> None:
+        if message["type"] == "http.response.start":
+            headers = list(message.get("headers", []))
+            headers.append((header_name.encode(), request_id.encode()))
+            message = {**message, "headers": headers}
+
+        await send(message)
+
+    await app(scope, receive, send_wrapper)
+
+
+settings = DeploymentSettings(
+    custom_middlewares=[
+        MiddlewareSpec(
+            middleware=request_id_middleware,
+            order=5,
+            init_kwargs={"header_name": "x-request-id"},
+        ),
+    ]
+)
+```
+
+3. 
a FastAPI/Starlette-native middleware that adds GZIP support, implemented as a native middleware:
+
+```python
+from starlette.middleware.gzip import GZipMiddleware
+from zenml.config import DeploymentSettings, MiddlewareSpec
+
+settings = DeploymentSettings(
+    custom_middlewares=[
+        MiddlewareSpec(
+            middleware=GZipMiddleware,
+            native=True,
+            order=20,
+            extra_kwargs={"minimum_size": 1024},
+        ),
+    ]
+)
+```
+
+### App extensions
+
+App extensions are pluggable components that run as part of the ASGI application factory and can install complex, possibly framework-specific structures. Typical scenarios for using a full-blown extension instead of endpoints/middleware include:
+
+- Advanced authentication and authorization
+  - install org-wide dependencies (e.g., OAuth/OIDC auth, RBAC guards)
+  - register custom exception handlers for uniform error envelopes
+  - augment OpenAPI with security schemes and per-route security policies
+
+- Multi-tenant and routing topology
+  - programmatically include routers per tenant/region/version
+  - mount sub-apps for internal admin vs public APIs under different prefixes
+  - dynamic route rewrites/switches for blue/green or canary rollouts
+
+- Observability and platform integration
+  - wire OpenTelemetry instrumentation at the app level (tracer/meter providers)
+  - register global request/response logging with redaction policies
+  - expose or mount vendor-specific observability apps (e.g., Prometheus)
+
+- LLM agent control plane
+  - attach a tool registry/router and lifecycle hooks for tools
+  - register guardrail handlers and policy engines across routes
+  - install runtime prompt/template catalogs and index management routers
+
+- API ergonomics and governance
+  - reshape OpenAPI (tags, servers, components) and versioned docs
+  - global response model wrapping, pagination conventions, error mappers
+  - maintenance-mode switch and graceful-drain controls at the app level
+
+App extensions support multiple
definition modes (see code examples below):
+
+1) Extension class - a class that implements the `BaseAppExtension` abstract class. The class constructor must accept any keyword arguments and the `install` method must accept an `app_runner` argument of type `BaseDeploymentAppRunner`.
+2) Extension callable - a callable that takes the `app_runner` argument of type `BaseDeploymentAppRunner`.
+
+In both cases, the `app_runner` argument is the application factory responsible for building the ASGI application. You can use it to access information such as:
+
+* the ASGI application instance that is being built
+* the deployment service instance that is being deployed
+* the `DeploymentResponse` object itself, which also contains details about the snapshot, pipeline, etc.
+
+Definitions can be provided as Python objects or as loadable source path strings.
+
+Extensions are invoked near the end of the ASGI application building process - after the ASGI app has been assembled according to the deployment configuration settings.
+
+The example below installs API key authentication at the FastAPI application
+level, attaches the dependency to selected routes, registers an auth error
+handler, and augments the OpenAPI schema with the security scheme.
+ +```python +from __future__ import annotations + +from typing import Literal, Sequence, Set + +from fastapi import FastAPI, HTTPException, Request, status +from fastapi.openapi.utils import get_openapi +from fastapi.responses import JSONResponse +from fastapi.security import APIKeyHeader + +from zenml.config import AppExtensionSpec, DeploymentSettings +from zenml.deployers.server.app import BaseDeploymentAppRunner +from zenml.deployers.server.extensions import BaseAppExtension + + +class FastAPIAuthExtension(BaseAppExtension): + """Install API key auth and OpenAPI security on a FastAPI app.""" + + def __init__( + self, + scheme: Literal["api_key"] = "api_key", + header_name: str = "x-api-key", + valid_keys: Sequence[str] | None = None, + ) -> None: + self.scheme = scheme + self.header_name = header_name + self.valid_keys: Set[str] = set(valid_keys or []) + + def install(self, app_runner: BaseDeploymentAppRunner) -> None: + app = app_runner.asgi_app + if not isinstance(app, FastAPI): + raise RuntimeError("FastAPIAuthExtension requires FastAPI") + + api_key_header = APIKeyHeader( + name=self.header_name, auto_error=True + ) + + # Find endpoints that have auth_required=True + protected_endpoints = [ + endpoint.path + for endpoint in app_runner.endpoints + if endpoint.auth_required + ] + + @app.middleware("http") + async def api_key_guard(request: Request, call_next): + if request.url.path in protected_endpoints: + api_key = await api_key_header(request) + if api_key not in self.valid_keys: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or missing API key", + ) + return await call_next(request) + + # Auth error handler + @app.exception_handler(HTTPException) + async def auth_exception_handler( + _, exc: HTTPException + ) -> JSONResponse: + if exc.status_code == status.HTTP_401_UNAUTHORIZED: + return JSONResponse( + status_code=exc.status_code, + content={"detail": exc.detail}, + headers={"WWW-Authenticate": "ApiKey"}, + ) + 
return JSONResponse( + status_code=exc.status_code, content={"detail": exc.detail} + ) + + # OpenAPI security + def custom_openapi() -> dict: + if app.openapi_schema: + return app.openapi_schema # type: ignore[return-value] + openapi_schema = get_openapi( + title=app.title, + version=app.version if app.version else "0.1.0", + description=app.description, + routes=app.routes, + ) + components = openapi_schema.setdefault("components", {}) + security_schemes = components.setdefault("securitySchemes", {}) + security_schemes["ApiKeyAuth"] = { + "type": "apiKey", + "in": "header", + "name": self.header_name, + } + openapi_schema["security"] = [{"ApiKeyAuth": []}] + app.openapi_schema = openapi_schema + return openapi_schema + + app.openapi = custom_openapi # type: ignore[assignment] + + +settings = DeploymentSettings( + app_extensions=[ + AppExtensionSpec( + extension=( + "my_project.extensions.FastAPIAuthExtension" + ), + extension_kwargs={ + "scheme": "api_key", + "header_name": "x-api-key", + "valid_keys": ["secret-1", "secret-2"], + }, + ) + ] +) +``` + +## Implementation customizations for advanced use cases + +For cases where you need deeper control over how the ASGI app is created or +how the deployment logic is implemented, you can swap/extend the core +components using the following `DeploymentSettings` fields: + +- `deployment_app_runner_flavor` and `deployment_app_runner_kwargs` let you + choose or extend the app runner that constructs and runs the ASGI app. This + needs to be set to a subclass of `BaseDeploymentAppRunnerFlavor`, which is + basically a descriptor of an app runner implementation that itself is a + subclass of `BaseDeploymentAppRunner`. +- `deployment_service_class` and `deployment_service_kwargs` let you provide + your own deployment service to customize the pipeline deployment logic. This + needs to be set to a subclass of `BasePipelineDeploymentService`. + +Both accept loadable sources or objects. 
We cover how to implement custom +runner flavors and services in a dedicated guide. diff --git a/docs/book/how-to/steps-pipelines/configuration.md b/docs/book/how-to/steps-pipelines/configuration.md index 6cef5baf142..92fce0c8fa1 100644 --- a/docs/book/how-to/steps-pipelines/configuration.md +++ b/docs/book/how-to/steps-pipelines/configuration.md @@ -135,11 +135,12 @@ This approach allows you to use different components for different steps in your ## Types of Settings -Settings in ZenML are categorized into two main types: +Settings in ZenML are categorized into three main types: * **General settings** that can be used on all ZenML pipelines: * `DockerSettings` for container configuration * `ResourceSettings` for CPU, memory, and GPU allocation + * `DeploymentSettings` for pipeline deployment configuration - can only be set at the pipeline level * **Stack-component-specific settings** for configuring behaviors of components in your stack: * These use the pattern `` or `.` as keys @@ -185,7 +186,7 @@ simple_ml_pipeline.configuration.settings["resources"] ### Resource Settings -Resource settings allow you to specify the CPU, memory, and GPU requirements for your steps: +Resource settings allow you to specify the CPU, memory, and GPU requirements for your steps. ```python from zenml.config import ResourceSettings @@ -211,6 +212,28 @@ When both pipeline and step resource settings are specified, they are merged wit Note that `ResourceSettings` are not always applied by all orchestrators. The ability to enforce resource constraints depends on the specific orchestrator being used. Some orchestrators like Kubernetes fully support these settings, while others may ignore them. In order to learn more, read the [individual pages](https://docs.zenml.io/stacks/stack-components/orchestrators) of the orchestrator you are using. 
{% endhint %}
+
+When used at the pipeline level, resource settings also let you configure scaling options for your pipeline deployments, including the minimum and maximum number of instances and the scaling policy:
+
+```python
+from zenml import pipeline
+from zenml.config import ResourceSettings
+
+@pipeline(settings={"resources": ResourceSettings(
+    cpu_count=2,
+    memory="4GB",
+    min_replicas=0,
+    max_replicas=10,
+    max_concurrency=10
+)})
+def simple_llm_pipeline(parameter: int):
+    ...
+```
+
+
+{% hint style="info" %}
+Note that `ResourceSettings` are not always applied exactly as specified by all deployers. Some deployers fully support these settings, while others may adjust them automatically to match a set of predefined static values or simply ignore them. In order to learn more, read the [individual pages](https://docs.zenml.io/stacks/stack-components/deployers) of the deployer you are using.
+{% endhint %}
+
+
 ### Docker Settings
 
 Docker settings allow you to customize the containerization process:
@@ -227,6 +250,50 @@ def my_pipeline():
 For more detailed information on containerization options, see the [containerization guide](../containerization/containerization.md).
 
+### Deployment Settings
+
+Deployment settings allow you to customize the web server and ASGI application used to run your pipeline deployments.
You can specify a range of options, including custom endpoints, middleware, extensions and even custom files used to serve an entire single-page application alongside your pipeline:
+
+```python
+from typing import Dict, Any
+import psutil
+from zenml.config import DeploymentSettings, EndpointSpec, EndpointMethod, SecureHeadersConfig
+from zenml import pipeline
+
+async def health_detailed() -> Dict[str, Any]:
+    return {
+        "status": "healthy",
+        "cpu_percent": psutil.cpu_percent(),
+        "memory_percent": psutil.virtual_memory().percent,
+        "disk_percent": psutil.disk_usage("/").percent,
+    }
+
+@pipeline(settings={
+    "deployment": DeploymentSettings(
+        custom_endpoints=[
+            EndpointSpec(
+                path="/health",
+                method=EndpointMethod.GET,
+                handler=health_detailed,
+                auth_required=False,
+            ),
+        ],
+        secure_headers=SecureHeadersConfig(
+            csp=(
+                "default-src 'none'; "
+                "script-src 'self' 'unsafe-inline' https://cdn.jsdelivr.net; "
+                "connect-src 'self' https://cdn.jsdelivr.net; "
+                "style-src 'self' 'unsafe-inline'"
+            ),
+        ),
+        dashboard_files_path="my/custom/ui",
+    ),
+})
+def my_pipeline():
+    ...
+```
+
+For more detailed information on deployment options, see the [pipeline deployment guide](../deployment/deployment.md), particularly the [deployment settings](../deployment/deployment_settings.md) section.
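Since custom endpoint handlers are plain async callables, they can be unit-tested directly before deploying. A minimal sketch, using a hypothetical stdlib-only handler (rather than the `psutil`-based one above) so it runs anywhere:

```python
import asyncio
import os
from typing import Any, Dict

# Hypothetical handler for illustration; not part of the ZenML API.
async def health_basic() -> Dict[str, Any]:
    return {
        "status": "healthy",
        "pid": os.getpid(),
        "cpu_count": os.cpu_count(),
    }

# Exercise the handler directly, exactly as the ASGI app would await it.
result = asyncio.run(health_basic())
assert result["status"] == "healthy"
```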
+ ## Stack Component Configuration ### Registration-time vs Runtime Stack Component Settings @@ -369,6 +436,57 @@ settings: cpu_count: Optional[PositiveFloat] gpu_count: Optional[NonNegativeInt] memory: Optional[ConstrainedStrValue] + deployment: + api_url_path: str + app_description: Union[str, NoneType] + app_extensions: Union[List[AppExtensionSpec], NoneType] + app_kwargs: Dict[str, Any] + app_title: Union[str, NoneType] + app_version: Union[str, NoneType] + cors: + allow_credentials: bool + allow_headers: List[str] + allow_methods: List[str] + allow_origins: List[str] + custom_endpoints: Union[List[EndpointSpec], NoneType] + custom_middlewares: Union[List[MiddlewareSpec], NoneType] + dashboard_files_path: Union[str, NoneType] + deployment_app_runner_flavor: Union[Annotated[SourceOrObject, BeforeValidator, + PlainSerializer], NoneType] + deployment_app_runner_kwargs: Dict[str, Any] + deployment_service_class: Union[Annotated[SourceOrObject, BeforeValidator, PlainSerializer], + NoneType] + deployment_service_kwargs: Dict[str, Any] + docs_url_path: str + health_url_path: str + include_default_endpoints: bool + include_default_middleware: bool + info_url_path: str + invoke_url_path: str + log_level: LoggingLevels + metrics_url_path: str + redoc_url_path: str + root_url_path: str + secure_headers: + cache: Union[bool, str] + content: Union[bool, str] + csp: Union[bool, str] + hsts: Union[bool, str] + permissions: Union[bool, str] + referrer: Union[bool, str] + server: Union[bool, str] + xfo: Union[bool, str] + shutdown_hook: Union[Annotated[SourceOrObject, BeforeValidator, PlainSerializer], + NoneType] + shutdown_hook_kwargs: Dict[str, Any] + startup_hook: Union[Annotated[SourceOrObject, BeforeValidator, PlainSerializer], + NoneType] + startup_hook_kwargs: Dict[str, Any] + thread_pool_size: int + uvicorn_host: str + uvicorn_kwargs: Dict[str, Any] + uvicorn_port: int + uvicorn_workers: int steps: load_data: enable_artifact_metadata: Optional[bool] diff --git 
a/docs/book/toc.md b/docs/book/toc.md index 5c1f6f21791..76ed0c02339 100644 --- a/docs/book/toc.md +++ b/docs/book/toc.md @@ -48,6 +48,7 @@ * [Service Connectors](how-to/stack-components/service_connectors.md) * [Pipeline Snapshots](how-to/snapshots/snapshots.md) * [Pipeline Deployments](how-to/deployment/deployment.md) + * [Deployment Settings](how-to/deployment/deployment_settings.md) * [Containerization](how-to/containerization/containerization.md) * [Code Repositories](how-to/code-repositories/code-repositories.md) * [Secrets](how-to/secrets/secrets.md) diff --git a/examples/weather_agent/pipelines/weather_agent.py b/examples/weather_agent/pipelines/weather_agent.py index 3ede617a0c3..b77a95a08a5 100644 --- a/examples/weather_agent/pipelines/weather_agent.py +++ b/examples/weather_agent/pipelines/weather_agent.py @@ -1,25 +1,241 @@ """Weather Agent Pipeline.""" import os +import time +import uuid +from typing import Any, Awaitable, Callable, Dict +from asgiref.compatibility import guarantee_single_callable +from asgiref.typing import ( + ASGIApplication, + ASGIReceiveCallable, + ASGISendCallable, + ASGISendEvent, + Scope, +) from pipelines.hooks import ( InitConfig, cleanup_hook, init_hook, ) +from starlette.middleware.gzip import GZipMiddleware from steps import analyze_weather_with_llm, get_weather from zenml import pipeline -from zenml.config import DockerSettings - -# Import enums for type-safe capture mode configuration -from zenml.config.docker_settings import PythonPackageInstaller -from zenml.config.resource_settings import ResourceSettings +from zenml.config import ( + DeploymentSettings, + DockerSettings, + EndpointMethod, + EndpointSpec, + MiddlewareSpec, + ResourceSettings, + SecureHeadersConfig, +) +from zenml.config.deployment_settings import DeploymentDefaultMiddleware +from zenml.deployers.server.app import BaseDeploymentAppRunner +from zenml.enums import LoggingLevels +from zenml.models import DeploymentResponse docker_settings = 
DockerSettings( requirements=["openai"], prevent_build_reuse=True, - python_package_installer=PythonPackageInstaller.UV, +) + + +async def health_detailed() -> Dict[str, Any]: + """Detailed health check with system metrics.""" + import psutil + + from zenml.client import Client + + client = Client() + + return { + "status": "healthy", + "cpu_percent": psutil.cpu_percent(), + "memory_percent": psutil.virtual_memory().percent, + "disk_percent": psutil.disk_usage("/").percent, + "zenml": client.zen_store.get_store_info().model_dump(), + } + + +class RequestTimingMiddleware: + """ASGI middleware to measure request processing time. + + Uses the standard ASGI interface (scope, receive, send) which works + across all ASGI frameworks: FastAPI, Django, Starlette, Quart, etc. + """ + + def __init__(self, app: ASGIApplication): + """Initialize the middleware. + + Args: + app: The ASGI application to wrap. + """ + self.app = guarantee_single_callable(app) + + async def __call__( + self, + scope: Scope, + receive: ASGIReceiveCallable, + send: ASGISendCallable, + ) -> None: + """Process ASGI request with timing measurement. + + Args: + scope: ASGI connection scope (contains request info). + receive: Async callable to receive ASGI events. + send: Async callable to send ASGI events. 
+ """ + if scope["type"] != "http": + return await self.app(scope, receive, send) + + start_time = time.time() + + async def send_wrapper(message: ASGISendEvent) -> None: + """Intercept response to add timing header.""" + if message["type"] == "http.response.start": + process_time = (time.time() - start_time) * 1000 + headers = list(message.get("headers", [])) + headers.append( + ( + b"x-process-time-ms", + str(process_time).encode(), + ) + ) + message = {**message, "headers": headers} + + await send(message) + + await self.app(scope, receive, send_wrapper) + + +def build_deployment_info( + app_runner: BaseDeploymentAppRunner, +) -> Callable[[], Awaitable[DeploymentResponse]]: + """Build the deployment info endpoint. + + Args: + app_runner: The deployment app runner. + + Returns: + The deployment info endpoint. + """ + + async def endpoint() -> DeploymentResponse: + return app_runner.deployment + + return endpoint + + +async def request_id_middleware( + app: ASGIApplication, + scope: Scope, + receive: ASGIReceiveCallable, + send: ASGISendCallable, + header_name: str = "x-request-id", +) -> None: + """ASGI function middleware that ensures a correlation ID header exists.""" + + app = guarantee_single_callable(app) + + if scope["type"] != "http": + await app(scope, receive, send) + return + + # Reuse existing request ID if present; otherwise generate one + request_id = None + for k, v in scope.get("headers", []): + if k.decode().lower() == header_name: + request_id = v.decode() + break + + if not request_id: + request_id = str(uuid.uuid4()) + + async def send_wrapper(message: ASGISendEvent) -> None: + if message["type"] == "http.response.start": + headers = list(message.get("headers", [])) + headers.append((header_name.encode(), request_id.encode())) + message = {**message, "headers": headers} + + await send(message) + + await app(scope, receive, send_wrapper) + + +def on_startup( + app_runner: BaseDeploymentAppRunner, warm: bool = False +) -> None: + """Startup hook. 
+ + Args: + app_runner: The deployment app runner. + warm: Whether to warm the app. + """ + print(f"Startup hook called with warm={warm}") + + +def on_shutdown( + app_runner: BaseDeploymentAppRunner, drain_timeout_s: int = 2 +) -> None: + """Shutdown hook. + + Args: + app_runner: The deployment app runner. + drain_timeout_s: The drain timeout in seconds. + """ + print(f"Shutdown hook called with drain_timeout_s={drain_timeout_s}") + + +deployment_settings = DeploymentSettings( + custom_endpoints=[ + EndpointSpec( + path="/health", + method=EndpointMethod.GET, + handler=health_detailed, + auth_required=False, + ), + EndpointSpec( + path="/deployment", + method=EndpointMethod.GET, + handler=build_deployment_info, + auth_required=True, + ), + ], + custom_middlewares=[ + MiddlewareSpec( + middleware=RequestTimingMiddleware, + order=10, + ), + MiddlewareSpec( + middleware=request_id_middleware, + order=5, + init_kwargs={"header_name": "x-request-id"}, + ), + MiddlewareSpec( + middleware=GZipMiddleware, + native=True, + order=20, + extra_kwargs={"minimum_size": 1024}, + ), + ], + dashboard_files_path="ui", + secure_headers=SecureHeadersConfig( + csp=( + "default-src 'none'; " + "script-src 'self' 'unsafe-inline' https://cdn.jsdelivr.net; " + "connect-src 'self' https://cdn.jsdelivr.net; " + "style-src 'self' 'unsafe-inline'" + ), + ), + startup_hook=on_startup, + shutdown_hook=on_shutdown, + startup_hook_kwargs={"warm": True}, + shutdown_hook_kwargs={"drain_timeout_s": 2}, + include_default_middleware=DeploymentDefaultMiddleware.CORS + | DeploymentDefaultMiddleware.SECURE_HEADERS, + log_level=LoggingLevels.DEBUG, ) environment = {} @@ -34,12 +250,8 @@ on_cleanup=cleanup_hook, settings={ "docker": docker_settings, - "deployer.gcp": { - "allow_unauthenticated": True, - # "location": "us-central1", - "generate_auth_key": True, - }, - "deployer.aws": { + "deployment": deployment_settings, + "deployer": { "generate_auth_key": True, }, "resources": ResourceSettings( @@ -54,14 
+266,8 @@ ) def weather_agent( city: str = "London", -) -> str: - """Weather agent pipeline optimized for run-only serving. - - Automatically uses run-only architecture for millisecond-class latency: - - Zero database writes - - Zero filesystem operations - - In-memory step output handoff - - Perfect for real-time inference +) -> tuple[Dict[str, float], str]: + """Weather agent pipeline. Args: city: City name to analyze weather for @@ -71,4 +277,4 @@ def weather_agent( """ weather_data = get_weather(city=city) result = analyze_weather_with_llm(weather_data=weather_data, city=city) - return result + return weather_data, result diff --git a/examples/weather_agent/ui/index.html b/examples/weather_agent/ui/index.html new file mode 100644 index 00000000000..03091632bd7 --- /dev/null +++ b/examples/weather_agent/ui/index.html @@ -0,0 +1,533 @@ + + +
+<!-- index.html: a single-page "Weather Activity Agent" UI. Heading: "Weather
+Activity Agent". Tagline: "Type a city and get curated, weather-aware activity
+ideas. Powered by your ZenML deployment." Includes a "Try it" city form. Full
+markup elided. -->
+ + + + \ No newline at end of file diff --git a/helm/values.yaml b/helm/values.yaml index 4c16d5f3a36..bacac04949d 100644 --- a/helm/values.yaml +++ b/helm/values.yaml @@ -1004,10 +1004,6 @@ zenml: hsts: enabled # The `X-Frame-Options` HTTP header value. The default value is `SAMEORIGIN`. xfo: enabled - # The `X-XSS-Protection` HTTP header value. The default value is `0`. - # NOTE: this header is deprecated and should not be customized anymore. The - # `Content-Security-Policy` header should be used instead. - xxp: enabled # The `X-Content-Type-Options` HTTP header value. The default value is # `nosniff`. content: enabled diff --git a/pyproject.toml b/pyproject.toml index 617bfb71a93..fccc10b48b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ classifiers = [ requires-python = ">=3.10,<3.14" dependencies = [ + "asgiref~=3.10.0", "click>=8.0.1,<=8.2.1", "cloudpickle>=2.0.0", "distro>=1.6.0,<2.0.0", @@ -72,7 +73,7 @@ server = [ "orjson~=3.10.0", "Jinja2", "ipinfo>=4.4.3", - "secure~=0.3.0", + "secure~=1.0.1", "tldextract~=5.1.0", "itsdangerous~=2.2.0", ] diff --git a/src/zenml/config/__init__.py b/src/zenml/config/__init__.py index 5adf1d13546..e1654db9889 100644 --- a/src/zenml/config/__init__.py +++ b/src/zenml/config/__init__.py @@ -12,6 +12,17 @@ # or implied. See the License for the specific language governing # permissions and limitations under the License. 
"""Config classes.""" +from zenml.config.deployment_settings import ( + DeploymentSettings, + DeploymentDefaultEndpoints, + DeploymentDefaultMiddleware, + EndpointSpec, + EndpointMethod, + MiddlewareSpec, + AppExtensionSpec, + SecureHeadersConfig, + CORSConfig, +) from zenml.config.docker_settings import ( DockerSettings, PythonPackageInstaller, @@ -24,6 +35,15 @@ from zenml.config.cache_policy import CachePolicy __all__ = [ + "DeploymentSettings", + "DeploymentDefaultEndpoints", + "DeploymentDefaultMiddleware", + "EndpointSpec", + "EndpointMethod", + "MiddlewareSpec", + "AppExtensionSpec", + "SecureHeadersConfig", + "CORSConfig", "DockerSettings", "PythonPackageInstaller", "PythonEnvironmentExportMethod", diff --git a/src/zenml/config/build_configuration.py b/src/zenml/config/build_configuration.py index 1ff01604a44..b36610f3981 100644 --- a/src/zenml/config/build_configuration.py +++ b/src/zenml/config/build_configuration.py @@ -15,7 +15,7 @@ import hashlib import json -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING, Dict, List, Optional from pydantic import BaseModel @@ -39,6 +39,10 @@ class BuildConfiguration(BaseModel): step_name: Name of the step for which this image will be built. entrypoint: Optional entrypoint for the image. extra_files: Extra files to include in the Docker image. + extra_requirements_files: Extra requirements to install in the + Docker image. Each key is the name of a Python requirements file to + be created and the value is the list of requirements to be + installed. 
""" key: str @@ -46,6 +50,7 @@ class BuildConfiguration(BaseModel): step_name: Optional[str] = None entrypoint: Optional[str] = None extra_files: Dict[str, str] = {} + extra_requirements_files: Dict[str, List[str]] = {} def compute_settings_checksum( self, @@ -73,6 +78,7 @@ def compute_settings_checksum( default=json_utils.pydantic_encoder, ) hash_.update(settings_json.encode()) + if self.entrypoint: hash_.update(self.entrypoint.encode()) @@ -93,6 +99,7 @@ def compute_settings_checksum( stack=stack, code_repository=code_repository if pass_code_repo else None, log=False, + extra_requirements_files=self.extra_requirements_files, ) ) for _, requirements, _ in requirements_files: diff --git a/src/zenml/config/constants.py b/src/zenml/config/constants.py index e0ed854c602..1cb0364f661 100644 --- a/src/zenml/config/constants.py +++ b/src/zenml/config/constants.py @@ -15,3 +15,4 @@ DOCKER_SETTINGS_KEY = "docker" RESOURCE_SETTINGS_KEY = "resources" +DEPLOYMENT_SETTINGS_KEY = "deployment" diff --git a/src/zenml/config/deployment_settings.py b/src/zenml/config/deployment_settings.py new file mode 100644 index 00000000000..2265e3a6df3 --- /dev/null +++ b/src/zenml/config/deployment_settings.py @@ -0,0 +1,752 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+"""Deployment settings.""" + +from enum import Enum, IntFlag, auto +from typing import ( + Any, + Callable, + ClassVar, + Dict, + List, + Optional, + Union, +) + +from pydantic import ( + BaseModel, + ConfigDict, + Field, +) + +from zenml.config.base_settings import BaseSettings, ConfigurationLevel +from zenml.config.source import SourceOrObject, SourceOrObjectField +from zenml.enums import LoggingLevels +from zenml.logger import get_logger + +logger = get_logger(__name__) + +DEFAULT_DEPLOYMENT_APP_THREAD_POOL_SIZE = 20 + +DEFAULT_DEPLOYMENT_APP_SECURE_HEADERS_HSTS = ( + "max-age=63072000; includeSubdomains" +) +DEFAULT_DEPLOYMENT_APP_SECURE_HEADERS_XFO = "SAMEORIGIN" +DEFAULT_DEPLOYMENT_APP_SECURE_HEADERS_CONTENT = "nosniff" +DEFAULT_DEPLOYMENT_APP_SECURE_HEADERS_CSP = ( + "default-src 'none'; " + "script-src 'self' 'unsafe-inline'; " + "connect-src 'self'; " + "img-src 'self'; " + "style-src 'self' 'unsafe-inline'; " + "base-uri 'self'; " + "form-action 'self'; " + "font-src 'self';" + "frame-src 'self'" +) +DEFAULT_DEPLOYMENT_APP_SECURE_HEADERS_REFERRER = "no-referrer-when-downgrade" +DEFAULT_DEPLOYMENT_APP_SECURE_HEADERS_CACHE = ( + "no-store, no-cache, must-revalidate" +) +DEFAULT_DEPLOYMENT_APP_SECURE_HEADERS_PERMISSIONS = ( + "accelerometer=(), autoplay=(), camera=(), encrypted-media=(), " + "geolocation=(), gyroscope=(), magnetometer=(), microphone=(), midi=(), " + "payment=(), sync-xhr=(), usb=()" +) +DEFAULT_DEPLOYMENT_APP_SECURE_HEADERS_REPORT_TO = "default" +DEFAULT_DEPLOYMENT_APP_MAX_REQUEST_BODY_SIZE_IN_BYTES = 256 * 1024 * 1024 + + +class EndpointMethod(str, Enum): + """HTTP methods for endpoints.""" + + GET = "GET" + POST = "POST" + PUT = "PUT" + PATCH = "PATCH" + DELETE = "DELETE" + + +class EndpointSpec(BaseModel): + """Endpoint specification. + + Use this class to configure a custom endpoint that must be registered on the + deployment application in a framework-agnostic way. + + The handler field can be set to one of the following: + + 1. 
The function or method that represents the actual endpoint + implementation. This will be registered as is. The function itself may be + framework-specific - i.e. it may use framework-specific arguments, return + values or implementation details, e.g.: + + ```python + async def my_handler(request: Request, my_param: InputModel) -> OutputModel: + ... + ``` + + 2. An endpoint builder class - this is a callable class (i.e. a class that + implements the __call__ method, which is the actual endpoint implementation) + that is used to build the endpoint. When this is used, the class constructor + must accept an argument `app_runner` of type `BaseDeploymentAppRunner` + which will be passed by the adapter at application build time. + + The adapter will also pass the `init_kwargs` to the class constructor if + configured. The following is an example of an endpoint builder class: + + ```python + from zenml.deployers.server import BaseDeploymentAppRunner + + class MyHandler: + def __init__( + self, + app_runner: "BaseDeploymentAppRunner", + **kwargs: Any, + ) -> None: + self.app_runner = app_runner + self.kwargs = kwargs + ... + + async def __call__(self, request: Request, my_param: InputModel) -> OutputModel: + ... + ``` + + 3. An endpoint builder function - this is a function that is used to build + and return the endpoint. When this is used, the adapter will call the + provided function first and is expected to return the actual endpoint + function. The builder function must accept an argument `app_runner` of type + `BaseDeploymentAppRunner` which will be passed by the adapter at + application build time. + + The adapter will also pass the `init_kwargs` to the builder function if + configured. The following is an example of an endpoint builder function: + + ```python + from zenml.deployers.server import BaseDeploymentAppRunner + + def my_builder( + app_runner: "BaseDeploymentAppRunner", + **kwargs: Any, + ) -> Callable: + ...
+ + async def endpoint(request: Request, my_param: InputModel) -> OutputModel: + ... + + return endpoint + ``` + + 4. Alternatively, the handler can be set to any framework-specific + source-loadable object that can be used directly, by setting `native` to + `True`. In this case, the framework-specific endpoint adapter will decide + what to do with the object and how to use the init_kwargs. + + Attributes: + path: URL path (e.g., "/custom/metrics"). + method: HTTP method. + handler: Handler callable or source. If this is an endpoint builder + instead of the actual endpoint implementation, the adapter will call + the provided callable first and is expected to return the actual + endpoint callable. + native: Whether the endpoint is a framework-specific source-loadable + object that can be used directly. + auth_required: Whether authentication is required. This is an + indication for the adapter to apply any configured auth dependencies + or middlewares to the endpoint. + init_kwargs: Arguments to be passed to the endpoint builder function or + class constructor, if provided. + extra_kwargs: Arbitrary framework-specific arguments to be used when + registering the endpoint. + """ + + path: str + method: EndpointMethod + handler: SourceOrObjectField + native: bool = False + auth_required: bool = True + init_kwargs: Dict[str, Any] = Field(default_factory=dict) + extra_kwargs: Dict[str, Any] = Field(default_factory=dict) + + def load_sources(self) -> None: + """Load all source strings into callables.""" + assert isinstance(self.handler, SourceOrObject) + if not self.handler.is_loaded: + self.handler.load() + + model_config = ConfigDict( + # public attributes are mutable + frozen=False, + # prevent extra attributes during model initialization + extra="ignore", + ) + + +class MiddlewareSpec(BaseModel): + """Middleware specification. + + Use this class to configure custom middleware that must be registered on + the deployment application in a framework-agnostic way.
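The endpoint builder-function contract described in the `EndpointSpec` docstring can be exercised without any web framework. In this sketch, `AppRunner` is a stand-in for `BaseDeploymentAppRunner` and the handler is plain Python, so only the calling convention is illustrated:

```python
# Stand-in for zenml.deployers.server.BaseDeploymentAppRunner, used only
# to illustrate the endpoint-builder calling convention.
class AppRunner:
    pass

def my_builder(app_runner, **kwargs):
    # The adapter calls the builder with the app runner and the configured
    # init_kwargs, and expects the actual endpoint callable in return.
    prefix = kwargs.get("prefix", "")

    def endpoint(payload):
        return f"{prefix}{payload}"

    return endpoint

# What an adapter would conceptually do with a builder-style handler:
handler = my_builder(AppRunner(), prefix="echo: ")
```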
+ + The middleware field can be set to one of the following: + + 1. A middleware class - this class follows the standard ASGI middleware + interface, i.e. it implements the __call__ method and takes the scope, + receive and send arguments. + + The adapter will also pass the `init_kwargs` to the class constructor if + configured. The following is an example of a middleware class: + + ```python + from asgiref.typing import ( + ASGIApplication, + ASGIReceiveCallable, + ASGISendCallable, + Scope, + ) + + class MyMiddleware: + def __init__( + self, + app: ASGIApplication, + **kwargs: Any, + ) -> None: + self.app = app + self.kwargs = kwargs + + async def __call__( + self, + scope: Scope, + receive: ASGIReceiveCallable, + send: ASGISendCallable, + ) -> None: + # Middleware logic + ... + await self.app(scope, receive, send) + ``` + + 2. A middleware function - this function follows the standard ASGI middleware + interface, i.e. it takes the ASGIApp object, scope, receive and send arguments. + The adapter will pass the `init_kwargs` to the middleware function if + configured. The following is an example of a middleware function: + + ```python + from asgiref.typing import ( + ASGIApplication, + ASGIReceiveCallable, + ASGISendCallable, + Scope, + ) + + async def my_middleware( + app: ASGIApplication, + scope: Scope, + receive: ASGIReceiveCallable, + send: ASGISendCallable, + **kwargs: Any, + ) -> None: + ... + await app(scope, receive, send) + ``` + + 3. Alternatively, the middleware can be set to any framework-specific + source-loadable object that can be used directly, by setting `native` to + `True`. In this case, the framework-specific middleware adapter will decide + what to do with the object and how to use the init_kwargs. + + Attributes: + middleware: Middleware callable or source. 
If this is a middleware + builder instead of the actual middleware implementation, the + adapter will call the provided callable first and is expected to + return the actual middleware callable. + native: Whether the middleware is a framework-specific source-loadable + object that can be used directly. + order: Registration order (lower = earlier in chain). + init_kwargs: Arguments to be passed to the middleware builder function + or class constructor, if provided. + extra_kwargs: Arbitrary framework-specific arguments to be passed to + the middleware constructor by the adapter. + """ + + middleware: SourceOrObjectField + native: bool = False + order: int = 0 + init_kwargs: Dict[str, Any] = Field(default_factory=dict) + extra_kwargs: Dict[str, Any] = Field(default_factory=dict) + + def load_sources(self) -> None: + """Load source string into callable.""" + assert isinstance(self.middleware, SourceOrObject) + if not self.middleware.is_loaded: + self.middleware.load() + + model_config = ConfigDict( + # public attributes are mutable + frozen=False, + # prevent extra attributes during model initialization + extra="ignore", + ) + + +class AppExtensionSpec(BaseModel): + """Configuration for a pluggable app extension. + + Extensions can be: + 1. Simple callable - this is a function that is used to apply the extension + to the app. The function must accept an argument `app_runner` of type + `BaseDeploymentAppRunner` which will be passed by the adapter at + application build time. If configured, the function will also be passed the + `extension_kwargs` as keyword arguments. This is an example: + + ```python + from zenml.deployers.server import BaseDeploymentAppRunner + + def extension(app_runner: BaseDeploymentAppRunner, **kwargs: Any) -> None: + + @app_runner.asgi_app.get("/my-extension") + def my_extension(request: Request) -> Response: + ... + + ``` + + 2. BaseAppExtension subclass. If any kwargs are provided, they will be + passed to the class constructor.
The class must also implement the `install` + method, which will be called by the adapter at application build time. This + is an example: + + ```python + from zenml.deployers.server import BaseAppExtension + from zenml.deployers.server import BaseDeploymentAppRunner + + class MyExtension(BaseAppExtension): + + def __init__(self, **kwargs): + ... + self.router = APIRouter() + ... + + def install(self, app_runner: BaseDeploymentAppRunner) -> None: + + app_runner.asgi_app.include_router(self.router) + ``` + + Attributes: + extension: Extension callable/class or source. + extension_kwargs: Configuration passed during initialization. + """ + + extension: SourceOrObjectField + extension_kwargs: Dict[str, Any] = Field(default_factory=dict) + + def load_sources(self) -> None: + """Load source string into callable.""" + assert isinstance(self.extension, SourceOrObject) + if not self.extension.is_loaded: + self.extension.load() + + def resolve_extension_handler( + self, + ) -> Callable[..., Any]: + """Resolve the extension handler from the spec. + + Returns: + The extension handler. + + Raises: + ValueError: If the extension object is not callable. + """ + assert isinstance(self.extension, SourceOrObject) + + extension = self.extension.load() + if not callable(extension): + raise ValueError( + f"The extension object {extension} must be callable" + ) + return extension + + model_config = ConfigDict( + # public attributes are mutable + frozen=False, + # prevent extra attributes during model initialization + extra="ignore", + ) + + +class CORSConfig(BaseModel): + """Configuration for CORS.""" + + allow_origins: List[str] = ["*"] + allow_methods: List[str] = ["GET", "POST", "OPTIONS"] + allow_headers: List[str] = ["*"] + allow_credentials: bool = False + + +class SecureHeadersConfig(BaseModel): + """Configuration for secure headers. + + Attributes: + server: Custom value to be set in the `Server` HTTP header to identify + the server. 
If not specified, or if set to one of the reserved values + `enabled`, `yes`, `true`, `on`, the `Server` header will be set to the + default value (ZenML server ID). If set to one of the reserved values + `disabled`, `no`, `none`, `false`, `off` or to an empty string, the + `Server` header will not be included in responses. + hsts: The server header value to be set in the HTTP header + `Strict-Transport-Security`. If not specified, or if set to one of + the reserved values `enabled`, `yes`, `true`, `on`, the + `Strict-Transport-Security` header will be set to the default value + (`max-age=63072000; includeSubdomains`). If set to one of the reserved + values `disabled`, `no`, `none`, `false`, `off` or to an empty string, + the `Strict-Transport-Security` header will not be included in responses. + xfo: The server header value to be set in the HTTP header + `X-Frame-Options`. If not specified, or if set to one of the + reserved values `enabled`, `yes`, `true`, `on`, the `X-Frame-Options` + header will be set to the default value (`SAMEORIGIN`). If set to + one of the reserved values `disabled`, `no`, `none`, `false`, `off` + or to an empty string, the `X-Frame-Options` header will not be + included in responses. + content: The server header value to be set in the HTTP header + `X-Content-Type-Options`. If not specified, or if set to one + of the reserved values `enabled`, `yes`, `true`, `on`, the + `X-Content-Type-Options` header will be set to the default value + (`nosniff`). If set to one of the reserved values `disabled`, `no`, + `none`, `false`, `off` or to an empty string, the + `X-Content-Type-Options` header will not be included in responses. + csp: The server header value to be set in the HTTP header + `Content-Security-Policy`. If not specified, or if set to one + of the reserved values `enabled`, `yes`, `true`, `on`, the + `Content-Security-Policy` header will be set to the default value + DEFAULT_DEPLOYMENT_APP_SECURE_HEADERS_CSP. 
If set to one of the + reserved values `disabled`, `no`, `none`, `false`, `off` or to an + empty string, the `Content-Security-Policy` header will not be + included in responses. + referrer: The server header value to be set in the HTTP header + `Referrer-Policy`. If not specified, or if set to one of the + reserved values `enabled`, `yes`, `true`, `on`, the `Referrer-Policy` + header will be set to the default value + (`no-referrer-when-downgrade`). If set to one of the reserved values + `disabled`, `no`, `none`, `false`, `off` or to an empty string, the + `Referrer-Policy` header will not be included in responses. + cache: The server header value to be set in the HTTP header + `Cache-Control`. If not specified, or if set to one of the + reserved values `enabled`, `yes`, `true`, `on`, the `Cache-Control` + header will be set to the default value + (`no-store, no-cache, must-revalidate`). If set to one of the + reserved values `disabled`, `no`, `none`, `false`, `off` or to an + empty string, the `Cache-Control` header will not be included in + responses. + permissions: The server header value to be set in the HTTP header + `Permissions-Policy`. If not specified, or if set to one + of the reserved values `enabled`, `yes`, `true`, `on`, the + `Permissions-Policy` header will be set to the default value + DEFAULT_DEPLOYMENT_APP_SECURE_HEADERS_PERMISSIONS. If set to + one of the reserved values `disabled`, `no`, `none`, `false`, `off` + or to an empty string, the `Permissions-Policy` header will not be + included in responses. 
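The enable/disable semantics spelled out for each secure header above can be summarized in a small helper. This is an illustrative reading of the docstring's reserved-value rules, not ZenML's actual implementation:

```python
# Reserved values from the SecureHeadersConfig docstring.
_ENABLE = {"enabled", "yes", "true", "on"}
_DISABLE = {"disabled", "no", "none", "false", "off", ""}

def resolve_header(value, default):
    """Return the header value to emit, or None to omit the header."""
    if value is True or (isinstance(value, str) and value.lower() in _ENABLE):
        return default
    if value is False or (isinstance(value, str) and value.lower() in _DISABLE):
        return None
    # Any other string is a custom value that overrides the default.
    return value
```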
+ """ + + server: Union[bool, str] = Field( + default=True, + union_mode="left_to_right", + ) + hsts: Union[bool, str] = Field( + default=DEFAULT_DEPLOYMENT_APP_SECURE_HEADERS_HSTS, + union_mode="left_to_right", + ) + xfo: Union[bool, str] = Field( + default=DEFAULT_DEPLOYMENT_APP_SECURE_HEADERS_XFO, + union_mode="left_to_right", + ) + content: Union[bool, str] = Field( + default=DEFAULT_DEPLOYMENT_APP_SECURE_HEADERS_CONTENT, + union_mode="left_to_right", + ) + csp: Union[bool, str] = Field( + default=DEFAULT_DEPLOYMENT_APP_SECURE_HEADERS_CSP, + union_mode="left_to_right", + ) + referrer: Union[bool, str] = Field( + default=DEFAULT_DEPLOYMENT_APP_SECURE_HEADERS_REFERRER, + union_mode="left_to_right", + ) + cache: Union[bool, str] = Field( + default=DEFAULT_DEPLOYMENT_APP_SECURE_HEADERS_CACHE, + union_mode="left_to_right", + ) + permissions: Union[bool, str] = Field( + default=DEFAULT_DEPLOYMENT_APP_SECURE_HEADERS_PERMISSIONS, + union_mode="left_to_right", + ) + + +DEFAULT_DEPLOYMENT_APP_ROOT_URL_PATH = "" +DEFAULT_DEPLOYMENT_APP_API_URL_PATH = "" +DEFAULT_DEPLOYMENT_APP_DOCS_URL_PATH = "/docs" +DEFAULT_DEPLOYMENT_APP_REDOC_URL_PATH = "/redoc" +DEFAULT_DEPLOYMENT_APP_INVOKE_URL_PATH = "/invoke" +DEFAULT_DEPLOYMENT_APP_HEALTH_URL_PATH = "/health" +DEFAULT_DEPLOYMENT_APP_INFO_URL_PATH = "/info" +DEFAULT_DEPLOYMENT_APP_METRICS_URL_PATH = "/metrics" + + +class DeploymentDefaultEndpoints(IntFlag): + """Default endpoints for the deployment application.""" + + NONE = 0 + DOCS = auto() + REDOC = auto() + INVOKE = auto() + HEALTH = auto() + INFO = auto() + METRICS = auto() + DASHBOARD = auto() + + DOC = DOCS | REDOC + API = INVOKE | HEALTH | INFO | METRICS + ALL = DOCS | REDOC | INVOKE | HEALTH | INFO | METRICS | DASHBOARD + + +class DeploymentDefaultMiddleware(IntFlag): + """Default middleware for the deployment application.""" + + NONE = 0 + CORS = auto() + SECURE_HEADERS = auto() + + ALL = CORS | SECURE_HEADERS + + +class DeploymentSettings(BaseSettings): + """Settings 
for the pipeline deployment. + + Use these settings to fully customize all aspects of the uvicorn web server + and ASGI web application that constitute the pipeline deployment. + + Note that these settings are only available at the pipeline level. + + The following customizations can be used to configure aspects that are + framework-agnostic (i.e. not specific to a particular ASGI framework like + FastAPI, Django, Flask, etc.): + + * the ASGI application details: `app_title`, `app_description`, + `app_version` and `app_kwargs` + * the URL paths for the various built-in endpoints: `root_url_path`, + `api_url_path`, `docs_url_path`, `redoc_url_path`, `invoke_url_path`, + `health_url_path`, `info_url_path` and `metrics_url_path` + * the location of dashboard static files can be provided to replace the + default UI that is included with the deployment ASGI application: + `dashboard_files_path` + * which default endpoints and middleware to include: + `include_default_endpoints` and `include_default_middleware` + * the CORS configuration: `cors` + * the secure headers configuration: `secure_headers` + * the thread pool size: `thread_pool_size` + * custom application startup and shutdown hooks: `startup_hook_source`, + `shutdown_hook_source`, `startup_hook_kwargs` and `shutdown_hook_kwargs` + * uvicorn server configuration: `uvicorn_host`, `uvicorn_port`, + `uvicorn_workers` and `uvicorn_kwargs` + + In addition to the above, the following advanced features can be used to + customize the implementation-specific details of the deployment application: + + * custom endpoints (e.g. custom metrics, custom health, etc.): `custom_endpoints` + * custom middlewares (e.g. 
authentication, logging, etc.): `custom_middlewares` + * application building extensions - these are pluggable components that can + be used to add advanced framework-specific features like custom authentication, + logging, metrics, etc.: `app_extensions` + + Ultimately, if none of the above is sufficient, the user can provide + custom implementations for the two core components that are used to build + and run the deployment application itself: + + * the deployment app runner - this is the component that is responsible for + building and running the ASGI application. It is represented by the + `zenml.deployers.server.BaseDeploymentAppRunner` class. + See: `deployment_app_runner_flavor` and `deployment_app_runner_kwargs` + * the deployment service - this is the component that is responsible for + handling the business logic of the pipeline deployment. It is represented by + the `zenml.deployers.server.BaseDeploymentService` class. See: + `deployment_service_class` and `deployment_service_kwargs` + + Both of these base classes or their existing implementations can be extended + and provided as sources in the deployment settings to be loaded at runtime. + + Attributes: + app_title: Title of the deployment application. + app_description: Description of the deployment application. + app_version: Version of the deployment application. + app_kwargs: Arbitrary framework-specific keyword arguments to be passed + to the deployment ASGI application constructor. + + include_default_endpoints: Whether to include the default endpoints in + the ASGI application. Can be a boolean or a list of default endpoints + to include. See the `DeploymentDefaultEndpoints` enum for the available + default endpoints. + include_default_middleware: Whether to include the default middleware + in the ASGI application. Can be a boolean or a list of default middleware + to include. See the `DeploymentDefaultMiddleware` enum for the available + default middleware.
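Because `DeploymentDefaultEndpoints` is an `IntFlag`, the composite values (`DOC`, `API`, `ALL`) behave as bitmasks, and membership checks like the `endpoint_enabled` helper's `endpoint in self.include_default_endpoints` work as expected. A self-contained replica of the enum shows this:

```python
from enum import IntFlag, auto

# Replica of DeploymentDefaultEndpoints for illustration purposes only.
class DeploymentDefaultEndpoints(IntFlag):
    NONE = 0
    DOCS = auto()
    REDOC = auto()
    INVOKE = auto()
    HEALTH = auto()
    INFO = auto()
    METRICS = auto()
    DASHBOARD = auto()

    DOC = DOCS | REDOC
    API = INVOKE | HEALTH | INFO | METRICS
    ALL = DOCS | REDOC | INVOKE | HEALTH | INFO | METRICS | DASHBOARD

# Groups combine with |, the way include_default_endpoints is consumed:
selected = DeploymentDefaultEndpoints.API | DeploymentDefaultEndpoints.DOCS
```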
+ + root_url_path: Root URL path. + docs_url_path: URL path for the OpenAPI documentation endpoint. + redoc_url_path: URL path for the Redoc documentation endpoint. + api_url_path: URL path for the API endpoints. + invoke_url_path: URL path for the API invoke endpoint. + health_url_path: URL path for the API health check endpoint. + info_url_path: URL path for the API info endpoint. + metrics_url_path: URL path for the API metrics endpoint. + dashboard_files_path: Path where the dashboard static files (e.g. for a + single-page application) are located. This can be used to replace the + default UI that is included with the deployment ASGI application. + The referenced directory must contain at a minimum an `index.html` + file. One or more subdirectories can be included to serve static + files (e.g. /assets, /css, /js, etc.). The path must be relative to + the source root (e.g. relative to the directory where `zenml init` + was run or where the main running Python script is located). + + cors: Configuration for CORS. + secure_headers: Configuration for secure headers. + thread_pool_size: Size of the thread pool for the ASGI application. + + startup_hook: Custom startup hook for the ASGI application. + shutdown_hook: Custom shutdown hook for the ASGI application. + startup_hook_kwargs: Keyword arguments for the startup hook. + shutdown_hook_kwargs: Keyword arguments for the shutdown hook. + + custom_endpoints: Custom endpoints for the ASGI application. See the + `EndpointSpec` class for more details. + custom_middlewares: Custom middlewares for the ASGI application. See the + `MiddlewareSpec` class for more details. + app_extensions: App extensions used to build the ASGI application. See + the `AppExtensionSpec` class for more details. + + uvicorn_host: Host of the uvicorn server. + uvicorn_port: Port of the uvicorn server. + uvicorn_workers: Number of workers for the uvicorn server. + log_level: Log level for the deployment application.
+ uvicorn_kwargs: Keyword arguments for the uvicorn server. + + deployment_app_runner_flavor: Flavor of the deployment app runner. Must + point to a class that extends the + `zenml.deployers.server.BaseDeploymentAppRunnerFlavor` class. + deployment_app_runner_kwargs: Keyword arguments for the deployment app + runner. These will be passed to the constructor of the deployment app + runner class. + deployment_service_class: Class of the deployment service. Must point + to a class that extends the + `zenml.deployers.server.BaseDeploymentService` class. + deployment_service_kwargs: Keyword arguments for the deployment service. + These will be passed to the constructor of the deployment service class. + """ + + # These settings are only available at the pipeline level + LEVEL: ClassVar[ConfigurationLevel] = ConfigurationLevel.PIPELINE + + app_title: Optional[str] = None + app_description: Optional[str] = None + app_version: Optional[str] = None + app_kwargs: Dict[str, Any] = {} + + include_default_endpoints: DeploymentDefaultEndpoints = ( + DeploymentDefaultEndpoints.ALL + ) + include_default_middleware: DeploymentDefaultMiddleware = ( + DeploymentDefaultMiddleware.ALL + ) + + root_url_path: str = DEFAULT_DEPLOYMENT_APP_ROOT_URL_PATH + api_url_path: str = DEFAULT_DEPLOYMENT_APP_API_URL_PATH + docs_url_path: str = DEFAULT_DEPLOYMENT_APP_DOCS_URL_PATH + redoc_url_path: str = DEFAULT_DEPLOYMENT_APP_REDOC_URL_PATH + invoke_url_path: str = DEFAULT_DEPLOYMENT_APP_INVOKE_URL_PATH + health_url_path: str = DEFAULT_DEPLOYMENT_APP_HEALTH_URL_PATH + info_url_path: str = DEFAULT_DEPLOYMENT_APP_INFO_URL_PATH + metrics_url_path: str = DEFAULT_DEPLOYMENT_APP_METRICS_URL_PATH + + dashboard_files_path: Optional[str] = None + cors: CORSConfig = CORSConfig() + secure_headers: SecureHeadersConfig = SecureHeadersConfig() + + thread_pool_size: int = DEFAULT_DEPLOYMENT_APP_THREAD_POOL_SIZE + + startup_hook: Optional[SourceOrObjectField] = None + shutdown_hook: Optional[SourceOrObjectField] = 
None + startup_hook_kwargs: Dict[str, Any] = {} + shutdown_hook_kwargs: Dict[str, Any] = {} + + # Framework-agnostic endpoint/middleware configuration + custom_endpoints: Optional[List[EndpointSpec]] = None + custom_middlewares: Optional[List[MiddlewareSpec]] = None + + # Pluggable app extensions for advanced features + app_extensions: Optional[List[AppExtensionSpec]] = None + + uvicorn_host: str = "0.0.0.0" # nosec + uvicorn_port: int = 8000 + uvicorn_workers: int = 1 + log_level: LoggingLevels = LoggingLevels.INFO + + uvicorn_kwargs: Dict[str, Any] = {} + + deployment_app_runner_flavor: Optional[SourceOrObjectField] = None + deployment_app_runner_kwargs: Dict[str, Any] = {} + deployment_service_class: Optional[SourceOrObjectField] = None + deployment_service_kwargs: Dict[str, Any] = {} + + def load_sources(self) -> None: + """Load source string into callable.""" + if self.startup_hook is not None: + assert isinstance(self.startup_hook, SourceOrObject) + self.startup_hook.load() + if self.shutdown_hook is not None: + assert isinstance(self.shutdown_hook, SourceOrObject) + self.shutdown_hook.load() + if self.deployment_app_runner_flavor is not None: + assert isinstance( + self.deployment_app_runner_flavor, SourceOrObject + ) + self.deployment_app_runner_flavor.load() + if self.deployment_service_class is not None: + assert isinstance(self.deployment_service_class, SourceOrObject) + self.deployment_service_class.load() + + def endpoint_enabled(self, endpoint: DeploymentDefaultEndpoints) -> bool: + """Check if an endpoint is enabled. + + Args: + endpoint: The endpoint to check. + + Returns: + True if the endpoint is enabled, False otherwise. + """ + return endpoint in self.include_default_endpoints + + def middleware_enabled( + self, middleware: DeploymentDefaultMiddleware + ) -> bool: + """Check if a middleware is enabled. + + Args: + middleware: The middleware to check. + + Returns: + True if the middleware is enabled, False otherwise. 
+ """ + return middleware in self.include_default_middleware + + model_config = ConfigDict( + # public attributes are mutable + frozen=False, + # prevent extra attributes during model initialization + extra="ignore", + ) diff --git a/src/zenml/config/docker_settings.py b/src/zenml/config/docker_settings.py index 1b398f03eb0..9274f92d49c 100644 --- a/src/zenml/config/docker_settings.py +++ b/src/zenml/config/docker_settings.py @@ -235,6 +235,7 @@ class DockerSettings(BaseSettings): ) required_integrations: List[str] = [] install_stack_requirements: bool = True + install_deployment_requirements: bool = True local_project_install_command: Optional[str] = None apt_packages: List[str] = [] environment: Dict[str, Any] = {} diff --git a/src/zenml/config/pipeline_configurations.py b/src/zenml/config/pipeline_configurations.py index 1c059a47d5e..080be3b2b85 100644 --- a/src/zenml/config/pipeline_configurations.py +++ b/src/zenml/config/pipeline_configurations.py @@ -20,7 +20,11 @@ from pydantic import SerializeAsAny, field_validator from zenml.config.cache_policy import CachePolicyWithValidator -from zenml.config.constants import DOCKER_SETTINGS_KEY, RESOURCE_SETTINGS_KEY +from zenml.config.constants import ( + DEPLOYMENT_SETTINGS_KEY, + DOCKER_SETTINGS_KEY, + RESOURCE_SETTINGS_KEY, +) from zenml.config.frozen_base_model import FrozenBaseModel from zenml.config.retry_config import StepRetryConfig from zenml.config.source import SourceWithValidator @@ -30,7 +34,11 @@ from zenml.utils.time_utils import utc_now if TYPE_CHECKING: - from zenml.config import DockerSettings, ResourceSettings + from zenml.config import ( + DeploymentSettings, + DockerSettings, + ResourceSettings, + ) from zenml.config.base_settings import BaseSettings, SettingsOrDict @@ -145,3 +153,20 @@ def resource_settings(self) -> "ResourceSettings": if isinstance(model_or_dict, BaseSettings): model_or_dict = model_or_dict.model_dump() return ResourceSettings.model_validate(model_or_dict) + + @property + def 
deployment_settings(self) -> "DeploymentSettings": + """Deployment settings of this pipeline configuration. + + Returns: + The deployment settings of this pipeline configuration. + """ + from zenml.config import DeploymentSettings + + model_or_dict: SettingsOrDict = self.settings.get( + DEPLOYMENT_SETTINGS_KEY, {} + ) + + if isinstance(model_or_dict, BaseSettings): + model_or_dict = model_or_dict.model_dump() + return DeploymentSettings.model_validate(model_or_dict) diff --git a/src/zenml/config/server_config.py b/src/zenml/config/server_config.py index da5da446a2f..9c8b3e5a82f 100644 --- a/src/zenml/config/server_config.py +++ b/src/zenml/config/server_config.py @@ -55,7 +55,6 @@ DEFAULT_ZENML_SERVER_SECURE_HEADERS_PERMISSIONS, DEFAULT_ZENML_SERVER_SECURE_HEADERS_REFERRER, DEFAULT_ZENML_SERVER_SECURE_HEADERS_XFO, - DEFAULT_ZENML_SERVER_SECURE_HEADERS_XXP, DEFAULT_ZENML_SERVER_THREAD_POOL_SIZE, ENV_ZENML_SERVER_PREFIX, ENV_ZENML_SERVER_PRO_PREFIX, @@ -186,15 +185,6 @@ class ServerConfiguration(BaseModel): one of the reserved values `disabled`, `no`, `none`, `false`, `off` or to an empty string, the `X-Frame-Options` header will not be included in responses. - secure_headers_xxp: The server header value to be set in the HTTP - header `X-XSS-Protection`. If not specified, or if set to one of the - reserved values `enabled`, `yes`, `true`, `on`, the `X-XSS-Protection` - header will be set to the default value (`0`). If set to one of the - reserved values `disabled`, `no`, `none`, `false`, `off` or - to an empty string, the `X-XSS-Protection` header will not be - included in responses. NOTE: this header is deprecated and should - always be set to `0`. The `Content-Security-Policy` header should be - used instead. secure_headers_content: The server header value to be set in the HTTP header `X-Content-Type-Options`. 
If not specified, or if set to one of the reserved values `enabled`, `yes`, `true`, `on`, the @@ -325,10 +315,6 @@ class ServerConfiguration(BaseModel): default=DEFAULT_ZENML_SERVER_SECURE_HEADERS_XFO, union_mode="left_to_right", ) - secure_headers_xxp: Union[bool, str] = Field( - default=DEFAULT_ZENML_SERVER_SECURE_HEADERS_XXP, - union_mode="left_to_right", - ) secure_headers_content: Union[bool, str] = Field( default=DEFAULT_ZENML_SERVER_SECURE_HEADERS_CONTENT, union_mode="left_to_right", diff --git a/src/zenml/config/source.py b/src/zenml/config/source.py index f41be94dca4..e6b4e2b3332 100644 --- a/src/zenml/config/source.py +++ b/src/zenml/config/source.py @@ -14,13 +14,15 @@ """Source classes.""" from enum import Enum -from typing import TYPE_CHECKING, Any, Dict, Optional +from types import BuiltinFunctionType, FunctionType, ModuleType +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, Union from uuid import UUID from pydantic import ( BaseModel, BeforeValidator, ConfigDict, + PlainSerializer, SerializeAsAny, field_validator, ) @@ -162,6 +164,201 @@ def model_dump_json(self, **kwargs: Any) -> str: """ return super().model_dump_json(serialize_as_any=True, **kwargs) + @classmethod + def convert_source(cls, source: Any) -> Any: + """Converts an old source string to a source object. + + Args: + source: Source string or object. + + Returns: + The converted source. + """ + if isinstance(source, str): + source = cls.from_import_path(source) + + return source + + +ObjectType = Union[ + Type[Any], + Callable[..., Any], + ModuleType, + FunctionType, + BuiltinFunctionType, +] + + +class SourceOrObject(Source): + """Hybrid type that can hold either a Source path or a loaded object (type, function, variable, etc.). 
+
+    This enables:
+    - Internal use: Pass actual objects directly
+    - External use: Pass source strings from config
+    - Lazy serialization: Converts objects to source path strings only at
+      serialization time
+    - Lazy loading: Only loads sources when explicitly requested
+
+    Examples:
+        * the source `SourceOrObject(module="zenml.config.source", attribute="SourceOrObject")`
+          references the class that this docstring is describing.
+        * the object `SourceOrObject.from_object(object=my_object)` creates a source
+          or object from the object `my_object`.
+
+    Attributes:
+        _object: The loaded object (if initialized from an object or loaded).
+        _is_loaded: Whether the object has been loaded.
+    """
+
+    _object: Optional[ObjectType] = None
+    _is_loaded: bool = False
+
+    @classmethod
+    def from_source(cls, source: Source) -> "SourceOrObject":
+        """Creates a source or object from a source instance.
+
+        Args:
+            source: The source instance.
+
+        Returns:
+            The source or object instance.
+        """
+        return cls(**source.model_dump())
+
+    @classmethod
+    def from_object(cls, object: ObjectType) -> "SourceOrObject":
+        """Creates a source or object from an object instance.
+
+        Args:
+            object: The object instance.
+
+        Returns:
+            The source or object instance.
+        """
+        # We don't resolve the object right away, we only do it at serialization
+        # time. This allows us to store temporary objects in this class too.
+        result = cls(
+            module="",
+            type=SourceType.UNKNOWN,
+        )
+        result._object = object
+        result._is_loaded = True
+        return result
+
+    def load(self) -> ObjectType:
+        """Load and return the object.
+
+        Returns:
+            The loaded object.
+
+        Raises:
+            RuntimeError: If loading fails.
+ """ + from zenml.utils import source_utils + + if not self._is_loaded: + try: + self._object = source_utils.load(self) + self._is_loaded = True + except Exception as e: + raise RuntimeError( + f"Failed to load object from {self.import_path}: {e}" + ) from e + + assert self._object is not None + return self._object + + def resolve(self) -> None: + """Resolve the stored object to a source.""" + from zenml.utils import source_utils + + if not self._is_loaded: + return + + if self.module: + # Already resolved + return + + source = source_utils.resolve(self._object) + for k, v in source.model_dump().items(): + setattr(self, k, v) + + @property + def is_loaded(self) -> bool: + """Whether the object has been loaded. + + Returns: + True if the object has been loaded, False otherwise. + """ + return self._is_loaded + + @classmethod + def convert_source_or_object( + cls, + source: Union[ + str, "SourceOrObject", Source, Dict[str, Any], ObjectType + ], + ) -> Union["SourceOrObject", Dict[str, Any]]: + """Converts a source string or object to a SourceOrObject object. + + Args: + source: Source string or object. + + Returns: + The converted source or object. + """ + if isinstance(source, str): + source = cls.from_import_path(source) + + if isinstance(source, cls): + return source + + if isinstance(source, Source): + return cls.from_source(source) + + if isinstance(source, dict): + return source + + return cls.from_object(source) + + def model_dump(self, **kwargs: Any) -> Dict[str, Any]: + """Dump the source as a dictionary. + + Args: + **kwargs: Additional keyword arguments. + + Returns: + The source as a dictionary. + """ + self.resolve() + return super().model_dump(**kwargs) + + def model_dump_json(self, **kwargs: Any) -> str: + """Dump the source as a JSON string. + + Args: + **kwargs: Additional keyword arguments. + + Returns: + The source as a JSON string. 
+ """ + self.resolve() + return super().model_dump_json(**kwargs) + + @classmethod + def serialize_source_or_object( + cls, value: "SourceOrObject" + ) -> Dict[str, Any]: + """Serialize the source or object as a dictionary. + + Args: + value: The source or object to serialize. + + Returns: + The source or object as a string. + """ + return value.model_dump() + class DistributionPackageSource(Source): """Source representing an object from a distribution package. @@ -283,22 +480,18 @@ def _validate_module(cls, value: str) -> str: return value -def convert_source(source: Any) -> Any: - """Converts an old source string to a source object. - - Args: - source: Source string or object. - - Returns: - The converted source. - """ - if isinstance(source, str): - source = Source.from_import_path(source) - - return source - - SourceWithValidator = Annotated[ SerializeAsAny[Source], - BeforeValidator(convert_source), + BeforeValidator(Source.convert_source), ] + +if TYPE_CHECKING: + SourceOrObjectField = Union[ObjectType, SourceOrObject, Source, str] +else: + SourceOrObjectField = Annotated[ + SourceOrObject, + BeforeValidator(SourceOrObject.convert_source_or_object), + PlainSerializer( + SourceOrObject.serialize_source_or_object, return_type=dict + ), + ] diff --git a/src/zenml/constants.py b/src/zenml/constants.py index 3c683b6387f..0145115b372 100644 --- a/src/zenml/constants.py +++ b/src/zenml/constants.py @@ -319,7 +319,6 @@ def handle_int_env_var(var: str, default: int = 0) -> int: "max-age=63072000; includeSubdomains" ) DEFAULT_ZENML_SERVER_SECURE_HEADERS_XFO = "SAMEORIGIN" -DEFAULT_ZENML_SERVER_SECURE_HEADERS_XXP = "0" DEFAULT_ZENML_SERVER_SECURE_HEADERS_CONTENT = "nosniff" _csp_script_src_urls = ["https://widgets-v3.featureos.app"] _csp_connect_src_urls = [ diff --git a/src/zenml/deployers/containerized_deployer.py b/src/zenml/deployers/containerized_deployer.py index ad12c955221..6335193678e 100644 --- a/src/zenml/deployers/containerized_deployer.py +++ 
b/src/zenml/deployers/containerized_deployer.py @@ -26,6 +26,7 @@ DEPLOYER_DOCKER_IMAGE_KEY, ) from zenml.deployers.base_deployer import BaseDeployer +from zenml.deployers.utils import load_deployment_requirements from zenml.logger import get_logger from zenml.models import ( PipelineSnapshotBase, @@ -38,8 +39,6 @@ class ContainerizedDeployer(BaseDeployer, ABC): """Base class for all containerized deployers.""" - CONTAINER_REQUIREMENTS: List[str] = [] - @staticmethod def get_image(snapshot: PipelineSnapshotResponse) -> str: """Get the docker image used to deploy a pipeline snapshot. @@ -70,7 +69,6 @@ def requirements(self) -> Set[str]: A set of PyPI requirements for the deployer. """ requirements = super().requirements - requirements.update(self.CONTAINER_REQUIREMENTS) if self.config.is_local and GlobalConfiguration().uses_sql_store: # If we're directly connected to a DB, we need to install the @@ -90,9 +88,22 @@ def get_docker_builds( Returns: The required Docker builds. """ + deployment_settings = ( + snapshot.pipeline_configuration.deployment_settings + ) + docker_settings = snapshot.pipeline_configuration.docker_settings + if not docker_settings.install_deployment_requirements: + return [] + + deployment_requirements = load_deployment_requirements( + deployment_settings + ) return [ BuildConfiguration( key=DEPLOYER_DOCKER_IMAGE_KEY, settings=snapshot.pipeline_configuration.docker_settings, + extra_requirements_files={ + ".zenml_deployment_requirements": deployment_requirements, + }, ) ] diff --git a/src/zenml/deployers/docker/docker_deployer.py b/src/zenml/deployers/docker/docker_deployer.py index d0ddf07f84c..54c532305e3 100644 --- a/src/zenml/deployers/docker/docker_deployer.py +++ b/src/zenml/deployers/docker/docker_deployer.py @@ -53,9 +53,7 @@ DeploymentProvisionError, ) from zenml.deployers.server.entrypoint_configuration import ( - AUTH_KEY_OPTION, DEPLOYMENT_ID_OPTION, - PORT_OPTION, DeploymentEntrypointConfiguration, ) from zenml.enums import 
DeploymentStatus, StackComponentType @@ -137,7 +135,6 @@ def from_deployment( class DockerDeployer(ContainerizedDeployer): """Deployer responsible for deploying pipelines locally using Docker.""" - CONTAINER_REQUIREMENTS: List[str] = ["uvicorn", "fastapi"] _docker_client: Optional[DockerClient] = None @property @@ -302,10 +299,7 @@ def do_provision_deployment( entrypoint_kwargs = { DEPLOYMENT_ID_OPTION: deployment.id, - PORT_OPTION: 8000, } - if deployment.auth_key: - entrypoint_kwargs[AUTH_KEY_OPTION] = deployment.auth_key arguments = DeploymentEntrypointConfiguration.get_entrypoint_arguments( **entrypoint_kwargs @@ -367,7 +361,10 @@ def do_provision_deployment( range=settings.port_range, address="0.0.0.0", # nosec ) - ports: Dict[str, Optional[int]] = {"8000/tcp": port} + container_port = ( + snapshot.pipeline_configuration.deployment_settings.uvicorn_port + ) + ports: Dict[str, Optional[int]] = {f"{container_port}/tcp": port} uid_args: Dict[str, Any] = {} if sys.platform == "win32": diff --git a/src/zenml/deployers/server/__init__.py b/src/zenml/deployers/server/__init__.py new file mode 100644 index 00000000000..25ac189c541 --- /dev/null +++ b/src/zenml/deployers/server/__init__.py @@ -0,0 +1,28 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+"""Deployment server web application implementation.""" + +from zenml.deployers.server.app import BaseDeploymentAppRunner +from zenml.deployers.server.extensions import BaseAppExtension +from zenml.deployers.server.adapters import ( + EndpointAdapter, + MiddlewareAdapter, +) + +__all__ = [ + "BaseDeploymentAppRunner", + "BaseAppExtension", + "EndpointAdapter", + "MiddlewareAdapter", +] \ No newline at end of file diff --git a/src/zenml/deployers/server/adapters.py b/src/zenml/deployers/server/adapters.py new file mode 100644 index 00000000000..299959ff732 --- /dev/null +++ b/src/zenml/deployers/server/adapters.py @@ -0,0 +1,240 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Framework adapter interfaces.""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Callable, cast + +from asgiref.typing import ( + ASGIApplication, + ASGIReceiveCallable, + ASGISendCallable, + Scope, +) + +from zenml.config.source import SourceOrObject + +if TYPE_CHECKING: + from zenml.config import ( + EndpointSpec, + MiddlewareSpec, + ) + from zenml.deployers.server.app import BaseDeploymentAppRunner + + +class EndpointAdapter(ABC): + """Converts framework-agnostic endpoint specs to framework endpoints.""" + + def resolve_endpoint_handler( + self, + app_runner: "BaseDeploymentAppRunner", + endpoint_spec: "EndpointSpec", + ) -> Any: + """Resolve an endpoint handler from its specification. 
+
+        This method handles three types of handlers as defined in EndpointSpec:
+        1. Direct endpoint function - returned as-is
+        2. Endpoint builder class - instantiated with app_runner, app, and
+           init_kwargs
+        3. Endpoint builder function - called with app_runner, app, and
+           init_kwargs to obtain the actual endpoint
+
+        Args:
+            app_runner: Deployment app runner instance.
+            endpoint_spec: The endpoint specification to resolve the handler
+                from.
+
+        Returns:
+            The actual endpoint callable ready to be registered.
+
+        Raises:
+            ValueError: If handler is not callable or builder returns
+                non-callable.
+            RuntimeError: If handler resolution fails.
+        """
+        import inspect
+
+        assert isinstance(endpoint_spec.handler, SourceOrObject)
+        handler = endpoint_spec.handler.load()
+
+        if endpoint_spec.native:
+            return handler
+
+        # Type 2: Endpoint builder class
+        if isinstance(handler, type):
+            if not hasattr(handler, "__call__"):
+                raise ValueError(
+                    f"Handler class {handler.__name__} must implement "
+                    "__call__ method"
+                )
+            try:
+                inner_handler = handler(
+                    app_runner=app_runner,
+                    **endpoint_spec.init_kwargs,
+                )
+            except TypeError as e:
+                raise RuntimeError(
+                    f"Failed to instantiate handler class "
+                    f"{handler.__name__}: {e}"
+                ) from e
+
+            if not callable(inner_handler):
+                raise ValueError(
+                    f"The __call__ method of the handler class "
+                    f"{handler.__name__} must return a callable"
+                )
+
+            return inner_handler
+
+        if not callable(handler):
+            raise ValueError(f"Handler {handler} is not callable")
+
+        # Determine if it's Type 3 (builder function) or Type 1 (direct)
+        try:
+            sig = inspect.signature(handler)
+            params = set(sig.parameters.keys())
+
+            # Type 3: Builder function (has app_runner parameter)
+            if "app_runner" in params:
+                try:
+                    inner_handler = handler(
+                        app_runner=app_runner,
+                        **endpoint_spec.init_kwargs,
+                    )
+                    if not callable(inner_handler):
+                        raise ValueError(
+                            f"Builder function {handler.__name__} must "
+                            f"return a callable, got {type(inner_handler)}"
+                        )
+                    return inner_handler
+                except TypeError as e:
+                    raise RuntimeError(
+                        f"Failed to call builder function "
+                        f"{handler.__name__}: {e}"
+                    ) from e
+
+            # Type 1: Direct endpoint function
+            return handler
+
+        except ValueError:
+            # inspect.signature failed, assume it's a direct endpoint
+            return handler
+
+    @abstractmethod
+    def register_endpoint(
+        self,
+        app_runner: "BaseDeploymentAppRunner",
+        spec: "EndpointSpec",
+    ) -> None:
+        """Register an endpoint on the app.
+
+        Args:
+            app_runner: Deployment app runner instance.
+            spec: Framework-agnostic endpoint specification.
+
+        Raises:
+            RuntimeError: If endpoint registration fails.
+        """
+
+
+class MiddlewareAdapter(ABC):
+    """Converts framework-agnostic middleware specs to framework middleware."""
+
+    def resolve_middleware_handler(
+        self,
+        app_runner: "BaseDeploymentAppRunner",
+        middleware_spec: "MiddlewareSpec",
+    ) -> Any:
+        """Resolve a middleware handler from its specification.
+
+        This method handles three types of middleware as defined in MiddlewareSpec:
+        1. Middleware callable class
+        2. Middleware callable function
+        3. Native middleware object
+
+        Args:
+            app_runner: Deployment app runner instance.
+            middleware_spec: The middleware specification to resolve the handler
+                from.
+
+        Returns:
+            The actual middleware callable ready to be registered.
+
+        Raises:
+            ValueError: If middleware is not callable or builder returns
+                non-callable.
+ """ + import inspect + + assert isinstance(middleware_spec.middleware, SourceOrObject) + middleware = middleware_spec.middleware.load() + + if middleware_spec.native: + return middleware + + # Type 1: Middleware class + if isinstance(middleware, type): + return middleware + + if not callable(middleware): + raise ValueError(f"Middleware {middleware} is not callable") + + # Wrap the middleware function in a middleware class + class _MiddlewareAdapter: + def __init__(self, app: ASGIApplication, **kwargs: Any) -> None: + self.app = app + self.kwargs = kwargs + + async def __call__( + self, + scope: Scope, + receive: ASGIReceiveCallable, + send: ASGISendCallable, + ) -> None: + callable_middleware = cast(Callable[..., Any], middleware) + if inspect.iscoroutinefunction(callable_middleware): + await callable_middleware( + app=self.app, + scope=scope, + receive=receive, + send=send, + **self.kwargs, + ) + else: + callable_middleware( + app=self.app, + scope=scope, + receive=receive, + send=send, + **self.kwargs, + ) + + return _MiddlewareAdapter + + @abstractmethod + def register_middleware( + self, + app_runner: "BaseDeploymentAppRunner", + spec: "MiddlewareSpec", + ) -> None: + """Register middleware on the app. + + Args: + app_runner: Deployment app runner instance. + spec: Framework-agnostic middleware specification. + + Raises: + ValueError: If middleware scope requires missing parameters. + RuntimeError: If middleware registration fails. + """ diff --git a/src/zenml/deployers/server/app.py b/src/zenml/deployers/server/app.py index 4717516d1e3..9a46d4bf2b8 100644 --- a/src/zenml/deployers/server/app.py +++ b/src/zenml/deployers/server/app.py @@ -1,4 +1,4 @@ -# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -11,371 +11,1010 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 # or implied. See the License for the specific language governing
 # permissions and limitations under the License.
-"""FastAPI application for running ZenML pipeline deployments."""
+"""Base deployment app runner."""
 
 import os
-from contextlib import asynccontextmanager
-from typing import AsyncGenerator, Literal, Optional
-
-from fastapi import (
-    APIRouter,
-    Depends,
-    FastAPI,
-    HTTPException,
-    Request,
+import re
+from abc import ABC, abstractmethod
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+)
+from uuid import UUID
+
+from asgiref.compatibility import guarantee_single_callable
+from asgiref.typing import (
+    ASGIApplication,
+    ASGIReceiveCallable,
+    ASGISendCallable,
+    ASGISendEvent,
+    Scope,
 )
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import HTMLResponse, JSONResponse
-from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
 
+from zenml.client import Client
+from zenml.config import (
+    AppExtensionSpec,
+    DeploymentDefaultEndpoints,
+    DeploymentDefaultMiddleware,
+    DeploymentSettings,
+    EndpointMethod,
+    EndpointSpec,
+    MiddlewareSpec,
+)
+from zenml.config.source import SourceOrObject
+from zenml.deployers.server.adapters import (
+    EndpointAdapter,
+    MiddlewareAdapter,
+)
+from zenml.deployers.server.extensions import BaseAppExtension
 from zenml.deployers.server.models import (
-    ExecutionMetrics,
-    ServiceInfo,
-    get_pipeline_invoke_models,
+    BaseDeploymentInvocationRequest,
+    BaseDeploymentInvocationResponse,
 )
-from zenml.deployers.server.service import PipelineDeploymentService
+from zenml.deployers.server.service import (
+    BasePipelineDeploymentService,
+    PipelineDeploymentService,
+)
+from zenml.integrations.registry import integration_registry
 from zenml.logger import get_logger
+from zenml.models.v2.core.deployment import DeploymentResponse
+from zenml.utils import source_utils
+
+if TYPE_CHECKING:
+    from secure import Secure
 
 logger = get_logger(__name__)
 
-_service: Optional[PipelineDeploymentService] = None
 
+class BaseDeploymentAppRunner(ABC):
+    """Base class for deployment app runners.
+
+    This class is responsible for building and running the ASGI compatible web
+    application (e.g. FastAPI, Django, Flask, Falcon, Quart, BlackSheep, etc.) and the
+    associated deployment service for the pipeline deployment. It also acts as
+    an adaptation layer between the REST API interface and the deployment service to
+    preserve the following separation of concerns between the two components:
+
+    * the ASGI application is responsible for handling the HTTP requests and
+      responses to the user
+    * the deployment service is responsible for handling the business logic
+
+    The deployment service code should be free of any ASGI application specific
+    code and concerns and vice-versa. This allows them to be independently
+    extendable and easily swappable.
+
+    Implementations of this class must use the deployment and its settings to
+    configure and run the web application (e.g. FastAPI, Flask, Falcon, Quart,
+    BlackSheep, etc.) that wraps the deployment service according to the user's
+    specifications, particularly concerning the following:
+
+    * exposed endpoints (URL paths, methods, input/output models)
+    * middleware (CORS, authentication, logging, etc.)
+    * error handling
+    * lifecycle management (startup, shutdown)
+    * custom hooks (startup, shutdown)
+    * app configuration (workers, host, port, thread pool size, etc.)
+
+    The following methods must be provided by implementations of this class:
+
+    * flavor: Return the flavor class associated with this deployment
+      application runner.
+    * build: Build and return an ASGI compatible web application (i.e. an
+      ASGIApplication object that can be run with uvicorn). Most Python ASGI
+      frameworks provide an ASGIApplication object. This method also has to
+      register all the endpoints, middleware and extensions that are either
+      required internally or supplied to it. It must also configure the `startup`
+      and `shutdown` methods to be run as part of the ASGI application's lifespan
+      or overload the `_run_asgi_app` method to handle the startup and shutdown as
+      an alternative.
+    * _get_dashboard_endpoints: Gets the dashboard endpoints specs from the
+      deployment configuration. Only required if the dashboard files path is set
+      in the deployment configuration and the app runner supports serving a
+      dashboard alongside the API.
+    * _build_cors_middleware: Builds the CORS middleware from the CORS settings
+      in the deployment configuration.
+    """
 
-@asynccontextmanager
-async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
-    """Manage application lifespan.
+    def __init__(
+        self, deployment: Union[str, UUID, "DeploymentResponse"], **kwargs: Any
+    ):
+        """Initialize the deployment app.
+
+        Args:
+            deployment: The deployment to run.
+            **kwargs: Additional keyword arguments for the deployment app runner.
+        """
+        self.deployment = self.load_deployment(deployment)
+        assert self.deployment.snapshot is not None
+        self.snapshot = self.deployment.snapshot
+
+        self.settings = (
+            self.snapshot.pipeline_configuration.deployment_settings
+        )
 
-    Args:
-        app: The FastAPI application instance being deployed.
+        self.service = self.load_deployment_service()
 
-    Yields:
-        None: Control is handed back to FastAPI once initialization completes.
+        # Create framework-specific adapters
+        self.endpoint_adapter = self._create_endpoint_adapter()
+        self.middleware_adapter = self._create_middleware_adapter()
+        self._asgi_app: Optional[ASGIApplication] = None
 
-    Raises:
-        ValueError: If no deployment identifier is configured.
-        Exception: If initialization or cleanup fails.
- """ - # Check for test mode - if os.getenv("ZENML_DEPLOYMENT_TEST_MODE", "false").lower() == "true": - logger.info("🧪 Running in test mode - skipping initialization") - yield - return - - # Startup - logger.info("🚀 Starting ZenML Pipeline Serving service...") - - deployment_id = os.getenv("ZENML_DEPLOYMENT_ID") - if not deployment_id: - raise ValueError( - "ZENML_DEPLOYMENT_ID environment variable is required" - ) + self.endpoints: List[EndpointSpec] = [] + self.middlewares: List[MiddlewareSpec] = [] + self.extensions: List[AppExtensionSpec] = [] - try: - global _service - _service = PipelineDeploymentService(deployment_id) - _service.initialize() - app.include_router(_build_invoke_router(_service)) - logger.info("✅ Pipeline deployment service initialized successfully") - except Exception as e: - logger.error(f"❌ Failed to initialize: {e}") - raise - - yield - - # Shutdown - logger.info("🛑 Shutting down ZenML Pipeline Deployment service...") - try: - if _service: - _service.cleanup() - logger.info( - "✅ Pipeline deployment service cleaned up successfully" + @property + def asgi_app(self) -> ASGIApplication: + """Get the ASGI application. + + Returns: + The ASGI application. + + Raises: + RuntimeError: If the ASGI application is not built yet. + """ + if self._asgi_app is None: + raise RuntimeError( + "ASGI application is not built yet. Run the deployment app runner's `build` method first." 
) - except Exception as e: - logger.error(f"❌ Error during service cleanup: {e}") - finally: - # Ensure globals are reset to avoid stale references across lifecycles - _service = None - - -# Create FastAPI application with OpenAPI security scheme -app = FastAPI( - title=f"ZenML Pipeline Deployment {os.getenv('ZENML_DEPLOYMENT_ID')}", - description="deploy ZenML pipelines as FastAPI endpoints", - version="0.2.0", - lifespan=lifespan, - docs_url="/docs", - redoc_url="/redoc", -) + return self._asgi_app + + @classmethod + def load_deployment( + cls, deployment: Union[str, UUID, "DeploymentResponse"] + ) -> DeploymentResponse: + """Load the deployment. + + Args: + deployment: The deployment to load. + + Returns: + The deployment. + + Raises: + RuntimeError: If the deployment or its snapshot cannot be loaded. + """ + if isinstance(deployment, str): + deployment = UUID(deployment) + + if isinstance(deployment, UUID): + try: + deployment = Client().zen_store.get_deployment( + deployment_id=deployment + ) + except Exception as e: + raise RuntimeError( + f"Failed to load deployment {deployment}: {e}" + ) from e + else: + assert isinstance(deployment, DeploymentResponse) + + if deployment.snapshot is None: + raise RuntimeError(f"Deployment {deployment.id} has no snapshot") + + return deployment + + @classmethod + def load_app_runner( + cls, deployment: Union[str, UUID, "DeploymentResponse"] + ) -> "BaseDeploymentAppRunner": + """Load the app runner for the deployment. + + Args: + deployment: The deployment to load the app runner for. + + Returns: + The app runner for the deployment. + + Raises: + RuntimeError: If the deployment app runner cannot be loaded. 
+ """ + deployment = cls.load_deployment(deployment) + assert deployment.snapshot is not None + + settings = ( + deployment.snapshot.pipeline_configuration.deployment_settings + ) -# Define security scheme for OpenAPI documentation -security = HTTPBearer( - scheme_name="Bearer Token", - description="Enter your API key as a Bearer token", - auto_error=False, # We handle errors in our dependency -) + app_runner_flavor = ( + BaseDeploymentAppRunnerFlavor.load_app_runner_flavor(settings) + ) + + app_runner_cls = app_runner_flavor.implementation_class + logger.info( + f"Instantiating deployment app runner class '{app_runner_cls}' for " + f"deployment {deployment.id}" + ) -def _build_invoke_router(service: PipelineDeploymentService) -> APIRouter: - """Create an idiomatic APIRouter that exposes /invoke. + try: + return app_runner_cls( + deployment, **settings.deployment_app_runner_kwargs + ) + except Exception as e: + raise RuntimeError( + f"Failed to instantiate deployment app runner class " + f"'{app_runner_cls}' for deployment {deployment.id}: {e}" + ) from e + + def load_deployment_service(self) -> BasePipelineDeploymentService: + """Load the service for the deployment. + + Returns: + The deployment service for the deployment. + + Raises: + RuntimeError: If the deployment service cannot be loaded. + """ + settings = self.snapshot.pipeline_configuration.deployment_settings + if settings.deployment_service_class is None: + service_cls: Type[BasePipelineDeploymentService] = ( + PipelineDeploymentService + ) + else: + assert isinstance( + settings.deployment_service_class, SourceOrObject + ) + try: + loaded_service_cls = settings.deployment_service_class.load() + except Exception as e: + raise RuntimeError( + f"Failed to load deployment service from source " + f"{settings.deployment_service_class}: {e}\n" + "Please check that the source is valid and that the " + "deployment service class is importable from the source " + "root directory. 
Hint: run `zenml init` in your local " + "source directory to initialize the source root path." + ) from e + + if not isinstance(loaded_service_cls, type) or not issubclass( + loaded_service_cls, BasePipelineDeploymentService + ): + raise RuntimeError( + f"Deployment service class '{loaded_service_cls}' is not a " + "subclass of 'BasePipelineDeploymentService'" + ) + service_cls = loaded_service_cls + + logger.info( + f"Instantiating deployment service class '{service_cls}' for " + f"deployment {self.deployment.id}" + ) - Args: - service: The deployment service used to execute pipeline runs. + try: + return service_cls(self, **settings.deployment_service_kwargs) + except Exception as e: + raise RuntimeError( + f"Failed to instantiate deployment service class " + f"'{service_cls}' for deployment {self.deployment.id}: {e}" + ) from e + + @property + @abstractmethod + def flavor(cls) -> "BaseDeploymentAppRunnerFlavor": + """Return the flavor associated with this deployment application runner. + + Returns: + The flavor associated with this deployment application runner. + """ + + @abstractmethod + def _create_endpoint_adapter(self) -> EndpointAdapter: + """Create the framework-specific endpoint adapter. + + Returns: + Endpoint adapter instance for this framework. + """ + + @abstractmethod + def _create_middleware_adapter(self) -> MiddlewareAdapter: + """Create the framework-specific middleware adapter. + + Returns: + Middleware adapter instance for this framework. + """ + + def _build_invoke_endpoint( + self, + ) -> Callable[ + [BaseDeploymentInvocationRequest], BaseDeploymentInvocationResponse + ]: + """Create the endpoint used to invoke the pipeline deployment. + + Returns: + The invoke endpoint, built according to the pipeline deployment + input and output specifications. + """ + PipelineInvokeRequest, PipelineInvokeResponse = ( + self.service.get_pipeline_invoke_models() + ) - Returns: - A router exposing the `/invoke` endpoint wired to the service. 
- """ - router = APIRouter() + def _invoke_endpoint( + request: PipelineInvokeRequest, # type: ignore[valid-type] + ) -> PipelineInvokeResponse: # type: ignore[valid-type] + return self.service.execute_pipeline(request) - PipelineInvokeRequest, PipelineInvokeResponse = get_pipeline_invoke_models( - service - ) + return _invoke_endpoint - @router.post( - "/invoke", - name="invoke_pipeline", - summary="Invoke the pipeline with validated parameters", - response_model=PipelineInvokeResponse, - ) - def _( - request: PipelineInvokeRequest, # type: ignore[valid-type] - _: None = Depends(verify_token), - ) -> PipelineInvokeResponse: # type: ignore[valid-type] - return service.execute_pipeline(request) + def dashboard_files_path(self) -> Optional[str]: + """Get the absolute path of the dashboard files directory. - return router + Returns: + Absolute path. + Raises: + ValueError: If the dashboard files path is absolute. + RuntimeError: If the dashboard files path does not exist. + """ + # If an absolute path is provided, use it + dashboard_files_path = self.settings.dashboard_files_path + if not dashboard_files_path: + import zenml -def get_pipeline_service() -> PipelineDeploymentService: - """Get the pipeline deployment service. + return os.path.join( + zenml.__path__[0], "deployers", "server", "dashboard" + ) - Returns: - The initialized pipeline deployment service instance. - """ - assert _service is not None - return _service + if os.path.isabs(dashboard_files_path): + raise ValueError( + f"Dashboard files path '{dashboard_files_path}' must be " + "relative to the source root, not absolute." + ) + # Otherwise, assume this is a path relative to the source root + source_root = source_utils.get_source_root() + dashboard_path = os.path.join(source_root, dashboard_files_path) + if not os.path.exists(dashboard_path): + raise RuntimeError( + f"Dashboard files path '{dashboard_path}' does not exist. 
" + f"Please check that the path exists and that the source root " + f"is set correctly. Hint: run `zenml init` in your local source " + f"directory to initialize the source root path." + ) + return dashboard_path + + @abstractmethod + def _get_dashboard_endpoints(self) -> List[EndpointSpec]: + """Get the dashboard endpoints specs. + + This is called if the dashboard files path is set to construct the + endpoints specs for the dashboard. + + Returns: + The dashboard endpoints specs. + """ + + def _create_default_endpoint_specs(self) -> List[EndpointSpec]: + """Create EndpointSpec objects for default endpoints. + + Returns: + List of endpoint specs for default endpoints. + """ + specs = [] + + if self.settings.endpoint_enabled(DeploymentDefaultEndpoints.INVOKE): + specs.append( + EndpointSpec( + path=f"{self.settings.api_url_path}{self.settings.invoke_url_path}", + method=EndpointMethod.POST, + handler=self._build_invoke_endpoint(), + auth_required=True, + ) + ) -def verify_token( - credentials: Optional[HTTPAuthorizationCredentials] = Depends(security), -) -> None: - """Verify the provided Bearer token for authentication. + if self.settings.endpoint_enabled(DeploymentDefaultEndpoints.HEALTH): + specs.append( + EndpointSpec( + path=f"{self.settings.api_url_path}{self.settings.health_url_path}", + method=EndpointMethod.GET, + handler=self.service.health_check, + auth_required=False, + ) + ) - This dependency function integrates with FastAPI's security system - to provide proper OpenAPI documentation and authentication UI. 
+ if self.settings.endpoint_enabled(DeploymentDefaultEndpoints.INFO): + specs.append( + EndpointSpec( + path=f"{self.settings.api_url_path}{self.settings.info_url_path}", + method=EndpointMethod.GET, + handler=self.service.get_service_info, + auth_required=False, + ) + ) - Args: - credentials: HTTP Bearer credentials from the request + if self.settings.endpoint_enabled(DeploymentDefaultEndpoints.METRICS): + specs.append( + EndpointSpec( + path=f"{self.settings.api_url_path}{self.settings.metrics_url_path}", + method=EndpointMethod.GET, + handler=self.service.get_execution_metrics, + auth_required=False, + ) + ) - Raises: - HTTPException: If authentication is required but token is invalid - """ - auth_key = os.getenv("ZENML_DEPLOYMENT_AUTH_KEY", "").strip() - auth_enabled = auth_key and auth_key != "" - - # If authentication is not enabled, allow all requests - if not auth_enabled: - return - - # If authentication is enabled, validate the token - if not credentials: - raise HTTPException( - status_code=401, - detail="Authorization header required", - headers={"WWW-Authenticate": "Bearer"}, + return specs + + def _get_secure_headers(self) -> "Secure": + """Get the secure headers settings. + + Returns: + The secure headers settings. 
+ """ + import secure + + # For each of the secure headers supported by the `secure` library, we + # check if the corresponding configuration is set in the deployment + # configuration: + # + # - if set to `True`, we use the default value for the header + # - if set to a string, we use the string as the value for the header + # - if set to `False`, we don't set the header + + server: Optional[secure.Server] = None + if self.settings.secure_headers.server: + server = secure.Server() + if isinstance(self.settings.secure_headers.server, str): + server.set(self.settings.secure_headers.server) + else: + server.set(str(self.deployment.id)) + + hsts: Optional[secure.StrictTransportSecurity] = None + if self.settings.secure_headers.hsts: + hsts = secure.StrictTransportSecurity() + if isinstance(self.settings.secure_headers.hsts, str): + hsts.set(self.settings.secure_headers.hsts) + + xfo: Optional[secure.XFrameOptions] = None + if self.settings.secure_headers.xfo: + xfo = secure.XFrameOptions() + if isinstance(self.settings.secure_headers.xfo, str): + xfo.set(self.settings.secure_headers.xfo) + + csp: Optional[secure.ContentSecurityPolicy] = None + if self.settings.secure_headers.csp: + csp = secure.ContentSecurityPolicy() + if isinstance(self.settings.secure_headers.csp, str): + csp.set(self.settings.secure_headers.csp) + + xcto: Optional[secure.XContentTypeOptions] = None + if self.settings.secure_headers.content: + xcto = secure.XContentTypeOptions() + if isinstance(self.settings.secure_headers.content, str): + xcto.set(self.settings.secure_headers.content) + + referrer: Optional[secure.ReferrerPolicy] = None + if self.settings.secure_headers.referrer: + referrer = secure.ReferrerPolicy() + if isinstance(self.settings.secure_headers.referrer, str): + referrer.set(self.settings.secure_headers.referrer) + + cache: Optional[secure.CacheControl] = None + if self.settings.secure_headers.cache: + cache = secure.CacheControl() + if 
isinstance(self.settings.secure_headers.cache, str): + cache.set(self.settings.secure_headers.cache) + + permissions: Optional[secure.PermissionsPolicy] = None + if self.settings.secure_headers.permissions: + permissions = secure.PermissionsPolicy() + if isinstance(self.settings.secure_headers.permissions, str): + # This one is special, because it doesn't allow setting the + # value as a string, but rather as a list of directives, so we + # hack our way around it by setting the private _default_value + # attribute. + permissions._default_value = ( + self.settings.secure_headers.permissions + ) + + return secure.Secure( + server=server, + hsts=hsts, + xfo=xfo, + csp=csp, + xcto=xcto, + referrer=referrer, + cache=cache, + permissions=permissions, ) - if credentials.credentials != auth_key: - raise HTTPException( - status_code=401, - detail="Invalid authentication token", - headers={"WWW-Authenticate": "Bearer"}, + def _build_secure_headers_middleware( + self, + ) -> MiddlewareSpec: + """Get the secure headers middleware. + + Returns: + The secure headers middleware spec. 
+ """ + secure_headers = self._get_secure_headers() + + async def set_secure_headers( + app: ASGIApplication, + scope: Scope, + receive: ASGIReceiveCallable, + send: ASGISendCallable, + ) -> None: + skip = False + if scope["type"] != "http": + skip = True + else: + path = scope["path"] + + if path.startswith( + self.settings.docs_url_path + ) or path.startswith(self.settings.redoc_url_path): + skip = True + + async def send_wrapper(message: ASGISendEvent) -> None: + if message["type"] == "http.response.start": + hdrs: List[Tuple[bytes, bytes]] = list( + message.get("headers", []) + ) + existing = {k: i for i, (k, _) in enumerate(hdrs)} + for k, v in secure_headers.headers.items(): + i = existing.get(k.encode()) + if i is not None: + hdrs[i] = (k.encode(), v.encode()) + else: + hdrs.append((k.encode(), v.encode())) + message["headers"] = hdrs + await send(message) + + wrapped_app = guarantee_single_callable(app) # type: ignore[no-untyped-call] + + await wrapped_app(scope, receive, send if skip else send_wrapper) + + return MiddlewareSpec( + middleware=set_secure_headers, ) - # Token is valid, authentication successful - return - + @abstractmethod + def _build_cors_middleware(self) -> MiddlewareSpec: + """Get the CORS middleware. + + Returns: + The CORS middleware spec. + """ + + def _create_default_middleware_specs(self) -> List[MiddlewareSpec]: + """Create MiddlewareSpec objects for default middleware. + + Returns: + List of middleware specs for default middleware. + """ + specs = [] + + if self.settings.middleware_enabled( + DeploymentDefaultMiddleware.SECURE_HEADERS + ): + specs.append(self._build_secure_headers_middleware()) + + if self.settings.middleware_enabled(DeploymentDefaultMiddleware.CORS): + specs.append(self._build_cors_middleware()) + + return specs + + def install_extensions(self, *extension_specs: AppExtensionSpec) -> None: + """Install the given app extensions. + + Args: + extension_specs: The app extensions to install. 
+ + Raises: + ValueError: If the extension is not a subclass of BaseAppExtension. + RuntimeError: If the extension cannot be initialized. + """ + for ext_spec in extension_specs: + # Load extension + ext_spec.load_sources() + extension_obj = ext_spec.resolve_extension_handler() + + # Handle callable vs class-based extensions + if isinstance(extension_obj, type): + if not issubclass(extension_obj, BaseAppExtension): + raise ValueError( + f"Extension type {extension_obj} is not a subclass of " + "BaseAppExtension" + ) + + try: + extension_instance = extension_obj( + **ext_spec.extension_kwargs + ) + except Exception as e: + raise RuntimeError( + f"Failed to initialize extension class {extension_obj}: {e}" + ) from e + + extension_instance.install(self) + else: + # Simple callable extension + extension_obj( + app_runner=self, + **ext_spec.extension_kwargs, + ) + self.extensions.append(ext_spec) + + def register_endpoints(self, *endpoint_specs: EndpointSpec) -> None: + """Register the given endpoints. + + Args: + endpoint_specs: The endpoints to register. + """ + for endpoint_spec in endpoint_specs: + endpoint_spec.load_sources() + self.endpoint_adapter.register_endpoint(self, endpoint_spec) + self.endpoints.append(endpoint_spec) + + def register_middlewares(self, *middleware_specs: MiddlewareSpec) -> None: + """Register the given middleware. + + Args: + middleware_specs: The middleware to register. + """ + for middleware_spec in middleware_specs: + middleware_spec.load_sources() + self.middleware_adapter.register_middleware(self, middleware_spec) + self.middlewares.append(middleware_spec) + + def _run_startup_hook(self) -> None: + """Run the startup hook. + + Raises: + ValueError: If the startup hook is not callable. + Exception: If the startup hook fails to execute. 
+ """ + if not self.settings.startup_hook: + return + + assert isinstance(self.settings.startup_hook, SourceOrObject) + startup_hook = self.settings.startup_hook.load() + + if not callable(startup_hook): + raise ValueError( + f"The startup hook object {startup_hook} must be callable" + ) -# Add CORS middleware to allow frontend access -# TODO: In production, restrict allow_origins to specific domains for security -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], # Allow all origins - restrict in production - allow_credentials=True, - allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], - allow_headers=["*"], -) + logger.info("Executing the deployment application startup hook...") + try: + startup_hook( + app_runner=self, + **self.settings.startup_hook_kwargs, + ) + except Exception as e: + logger.exception(f"Failed to execute startup hook: {e}") + raise + def startup(self) -> None: + """Startup the deployment app. -@app.get("/", response_class=HTMLResponse) -async def root( - service: PipelineDeploymentService = Depends(get_pipeline_service), -) -> str: - """Root endpoint with service information. + Raises: + Exception: If the service initialization fails. + """ + logger.info("🚀 Initializing the pipeline deployment service...") - Args: - service: The pipeline serving service dependency. + try: + self.service.initialize() + logger.info( + "✅ Pipeline deployment service initialized successfully" + ) + except Exception as e: + logger.error( + f"❌ Failed to initialize the pipeline deployment service: {e}" + ) + raise - Returns: - An HTML page describing the serving deployment. - """ - info = service.get_service_info() - - html_content = f""" - - - - ZenML Pipeline Deployment - - - -

-            [HTML markup for the old inline status page, tags stripped: heading "🚀 ZenML Pipeline Deployment", section "Service Status", "Status: Running", "Pipeline: {info.pipeline.name}"]
- - - - """ - return html_content + self._run_startup_hook() + def _run_shutdown_hook(self) -> None: + """Run the shutdown hook. -@app.get("/health") -async def health_check( - service: PipelineDeploymentService = Depends(get_pipeline_service), -) -> Literal["OK"]: - """Service health check endpoint. + Raises: + ValueError: If the shutdown hook is not callable. + Exception: If the shutdown hook fails to execute. + """ + if not self.settings.shutdown_hook: + return - Args: - service: The pipeline serving service dependency. + assert isinstance(self.settings.shutdown_hook, SourceOrObject) - Returns: - "OK" if the service is healthy, otherwise raises an HTTPException. + shutdown_hook = self.settings.shutdown_hook.load() - Raises: - HTTPException: If the service is not healthy. - """ - if not service.is_healthy(): - raise HTTPException(503, "Service is unhealthy") + if not shutdown_hook: + return - return "OK" + if not callable(shutdown_hook): + raise ValueError( + f"The shutdown hook object {shutdown_hook} must be callable" + ) + logger.info("Executing the deployment application shutdown hook...") + try: + shutdown_hook( + app_runner=self, + **self.settings.shutdown_hook_kwargs, + ) + except Exception as e: + logger.exception(f"Failed to execute shutdown hook: {e}") + raise -@app.get("/info") -async def info( - service: PipelineDeploymentService = Depends(get_pipeline_service), -) -> ServiceInfo: - """Get detailed information about the service, including pipeline metadata and schema. + def shutdown(self) -> None: + """Shutdown the deployment app. - Args: - service: The pipeline serving service dependency. + Raises: + Exception: If the service cleanup fails. + """ + self._run_shutdown_hook() - Returns: - Service info. 
- """ - return service.get_service_info() + logger.info("🛑 Cleaning up the pipeline deployment service...") + try: + self.service.cleanup() + logger.info( + "✅ The pipeline deployment service was cleaned up successfully" + ) + except Exception as e: + logger.error( + f"❌ Failed to clean up the pipeline deployment service: {e}" + ) + raise + def _build_asgi_app(self) -> ASGIApplication: + """Build the ASGI application. -@app.get("/metrics") -async def execution_metrics( - service: PipelineDeploymentService = Depends(get_pipeline_service), -) -> ExecutionMetrics: - """Get pipeline execution metrics and statistics. + Returns: + The ASGI application. + """ + endpoints = self._create_default_endpoint_specs() - Args: - service: The pipeline serving service dependency. + custom_endpoints = ( + self.settings.custom_endpoints + if self.settings.custom_endpoints + else [] + ) - Returns: - Aggregated execution metrics. - """ - return service.get_execution_metrics() + # Allow custom endpoints to override default endpoints when they share + # the same path and method. + + def normalize_path(path: str) -> str: + if path and path != "/": + path = re.sub(r"(? 
<!^)/+$", "", path) + + # normalize typed params: {name:path} -> {param:path} + path = re.sub(r"\{[^{}:/]+:path\}", "{param:path}", path) + + # normalize untyped params: {name} -> {param} + path = re.sub(r"\{[^{}:/]+\}", "{param}", path) + + return path + + custom_keys = { + (e.method, normalize_path(e.path)) for e in custom_endpoints + } + endpoints = [ + e + for e in endpoints + if (e.method, normalize_path(e.path)) not in custom_keys + ] + + endpoints.extend(custom_endpoints) + + if self.settings.endpoint_enabled( + DeploymentDefaultEndpoints.DASHBOARD + ): + endpoints.extend(self._get_dashboard_endpoints()) + + middlewares = self._create_default_middleware_specs() + + if self.settings.custom_middlewares: + middlewares.extend(self.settings.custom_middlewares) + + extensions = [] + if self.settings.app_extensions: + extensions.extend(self.settings.app_extensions) + + return self.build(middlewares, endpoints, extensions) + + def _run_asgi_app(self, asgi_app: ASGIApplication) -> None: + """Run the ASGI application. + + Args: + asgi_app: The ASGI application to run. + + Raises: + Exception: If the application fails to start.
+ """ + import uvicorn + + settings = self.settings + + logger.info(f""" +🚀 Starting ZenML pipeline deployment application: + Deployment ID: {self.deployment.id} + Deployment Name: {self.deployment.name} + Snapshot ID: {self.snapshot.id} + Snapshot Name: {self.snapshot.name or "N/A"} + Pipeline ID: {self.snapshot.pipeline.id} + Pipeline Name: {self.snapshot.pipeline.name} + Host: {settings.uvicorn_host} + Port: {settings.uvicorn_port} + Workers: {settings.uvicorn_workers} + Log Level: {settings.log_level} +""") + + uvicorn_kwargs: Dict[str, Any] = dict( + host=settings.uvicorn_host, + port=settings.uvicorn_port, + workers=settings.uvicorn_workers, + log_level=settings.log_level.value, + access_log=True, + ) + if settings.uvicorn_kwargs: + uvicorn_kwargs.update(settings.uvicorn_kwargs) + + try: + # Start the ASGI application + uvicorn.run( + asgi_app, + **uvicorn_kwargs, + ) + except KeyboardInterrupt: + logger.info("\n🛑 Deployment application shutdown") + except Exception as e: + logger.error( + f"❌ Failed to start deployment application: {str(e)}" + ) + raise + def run(self) -> None: + """Run the deployment app.""" + if self._asgi_app is None: + self._build_asgi_app() -# Custom exception handlers -@app.exception_handler(ValueError) -def value_error_handler(request: Request, exc: ValueError) -> JSONResponse: - """Handle ValueError exceptions (synchronous for unit tests). + self._run_asgi_app(self.asgi_app) - Args: - request: The request. - exc: The exception. + @abstractmethod + def build( + self, + middlewares: List[MiddlewareSpec], + endpoints: List[EndpointSpec], + extensions: List[AppExtensionSpec], + ) -> ASGIApplication: + """Build the ASGI compatible web application. - Returns: - The error response. - """ - logger.error("ValueError in request: %s", exc) - return JSONResponse(status_code=400, content={"detail": str(exc)}) + Args: + middlewares: The middleware to register. + endpoints: The endpoints to register. + extensions: The extensions to install. 
+ Returns: + The ASGI compatible web application. + """ -@app.exception_handler(RuntimeError) -def runtime_error_handler(request: Request, exc: RuntimeError) -> JSONResponse: - """Handle RuntimeError exceptions (synchronous for unit tests). - Args: - request: The request. - exc: The exception. +class BaseDeploymentAppRunnerFlavor(ABC): + """Base class for deployment app runner flavors. - Returns: - The error response. + BaseDeploymentAppRunner implementations must also provide implementations + for this class. The flavor class implementation should be kept separate from + the implementation class to allow it to be imported without importing the + implementation class and all its dependencies. """ - logger.error("RuntimeError in request: %s", exc) - return JSONResponse(status_code=500, content={"detail": str(exc)}) + + @property + @abstractmethod + def name(self) -> str: + """The name of the deployment app runner flavor. + + Returns: + The name of the deployment app runner flavor. + """ + + @property + @abstractmethod + def implementation_class(self) -> Type[BaseDeploymentAppRunner]: + """The class that implements the deployment app runner. + + Returns: + The implementation class for the deployment app runner. + """ + + @property + def requirements(self) -> List[str]: + """The software requirements for the deployment app runner. + + Returns: + The software requirements for the deployment app runner. + """ + return ["uvicorn", "secure~=1.0.1", "asgiref~=3.10.0", "Jinja2"] + + @classmethod + def load_app_runner_flavor( + cls, settings: DeploymentSettings + ) -> "BaseDeploymentAppRunnerFlavor": + """Load the app runner flavor for the deployment settings. + + Args: + settings: The deployment settings to load the app runner flavor for. + + Returns: + The app runner flavor for the deployment. + + Raises: + RuntimeError: If the deployment app runner flavor cannot be loaded. 
+ """ + from zenml.deployers.server.fastapi import ( + FastAPIDeploymentAppRunnerFlavor, + ) + + if settings.deployment_app_runner_flavor is None: + app_runner_flavor_class: Type[BaseDeploymentAppRunnerFlavor] = ( + FastAPIDeploymentAppRunnerFlavor + ) + else: + assert isinstance( + settings.deployment_app_runner_flavor, SourceOrObject + ) + try: + loaded_app_runner_flavor_class = ( + settings.deployment_app_runner_flavor.load() + ) + except Exception as e: + raise RuntimeError( + f"Failed to load deployment app runner flavor from source " + f"{settings.deployment_app_runner_flavor}: {e}\n" + "Please check that the source is valid and that the " + "deployment app runner flavor class is importable from the " + "source root directory. Hint: run `zenml init` in your " + "local source directory to initialize the source root path." + ) from e + + if not isinstance( + loaded_app_runner_flavor_class, type + ) or not issubclass( + loaded_app_runner_flavor_class, BaseDeploymentAppRunnerFlavor + ): + raise RuntimeError( + f"The object '{loaded_app_runner_flavor_class}' is not a " + "subclass of 'BaseDeploymentAppRunnerFlavor'" + ) + + app_runner_flavor_class = loaded_app_runner_flavor_class + + try: + app_runner_flavor = app_runner_flavor_class() + except Exception as e: + raise RuntimeError( + f"Failed to instantiate deployment app runner flavor " + f"'{loaded_app_runner_flavor_class}': {e}" + ) from e + + return app_runner_flavor if __name__ == "__main__": import argparse - import uvicorn - parser = argparse.ArgumentParser() parser.add_argument( "--deployment_id", default=os.getenv("ZENML_DEPLOYMENT_ID"), help="Pipeline snapshot ID", ) - parser.add_argument( - "--host", - default=os.getenv("ZENML_SERVICE_HOST", "0.0.0.0"), # nosec - ) - parser.add_argument( - "--port", - type=int, - default=int(os.getenv("ZENML_SERVICE_PORT", "8001")), - ) - parser.add_argument( - "--workers", - type=int, - default=int(os.getenv("ZENML_SERVICE_WORKERS", "1")), - ) - parser.add_argument( - 
"--log_level", default=os.getenv("ZENML_LOG_LEVEL", "info").lower() - ) - parser.add_argument( - "--auth_key", default=os.getenv("ZENML_DEPLOYMENT_AUTH_KEY", "") - ) args = parser.parse_args() - if args.deployment_id: - os.environ["ZENML_DEPLOYMENT_ID"] = args.deployment_id - if args.auth_key: - os.environ["ZENML_DEPLOYMENT_AUTH_KEY"] = args.auth_key + logger.info( + f"Starting deployment application server for deployment " + f"{args.deployment_id}" + ) - logger.info(f"Starting FastAPI server on {args.host}:{args.port}") + # Activate integrations to ensure all components are available + integration_registry.activate_integrations() - uvicorn.run( - "zenml.deployers.server.app:app", - host=args.host, - port=args.port, - workers=args.workers, - log_level=args.log_level, - reload=False, - ) + app_runner = BaseDeploymentAppRunner.load_app_runner(args.deployment_id) + app_runner.run() diff --git a/src/zenml/deployers/server/dashboard/index.html b/src/zenml/deployers/server/dashboard/index.html new file mode 100644 index 00000000000..f8a6c45be69 --- /dev/null +++ b/src/zenml/deployers/server/dashboard/index.html @@ -0,0 +1,17 @@ + + +

+        [HTML markup for the new dashboard Jinja template, tags stripped: heading "🚀 ZenML Pipeline Deployment", section "Service Status", "Status: Running", "Pipeline: {{ service_info.pipeline.name }}"]
+ \ No newline at end of file diff --git a/src/zenml/deployers/server/entrypoint_configuration.py b/src/zenml/deployers/server/entrypoint_configuration.py index 792ca535bc9..ef9f1998dea 100644 --- a/src/zenml/deployers/server/entrypoint_configuration.py +++ b/src/zenml/deployers/server/entrypoint_configuration.py @@ -1,4 +1,4 @@ -# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +13,6 @@ # permissions and limitations under the License. """ZenML Pipeline Deployment Entrypoint Configuration.""" -import os from typing import Any, List, Set from uuid import UUID @@ -23,19 +22,13 @@ ) from zenml.integrations.registry import integration_registry from zenml.logger import get_logger -from zenml.models.v2.core.pipeline_snapshot import PipelineSnapshotResponse +from zenml.models import DeploymentResponse from zenml.utils import uuid_utils logger = get_logger(__name__) # Deployment-specific entrypoint options DEPLOYMENT_ID_OPTION = "deployment_id" -HOST_OPTION = "host" -PORT_OPTION = "port" -WORKERS_OPTION = "workers" -LOG_LEVEL_OPTION = "log_level" -CREATE_RUNS_OPTION = "create_runs" -AUTH_KEY_OPTION = "auth_key" class DeploymentEntrypointConfiguration(BaseEntrypointConfiguration): @@ -54,12 +47,6 @@ def get_entrypoint_options(cls) -> Set[str]: """ return { DEPLOYMENT_ID_OPTION, - HOST_OPTION, - PORT_OPTION, - WORKERS_OPTION, - LOG_LEVEL_OPTION, - CREATE_RUNS_OPTION, - AUTH_KEY_OPTION, } @classmethod @@ -91,38 +78,21 @@ def get_entrypoint_arguments(cls, **kwargs: Any) -> List[str]: deployment_args = [ f"--{DEPLOYMENT_ID_OPTION}", str(kwargs.get(DEPLOYMENT_ID_OPTION, "")), - f"--{HOST_OPTION}", - str(kwargs.get(HOST_OPTION, "0.0.0.0")), # nosec - f"--{PORT_OPTION}", - str(kwargs.get(PORT_OPTION, 8001)), - f"--{WORKERS_OPTION}", - str(kwargs.get(WORKERS_OPTION, 1)), - 
f"--{LOG_LEVEL_OPTION}", - str(kwargs.get(LOG_LEVEL_OPTION, "info")), - f"--{CREATE_RUNS_OPTION}", - str(kwargs.get(CREATE_RUNS_OPTION, "false")), - f"--{AUTH_KEY_OPTION}", - str(kwargs.get(AUTH_KEY_OPTION, "")), ] return base_args + deployment_args - def load_snapshot(self) -> "PipelineSnapshotResponse": - """Loads the deployment snapshot. + def load_deployment(self) -> "DeploymentResponse": + """Loads the deployment. Returns: - The deployment snapshot. - - Raises: - RuntimeError: If the deployment has no snapshot. + The deployment. """ deployment_id = UUID(self.entrypoint_args[DEPLOYMENT_ID_OPTION]) deployment = Client().zen_store.get_deployment( deployment_id=deployment_id ) - if deployment.snapshot is None: - raise RuntimeError("Deployment has no snapshot") - return deployment.snapshot + return deployment def run(self) -> None: """Run the ZenML pipeline deployment application. @@ -131,62 +101,19 @@ def run(self) -> None: and the specified pipeline deployment. Raises: - Exception: If the server fails to start. + RuntimeError: If the deployment has no snapshot. 
""" - import uvicorn + from zenml.deployers.server.app import BaseDeploymentAppRunner # Activate integrations to ensure all components are available integration_registry.activate_integrations() - # Extract configuration from entrypoint args - deployment_id = self.entrypoint_args[DEPLOYMENT_ID_OPTION] - host = self.entrypoint_args.get(HOST_OPTION, "0.0.0.0") # nosec - port = int(self.entrypoint_args.get(PORT_OPTION, 8001)) - workers = int(self.entrypoint_args.get(WORKERS_OPTION, 1)) - log_level = self.entrypoint_args.get(LOG_LEVEL_OPTION, "info") - create_runs = ( - self.entrypoint_args.get(CREATE_RUNS_OPTION, "false").lower() - == "true" - ) - auth_key = self.entrypoint_args.get(AUTH_KEY_OPTION, None) - - snapshot = self.load_snapshot() + deployment = self.load_deployment() + if not deployment.snapshot: + raise RuntimeError(f"Deployment {deployment.id} has no snapshot") # Download code if necessary (for remote execution environments) - self.download_code_if_necessary(snapshot=snapshot) - - # Set environment variables for the deployment application - os.environ["ZENML_DEPLOYMENT_ID"] = deployment_id - if create_runs: - os.environ["ZENML_DEPLOYMENT_CREATE_RUNS"] = "true" - if auth_key: - os.environ["ZENML_DEPLOYMENT_AUTH_KEY"] = auth_key - - logger.info("🚀 Starting ZenML Pipeline Deployment...") - logger.info(f" Deployment ID: {deployment_id}") - logger.info(f" Snapshot ID: {snapshot.id}") - logger.info(f" Host: {host}") - logger.info(f" Port: {port}") - logger.info(f" Workers: {workers}") - logger.info(f" Log Level: {log_level}") - logger.info(f" Create Runs: {create_runs}") - logger.info("") - logger.info(f"📖 API Documentation: http://{host}:{port}/docs") - logger.info(f"🔍 Health Check: http://{host}:{port}/health") - logger.info("") - - try: - # Start the FastAPI server - uvicorn.run( - "zenml.deployers.server.app:app", - host=host, - port=port, - workers=workers, - log_level=log_level.lower(), - access_log=True, - ) - except KeyboardInterrupt: - logger.info("\n🛑 
Deployment stopped by user") - except Exception as e: - logger.error(f"❌ Failed to start deployment: {str(e)}") - raise + self.download_code_if_necessary(snapshot=deployment.snapshot) + + app_runner = BaseDeploymentAppRunner.load_app_runner(deployment) + app_runner.run() diff --git a/src/zenml/deployers/server/extensions.py b/src/zenml/deployers/server/extensions.py new file mode 100644 index 00000000000..179bb1640be --- /dev/null +++ b/src/zenml/deployers/server/extensions.py @@ -0,0 +1,49 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Base app extension interface.""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from zenml.deployers.server.app import BaseDeploymentAppRunner + + +class BaseAppExtension(ABC): + """Abstract base for app extensions. + + Extensions provide advanced framework-specific capabilities like: + - Custom authentication/authorization + - Observability (logging, tracing, metrics) + - Complex routers with framework-specific features + - OpenAPI customizations + - Advanced middleware patterns + + Subclasses must implement install() to modify the app. + """ + + @abstractmethod + def install( + self, + app_runner: "BaseDeploymentAppRunner", + ) -> None: + """Install extension into the application. + + Args: + app_runner: The deployment app runner instance being used to build + and run the web application. 
+ + Raises: + RuntimeError: If installation fails. + """ diff --git a/src/zenml/deployers/server/fastapi/__init__.py b/src/zenml/deployers/server/fastapi/__init__.py new file mode 100644 index 00000000000..b4ef6cccc46 --- /dev/null +++ b/src/zenml/deployers/server/fastapi/__init__.py @@ -0,0 +1,51 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""FastAPI implementation of the deployment app factory and adapters.""" + + +from typing import List, Type +from zenml.deployers.server.app import BaseDeploymentAppRunner, BaseDeploymentAppRunnerFlavor + +FASTAPI_APP_RUNNER_FLAVOR_NAME = "fastapi" + +class FastAPIDeploymentAppRunnerFlavor(BaseDeploymentAppRunnerFlavor): + """FastAPI deployment app runner flavor.""" + + @property + def name(self) -> str: + """The name of the deployment app runner flavor. + + Returns: + The name of the deployment app runner flavor. + """ + return FASTAPI_APP_RUNNER_FLAVOR_NAME + + @property + def implementation_class(self) -> Type[BaseDeploymentAppRunner]: + """The class that implements the deployment app runner. + + Returns: + The implementation class for the deployment app runner. + """ + from zenml.deployers.server.fastapi.app import FastAPIDeploymentAppRunner + return FastAPIDeploymentAppRunner + + @property + def requirements(self) -> List[str]: + """The software requirements for the deployment app runner. + + Returns: + The software requirements for the deployment app runner. 
+ """ + return super().requirements + ["fastapi"] \ No newline at end of file diff --git a/src/zenml/deployers/server/fastapi/adapters.py b/src/zenml/deployers/server/fastapi/adapters.py new file mode 100644 index 00000000000..ed7ade426e9 --- /dev/null +++ b/src/zenml/deployers/server/fastapi/adapters.py @@ -0,0 +1,220 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""FastAPI adapter implementations.""" + +from typing import Any, Callable, Dict, Optional + +from fastapi import APIRouter, Depends, FastAPI, HTTPException +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from starlette.middleware.base import BaseHTTPMiddleware + +from zenml.config import ( + EndpointMethod, + EndpointSpec, + MiddlewareSpec, +) +from zenml.deployers.server.adapters import ( + EndpointAdapter, + MiddlewareAdapter, +) +from zenml.deployers.server.app import BaseDeploymentAppRunner + + +class FastAPIEndpointAdapter(EndpointAdapter): + """FastAPI implementation of endpoint adapter.""" + + def _build_auth_dependency(self, api_key: str) -> Callable[..., Any]: + """Build a FastAPI auth dependency. + + Args: + api_key: The API key to use for authentication. + + Returns: + FastAPI auth enforcement callable. 
+ """ + security = HTTPBearer( + scheme_name="Bearer Token", + description="Enter your API key as a Bearer token", + auto_error=False, + ) + + def verify_token( + credentials: Optional[HTTPAuthorizationCredentials] = Depends( + security + ), + ) -> None: + """Verify the provided Bearer token for authentication. + + Args: + credentials: HTTP Bearer credentials from the request. + + Raises: + HTTPException: If token is invalid. + """ + if not credentials: + raise HTTPException( + status_code=401, + detail="Authorization header required", + headers={"WWW-Authenticate": "Bearer"}, + ) + if credentials.credentials != api_key: + raise HTTPException( + status_code=401, + detail="Invalid authentication token", + headers={"WWW-Authenticate": "Bearer"}, + ) + + return verify_token + + def register_endpoint( + self, + app_runner: BaseDeploymentAppRunner, + spec: EndpointSpec, + ) -> None: + """Register endpoint with FastAPI. + + Args: + app_runner: Deployment app runner instance. + spec: Framework-agnostic endpoint specification. + + Raises: + RuntimeError: If the adapter is not used with a FastAPI application. 
+ """ + app = app_runner.asgi_app + + if not isinstance(app, FastAPI): + raise RuntimeError( + f"The {self.__class__.__name__} adapter must be used with a " + "FastAPI application" + ) + + # Ensure handler is loaded + handler = self.resolve_endpoint_handler(app_runner, spec) + + # Apply auth dependency if required + dependencies = [] + if spec.auth_required and app_runner.deployment.auth_key: + auth_dependency = self._build_auth_dependency( + app_runner.deployment.auth_key + ) + dependencies.append(Depends(auth_dependency)) + + if spec.native: + if isinstance(handler, APIRouter): + app.include_router( + handler, prefix=spec.path, **spec.extra_kwargs + ) + return + + # Register with appropriate HTTP method + route_kwargs: Dict[str, Any] = {"dependencies": dependencies} + route_kwargs.update(spec.extra_kwargs) + + if spec.method == EndpointMethod.GET: + app.get(spec.path, **route_kwargs)(handler) + elif spec.method == EndpointMethod.POST: + app.post(spec.path, **route_kwargs)(handler) + elif spec.method == EndpointMethod.PUT: + app.put(spec.path, **route_kwargs)(handler) + elif spec.method == EndpointMethod.PATCH: + app.patch(spec.path, **route_kwargs)(handler) + elif spec.method == EndpointMethod.DELETE: + app.delete(spec.path, **route_kwargs)(handler) + + +class FastAPIMiddlewareAdapter(MiddlewareAdapter): + """FastAPI implementation of middleware adapter. + + We support two types of native middleware: + + * A middleware class like that receives the ASGIApp object in the + constructor and implements the __call__ method to dispatch the middleware, + e.g.: + + ```python + from starlette.types import ASGIApp, Receive, Scope, Send + + class MyMiddleware: + def __init__(self, app: ASGIApp) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + ... 
+            await self.app(scope, receive, send)
+    ```
+
+    * A middleware function that takes a request and a `call_next` callable
+    and returns a response, e.g.:
+
+    ```python
+    from typing import Awaitable, Callable
+
+    from fastapi import Request, Response
+
+    async def my_middleware(request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
+        ...
+        return await call_next(request)
+    ```
+
+    """
+
+    def register_middleware(
+        self,
+        app_runner: BaseDeploymentAppRunner,
+        spec: MiddlewareSpec,
+    ) -> None:
+        """Register middleware with FastAPI.
+
+        Args:
+            app_runner: Deployment app runner instance.
+            spec: Framework-agnostic middleware specification.
+
+        Raises:
+            RuntimeError: If the adapter is not used with a FastAPI application.
+        """
+        app = app_runner.asgi_app
+
+        if not isinstance(app, FastAPI):
+            raise RuntimeError(
+                f"The {self.__class__.__name__} adapter must be used with a "
+                "FastAPI application"
+            )
+
+        middleware = self.resolve_middleware_handler(app_runner, spec)
+
+        if isinstance(middleware, type):
+            # Middleware classes implement the ASGI protocol and can be
+            # added directly, whether the spec is native or not.
+            app.add_middleware(
+                middleware,  # type: ignore[arg-type]
+                **spec.extra_kwargs,
+            )
+            return
+
+        # Convert the unified middleware function to a FastAPI middleware
+        # class by wrapping it in BaseHTTPMiddleware as its dispatch callable.
+        app.add_middleware(
+            BaseHTTPMiddleware,
+            dispatch=middleware,
+            **spec.extra_kwargs,
+        )
diff --git a/src/zenml/deployers/server/fastapi/app.py b/src/zenml/deployers/server/fastapi/app.py
new file mode 100644
index 00000000000..8a6091f1d26
--- /dev/null
+++ b/src/zenml/deployers/server/fastapi/app.py
@@ -0,0 +1,328 @@
+# Copyright (c) ZenML GmbH 2025. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+# or implied. See the License for the specific language governing
+# permissions and limitations under the License.
+"""FastAPI application for running ZenML pipeline deployments."""
+
+import os
+from contextlib import asynccontextmanager
+from os.path import isdir, isfile
+from typing import Any, AsyncGenerator, Dict, List, Optional, cast
+
+from anyio import to_thread
+from asgiref.typing import (
+    ASGIApplication,
+)
+from fastapi import (
+    FastAPI,
+    HTTPException,
+    Request,
+)
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import FileResponse, JSONResponse
+from fastapi.staticfiles import StaticFiles
+from fastapi.templating import Jinja2Templates
+
+from zenml import __version__ as zenml_version
+from zenml.config import (
+    AppExtensionSpec,
+    DeploymentDefaultEndpoints,
+    EndpointMethod,
+    EndpointSpec,
+    MiddlewareSpec,
+)
+from zenml.deployers.server.adapters import (
+    EndpointAdapter,
+    MiddlewareAdapter,
+)
+from zenml.deployers.server.app import (
+    BaseDeploymentAppRunner,
+    BaseDeploymentAppRunnerFlavor,
+)
+from zenml.deployers.server.fastapi import FastAPIDeploymentAppRunnerFlavor
+from zenml.deployers.server.fastapi.adapters import (
+    FastAPIEndpointAdapter,
+    FastAPIMiddlewareAdapter,
+)
+from zenml.logger import get_logger
+
+logger = get_logger(__name__)
+
+
+class FastAPIDeploymentAppRunner(BaseDeploymentAppRunner):
+    """FastAPI deployment app runner."""
+
+    @property
+    def flavor(self) -> "BaseDeploymentAppRunnerFlavor":
+        """Return the flavor associated with this deployment application runner.
+
+        Returns:
+            The flavor associated with this deployment application runner.
+ """ + return FastAPIDeploymentAppRunnerFlavor() + + def _create_endpoint_adapter(self) -> EndpointAdapter: + """Create FastAPI endpoint adapter. + + Returns: + FastAPI endpoint adapter instance. + """ + return FastAPIEndpointAdapter() + + def _create_middleware_adapter(self) -> MiddlewareAdapter: + """Create FastAPI middleware adapter. + + Returns: + FastAPI middleware adapter instance. + """ + return FastAPIMiddlewareAdapter() + + def _build_cors_middleware(self) -> MiddlewareSpec: + """Get the CORS middleware. + + Returns: + The CORS middleware. + """ + return MiddlewareSpec( + middleware=CORSMiddleware, + extra_kwargs=dict( + allow_origins=self.settings.cors.allow_origins, + allow_credentials=self.settings.cors.allow_credentials, + allow_methods=self.settings.cors.allow_methods, + allow_headers=self.settings.cors.allow_headers, + ), + native=True, + ) + + def _get_dashboard_endpoints(self) -> List[EndpointSpec]: + """Get the dashboard endpoints specs. + + This is called if the dashboard files path is set to construct the + endpoints specs for the dashboard. + + Returns: + The dashboard endpoints specs. + + Raises: + ValueError: If the index HTML file is not found in the dashboard + files path. + """ + dashboard_files_path = self.dashboard_files_path() + if not dashboard_files_path or not os.path.isdir(dashboard_files_path): + return [] + + endpoints: List[EndpointSpec] = [] + + async def catch_invalid_api(invalid_api_path: str) -> None: + """Invalid API endpoint. + + All API endpoints that are not defined in the API routers will be + caught by this endpoint and will return a 404 error. + + Args: + invalid_api_path: Invalid API path. + + Raises: + HTTPException: 404 error. 
+ """ + logger.debug(f"Invalid API path requested: {invalid_api_path}") + raise HTTPException(status_code=404) + + if self.settings.api_url_path: + endpoints.append( + EndpointSpec( + path=f"{self.settings.api_url_path}" + + "/{invalid_api_path:path}", + method=EndpointMethod.GET, + handler=catch_invalid_api, + native=True, + extra_kwargs=dict( + include_in_schema=False, + ), + ) + ) + + static_files = [] + static_directories = [] + index_html_path = None + for file in os.listdir(dashboard_files_path): + if file.startswith("__"): + logger.debug(f"Skipping private file: {file}") + continue + if file in ["index.html", "index.htm"]: + # this is served separately + index_html_path = os.path.join(dashboard_files_path, file) + continue + if isfile(os.path.join(dashboard_files_path, file)): + static_files.append(file) + elif isdir(os.path.join(dashboard_files_path, file)): + static_directories.append(file) + + if index_html_path is None: + raise ValueError( + f"Index HTML file not found in the dashboard files path: " + f"{dashboard_files_path}" + ) + + for static_dir in static_directories: + static_files_endpoint = StaticFiles( + directory=os.path.join(dashboard_files_path, static_dir), + check_dir=False, + ) + endpoints.append( + EndpointSpec( + path=f"/{static_dir}", + method=EndpointMethod.GET, + handler=static_files_endpoint, + native=True, + auth_required=False, + ) + ) + + templates = Jinja2Templates(directory=dashboard_files_path) + + async def catch_all_endpoint(request: Request, file_path: str) -> Any: + """Dashboard catch-all endpoint. + + Args: + request: Request object. + file_path: Path to a file in the dashboard root folder. + + Returns: + The files in the dashboard root directory. 
+ """ + # some static files need to be served directly from the root dashboard + # directory + if file_path and file_path in static_files: + logger.debug(f"Returning static file: {file_path}") + full_path = os.path.join(dashboard_files_path, file_path) + return FileResponse(full_path) + + # everything else is directed to the index.html file that hosts the + # single-page application - this is to support client-side routing + return templates.TemplateResponse( + "index.html", + dict( + request=request, + service_info=self.service.get_service_info().model_dump(), + ), + ) + + endpoints.append( + EndpointSpec( + path="/{file_path:path}", + method=EndpointMethod.GET, + handler=catch_all_endpoint, + native=True, + auth_required=False, + extra_kwargs=dict( + include_in_schema=False, + ), + ), + ) + + return endpoints + + def error_handler(self, request: Request, exc: ValueError) -> JSONResponse: + """FastAPI error handler. + + Args: + request: The request. + exc: The exception. + + Returns: + The error response. + """ + logger.error("Error in request: %s", exc) + return JSONResponse(status_code=500, content={"detail": str(exc)}) + + def build( + self, + middlewares: List[MiddlewareSpec], + endpoints: List[EndpointSpec], + extensions: List[AppExtensionSpec], + ) -> ASGIApplication: + """Build the FastAPI app for the deployment. + + Args: + middlewares: The middleware to register. + endpoints: The endpoints to register. + extensions: The extensions to install. + + Returns: + The configured FastAPI application instance. 
+ """ + title = ( + self.settings.app_title + or f"ZenML Pipeline Deployment {self.deployment.name}" + ) + description = ( + self.settings.app_description + or f"ZenML pipeline deployment server for the " + f"{self.deployment.name} deployment" + ) + docs_url_path: Optional[str] = None + redoc_url_path: Optional[str] = None + if self.settings.endpoint_enabled(DeploymentDefaultEndpoints.DOCS): + docs_url_path = self.settings.docs_url_path + if self.settings.endpoint_enabled(DeploymentDefaultEndpoints.REDOC): + redoc_url_path = self.settings.redoc_url_path + + fastapi_kwargs: Dict[str, Any] = dict( + title=title, + description=description, + version=self.settings.app_version + if self.settings.app_version is not None + else zenml_version, + root_path=self.settings.root_url_path, + docs_url=docs_url_path, + redoc_url=redoc_url_path, + lifespan=self.lifespan, + ) + fastapi_kwargs.update(self.settings.app_kwargs) + + asgi_app = FastAPI(**fastapi_kwargs) + + # Save this so it's available for the middleware, endpoint adapters and + # extensions + self._asgi_app = cast(ASGIApplication, asgi_app) + + # Bind the app runner to the app state + asgi_app.state.app_runner = self + asgi_app.exception_handler(Exception)(self.error_handler) + + self.register_middlewares(*middlewares) + self.register_endpoints(*endpoints) + self.install_extensions(*extensions) + + return self._asgi_app + + @asynccontextmanager + async def lifespan(self, app: FastAPI) -> AsyncGenerator[None, None]: + """Manage the deployment application lifespan. + + Args: + app: The FastAPI application instance being deployed. + + Yields: + None: Control is handed back to FastAPI once initialization completes. 
+ """ + # Set the maximum number of worker threads + to_thread.current_default_thread_limiter().total_tokens = ( + self.settings.thread_pool_size + ) + + self.startup() + + yield + + self.shutdown() diff --git a/src/zenml/deployers/server/models.py b/src/zenml/deployers/server/models.py index 6b341d0de7e..b2afd395530 100644 --- a/src/zenml/deployers/server/models.py +++ b/src/zenml/deployers/server/models.py @@ -1,4 +1,4 @@ -# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,18 +14,15 @@ """FastAPI application models.""" from datetime import datetime -from typing import TYPE_CHECKING, Annotated, Any, Dict, Optional, Tuple, Type +from typing import Any, Dict, Optional from uuid import UUID -from pydantic import BaseModel, Field, WithJsonSchema +from pydantic import BaseModel, Field from zenml.logger import get_logger logger = get_logger(__name__) -if TYPE_CHECKING: - from zenml.deployers.server.service import PipelineDeploymentService - class DeploymentInvocationResponseMetadata(BaseModel): """Pipeline invoke response metadata model.""" @@ -110,6 +107,9 @@ class DeploymentInfo(BaseModel): id: UUID = Field(title="The ID of the deployment.") name: str = Field(title="The name of the deployment.") + auth_enabled: bool = Field( + title="Whether the deployment is authenticated." + ) class SnapshotInfo(BaseModel): @@ -121,6 +121,18 @@ class SnapshotInfo(BaseModel): ) +class AppInfo(BaseModel): + """App info model.""" + + app_runner_flavor: str + docs_url_path: str + redoc_url_path: str + invoke_url_path: str + health_url_path: str + info_url_path: str + metrics_url_path: str + + class ServiceInfo(BaseModel): """Service info model.""" @@ -133,6 +145,7 @@ class ServiceInfo(BaseModel): pipeline: PipelineInfo = Field( title="The pipeline of the pipeline service." 
) + app: AppInfo = Field(title="The deployment application") total_executions: int = Field( title="The total number of pipeline executions." ) @@ -152,36 +165,3 @@ class ExecutionMetrics(BaseModel): last_execution_time: Optional[datetime] = Field( default=None, title="The time of the last pipeline execution." ) - - -def get_pipeline_invoke_models( - service: "PipelineDeploymentService", -) -> Tuple[Type[BaseModel], Type[BaseModel]]: - """Generate the request and response models for the pipeline invoke endpoint. - - Args: - service: The pipeline deployment service. - - Returns: - A tuple containing the request and response models. - """ - if TYPE_CHECKING: - # mypy has a difficult time with dynamic models, so we return something - # static for mypy to use - return BaseModel, BaseModel - - else: - - class PipelineInvokeRequest(BaseDeploymentInvocationRequest): - parameters: Annotated[ - service.input_model, - WithJsonSchema(service.input_schema, mode="validation"), - ] - - class PipelineInvokeResponse(BaseDeploymentInvocationResponse): - outputs: Annotated[ - Optional[Dict[str, Any]], - WithJsonSchema(service.output_schema, mode="serialization"), - ] - - return PipelineInvokeRequest, PipelineInvokeResponse diff --git a/src/zenml/deployers/server/service.py b/src/zenml/deployers/server/service.py index a1f734cb0fb..7161a97f96f 100644 --- a/src/zenml/deployers/server/service.py +++ b/src/zenml/deployers/server/service.py @@ -16,16 +16,26 @@ import contextvars import time import traceback +from abc import ABC, abstractmethod from datetime import datetime, timezone -from typing import Any, Dict, Optional, Tuple, Type, Union -from uuid import UUID, uuid4 +from typing import ( + TYPE_CHECKING, + Annotated, + Any, + Dict, + Optional, + Tuple, + Type, +) +from uuid import uuid4 -from pydantic import BaseModel +from pydantic import BaseModel, WithJsonSchema import zenml.pipelines.run_utils as run_utils from zenml.client import Client from zenml.deployers.server import runtime 
from zenml.deployers.server.models import ( + AppInfo, BaseDeploymentInvocationRequest, BaseDeploymentInvocationResponse, DeploymentInfo, @@ -51,11 +61,14 @@ LocalOrchestrator, LocalOrchestratorConfig, ) -from zenml.pipelines.pipeline_definition import Pipeline from zenml.stack import Stack from zenml.steps.utils import get_unique_step_output_names from zenml.utils import env_utils, source_utils +if TYPE_CHECKING: + from zenml.deployers.server.app import BaseDeploymentAppRunner + from zenml.pipelines.pipeline_definition import Pipeline + logger = get_logger(__name__) @@ -108,22 +121,207 @@ def run_cleanup_hook(cls, snapshot: "PipelineSnapshotResponse") -> None: pass -class PipelineDeploymentService: - """Pipeline deployment service.""" +class BasePipelineDeploymentService(ABC): + """Abstract base class for pipeline deployment services. + + Subclasses must implement lifecycle management, execution, health, + and schema accessors. This contract enables swapping implementations + via import-source configuration without modifying the FastAPI app + wiring code. + """ + + def __init__( + self, app_runner: "BaseDeploymentAppRunner", **kwargs: Any + ) -> None: + """Initialize the deployment service. + + Args: + app_runner: The deployment application runner used with this service. + **kwargs: Additional keyword arguments for the deployment service. + + Raises: + RuntimeError: If snapshot cannot be loaded. + """ + self.app_runner = app_runner + self.deployment = app_runner.deployment + if self.deployment.snapshot is None: + raise RuntimeError("Deployment has no snapshot") + self.snapshot = self.deployment.snapshot + + @abstractmethod + def initialize(self) -> None: + """Initialize service resources and run init hooks. + + Raises: + Exception: If the service cannot be initialized. 
+ """ + + @abstractmethod + def cleanup(self) -> None: + """Cleanup service resources and run cleanup hooks.""" - def __init__(self, deployment_id: Union[str, UUID]) -> None: - """Initialize service with minimal state. + @abstractmethod + def execute_pipeline( + self, request: BaseDeploymentInvocationRequest + ) -> BaseDeploymentInvocationResponse: + """Execute the deployment with the given parameters. Args: - deployment_id: The ID of the running deployment. + request: Runtime parameters supplied by the caller. + + Returns: + A BaseDeploymentInvocationResponse describing the execution result. + """ + + @abstractmethod + def get_service_info(self) -> ServiceInfo: + """Get service information. + + Returns: + A dictionary containing service information. + """ + + @abstractmethod + def get_execution_metrics(self) -> ExecutionMetrics: + """Return lightweight execution metrics for observability. + + Returns: + A dictionary containing execution metrics. + """ + + @abstractmethod + def health_check(self) -> None: + """Check service health. Raises: - RuntimeError: If the deployment or snapshot cannot be loaded. + RuntimeError: If the service is not healthy. """ - # Accept both str and UUID for flexibility - if isinstance(deployment_id, str): - deployment_id = UUID(deployment_id) + # ---------- + # Schemas and models for OpenAPI enrichment + # ---------- + + @property + def input_model( + self, + ) -> Type[BaseModel]: + """Construct a Pydantic model representing pipeline input parameters. + + Load the pipeline class from `pipeline_spec.source` and derive the + entrypoint signature types to create a dynamic Pydantic model + (extra='forbid') to use for parameter validation. + + Returns: + A Pydantic `BaseModel` subclass that validates the pipeline input + parameters. + + Raises: + RuntimeError: If the pipeline class cannot be loaded or if no + parameters model can be constructed for the pipeline. 
+ """ + if ( + not self.snapshot.pipeline_spec + or not self.snapshot.pipeline_spec.source + ): + raise RuntimeError( + f"Snapshot `{self.snapshot.id}` is missing a " + "pipeline_spec.source; cannot build input model." + ) + + try: + pipeline_class: "Pipeline" = source_utils.load( + self.snapshot.pipeline_spec.source + ) + except Exception as e: + raise RuntimeError( + "Failed to load pipeline class from snapshot" + ) from e + + model = pipeline_class._compute_input_model() + if not model: + raise RuntimeError( + f"Failed to construct input model from pipeline " + f"`{self.snapshot.pipeline_configuration.name}`." + ) + return model + + @property + def input_schema(self) -> Dict[str, Any]: + """Return the JSON schema for pipeline input parameters. + + Returns: + The JSON schema for pipeline parameters. + + Raises: + RuntimeError: If the pipeline input schema is not available. + """ + if ( + self.snapshot.pipeline_spec + and self.snapshot.pipeline_spec.input_schema + ): + return self.snapshot.pipeline_spec.input_schema + # This should never happen, given that we check for this in the + # base deployer. + raise RuntimeError("The pipeline input schema is not available.") + + @property + def output_schema(self) -> Dict[str, Any]: + """Return the JSON schema for the pipeline outputs. + + Returns: + The JSON schema for the pipeline outputs. + + Raises: + RuntimeError: If the pipeline output schema is not available. + """ + if ( + self.snapshot.pipeline_spec + and self.snapshot.pipeline_spec.output_schema + ): + return self.snapshot.pipeline_spec.output_schema + # This should never happen, given that we check for this in the + # base deployer. + raise RuntimeError("The pipeline output schema is not available.") + + def get_pipeline_invoke_models( + self, + ) -> Tuple[Type[BaseModel], Type[BaseModel]]: + """Generate the request and response models for the pipeline invoke endpoint. + + Returns: + A tuple containing the request and response models. 
+ """ + if TYPE_CHECKING: + # mypy has a difficult time with dynamic models, so we return something + # static for mypy to use + return BaseModel, BaseModel + + else: + + class PipelineInvokeRequest(BaseDeploymentInvocationRequest): + parameters: Annotated[ + self.input_model, + WithJsonSchema(self.input_schema, mode="validation"), + ] + + class PipelineInvokeResponse(BaseDeploymentInvocationResponse): + outputs: Annotated[ + Optional[Dict[str, Any]], + WithJsonSchema(self.output_schema, mode="serialization"), + ] + + return PipelineInvokeRequest, PipelineInvokeResponse + + +class PipelineDeploymentService(BasePipelineDeploymentService): + """Default pipeline deployment service implementation.""" + + def initialize(self) -> None: + """Initialize service with proper error handling. + + Raises: + Exception: If the service cannot be initialized. + """ self._client = Client() # Execution tracking @@ -143,25 +341,6 @@ def __init__(self, deployment_id: Union[str, UUID]) -> None: updated=datetime.now(), ) - logger.info("Loading pipeline snapshot configuration...") - - try: - self.deployment = self._client.zen_store.get_deployment( - deployment_id=deployment_id - ) - except Exception as e: - raise RuntimeError(f"Failed to load deployment: {e}") from e - - if self.deployment.snapshot is None: - raise RuntimeError("Deployment has no snapshot") - self.snapshot = self.deployment.snapshot - - def initialize(self) -> None: - """Initialize service with proper error handling. - - Raises: - Exception: If the service cannot be initialized. - """ try: # Execute init hook BaseOrchestrator.run_init_hook(self.snapshot) @@ -241,10 +420,13 @@ def get_service_info(self) -> ServiceInfo: A dictionary containing service information. 
""" uptime = time.time() - self.service_start_time + settings = self.app_runner.settings + api_urlpath = f"{self.app_runner.settings.root_url_path}{self.app_runner.settings.api_url_path}" return ServiceInfo( deployment=DeploymentInfo( id=self.deployment.id, name=self.deployment.name, + auth_enabled=self.deployment.auth_key is not None, ), snapshot=SnapshotInfo( id=self.snapshot.id, @@ -258,6 +440,15 @@ def get_service_info(self) -> ServiceInfo: input_schema=self.input_schema, output_schema=self.output_schema, ), + app=AppInfo( + app_runner_flavor=self.app_runner.flavor.name, + docs_url_path=settings.docs_url_path, + redoc_url_path=settings.redoc_url_path, + invoke_url_path=api_urlpath + settings.invoke_url_path, + health_url_path=api_urlpath + settings.health_url_path, + info_url_path=api_urlpath + settings.info_url_path, + metrics_url_path=api_urlpath + settings.metrics_url_path, + ), total_executions=self.total_executions, last_execution_time=self.last_execution_time, status="healthy", @@ -275,13 +466,9 @@ def get_execution_metrics(self) -> ExecutionMetrics: last_execution_time=self.last_execution_time, ) - def is_healthy(self) -> bool: - """Check service health. - - Returns: - True if the service is healthy, otherwise False. - """ - return True + def health_check(self) -> None: + """Check service health.""" + pass def _map_outputs( self, @@ -520,89 +707,3 @@ def _build_response( snapshot_name=self.snapshot.name, ), ) - - # ---------- - # Schemas and models for OpenAPI enrichment - # ---------- - - @property - def input_model( - self, - ) -> Type[BaseModel]: - """Construct a Pydantic model representing pipeline input parameters. - - Load the pipeline class from `pipeline_spec.source` and derive the - entrypoint signature types to create a dynamic Pydantic model - (extra='forbid') to use for parameter validation. - - Returns: - A Pydantic `BaseModel` subclass that validates the pipeline input - parameters. 
- - Raises: - RuntimeError: If the pipeline class cannot be loaded or if no - parameters model can be constructed for the pipeline. - """ - if ( - not self.snapshot.pipeline_spec - or not self.snapshot.pipeline_spec.source - ): - raise RuntimeError( - f"Snapshot `{self.snapshot.id}` is missing a " - "pipeline_spec.source; cannot build input model." - ) - - try: - pipeline_class: Pipeline = source_utils.load( - self.snapshot.pipeline_spec.source - ) - except Exception as e: - raise RuntimeError( - "Failed to load pipeline class from snapshot" - ) from e - - model = pipeline_class._compute_input_model() - if not model: - raise RuntimeError( - f"Failed to construct input model from pipeline " - f"`{self.snapshot.pipeline_configuration.name}`." - ) - return model - - @property - def input_schema(self) -> Dict[str, Any]: - """Return the JSON schema for pipeline input parameters. - - Returns: - The JSON schema for pipeline parameters. - - Raises: - RuntimeError: If the pipeline input schema is not available. - """ - if ( - self.snapshot.pipeline_spec - and self.snapshot.pipeline_spec.input_schema - ): - return self.snapshot.pipeline_spec.input_schema - # This should never happen, given that we check for this in the - # base deployer. - raise RuntimeError("The pipeline input schema is not available.") - - @property - def output_schema(self) -> Dict[str, Any]: - """Return the JSON schema for the pipeline outputs. - - Returns: - The JSON schema for the pipeline outputs. - - Raises: - RuntimeError: If the pipeline output schema is not available. - """ - if ( - self.snapshot.pipeline_spec - and self.snapshot.pipeline_spec.output_schema - ): - return self.snapshot.pipeline_spec.output_schema - # This should never happen, given that we check for this in the - # base deployer. 
- raise RuntimeError("The pipeline output schema is not available.") diff --git a/src/zenml/deployers/utils.py b/src/zenml/deployers/utils.py index 81c2600c0f3..74ffe88e8c3 100644 --- a/src/zenml/deployers/utils.py +++ b/src/zenml/deployers/utils.py @@ -14,13 +14,17 @@ """ZenML deployers utilities.""" import json -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, List, Optional, Union from uuid import UUID import jsonref import requests from zenml.client import Client +from zenml.config.deployment_settings import ( + DEFAULT_DEPLOYMENT_APP_INVOKE_URL_PATH, + DeploymentSettings, +) from zenml.config.step_configurations import Step from zenml.deployers.exceptions import ( DeploymentHTTPError, @@ -222,8 +226,15 @@ def invoke_deployment( f"Failed to serialize request data to JSON: {e}" ) + invoke_url_path = DEFAULT_DEPLOYMENT_APP_INVOKE_URL_PATH + if deployment.snapshot: + deployment_settings = ( + deployment.snapshot.pipeline_configuration.deployment_settings + ) + invoke_url_path = f"{deployment_settings.root_url_path}{deployment_settings.api_url_path}{deployment_settings.invoke_url_path}" + # Construct the invoke endpoint URL - invoke_url = deployment.url.rstrip("/") + "/invoke" + invoke_url = deployment.url.rstrip("/") + invoke_url_path # Prepare headers headers = { @@ -398,3 +409,35 @@ def deployment_snapshot_request_from_source_snapshot( pipeline_version_hash=source_snapshot.pipeline_version_hash, pipeline_spec=updated_pipeline_spec, ) + + +def load_deployment_requirements( + deployment_settings: DeploymentSettings, +) -> List[str]: + """Load the software requirements for a deployment. + + Args: + deployment_settings: The deployment settings for which to load the + software requirements. + + Returns: + The software requirements for the deployment. + + Raises: + RuntimeError: If the deployment app runner flavor cannot be loaded. 
+ """ + from zenml.deployers.server.app import BaseDeploymentAppRunnerFlavor + + try: + deployment_app_runner_flavor = ( + BaseDeploymentAppRunnerFlavor.load_app_runner_flavor( + deployment_settings + ) + ) + except Exception as e: + raise RuntimeError( + f"Failed to load deployment app runner flavor from deployment " + f"settings: {e}" + ) from e + + return deployment_app_runner_flavor.requirements diff --git a/src/zenml/integrations/aws/deployers/aws_deployer.py b/src/zenml/integrations/aws/deployers/aws_deployer.py index 3eaa8023483..80c0f6fee60 100644 --- a/src/zenml/integrations/aws/deployers/aws_deployer.py +++ b/src/zenml/integrations/aws/deployers/aws_deployer.py @@ -44,9 +44,7 @@ DeploymentProvisionError, ) from zenml.deployers.server.entrypoint_configuration import ( - AUTH_KEY_OPTION, DEPLOYMENT_ID_OPTION, - PORT_OPTION, DeploymentEntrypointConfiguration, ) from zenml.enums import DeploymentStatus, StackComponentType @@ -255,8 +253,6 @@ def from_deployment( class AWSDeployer(ContainerizedDeployer): """Deployer responsible for deploying pipelines on AWS App Runner.""" - CONTAINER_REQUIREMENTS: List[str] = ["uvicorn", "fastapi"] - _boto_session: Optional[boto3.Session] = None _region: Optional[str] = None _app_runner_client: Optional[Any] = None @@ -1285,8 +1281,6 @@ def do_provision_deployment( arguments = DeploymentEntrypointConfiguration.get_entrypoint_arguments( **{ DEPLOYMENT_ID_OPTION: deployment.id, - PORT_OPTION: settings.port, - AUTH_KEY_OPTION: deployment.auth_key, } ) @@ -1312,8 +1306,11 @@ def do_provision_deployment( f"deploying to App Runner." 
) + container_port = ( + snapshot.pipeline_configuration.deployment_settings.uvicorn_port + ) image_config: Dict[str, Any] = { - "Port": str(settings.port), + "Port": str(container_port), "StartCommand": " ".join(entrypoint + arguments), } @@ -1385,8 +1382,15 @@ def do_provision_deployment( "UnhealthyThreshold": settings.health_check_unhealthy_threshold, } + deployment_settings = ( + snapshot.pipeline_configuration.deployment_settings + ) + if settings.health_check_protocol.upper() == "HTTP": - health_check_configuration["Path"] = settings.health_check_path + root_path = deployment_settings.root_url_path + api_url_path = deployment_settings.api_url_path + health_check_path = f"{root_path}{api_url_path}{deployment_settings.health_url_path}" + health_check_configuration["Path"] = health_check_path network_configuration = { "IngressConfiguration": { diff --git a/src/zenml/integrations/aws/flavors/aws_deployer_flavor.py b/src/zenml/integrations/aws/flavors/aws_deployer_flavor.py index b21e5841b3d..95348f02e1f 100644 --- a/src/zenml/integrations/aws/flavors/aws_deployer_flavor.py +++ b/src/zenml/integrations/aws/flavors/aws_deployer_flavor.py @@ -65,11 +65,6 @@ class AWSDeployerSettings(BaseDeployerSettings): description="Interval between health checks in seconds. Range: 1-20.", ) - health_check_path: str = Field( - default="/health", - description="Health check path for the App Runner service.", - ) - health_check_protocol: str = Field( default="TCP", description="Health check protocol. 
Options: 'TCP', 'HTTP'.", @@ -121,14 +116,6 @@ class AWSDeployerSettings(BaseDeployerSettings): description="Tags to apply to the App Runner service.", ) - # App Runner specific settings - port: int = Field( - default=8080, - ge=1, - le=65535, - description="Port on which the container listens for requests.", - ) - # Secret management configuration use_secrets_manager: bool = Field( default=True, diff --git a/src/zenml/integrations/gcp/deployers/gcp_deployer.py b/src/zenml/integrations/gcp/deployers/gcp_deployer.py index ce711f26bbb..62d6da6d3e2 100644 --- a/src/zenml/integrations/gcp/deployers/gcp_deployer.py +++ b/src/zenml/integrations/gcp/deployers/gcp_deployer.py @@ -45,9 +45,7 @@ DeploymentProvisionError, ) from zenml.deployers.server.entrypoint_configuration import ( - AUTH_KEY_OPTION, DEPLOYMENT_ID_OPTION, - PORT_OPTION, DeploymentEntrypointConfiguration, ) from zenml.enums import DeploymentStatus, StackComponentType @@ -249,8 +247,6 @@ def from_deployment( class GCPDeployer(ContainerizedDeployer, GoogleCredentialsMixin): """Deployer responsible for deploying pipelines on GCP Cloud Run.""" - CONTAINER_REQUIREMENTS: List[str] = ["uvicorn", "fastapi"] - _credentials: Optional[Any] = None _project_id: Optional[str] = None _cloud_run_client: Optional[run_v2.ServicesClient] = None @@ -1048,8 +1044,6 @@ def do_provision_deployment( arguments = DeploymentEntrypointConfiguration.get_entrypoint_arguments( **{ DEPLOYMENT_ID_OPTION: deployment.id, - PORT_OPTION: settings.port, - AUTH_KEY_OPTION: deployment.auth_key, } ) @@ -1073,13 +1067,16 @@ def do_provision_deployment( if settings.vpc_connector: vpc_access = run_v2.VpcAccess(connector=settings.vpc_connector) + container_port = ( + snapshot.pipeline_configuration.deployment_settings.uvicorn_port + ) container = run_v2.Container( image=image, command=entrypoint, args=arguments, env=env_vars, resources=resources, - ports=[run_v2.ContainerPort(container_port=settings.port)], + 
ports=[run_v2.ContainerPort(container_port=container_port)], ) template = run_v2.RevisionTemplate( diff --git a/src/zenml/integrations/gcp/flavors/gcp_deployer_flavor.py b/src/zenml/integrations/gcp/flavors/gcp_deployer_flavor.py index d8c4c660b37..aa2d87eb7e9 100644 --- a/src/zenml/integrations/gcp/flavors/gcp_deployer_flavor.py +++ b/src/zenml/integrations/gcp/flavors/gcp_deployer_flavor.py @@ -101,13 +101,6 @@ class GCPDeployerSettings(BaseDeployerSettings): description="Execution environment generation. Options: 'gen1', 'gen2'.", ) - port: int = Field( - default=8080, - ge=1, - le=65535, - description="Port on which the container listens for requests.", - ) - # Deployment configuration traffic_allocation: Dict[str, int] = Field( default_factory=lambda: {"LATEST": 100}, diff --git a/src/zenml/orchestrators/step_launcher.py b/src/zenml/orchestrators/step_launcher.py index 050f338b9e6..eacb6fd24d1 100644 --- a/src/zenml/orchestrators/step_launcher.py +++ b/src/zenml/orchestrators/step_launcher.py @@ -26,7 +26,6 @@ ENV_ZENML_STEP_OPERATOR, handle_bool_env_var, ) -from zenml.deployers.server import runtime from zenml.enums import ExecutionMode, ExecutionStatus from zenml.environment import get_run_environment_dict from zenml.exceptions import RunInterruptedException, RunStoppedException @@ -424,6 +423,8 @@ def _run_step( step_run: The model of the current step run. force_write_logs: The context for the step logs. 
""" + from zenml.deployers.server import runtime + step_run_info = StepRunInfo( config=self._step.config, pipeline=self._snapshot.pipeline_configuration, diff --git a/src/zenml/orchestrators/step_runner.py b/src/zenml/orchestrators/step_runner.py index 2f248389ade..eb67b6f2a63 100644 --- a/src/zenml/orchestrators/step_runner.py +++ b/src/zenml/orchestrators/step_runner.py @@ -37,7 +37,6 @@ ENV_ZENML_STEP_OPERATOR, handle_bool_env_var, ) -from zenml.deployers.server import runtime from zenml.enums import ArtifactSaveType from zenml.exceptions import StepInterfaceError from zenml.hooks.hook_validators import load_and_run_hook @@ -137,6 +136,8 @@ def run( Raises: BaseException: A general exception if the step fails. """ + from zenml.deployers.server import runtime + if handle_bool_env_var(ENV_ZENML_DISABLE_STEP_LOGS_STORAGE, False): step_logging_enabled = False else: @@ -613,6 +614,8 @@ def _store_output_artifacts( Returns: The IDs of the published output artifacts. """ + from zenml.deployers.server import runtime + step_context = get_step_context() artifact_requests = [] diff --git a/src/zenml/pipelines/build_utils.py b/src/zenml/pipelines/build_utils.py index cd6f187a941..4e117561b89 100644 --- a/src/zenml/pipelines/build_utils.py +++ b/src/zenml/pipelines/build_utils.py @@ -407,6 +407,7 @@ def create_pipeline_build( entrypoint=build_config.entrypoint, extra_files=build_config.extra_files, code_repository=code_repository if pass_code_repo else None, + extra_requirements_files=build_config.extra_requirements_files, ) contains_code = include_files diff --git a/src/zenml/steps/__init__.py b/src/zenml/steps/__init__.py index fca2a278117..72d6cab12fc 100644 --- a/src/zenml/steps/__init__.py +++ b/src/zenml/steps/__init__.py @@ -26,8 +26,8 @@ decorator. 
""" -from zenml.config import ResourceSettings from zenml.steps.base_step import BaseStep +from zenml.config.resource_settings import ResourceSettings from zenml.steps.step_context import StepContext, get_step_context from zenml.steps.step_decorator import step diff --git a/src/zenml/utils/pipeline_docker_image_builder.py b/src/zenml/utils/pipeline_docker_image_builder.py index da000aebf15..3bc64d52021 100644 --- a/src/zenml/utils/pipeline_docker_image_builder.py +++ b/src/zenml/utils/pipeline_docker_image_builder.py @@ -84,6 +84,7 @@ def build_docker_image( entrypoint: Optional[str] = None, extra_files: Optional[Dict[str, str]] = None, code_repository: Optional["BaseCodeRepository"] = None, + extra_requirements_files: Dict[str, List[str]] = {}, ) -> Tuple[str, Optional[str], Optional[str]]: """Builds (and optionally pushes) a Docker image to run a pipeline. @@ -102,6 +103,10 @@ def build_docker_image( content or a file path. code_repository: The code repository from which files will be downloaded. + extra_requirements_files: Extra requirements to install in the + Docker image. Each key is the name of a Python requirements file + to be created and the value is the list of requirements to be + installed. Returns: A tuple (image_digest, dockerfile, requirements): @@ -169,6 +174,7 @@ def build_docker_image( include_files, entrypoint, extra_files, + extra_requirements_files, ] ) @@ -276,6 +282,7 @@ def build_docker_image( docker_settings=docker_settings, stack=stack, code_repository=code_repository, + extra_requirements_files=extra_requirements_files, ) self._add_requirements_files( @@ -412,6 +419,7 @@ def gather_requirements_files( stack: "Stack", code_repository: Optional["BaseCodeRepository"] = None, log: bool = True, + extra_requirements_files: Dict[str, List[str]] = {}, ) -> List[Tuple[str, str, List[str]]]: """Gathers and/or generates pip requirements files. 
@@ -427,6 +435,10 @@ def gather_requirements_files( code_repository: The code repository from which files will be downloaded. log: If True, will log the requirements. + extra_requirements_files: Extra requirements to install in the + Docker image. Each key is the name of a Python requirements file + to be created and the value is the list of requirements to be + installed. Raises: RuntimeError: If the command to export the local python packages @@ -440,6 +452,7 @@ def gather_requirements_files( The files will be in the following order: - Packages installed in the local Python environment - Requirements defined by stack integrations + - Extra requirements files - Requirements defined by user integrations - Requirements exported from a pyproject.toml - User-defined requirements @@ -550,6 +563,17 @@ def gather_requirements_files( ", ".join(f"`{r}`" for r in stack_requirements_list), ) + for filename, requirements in extra_requirements_files.items(): + requirements_list = sorted(requirements) + requirements_file = "\n".join(requirements_list) + requirements_files.append((filename, requirements_file, [])) + if log: + logger.info( + "- Including extra requirements from file `%s`: %s", + filename, + ", ".join(f"`{r}`" for r in requirements_list), + ) + # Generate requirements file for all required integrations integration_requirements = set( itertools.chain.from_iterable( diff --git a/src/zenml/utils/settings_utils.py b/src/zenml/utils/settings_utils.py index c9254bd899d..ece67b769e3 100644 --- a/src/zenml/utils/settings_utils.py +++ b/src/zenml/utils/settings_utils.py @@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Dict, Sequence, Type from zenml.config.constants import ( + DEPLOYMENT_SETTINGS_KEY, DOCKER_SETTINGS_KEY, RESOURCE_SETTINGS_KEY, ) @@ -129,11 +130,16 @@ def get_general_settings() -> Dict[str, Type["BaseSettings"]]: Returns: Dictionary mapping general settings keys to their type. 
""" - from zenml.config import DockerSettings, ResourceSettings + from zenml.config import ( + DeploymentSettings, + DockerSettings, + ResourceSettings, + ) return { DOCKER_SETTINGS_KEY: DockerSettings, RESOURCE_SETTINGS_KEY: ResourceSettings, + DEPLOYMENT_SETTINGS_KEY: DeploymentSettings, } diff --git a/src/zenml/zen_server/middleware.py b/src/zenml/zen_server/middleware.py index be5316a8b31..7e61059d7e6 100644 --- a/src/zenml/zen_server/middleware.py +++ b/src/zenml/zen_server/middleware.py @@ -302,7 +302,7 @@ async def set_secure_headers(request: Request, call_next: Any) -> Any: ): return response - secure_headers().framework.fastapi(response) + await secure_headers().set_headers_async(response) return response diff --git a/src/zenml/zen_server/secure_headers.py b/src/zenml/zen_server/secure_headers.py index 17db00799a9..8cff2187182 100644 --- a/src/zenml/zen_server/secure_headers.py +++ b/src/zenml/zen_server/secure_headers.py @@ -71,23 +71,17 @@ def initialize_secure_headers() -> None: if isinstance(config.secure_headers_xfo, str): xfo.set(config.secure_headers_xfo) - xxp: Optional[secure.XXSSProtection] = None - if config.secure_headers_xxp: - xxp = secure.XXSSProtection() - if isinstance(config.secure_headers_xxp, str): - xxp.set(config.secure_headers_xxp) - csp: Optional[secure.ContentSecurityPolicy] = None if config.secure_headers_csp: csp = secure.ContentSecurityPolicy() if isinstance(config.secure_headers_csp, str): csp.set(config.secure_headers_csp) - content: Optional[secure.XContentTypeOptions] = None + xcto: Optional[secure.XContentTypeOptions] = None if config.secure_headers_content: - content = secure.XContentTypeOptions() + xcto = secure.XContentTypeOptions() if isinstance(config.secure_headers_content, str): - content.set(config.secure_headers_content) + xcto.set(config.secure_headers_content) referrer: Optional[secure.ReferrerPolicy] = None if config.secure_headers_referrer: @@ -105,15 +99,18 @@ def initialize_secure_headers() -> None: if 
config.secure_headers_permissions: permissions = secure.PermissionsPolicy() if isinstance(config.secure_headers_permissions, str): - permissions.value = config.secure_headers_permissions + # This header is special: the library doesn't allow setting + # the value as a string, only as a list of directives, so we + # work around it by setting the private _default_value + # attribute. + permissions._default_value = config.secure_headers_permissions _secure_headers = secure.Secure( server=server, hsts=hsts, xfo=xfo, - xxp=xxp, csp=csp, - content=content, + xcto=xcto, referrer=referrer, cache=cache, permissions=permissions, diff --git a/tests/integration/functional/deployers/server/test_app_endpoints.py b/tests/integration/functional/deployers/server/test_app_endpoints.py deleted file mode 100644 index 4a4272e0650..00000000000 --- a/tests/integration/functional/deployers/server/test_app_endpoints.py +++ /dev/null @@ -1,400 +0,0 @@ -# Copyright (c) ZenML GmbH 2025. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at: -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -# or implied. See the License for the specific language governing -# permissions and limitations under the License.
-"""Integration tests for FastAPI deployment application endpoints.""" - -import importlib -from types import ModuleType, SimpleNamespace -from typing import Generator, Optional, Tuple -from uuid import uuid4 - -import pytest -from fastapi.testclient import TestClient -from pydantic import BaseModel - -import zenml.deployers.server.app as deployment_app -from zenml.deployers.server.models import ( - BaseDeploymentInvocationRequest, - BaseDeploymentInvocationResponse, - DeploymentInfo, - DeploymentInvocationResponseMetadata, - ExecutionMetrics, - PipelineInfo, - ServiceInfo, - SnapshotInfo, -) - - -class MockWeatherRequest(BaseModel): - """Mock Pydantic model for testing.""" - - city: str - temperature: int = 20 - - -class StubDeploymentService: - """Stub service implementing the interface used by the FastAPI app.""" - - def __init__(self, deployment_id: str) -> None: - """Initialize the stub service. - - Args: - snapshot_id: The ID of the snapshot to use for the service. - """ - self._healthy = True - self.initialized = False - self.cleaned_up = False - self.last_request: Optional[BaseDeploymentInvocationRequest] = None - self.input_schema = { - "type": "object", - "properties": {"city": {"type": "string"}}, - } - self.output_schema = { - "type": "object", - "properties": {"result": {"type": "string"}}, - } - self.snapshot = SimpleNamespace( - id=uuid4(), - name="snapshot", - pipeline_configuration=SimpleNamespace( - name="test_pipeline", - environment={}, - init_hook_source=None, - init_hook_kwargs=None, - cleanup_hook_source=None, - ), - pipeline_spec=SimpleNamespace( - parameters={"city": "London"}, - input_schema=self.input_schema, - output_schema=self.output_schema, - ), - ) - self.deployment = SimpleNamespace( - id=uuid4(), - name="deployment", - snapshot=self.snapshot, - ) - - @property - def input_model(self) -> type[BaseModel]: # noqa: D401 - """Expose the request model expected by the service. - - Returns: - The request model expected by the service. 
- """ - - return MockWeatherRequest - - def initialize(self) -> None: # noqa: D401 - """Mark the service as initialized for verification in tests.""" - - self.initialized = True - - def cleanup(self) -> None: # noqa: D401 - """Mark the service as cleaned up for shutdown assertions.""" - - self.cleaned_up = True - - def is_healthy(self) -> bool: # noqa: D401 - """Return the current health flag used by tests.""" - - return self._healthy - - def set_health(self, healthy: bool) -> None: # noqa: D401 - """Set the health of the service. - - Args: - healthy: The health of the service. - """ - self._healthy = healthy - - def get_service_info(self) -> ServiceInfo: # noqa: D401 - """Retrieve public metadata describing the stub deployment.""" - - return ServiceInfo( - deployment=DeploymentInfo( - id=self.deployment.id, - name=self.deployment.name, - ), - snapshot=SnapshotInfo( - id=self.snapshot.id, name=self.snapshot.name - ), - pipeline=PipelineInfo( - name=self.snapshot.pipeline_configuration.name, - parameters=self.snapshot.pipeline_spec.parameters, - input_schema=self.input_schema, - output_schema=self.output_schema, - ), - total_executions=1, - last_execution_time=None, - status="healthy" if self._healthy else "unhealthy", - uptime=1.0, - ) - - def get_execution_metrics(self) -> ExecutionMetrics: # noqa: D401 - """Return execution metrics describing recent pipeline activity.""" - - return ExecutionMetrics(total_executions=1, last_execution_time=None) - - def execute_pipeline( - self, request: BaseDeploymentInvocationRequest - ) -> BaseDeploymentInvocationResponse: # noqa: D401 - """Execute the pipeline. - - Args: - request: The request to execute the pipeline. - - Returns: - The response from the pipeline. 
- """ - self.last_request = request - return BaseDeploymentInvocationResponse( - success=True, - outputs={"result": "ok"}, - execution_time=0.5, - metadata=DeploymentInvocationResponseMetadata( - deployment_id=self.deployment.id, - deployment_name=self.deployment.name, - pipeline_name="test_pipeline", - run_id=None, - run_name=None, - parameters_used=request.parameters.model_dump(), - snapshot_id=self.snapshot.id, - snapshot_name=self.snapshot.name, - ), - error=None, - ) - - -@pytest.fixture -def client_service_pair( - monkeypatch: pytest.MonkeyPatch, -) -> Generator[ - Tuple[TestClient, StubDeploymentService, ModuleType], None, None -]: - """Provide a fresh FastAPI client and stub service per test. - - Args: - monkeypatch: The monkeypatch fixture. - - Yields: - A tuple containing the FastAPI client, the stub service, and the reloaded app. - """ - reloaded_app = importlib.reload(deployment_app) - service = StubDeploymentService(str(uuid4())) - - monkeypatch.setenv("ZENML_DEPLOYMENT_ID", str(service.deployment.id)) - monkeypatch.delenv("ZENML_DEPLOYMENT_TEST_MODE", raising=False) - - def _service_factory(_: str) -> StubDeploymentService: - """Factory function for creating a stub service. - - Args: - _: The snapshot ID to use for the service. - - Returns: - The stub service. 
- """ - return service - - monkeypatch.setattr( - reloaded_app, - "PipelineDeploymentService", - _service_factory, - ) - - with TestClient(reloaded_app.app) as client: - yield client, service, reloaded_app - - -class TestFastAPIAppEndpoints: - """Integration tests for FastAPI application endpoints.""" - - def test_root_endpoint( - self, - client_service_pair: Tuple[ - TestClient, StubDeploymentService, ModuleType - ], - ) -> None: - """Ensure the root endpoint renders the deployment overview.""" - client, service, _ = client_service_pair - response = client.get("/") - assert response.status_code == 200 - assert "ZenML Pipeline Deployment" in response.text - assert "test_pipeline" in response.text - assert service.initialized is True - - def test_health_endpoint_healthy( - self, - client_service_pair: Tuple[ - TestClient, StubDeploymentService, ModuleType - ], - ) -> None: - """Ensure the health endpoint returns OK for healthy services.""" - client, _, _ = client_service_pair - response = client.get("/health") - assert response.status_code == 200 - assert response.json() == "OK" - - def test_health_endpoint_unhealthy( - self, - client_service_pair: Tuple[ - TestClient, StubDeploymentService, ModuleType - ], - ) -> None: - """Return a 503 status when the service reports unhealthy.""" - client, service, _ = client_service_pair - service.set_health(False) - response = client.get("/health") - assert response.status_code == 503 - - def test_info_endpoint( - self, - client_service_pair: Tuple[ - TestClient, StubDeploymentService, ModuleType - ], - ) -> None: - """Expose pipeline and snapshot metadata via /info.""" - client, service, _ = client_service_pair - response = client.get("/info") - assert response.status_code == 200 - data = response.json() - assert data["pipeline"]["name"] == "test_pipeline" - assert data["pipeline"]["input_schema"] == service.input_schema - assert data["snapshot"]["name"] == "snapshot" - - def test_metrics_endpoint( - self, - 
client_service_pair: Tuple[ - TestClient, StubDeploymentService, ModuleType - ], - ) -> None: - """Surface execution metrics through the metrics endpoint.""" - client, _, _ = client_service_pair - response = client.get("/metrics") - assert response.status_code == 200 - data = response.json() - assert data["total_executions"] == 1 - assert data["last_execution_time"] is None - - def test_invoke_endpoint_success( - self, - client_service_pair: Tuple[ - TestClient, StubDeploymentService, ModuleType - ], - ) -> None: - """Propagate successful execution responses for valid payloads.""" - client, service, _ = client_service_pair - payload = {"parameters": {"city": "Paris", "temperature": 25}} - - response = client.post("/invoke", json=payload) - - assert response.status_code == 200 - body = response.json() - assert body["success"] is True - assert body["outputs"] == {"result": "ok"} - assert service.last_request.parameters.city == "Paris" - - def test_invoke_endpoint_execution_failure( - self, - client_service_pair: Tuple[ - TestClient, StubDeploymentService, ModuleType - ], - ) -> None: - """Propagate failure responses without raising errors.""" - client, service, module = client_service_pair - failure_response = BaseDeploymentInvocationResponse( - success=False, - outputs=None, - execution_time=0.1, - metadata=DeploymentInvocationResponseMetadata( - deployment_id=service.deployment.id, - deployment_name=service.deployment.name, - pipeline_name="test_pipeline", - run_id=None, - run_name=None, - parameters_used={}, - snapshot_id=service.snapshot.id, - snapshot_name=service.snapshot.name, - ), - error="Pipeline execution failed", - ) - - service.execute_pipeline = lambda request: failure_response - - response = client.post( - "/invoke", json={"parameters": {"city": "Paris"}} - ) - assert response.status_code == 200 - assert response.json()["success"] is False - - def test_cleanup_called_on_shutdown( - self, - monkeypatch: pytest.MonkeyPatch, - client_service_pair: Tuple[ 
- TestClient, StubDeploymentService, ModuleType - ], - ) -> None: - """Trigger service cleanup when the application shuts down.""" - reloaded_app = importlib.reload(deployment_app) - service = StubDeploymentService(str(uuid4())) - monkeypatch.setenv("ZENML_DEPLOYMENT_ID", str(service.deployment.id)) - monkeypatch.setattr( - reloaded_app, - "PipelineDeploymentService", - lambda deployment_id: service, - ) - with TestClient(reloaded_app.app): - pass - - assert service.initialized is True - assert service.cleaned_up is True - - -class TestOpenAPIIntegration: - """Integration tests for OpenAPI schema installation.""" - - def test_openapi_includes_invoke_models( - self, - client_service_pair: Tuple[ - TestClient, StubDeploymentService, ModuleType - ], - ) -> None: - """Include invoke request / response models within the OpenAPI schema.""" - client, service, module = client_service_pair - schema = client.get("/openapi.json").json() - operation = schema["paths"]["/invoke"]["post"] - - request_schema = operation["requestBody"]["content"][ - "application/json" - ]["schema"] - if "$ref" in request_schema: - ref = request_schema["$ref"].split("/")[-1] - request_schema = schema["components"]["schemas"][ref] - - parameters_schema = request_schema["properties"]["parameters"] - assert parameters_schema["properties"]["city"]["type"] == "string" - - response_schema = operation["responses"]["200"]["content"][ - "application/json" - ]["schema"] - if "$ref" in response_schema: - ref = response_schema["$ref"].split("/")[-1] - response_schema = schema["components"]["schemas"][ref] - - outputs_schema = response_schema["properties"]["outputs"] - if "$ref" in outputs_schema: - ref = outputs_schema["$ref"].split("/")[-1] - outputs_schema = schema["components"]["schemas"][ref] - - assert outputs_schema["properties"]["result"]["type"] == "string" diff --git a/tests/unit/deployers/server/test_app.py b/tests/unit/deployers/server/test_app.py deleted file mode 100644 index 
eb22fc9ab41..00000000000 --- a/tests/unit/deployers/server/test_app.py +++ /dev/null @@ -1,400 +0,0 @@ -# Copyright (c) ZenML GmbH 2025. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at: -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -# or implied. See the License for the specific language governing -# permissions and limitations under the License. -"""Unit tests for deployment app functionality.""" - -from __future__ import annotations - -import asyncio -import json -from typing import cast -from uuid import uuid4 - -import pytest -from fastapi import FastAPI, HTTPException, Request -from fastapi.security import HTTPAuthorizationCredentials -from fastapi.testclient import TestClient -from pydantic import BaseModel -from pytest_mock import MockerFixture - -from zenml.deployers.server.app import ( - _build_invoke_router, - app, - get_pipeline_service, - lifespan, - runtime_error_handler, - value_error_handler, - verify_token, -) -from zenml.deployers.server.models import ( - BaseDeploymentInvocationResponse, - DeploymentInfo, - DeploymentInvocationResponseMetadata, - ExecutionMetrics, - PipelineInfo, - ServiceInfo, - SnapshotInfo, -) -from zenml.deployers.server.service import PipelineDeploymentService - - -class MockWeatherRequest(BaseModel): - """Mock Pydantic model for testing.""" - - city: str - temperature: int = 20 - - -@pytest.fixture -def mock_service(mocker: MockerFixture) -> PipelineDeploymentService: - """Mock pipeline deployment service configured for the app tests.""" - - service = cast( - PipelineDeploymentService, - mocker.MagicMock(spec=PipelineDeploymentService), - ) - snapshot_id = uuid4() - 
deployment_id = uuid4() - - service.input_model = MockWeatherRequest - service.is_healthy.return_value = True - service.input_schema = { - "type": "object", - "properties": {"city": {"type": "string"}}, - } - service.output_schema = { - "type": "object", - "properties": {"result": {"type": "string"}}, - } - - service.get_service_info.return_value = ServiceInfo( - deployment=DeploymentInfo(id=deployment_id, name="deployment"), - snapshot=SnapshotInfo(id=snapshot_id, name="snapshot"), - pipeline=PipelineInfo( - name="test_pipeline", - parameters={"city": "London"}, - input_schema=service.input_schema, - output_schema=service.output_schema, - ), - total_executions=3, - last_execution_time=None, - status="healthy", - uptime=12.34, - ) - service.get_execution_metrics.return_value = ExecutionMetrics( - total_executions=3, - last_execution_time=None, - ) - service.execute_pipeline.return_value = BaseDeploymentInvocationResponse( - success=True, - outputs={"result": "ok"}, - execution_time=0.5, - metadata=DeploymentInvocationResponseMetadata( - deployment_id=deployment_id, - deployment_name="deployment", - pipeline_name="test_pipeline", - run_id=None, - run_name=None, - parameters_used={"city": "Paris", "temperature": 25}, - snapshot_id=snapshot_id, - snapshot_name="snapshot", - ), - error=None, - ) - return service - - -class TestDeploymentAppRoutes: - """Test FastAPI app routes.""" - - def test_root_endpoint( - self, - mock_service: PipelineDeploymentService, - monkeypatch: pytest.MonkeyPatch, - ) -> None: - """Root endpoint returns HTML with pipeline information.""" - monkeypatch.setenv("ZENML_DEPLOYMENT_TEST_MODE", "true") - monkeypatch.setattr( - "zenml.deployers.server.app._service", mock_service - ) - with TestClient(app) as client: - response = client.get("/") - - assert response.status_code == 200 - assert response.headers["content-type"].startswith("text/html") - assert "ZenML Pipeline Deployment" in response.text - assert "test_pipeline" in response.text - - def 
test_health_endpoint( - self, - mock_service: PipelineDeploymentService, - monkeypatch: pytest.MonkeyPatch, - ) -> None: - """Health endpoint returns OK when service is healthy.""" - monkeypatch.setenv("ZENML_DEPLOYMENT_TEST_MODE", "true") - monkeypatch.setattr( - "zenml.deployers.server.app._service", mock_service - ) - with TestClient(app) as client: - response = client.get("/health") - - assert response.status_code == 200 - assert response.json() == "OK" - - def test_health_endpoint_unhealthy( - self, - mock_service: PipelineDeploymentService, - monkeypatch: pytest.MonkeyPatch, - ) -> None: - """Health endpoint raises when service reports unhealthy state.""" - mock_service.is_healthy.return_value = False - - monkeypatch.setenv("ZENML_DEPLOYMENT_TEST_MODE", "true") - monkeypatch.setattr( - "zenml.deployers.server.app._service", mock_service - ) - with TestClient(app) as client: - response = client.get("/health") - - assert response.status_code == 503 - assert response.json()["detail"] == "Service is unhealthy" - - def test_info_endpoint( - self, - mock_service: PipelineDeploymentService, - monkeypatch: pytest.MonkeyPatch, - ) -> None: - """Info endpoint returns service metadata.""" - monkeypatch.setenv("ZENML_DEPLOYMENT_TEST_MODE", "true") - monkeypatch.setattr( - "zenml.deployers.server.app._service", mock_service - ) - with TestClient(app) as client: - response = client.get("/info") - - assert response.status_code == 200 - data = response.json() - assert data["pipeline"]["name"] == "test_pipeline" - assert data["pipeline"]["parameters"] == {"city": "London"} - assert data["status"] == "healthy" - assert data["snapshot"]["name"] == "snapshot" - - def test_metrics_endpoint( - self, - mock_service: PipelineDeploymentService, - monkeypatch: pytest.MonkeyPatch, - ) -> None: - """Metrics endpoint exposes execution metrics.""" - monkeypatch.setenv("ZENML_DEPLOYMENT_TEST_MODE", "true") - monkeypatch.setattr( - "zenml.deployers.server.app._service", mock_service - ) - 
-        with TestClient(app) as client:
-            response = client.get("/metrics")
-
-        assert response.status_code == 200
-        data = response.json()
-        assert data["total_executions"] == 3
-        assert data["last_execution_time"] is None
-
-    def test_info_endpoint_includes_schemas(
-        self,
-        mock_service: PipelineDeploymentService,
-        monkeypatch: pytest.MonkeyPatch,
-    ) -> None:
-        """Info endpoint includes input/output schemas."""
-        monkeypatch.setenv("ZENML_DEPLOYMENT_TEST_MODE", "true")
-        monkeypatch.setattr(
-            "zenml.deployers.server.app._service", mock_service
-        )
-        with TestClient(app) as client:
-            response = client.get("/info")
-
-        data = response.json()
-        assert data["pipeline"]["input_schema"] == mock_service.input_schema
-        assert data["pipeline"]["output_schema"] == mock_service.output_schema
-
-    def test_get_pipeline_service_returns_current_instance(
-        self,
-        mock_service: PipelineDeploymentService,
-        monkeypatch: pytest.MonkeyPatch,
-    ) -> None:
-        """Ensure get_pipeline_service exposes the underlying instance."""
-        monkeypatch.setattr(
-            "zenml.deployers.server.app._service", mock_service
-        )
-        assert get_pipeline_service() is mock_service
-
-
-class TestDeploymentAppInvoke:
-    """Test pipeline invocation via FastAPI."""
-
-    def test_invoke_endpoint_executes_service(
-        self, mock_service: PipelineDeploymentService
-    ) -> None:
-        """Invoke router validates payloads and calls the service."""
-        fast_app = FastAPI()
-        fast_app.include_router(_build_invoke_router(mock_service))
-
-        with TestClient(fast_app) as client:
-            payload = {"parameters": {"city": "Paris", "temperature": 25}}
-            response = client.post("/invoke", json=payload)
-
-        assert response.status_code == 200
-        assert response.json()["success"] is True
-        mock_service.execute_pipeline.assert_called_once()
-        request_arg = mock_service.execute_pipeline.call_args.args[0]
-        assert request_arg.parameters.city == "Paris"
-        assert request_arg.skip_artifact_materialization is False
-
-    def test_invoke_endpoint_validation_error(
-        self, mock_service: PipelineDeploymentService
-    ) -> None:
-        """Invalid payloads trigger validation errors."""
-        fast_app = FastAPI()
-        fast_app.include_router(_build_invoke_router(mock_service))
-
-        with TestClient(fast_app) as client:
-            response = client.post("/invoke", json={"parameters": {}})
-
-        assert response.status_code == 422
-        mock_service.execute_pipeline.assert_not_called()
-
-    def test_verify_token_with_auth_enabled(
-        self, monkeypatch: pytest.MonkeyPatch
-    ) -> None:
-        """Token verification when authentication is enabled."""
-        monkeypatch.setenv("ZENML_DEPLOYMENT_AUTH_KEY", "test-auth-key")
-
-        credentials = HTTPAuthorizationCredentials(
-            scheme="Bearer", credentials="test-auth-key"
-        )
-        assert verify_token(credentials) is None
-
-        with pytest.raises(HTTPException):
-            verify_token(
-                HTTPAuthorizationCredentials(
-                    scheme="Bearer", credentials="wrong"
-                )
-            )
-
-        with pytest.raises(HTTPException):
-            verify_token(None)
-
-    def test_verify_token_with_auth_disabled(
-        self, monkeypatch: pytest.MonkeyPatch
-    ) -> None:
-        """Token verification when authentication is disabled."""
-        monkeypatch.delenv("ZENML_DEPLOYMENT_AUTH_KEY", raising=False)
-        assert verify_token(None) is None
-
-
-class TestDeploymentAppLifecycle:
-    """Test app lifecycle management."""
-
-    def test_lifespan_test_mode(self, monkeypatch: pytest.MonkeyPatch) -> None:
-        """Lifespan exits early in test mode."""
-        monkeypatch.setenv("ZENML_DEPLOYMENT_TEST_MODE", "true")
-
-        async def _run() -> None:
-            async with lifespan(app):
-                pass
-
-        asyncio.run(_run())
-
-    def test_lifespan_normal_mode(
-        self,
-        monkeypatch: pytest.MonkeyPatch,
-        mocker: MockerFixture,
-    ) -> None:
-        """Lifespan initializes and cleans up service in normal mode."""
-        monkeypatch.setenv("ZENML_DEPLOYMENT_ID", "test-deployment-id")
-
-        mock_service = cast(
-            PipelineDeploymentService,
-            mocker.MagicMock(spec=PipelineDeploymentService),
-        )
-        mock_service.input_model = MockWeatherRequest
-        mock_service.initialize = mocker.MagicMock()
-        mock_service.cleanup = mocker.MagicMock()
-
-        mocker.patch(
-            "zenml.deployers.server.app.PipelineDeploymentService",
-            return_value=mock_service,
-        )
-        mock_include = mocker.patch.object(app, "include_router")
-
-        async def _run() -> None:
-            async with lifespan(app):
-                pass
-
-        asyncio.run(_run())
-
-        mock_include.assert_called()
-        mock_service.initialize.assert_called_once()
-        mock_service.cleanup.assert_called_once()
-
-    def test_lifespan_missing_snapshot_id(
-        self, monkeypatch: pytest.MonkeyPatch
-    ) -> None:
-        """Lifespan raises when no snapshot id is configured."""
-        monkeypatch.delenv("ZENML_DEPLOYMENT_ID", raising=False)
-
-        async def _run() -> None:
-            with pytest.raises(ValueError, match="ZENML_DEPLOYMENT_ID"):
-                async with lifespan(app):
-                    pass
-
-        asyncio.run(_run())
-
-
-class TestDeploymentAppErrorHandling:
-    """Test app error handling."""
-
-    def test_value_error_handler(self) -> None:
-        """ValueError exception handler returns 400 with message."""
-        request = Request(
-            {"type": "http", "method": "POST", "url": "http://test"}
-        )
-        error = ValueError("Test error")
-
-        response = value_error_handler(request, error)
-        assert response.status_code == 400
-        payload = json.loads(response.body)
-        assert payload["detail"] == "Test error"
-
-    def test_runtime_error_handler(self) -> None:
-        """RuntimeError exception handler returns 500 with message."""
-        request = Request(
-            {"type": "http", "method": "POST", "url": "http://test"}
-        )
-        error = RuntimeError("Runtime error")
-
-        response = runtime_error_handler(request, error)
-        assert response.status_code == 500
-        payload = json.loads(response.body)
-        assert payload["detail"] == "Runtime error"
-
-
-class TestBuildInvokeRouter:
-    """Test the invoke router building functionality."""
-
-    def test_build_invoke_router(
-        self, mock_service: PipelineDeploymentService
-    ) -> None:
-        """Building the invoke router exposes /invoke route."""
-        router = _build_invoke_router(mock_service)
-
-        assert router is not None
-        routes = [route.path for route in router.routes]
-        assert "/invoke" in routes
diff --git a/tests/unit/deployers/server/test_service.py b/tests/unit/deployers/server/test_service.py
index 7878aa1781e..84f04ad5bb2 100644
--- a/tests/unit/deployers/server/test_service.py
+++ b/tests/unit/deployers/server/test_service.py
@@ -15,15 +15,25 @@
 
 from __future__ import annotations
 
-from contextlib import contextmanager
 from types import SimpleNamespace
-from typing import Dict, Iterator
-from uuid import UUID, uuid4
+from typing import Dict, List, Type
+from uuid import uuid4
 
 import pytest
 from pydantic import BaseModel
 from pytest_mock import MockerFixture
 
+from zenml.config import (
+    AppExtensionSpec,
+    DeploymentSettings,
+    EndpointSpec,
+    MiddlewareSpec,
+)
+from zenml.deployers.server.adapters import EndpointAdapter, MiddlewareAdapter
+from zenml.deployers.server.app import (
+    BaseDeploymentAppRunner,
+    BaseDeploymentAppRunnerFlavor,
+)
 from zenml.deployers.server.models import BaseDeploymentInvocationRequest
 from zenml.deployers.server.service import PipelineDeploymentService
@@ -44,6 +54,7 @@ def _make_snapshot() -> SimpleNamespace:
         init_hook_source=None,
         init_hook_kwargs={},
         cleanup_hook_source=None,
+        deployment_settings=DeploymentSettings(),
     )
     pipeline_spec = SimpleNamespace(
         parameters={"city": "London"},
@@ -67,14 +78,55 @@ def _make_snapshot() -> SimpleNamespace:
 def _make_deployment() -> SimpleNamespace:
     """Create a deployment stub with the attributes accessed by the service."""
     return SimpleNamespace(
-        id=uuid4(), name="deployment", snapshot=_make_snapshot()
+        id=uuid4(), name="deployment", snapshot=_make_snapshot(), auth_key=None
     )
 
 
+class _DummyDeploymentAppRunnerFlavor(BaseDeploymentAppRunnerFlavor):
+    @property
+    def name(self) -> str:
+        return "dummy"
+
+    @property
+    def implementation_class(self) -> Type[BaseDeploymentAppRunner]:
+        return _DummyDeploymentAppRunner
+
+
+class _DummyDeploymentAppRunner(BaseDeploymentAppRunner):
+    @classmethod
+    def load_deployment(cls, deployment):
+        return deployment
+
+    @property
+    def flavor(cls) -> "BaseDeploymentAppRunnerFlavor":
+        return _DummyDeploymentAppRunnerFlavor()
+
+    def _create_endpoint_adapter(self) -> EndpointAdapter:
+        return None
+
+    def _create_middleware_adapter(self) -> MiddlewareAdapter:
+        return None
+
+    def _get_dashboard_endpoints(self) -> List[EndpointSpec]:
+        return []
+
+    def _build_cors_middleware(self) -> MiddlewareSpec:
+        return None
+
+    def build(
+        self,
+        middlewares: List[MiddlewareSpec],
+        endpoints: List[EndpointSpec],
+        extensions: List[AppExtensionSpec],
+    ):
+        return None
+
+
 def _make_service_stub(mocker: MockerFixture) -> PipelineDeploymentService:
     """Create a service instance without running __init__ for isolated tests."""
     deployment = _make_deployment()
-    service = PipelineDeploymentService.__new__(PipelineDeploymentService)
+    app_runner = _DummyDeploymentAppRunner(deployment)
+    service = PipelineDeploymentService(app_runner)
     service._client = mocker.MagicMock()
     service._orchestrator = mocker.MagicMock()
     mocker.patch.object(
@@ -86,100 +138,9 @@ def _make_service_stub(mocker: MockerFixture) -> PipelineDeploymentService:
     service.service_start_time = 100.0
     service.last_execution_time = None
     service.total_executions = 0
-    service.deployment = deployment
-    service.snapshot = deployment.snapshot
     return service
 
 
-def test_initialization_loads_deployment(
-    monkeypatch: pytest.MonkeyPatch,
-) -> None:
-    """__init__ should load the deployment from the store."""
-    deployment = _make_deployment()
-
-    class DummyZenStore:
-        """In-memory zen store stub that records requested snapshot IDs."""
-
-        def __init__(self) -> None:
-            self.requested_snapshot_id: UUID | None = None
-            self.requested_deployment_id: UUID | None = None
-
-        def get_snapshot(self, snapshot_id: UUID) -> SimpleNamespace:  # noqa: D401
-            """Return the stored snapshot and remember the requested ID."""
-
-            self.requested_snapshot_id = snapshot_id
-            return deployment.snapshot
-
-        def get_deployment(self, deployment_id: UUID) -> SimpleNamespace:  # noqa: D401
-            """Return the stored deployment and remember the requested ID."""
-
-            self.requested_deployment_id = deployment_id
-            return deployment
-
-    dummy_store = DummyZenStore()
-
-    class DummyClient:
-        """Client stub providing access to the dummy zen store."""
-
-        def __init__(self) -> None:
-            self.zen_store = dummy_store
-
-    monkeypatch.setattr("zenml.deployers.server.service.Client", DummyClient)
-
-    service = PipelineDeploymentService(deployment.id)
-
-    assert service.deployment is deployment
-    assert service.snapshot is deployment.snapshot
-    assert dummy_store.requested_deployment_id == deployment.id
-    assert dummy_store.requested_snapshot_id is None
-
-
-def test_initialize_sets_up_orchestrator(
-    monkeypatch: pytest.MonkeyPatch, mocker: MockerFixture
-) -> None:
-    """initialize should activate integrations and build orchestrator."""
-    deployment = _make_deployment()
-
-    class DummyZenStore:
-        """Zen store stub that supplies the prepared snapshot."""
-
-        def get_snapshot(self, snapshot_id: UUID) -> SimpleNamespace:  # noqa: D401
-            return deployment.snapshot
-
-        def get_deployment(self, deployment_id: UUID) -> SimpleNamespace:  # noqa: D401
-            return deployment
-
-    class DummyClient:
-        """Client stub exposing only the attributes required by the service."""
-
-        def __init__(self) -> None:
-            self.zen_store = DummyZenStore()
-
-    monkeypatch.setattr("zenml.deployers.server.service.Client", DummyClient)
-
-    mock_orchestrator = mocker.MagicMock()
-    monkeypatch.setattr(
-        "zenml.deployers.server.service.SharedLocalOrchestrator",
-        mocker.MagicMock(return_value=mock_orchestrator),
-    )
-
-    @contextmanager
-    def _noop_env(_: object) -> Iterator[None]:
-        """Provide a no-op temporary environment context manager for tests."""
-
-        yield
-
-    monkeypatch.setattr(
-        "zenml.deployers.server.service.env_utils.temporary_environment",
-        _noop_env,
-    )
-
-    service = PipelineDeploymentService(uuid4())
-    service.initialize()
-
-    assert service._orchestrator is mock_orchestrator
-
-
 def test_execute_pipeline_calls_subroutines(mocker: MockerFixture) -> None:
     """execute_pipeline should orchestrate helper methods and return response."""
     service = _make_service_stub(mocker)
diff --git a/tests/unit/deployers/server/test_service_outputs.py b/tests/unit/deployers/server/test_service_outputs.py
index bf25a205aa4..97a3cb3d54f 100644
--- a/tests/unit/deployers/server/test_service_outputs.py
+++ b/tests/unit/deployers/server/test_service_outputs.py
@@ -14,14 +14,26 @@
 """Unit tests for PipelineDeploymentService output mapping with in-memory mode."""
 
 from types import SimpleNamespace
-from typing import Generator
+from typing import Generator, List
 from uuid import uuid4
 
 import pytest
 from pydantic import BaseModel
 from pytest_mock import MockerFixture
-
+from typing_extensions import Type
+
+from zenml.config import (
+    AppExtensionSpec,
+    DeploymentSettings,
+    EndpointSpec,
+    MiddlewareSpec,
+)
 from zenml.deployers.server import runtime
+from zenml.deployers.server.adapters import EndpointAdapter, MiddlewareAdapter
+from zenml.deployers.server.app import (
+    BaseDeploymentAppRunner,
+    BaseDeploymentAppRunnerFlavor,
+)
 from zenml.deployers.server.models import BaseDeploymentInvocationRequest
 from zenml.deployers.server.service import PipelineDeploymentService
@@ -42,6 +54,7 @@ def __init__(self) -> None:
             init_hook_source=None,
             init_hook_kwargs=None,
             cleanup_hook_source=None,
+            deployment_settings=DeploymentSettings(),
         )
         self.pipeline_spec = SimpleNamespace(
             parameters={},
@@ -66,6 +79,47 @@ def __init__(self) -> None:
         self.name = "test-run"
 
 
+class _DummyDeploymentAppRunnerFlavor(BaseDeploymentAppRunnerFlavor):
+    @property
+    def name(self) -> str:
+        return "dummy"
+
+    @property
+    def implementation_class(self) -> Type[BaseDeploymentAppRunner]:
+        return _DummyDeploymentAppRunner
+
+
+class _DummyDeploymentAppRunner(BaseDeploymentAppRunner):
+    @property
+    def flavor(cls) -> "BaseDeploymentAppRunnerFlavor":
+        return _DummyDeploymentAppRunnerFlavor()
+
+    def _create_endpoint_adapter(self) -> EndpointAdapter:
+        return None
+
+    def _create_middleware_adapter(self) -> MiddlewareAdapter:
+        return None
+
+    def _get_dashboard_endpoints(self) -> List[EndpointSpec]:
+        return []
+
+    def _build_cors_middleware(self) -> MiddlewareSpec:
+        return None
+
+    def build(
+        self,
+        middlewares: List[MiddlewareSpec],
+        endpoints: List[EndpointSpec],
+        extensions: List[AppExtensionSpec],
+    ):
+        return None
+
+
+class _DummyOrchestrator:
+    def run(self, snapshot, stack, placeholder_run):  # noqa: D401
+        runtime.record_step_outputs("step1", {"result": "fast_value"})
+
+
 @pytest.fixture(autouse=True)
 def clean_runtime_state() -> Generator[None, None, None]:
     """Ensure runtime state is reset before and after each test."""
@@ -111,9 +165,14 @@ def get_pipeline_run(
             return _DummyRun()
 
     monkeypatch.setattr("zenml.deployers.server.service.Client", DummyClient)
+    monkeypatch.setattr("zenml.deployers.server.app.Client", DummyClient)
 
-    service = PipelineDeploymentService(uuid4())
+    service = PipelineDeploymentService(
+        _DummyDeploymentAppRunner(deployment.id)
+    )
+    service.initialize()
     service.params_model = _DummyParams
+    service._orchestrator = _DummyOrchestrator()
 
     return service
@@ -134,12 +193,6 @@ def test_service_captures_in_memory_outputs(
         lambda source_snapshot, deployment_parameters: SimpleNamespace(),
     )
 
-    class _DummyOrchestrator:
-        def run(self, snapshot, stack, placeholder_run):  # noqa: D401
-            runtime.record_step_outputs("step1", {"result": "fast_value"})
-
-    service._orchestrator = _DummyOrchestrator()
-
     request = BaseDeploymentInvocationRequest(
         parameters=_DummyParams(),
         skip_artifact_materialization=True,
diff --git a/tests/unit/pipelines/test_build_utils.py b/tests/unit/pipelines/test_build_utils.py
index 88fe10a128a..affa2570dc7 100644
--- a/tests/unit/pipelines/test_build_utils.py
+++ b/tests/unit/pipelines/test_build_utils.py
@@ -198,6 +198,7 @@ def test_build_uses_correct_settings(mocker, empty_pipeline):  # noqa: F811
         extra_files=build_config.extra_files,
         include_files=True,
         code_repository=None,
+        extra_requirements_files={},
     )
     assert build.pipeline.id == pipeline_id
     assert build.is_local is True