diff --git a/docs/beta_todos.md b/docs/beta_todos.md new file mode 100644 index 00000000000..a6af45f4831 --- /dev/null +++ b/docs/beta_todos.md @@ -0,0 +1,62 @@ +# Beta Hardening TODOs + +This is a living checklist for post‑beta hardening. All beta blockers are already implemented; the items below are for production readiness and scale. + +## Serving Runtime & Publishing + +- Multi‑worker scaling per process + - Option A (threads): Add a small ThreadPoolExecutor consuming the existing bounded queue; preserve backpressure and flush semantics. + - Option B (async): Introduce an asyncio loop + asyncio.Queue + async workers, once client/publish calls have async variants and we opt into async serving. + - Keep bounded queue, inline fallback on Full, and orderly shutdown (join/cancel with timeout). + +- Backpressure & batching + - Tune queue maxsize defaults; expose env knob `ZENML_RT_QUEUE_MAXSIZE`. + - Optional: micro‑batch compatible events for fewer round‑trips. + +- Circuit breaker refinements + - Distinguish network vs. logical errors for better decisions. + - Add optional cool‑down logs with guidance. + +## Artifact Write Semantics + +- Server‑side atomicity / compensation + - Align with server to provide atomic batch create or server‑side compensation. + - Client: switch from best‑effort retries to idempotent, category‑aware retries once server semantics are defined. + - Document consistency guarantees and failure behavior. + +## Request Parameter Schema & Safety + +- Parameter schema from entrypoint annotations + - Generate/derive expected types from pipeline entrypoint annotations (or compiled schema) rather than inferring from defaults. + - Add total payload size cap; add per‑type caps (e.g., list length, dict depth). + - Optional: strict mode that rejects unknown params rather than dropping. + +## Monitoring, Metrics, Health + +- Metrics enrichment + - Export runtime metrics to Prometheus (queue depth, cache hit rate, error rate, op latency histograms). + - Add per‑worker metrics if multi‑worker is enabled. + +- Health/liveness + - Expose background worker liveness/health via the service. + - Add simple self‑check endpoints and document alerts. + +## Memory & Resource Management + +- Process memory monitoring / limits + - Add process memory watchdog and log warnings; document recommended container limits. + - Add a user‑facing docs note about caching large artifacts and tuning `max_entries` accordingly. + +## Operational Docs & UX + +- Serving docs + - Add a prominent warning about memory usage for large cached artifacts and sizing `ZENML_RT_CACHE_MAX_ENTRIES`. + - Add examples for scaling processes/replicas and interpreting metrics. + +## Notes (Implemented in Beta) + +- Request param allowlist / type coercion / size caps +- Memory‑only isolation (instance‑scoped) and cleanup +- Bounded queue with inline fallback; race‑free cache sweep +- Graceful shutdown with timeout and final metrics +- Defensive artifact write behavior with minimal retries and response validation diff --git a/docs/book/how-to/serving/serving.md b/docs/book/how-to/serving/serving.md index 9ea0858e726..e14b595dbc9 100644 --- a/docs/book/how-to/serving/serving.md +++ b/docs/book/how-to/serving/serving.md @@ -1,42 +1,25 @@ --- title: Serving Pipelines -description: Millisecond-class pipeline execution over HTTP with intelligent run-only optimization and streaming. +description: Run pipelines as fast HTTP services with async serving by default and optional memory-only execution. 
--- # Serving Pipelines -ZenML Serving runs pipelines as ultra-fast FastAPI services, achieving millisecond-class latency through intelligent run-only execution. Perfect for real-time inference, AI agents, and interactive workflows. +ZenML Serving exposes a pipeline as a FastAPI service. In serving, execution uses a Realtime runtime with async server updates by default for low latency. You can optionally run memory-only for maximum speed. ## Why Serving vs. Orchestrators -- **Performance**: Millisecond-class latency with run-only execution (no DB/FS writes in fast mode) -- **Simplicity**: Call your pipeline via HTTP; get results or stream progress -- **Intelligence**: Automatically switches between tracking and run-only modes based on capture settings -- **Flexibility**: Optional run/step tracking with fine-grained capture policies +- Performance: Async serving with in-process caching for low latency. +- Simplicity: Invoke your pipeline over HTTP; get results or stream progress. +- Control: Single, typed `Capture` config to tune observability or enable memory-only. -Use orchestrators for scheduled, long-running, reproducible workflows; use Serving for real-time request/response. - -## How It Works - -**Run-Only Architecture** (for millisecond latency): -- **ServingOverrides**: Per-request parameter injection using ContextVar isolation -- **ServingBuffer**: In-memory step output handoff with no persistence -- **Effective Config**: Runtime configuration merging without model mutations -- **Skip I/O**: Bypasses all database writes and filesystem operations -- **Input Injection**: Upstream step outputs automatically injected as parameters - -**Full Tracking Mode** (when capture enabled): -- Traditional ZenML tracking with runs, steps, artifacts, and metadata -- Orchestrator-based execution with full observability - -The service automatically chooses the optimal execution mode based on your capture settings. +Use orchestrators for scheduled, reproducible workflows. Use Serving for request/response inference. ## Quickstart Prerequisites - A deployed pipeline; note its deployment UUID as `ZENML_PIPELINE_DEPLOYMENT_ID`. -- Python env with dev deps (as per CONTRIBUTING). Start the service @@ -47,7 +30,7 @@ export ZENML_SERVICE_PORT=8001 python -m zenml.deployers.serving.app ``` -Synchronous invocation +Invoke (sync) ```bash curl -s -X POST "http://localhost:8001/invoke" \ @@ -55,65 +38,53 @@ curl -s -X POST "http://localhost:8001/invoke" \ -d '{"parameters": {"your_param": "value"}}' ``` -## Performance Modes +## Capture (typed-only) -ZenML Serving automatically chooses the optimal execution mode: - -### Run-Only Mode (Millisecond Latency) - -Activated when `capture="none"` or no capture settings specified: +Configure capture at the pipeline decorator using a single, typed `Capture`: ```python -@pipeline(settings={"capture": "none"}) -def fast_pipeline(x: int) -> int: - return x * 2 -``` +from zenml import pipeline +from zenml.capture.config import Capture -**Optimizations**: -- ✅ Zero database writes -- ✅ Zero filesystem operations -- ✅ In-memory step output handoff -- ✅ Per-request parameter injection -- ✅ Effective configuration merging -- ✅ Multi-worker safe (ContextVar isolation) +@pipeline(capture=Capture()) # serving async by default +def serve_pipeline(...): + ... -**Use for**: Real-time inference, AI agents, interactive demos - -### Full Tracking Mode +@pipeline(capture=Capture(memory_only=True)) # serving only +def max_speed_pipeline(...): + ... 
+``` -Activated when capture settings specify tracking: +Options (observability only; do not affect dataflow): +- `code`: include code/source/docstrings in metadata (default True) +- `logs`: persist step logs (default True) +- `metadata`: publish run/step metadata (default True) +- `visualizations`: persist visualizations (default True) +- `metrics`: emit runtime metrics (default True) -```python -@pipeline(settings={"capture": "full"}) -def tracked_pipeline(x: int) -> int: - return x * 2 -``` +Notes +- Serving is async by default; there is no `flush_on_step_end` knob. +- `memory_only=True` is ignored outside serving with a warning. -**Features**: -- Complete run/step tracking -- Artifact persistence -- Metadata collection -- Dashboard integration +## Request Parameters -**Use for**: Experimentation, debugging, audit trails +Request JSON under `parameters` is merged into the effective step config in serving. Logged keys indicate which parameters were applied. ## Execution Modes -- **Sync**: `POST /invoke` waits for completion; returns results or error. -- **Async**: `POST /invoke?mode=async` returns a `job_id`; poll `GET /jobs/{job_id}`. -- **Streaming**: `GET /stream/{job_id}` (SSE) or `WebSocket /stream` to receive progress and completion events in real time. +- Sync: `POST /invoke` waits for completion; returns results or error. +- Async: `POST /invoke?mode=async` returns a `job_id`; poll `GET /jobs/{job_id}`. +- Streaming: `GET /stream/{job_id}` (SSE) or `WebSocket /stream` to stream progress. Async example ```bash -# Submit -JOB_ID=$(curl -s -X POST "http://localhost:8001/invoke?mode=async" -H "Content-Type: application/json" -d '{"parameters":{}}' | jq -r .job_id) - -# Poll +JOB_ID=$(curl -s -X POST "http://localhost:8001/invoke?mode=async" \ + -H "Content-Type: application/json" -d '{"parameters":{}}' | jq -r .job_id) curl -s "http://localhost:8001/jobs/$JOB_ID" ``` -SSE example +SSE ```bash curl -N -H "Accept: text/event-stream" "http://localhost:8001/stream/$JOB_ID" @@ -123,124 +94,19 @@ curl -N -H "Accept: text/event-stream" "http://localhost:8001/stream/$JOB_ID" - `/health`: Service health and uptime. - `/info`: Pipeline name, steps, parameter schema, deployment info. -- `/metrics`: Execution statistics (counts, averages). +- `/metrics`: Execution statistics (queue depth, cache hit rate, latencies when metrics enabled). - `/status`: Service configuration snapshot. -- `/invoke`: Execute (sync/async) with optional parameter overrides. +- `/invoke`: Execute (sync/async) with optional parameters. - `/jobs`, `/jobs/{id}`, `/jobs/{id}/cancel`: Manage async jobs. -- `/stream/{id}`: Server‑Sent Events stream for a job; `WebSocket /stream` for bidirectional. - -## Configuration - -Key environment variables - -- `ZENML_PIPELINE_DEPLOYMENT_ID`: Deployment UUID (required). -- `ZENML_SERVING_CAPTURE_DEFAULT`: Default capture mode (`none` for run-only, `full` for tracking). -- `ZENML_SERVICE_HOST` (default: `0.0.0.0`), `ZENML_SERVICE_PORT` (default: `8001`). -- `ZENML_LOG_LEVEL`: Logging verbosity. 
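+
+As a sketch of the streaming flow over WebSocket (assumes the third-party `websockets` package and JSON event payloads; the exact event schema may differ on your deployment):
+
+```python
+# Minimal WebSocket streaming sketch. Sending {"parameters": ...} and the
+# "status" field on events are assumptions; check /info for the actual
+# parameter schema of your deployed pipeline.
+import asyncio
+import json
+
+import websockets  # pip install websockets
+
+
+async def stream_invoke() -> None:
+    async with websockets.connect("ws://localhost:8001/stream") as ws:
+        await ws.send(json.dumps({"parameters": {"your_param": "value"}}))
+        async for message in ws:
+            event = json.loads(message)
+            print(event)
+            if event.get("status") in ("completed", "failed"):
+                break
+
+
+asyncio.run(stream_invoke())
+```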
- -## Capture Policies - -Control what gets tracked per invocation: - -- **`none`**: Run-only mode, millisecond latency, no persistence -- **`metadata`**: Track runs/steps, no payload data -- **`full`**: Complete tracking with artifacts and metadata -- **`sampled`**: Probabilistic tracking for cost control -- **`errors_only`**: Track only failed executions - -Configuration locations: -- **Pipeline-level**: `@pipeline(settings={"capture": "none"})` -- **Request-level**: `{"capture_override": {"mode": "full"}}` -- **Environment**: `ZENML_SERVING_CAPTURE_DEFAULT=none` - -Precedence: Request > Pipeline > Environment > Default - -## Advanced Features - -### Input/Output Contracts - -Pipelines automatically expose their signature: - -```python -@pipeline -def my_pipeline(city: str, temperature: float) -> str: - return process_weather(city, temperature) - -# Automatic parameter schema: -# {"city": {"type": "str", "required": true}, -# "temperature": {"type": "float", "required": true}} -``` - -### Multi-Step Pipelines - -Step outputs automatically injected as inputs: - -```python -@step -def fetch_data(city: str) -> dict: - return {"weather": "sunny", "temp": 25} - -@step -def analyze_data(weather_data: dict) -> str: - return f"Analysis: {weather_data}" - -@pipeline -def weather_pipeline(city: str) -> str: - data = fetch_data(city) - return analyze_data(data) # weather_data auto-injected -``` - -### Response Building - -Only declared pipeline outputs returned: - -```python -@pipeline -def multi_output_pipeline(x: int) -> tuple[int, str]: - return x * 2, f"Result: {x}" - -# Response: {"outputs": {"output_0": 4, "output_1": "Result: 2"}} -``` - -## Testing & Local Dev - -Exercise endpoints locally: - -```bash -# Health check -curl http://localhost:8001/health - -# Pipeline info -curl http://localhost:8001/info - -# Execute with parameters -curl -X POST http://localhost:8001/invoke \ - -H "Content-Type: application/json" \ - -d '{"parameters": {"city": "Paris"}}' - -# Override capture mode -curl -X POST http://localhost:8001/invoke \ - -H "Content-Type: application/json" \ - -d '{"parameters": {"city": "Tokyo"}, "capture_override": {"mode": "full"}}' -``` +- `/stream/{id}`: Server‑Sent Events stream; `WebSocket /stream` for bidirectional. ## Troubleshooting -- **Missing deployment ID**: set `ZENML_PIPELINE_DEPLOYMENT_ID`. -- **Slow performance**: ensure `capture="none"` for run-only mode. -- **Import errors**: run-only mode bypasses some ZenML integrations that aren't needed for serving. -- **Memory leaks**: serving contexts are automatically cleared per request. -- **Multi-worker issues**: ContextVar isolation ensures thread safety. - -## Architecture Comparison +- Missing deployment ID: set `ZENML_PIPELINE_DEPLOYMENT_ID`. +- Slow responses: ensure you are in serving (async by default) or consider `Capture(memory_only=True)` for prototypes. +- Multi-worker/safety: Serving isolates request state; taps are cleared per request. -| Feature | Run-Only Mode | Full Tracking | -|---------|---------------|---------------| -| **Latency** | Milliseconds | Seconds | -| **DB Writes** | None | Full tracking | -| **FS Writes** | None | Artifacts | -| **Memory** | Minimal | Standard | -| **Debugging** | Limited | Complete | -| **Production** | ✅ Optimal | For experimentation | +## See Also -Choose run-only for production serving, full tracking for development and debugging. \ No newline at end of file +- Capture & Runtimes (advanced): serving defaults, toggles, memory-only behavior. 
+- Realtime Tuning: cache TTL/size, error reporting, and circuit breaker knobs.
diff --git a/docs/book/serving/advanced/capture-and-runtime.md b/docs/book/serving/advanced/capture-and-runtime.md
new file mode 100644
index 00000000000..6cad8a86b73
--- /dev/null
+++ b/docs/book/serving/advanced/capture-and-runtime.md
@@ -0,0 +1,94 @@
+---
+title: Capture & Execution Runtimes (Advanced)
+---
+
+# Capture & Execution Runtimes (Advanced)
+
+This page explains how capture options map to execution runtimes and how to tune them for production serving.
+
+## Execution Runtimes
+
+- DefaultStepRuntime (Batch)
+  - Standard ZenML execution: persists artifacts, creates runs and step runs, and captures metadata/logs based on capture toggles.
+  - Used outside serving.
+
+- RealtimeStepRuntime (Serving, async by default)
+  - Optimized for low latency with async server updates and an in‑process cache for downstream loads.
+  - Tunables via env: `ZENML_RT_CACHE_TTL_SECONDS`, `ZENML_RT_CACHE_MAX_ENTRIES`, `ZENML_RT_ERR_REPORT_INTERVAL`, and circuit breaker knobs (see the Realtime Tuning page).
+
+- MemoryStepRuntime (Serving with memory_only)
+  - Pure in‑memory execution: no runs/steps/artifacts or server calls.
+  - Inter‑step data is exchanged via in‑process handles.
+
+## Capture API (typed only)
+
+```python
+from zenml import pipeline
+from zenml.capture.config import Capture
+
+# Serving async (default) – explicit but not required
+@pipeline(capture=Capture())
+def serve(...):
+    ...
+
+# Serving memory-only (no DB/artifacts)
+@pipeline(capture=Capture(memory_only=True))
+def serve_memory_only(...):
+    ...
+```
+
+Options:
+- `memory_only` (serving only): in‑process handoff; no persistence.
+- Observability toggles (affect only observability, not dataflow):
+  - `code`: include code/source/docstrings in metadata (default True)
+  - `logs`: persist step logs (default True)
+  - `metadata`: publish run/step metadata (default True)
+  - `visualizations`: persist visualizations (default True)
+  - `metrics`: emit runtime metrics (default True)
+
+## Serving Defaults
+
+- Serving uses the Realtime runtime and returns asynchronously by default.
+- There is no `flush_on_step_end` knob; batch is blocking, serving is async.
+
+## Validation & Behavior
+
+- `memory_only` outside serving: ignored with a warning.
+- Observability toggles never affect dataflow/caching, only what’s recorded.
+
+## Step Operators & Remote Execution
+
+Step operators and remote entrypoints derive behavior from context; no capture env propagation is required.
+
+## Memory-Only Internals (for deeper understanding)
+
+- Handle format: `mem://<run_id>/<step_name>/<output_name>`
+- Memory runtime:
+  - `resolve_step_inputs`: constructs handles from the `run_id` plus name substitutions.
+  - `load_input_artifact`: resolves a handle to its value from a thread-safe in-process store.
+  - `store_output_artifacts`: stores outputs back to the store; returns new handles for downstream steps.
+- No server calls; no runs or artifacts are created.
+
+## Recipes
+
+- Low-latency serving (default): `@pipeline(capture=Capture())`
+- Memory-only (stateless service): `@pipeline(capture=Capture(memory_only=True))`
+
+### Disable code capture (docstring/source)
+
+Code capture affects metadata only (not execution). You can disable it via capture:
+
+```python
+from zenml import pipeline
+from zenml.capture.config import Capture
+
+@pipeline(capture=Capture(code=False))
+def serve(...):
+    ...
+
+@pipeline(capture=Capture(code=False))
+def train(...):
+    ...
+```
+
+## FAQ
+
+- Does `code=False` break step execution?
+  - No. It only disables docstring/source capture. Steps still run normally.
+- Can memory-only work with parallelism? + - Memory-only is per-process. For multi-process/multi-container setups, use persistence for cross-process data. diff --git a/docs/book/serving/advanced/realtime-tuning.md b/docs/book/serving/advanced/realtime-tuning.md new file mode 100644 index 00000000000..b11a1ab73f5 --- /dev/null +++ b/docs/book/serving/advanced/realtime-tuning.md @@ -0,0 +1,79 @@ +--- +title: Realtime Runtime Tuning & Circuit Breakers +--- + +# Realtime Runtime Tuning & Circuit Breakers + +This page documents advanced environment variables and metrics for tuning the Realtime runtime in production deployments. These knobs let you balance latency, throughput, and resilience under load. + +## When To Use This + +- High-QPS serving pipelines where latency and CPU efficiency matter +- Deployments needing stronger guardrails against cascading failures +- Teams instrumenting detailed metrics (cache hit rate, p95/p99 latencies) + +## Environment Variables + +Cache & Limits + +- `ZENML_RT_CACHE_TTL_SECONDS` (default: `60`) + - TTL for cached artifact values in seconds (in-process cache). +- `ZENML_RT_CACHE_MAX_ENTRIES` (default: `256`) + - LRU cache entry bound to prevent unbounded growth. + +Background Error Reporting + +- `ZENML_RT_ERR_REPORT_INTERVAL` (default: `15`) + - Minimum seconds between repeated background error logs (prevents log spam while maintaining visibility). + +Circuit Breaker (async → inline fallback) + +- `ZENML_RT_CB_ERR_THRESHOLD` (default: `0.1`) + - Error rate threshold to open the breaker (e.g., `0.1` = 10%). +- `ZENML_RT_CB_MIN_EVENTS` (default: `100`) + - Minimum number of publish events to evaluate before opening breaker. +- `ZENML_RT_CB_OPEN_SECONDS` (default: `300`) + - Duration (seconds) to keep breaker open; inline publishing is used while open. + +Notes + +- Serving uses the Realtime runtime by default. Outside serving, batch runtime is used. + +## Metrics & Observability + +`RealtimeStepRuntime.get_metrics()` returns a snapshot of: + +- Queue & Errors: `queued`, `processed`, `failed_total`, `queue_depth` +- Cache: `cache_hits`, `cache_misses`, `cache_hit_rate` +- Latency (op publish): `op_latency_p50_s`, `op_latency_p95_s`, `op_latency_p99_s` +- Config: `ttl_seconds`, `max_entries` + +Recommendation + +- Export metrics to your telemetry system (e.g., Prometheus) and alert on: + - Rising `failed_total` and sustained `queue_depth` + - Low `cache_hit_rate` + - High `op_latency_p95_s` / `op_latency_p99_s` + +## Recommended Production Defaults + +- Start conservative, then tune based on SLOs: + - `ZENML_RT_CACHE_TTL_SECONDS=60` + - `ZENML_RT_CACHE_MAX_ENTRIES=256` + - `ZENML_RT_ERR_REPORT_INTERVAL=15` + - `ZENML_RT_CB_ERR_THRESHOLD=0.1` + - `ZENML_RT_CB_MIN_EVENTS=100` + - `ZENML_RT_CB_OPEN_SECONDS=300` + +## Runbook (Common Scenarios) + +- High background errors: + - Check logs for circuit breaker events. If open, runtime will publish inline. Investigate upstream store or network failures. + - Consider temporarily reducing load or increasing `ZENML_RT_CB_OPEN_SECONDS` while recovering. + +- Rising queue depth / latency: + - Verify artifact store and API latency. + - Reduce cache TTL or size to reduce memory pressure; consider scaling workers. + +- Low cache hit rate: + - Check step dependencies and cache TTL; ensure downstream steps run in the same process to benefit from warm cache. 
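+
+## Exporting Metrics (Sketch)
+
+A minimal sketch of bridging `RealtimeStepRuntime.get_metrics()` into Prometheus, assuming you can reach the runtime instance from your service wiring and have the third-party `prometheus_client` package installed; the polling-thread approach is illustrative, not a built-in integration:
+
+```python
+# Illustrative bridge: polls get_metrics() and mirrors selected snapshot
+# keys into Prometheus gauges. The gauge names are made up for this example.
+import threading
+import time
+
+from prometheus_client import Gauge, start_http_server
+
+GAUGES = {
+    "queue_depth": Gauge("zenml_rt_queue_depth", "Pending publish events"),
+    "failed_total": Gauge("zenml_rt_failed_total", "Failed publish events"),
+    "cache_hit_rate": Gauge("zenml_rt_cache_hit_rate", "In-process cache hit rate"),
+    "op_latency_p95_s": Gauge("zenml_rt_op_latency_p95_s", "Publish latency p95 (s)"),
+}
+
+
+def export_runtime_metrics(runtime, port: int = 9100, interval: float = 5.0) -> None:
+    """Expose selected runtime metrics on a Prometheus /metrics endpoint."""
+    start_http_server(port)
+
+    def _poll() -> None:
+        while True:
+            snapshot = runtime.get_metrics()
+            for key, gauge in GAUGES.items():
+                if key in snapshot:
+                    gauge.set(float(snapshot[key]))
+            time.sleep(interval)
+
+    threading.Thread(target=_poll, daemon=True).start()
+```
+
+Polling a snapshot keeps the runtime's hot path free of exporter dependencies; a push-based exporter would work equally well.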
diff --git a/docs/book/serving/overview.md b/docs/book/serving/overview.md new file mode 100644 index 00000000000..e1d4ef59f0d --- /dev/null +++ b/docs/book/serving/overview.md @@ -0,0 +1,70 @@ +--- +title: Pipeline Serving Overview +--- + +# Pipeline Serving Overview + +## What Is Pipeline Serving? + +- Purpose: Expose a ZenML pipeline as a low-latency service (e.g., via FastAPI) that executes steps on incoming requests and returns results. +- Value: Production-grade orchestration with simple capture options to balance latency, observability, and lineage. +- Modes by context: Batch outside serving (blocking), Realtime in serving (async), and pure in-memory serving for maximum speed. + +## Quick Start + +1) Define your pipeline +- Use your normal `@pipeline` and `@step` definitions. +- No serving-specific changes required. + +2) Choose capture only when you need to change defaults +- You don’t need to set capture in most cases: + - Batch (outside serving) is blocking. + - Serving is async by default. +- Optional tweaks (typed API only): + - Make it explicit: `@pipeline(capture=Capture())` + - Pure in-memory serving: `@pipeline(capture=Capture(memory_only=True))` + +3) Deploy the serving service with your preferred deployer and call the FastAPI endpoint. + +## Capture Essentials + +- Batch (outside serving) + - Blocking publishes; full persistence as configured by capture toggles. + +- Serving (inside serving) + - Async publishes by default with an in‑process cache; low latency. + +- Memory-only (serving only) + - Pure in‑memory execution: no runs/steps/artifacts or server calls; maximum speed. + - Outside serving, `memory_only=True` is ignored with a warning. + +## Where To Configure Capture + +- In code (typed only) + - `@pipeline(capture=Capture(...))` + - Options: `memory_only`, `code`, `logs`, `metadata`, `visualizations`, `metrics` + +## Best Practices + +- Most users (serving-ready) + - `@pipeline(capture=Capture())` + - Good balance of immediate response and production tracking. + +- Maximum speed (no tracking at all) + - `@pipeline(capture=Capture(memory_only=True))` + - Great for tests, benchmarks, or hot paths where lineage is not needed. + +- Compliance or rich lineage + - Use Batch (outside serving) where publishes are blocking by default. + +## FAQ (Essentials) + +- Does serving always create pipeline runs? + - Batch/Realtime: Yes. + - Memory-only (Realtime with `memory_only=True`): No; executes purely in memory. + +- Will serving block responses to flush tracking? + - No. Serving is async by default and returns immediately. + +- Is memory-only safe for production? + - Yes for stateless, speed-critical paths. Note: No lineage or persisted artifacts. 
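+
+## End-to-End Sketch
+
+To make the quick start concrete, here is a minimal sketch. The `_create_deployment()` call and the serving entrypoint follow the pattern used in `examples/serving`; the step and pipeline bodies are placeholders:
+
+```python
+from zenml import pipeline, step
+from zenml.capture.config import Capture
+
+
+@step
+def double(x: int) -> int:
+    return x * 2
+
+
+@pipeline(capture=Capture())  # Capture() is optional; serving is async by default
+def serve_double(x: int = 1) -> int:
+    return double(x)
+
+
+if __name__ == "__main__":
+    # Create a deployment without running the pipeline, then serve it with
+    # ZENML_PIPELINE_DEPLOYMENT_ID=<printed id> python -m zenml.deployers.serving.app
+    deployment = serve_double._create_deployment()
+    print(deployment.id)
+```
+
+Once the service is up, `POST /invoke` with `{"parameters": {"x": 21}}` returns the pipeline output.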
diff --git a/docs/book/serving/toc.md b/docs/book/serving/toc.md new file mode 100644 index 00000000000..3dcc3d1ae74 --- /dev/null +++ b/docs/book/serving/toc.md @@ -0,0 +1,5 @@ +# Serving + +* [Pipeline Serving Overview](overview.md) +* Advanced + * [Capture Policy & Runtimes](advanced/capture-and-runtime.md) diff --git a/docs/book/toc.md b/docs/book/toc.md index eb7d4ebb15d..acce2d8b1b6 100644 --- a/docs/book/toc.md +++ b/docs/book/toc.md @@ -53,6 +53,9 @@ * [Models](how-to/models/models.md) * [Templates](how-to/templates/templates.md) * [Dashboard](how-to/dashboard/dashboard-features.md) +* Serving + * [Pipeline Serving Overview](serving/overview.md) + * [Capture Policy & Runtimes (Advanced)](serving/advanced/capture-and-runtime.md) * [Serving Pipelines](how-to/serving/serving.md) * [Pipeline Serving Capture Policies](how-to/serving/capture-policies.md) diff --git a/examples/serving/README.md b/examples/serving/README.md index c7e7062910a..ca5a8295741 100644 --- a/examples/serving/README.md +++ b/examples/serving/README.md @@ -1,338 +1,118 @@ # ZenML Pipeline Serving Examples -This directory contains examples demonstrating ZenML's new **run-only serving architecture** with millisecond-class latency for real-time inference and AI applications. +This directory contains examples that run pipelines as HTTP services using ZenML Serving. -## 🚀 **New Run-Only Architecture** +Highlights -ZenML Serving now automatically optimizes for performance: +- Async serving by default for low latency +- Optional memory-only execution via `Capture(memory_only=True)` +- Request parameter merging and streaming support -- **🏃‍♂️ Run-Only Mode**: Millisecond-class latency with zero DB/FS writes -- **🧠 Intelligent Switching**: Automatically chooses optimal execution mode -- **⚡ In-Memory Handoff**: Step outputs passed directly via serving buffer -- **🔄 Multi-Worker Safe**: ContextVar isolation for concurrent requests -- **📝 No Model Mutations**: Clean effective configuration merging +## Files -## 📁 Files +1. `weather_pipeline.py` – simple weather analysis +2. `chat_agent_pipeline.py` – conversational agent with streaming +3. `test_serving.py` – basic endpoint checks -1. **`weather_pipeline.py`** - Simple weather analysis with run-only optimization -2. **`chat_agent_pipeline.py`** - Streaming conversational AI with fast execution -3. **`test_serving.py`** - Test script to verify serving endpoints -4. **`README.md`** - This comprehensive guide +## Serving Modes (by context) -## 🎯 Examples Overview +- Batch (outside serving): blocking publishes; standard persistence +- Serving (default): async publishes with in‑process cache +- Memory-only (serving only): in‑process handoff; no DB/artifacts -### 1. Weather Agent Pipeline -- **Purpose**: Analyze weather for any city with AI recommendations -- **Mode**: Run-only optimization for millisecond response times -- **Features**: Automatic parameter injection, rule-based fallback -- **API**: Standard HTTP POST requests +## Quick Start: Weather Agent -### 2. 
Streaming Chat Agent Pipeline -- **Purpose**: Real-time conversational AI with streaming responses -- **Mode**: Run-only with optional streaming support -- **Features**: Token-by-token streaming, WebSocket support -- **API**: HTTP, WebSocket streaming, async jobs with SSE - -## 🏃‍♂️ **Run-Only vs Full Tracking** - -### Run-Only Mode (Default - Millisecond Latency) -```python -@pipeline # No capture settings = run-only mode -def fast_pipeline(city: str) -> str: - return analyze_weather(city) -``` - -**✅ Optimizations Active:** -- Zero database writes -- Zero filesystem operations -- In-memory step output handoff -- Per-request parameter injection -- Multi-worker safe execution - -### Full Tracking Mode (For Development) -```python -@pipeline(settings={"capture": "full"}) -def tracked_pipeline(city: str) -> str: - return analyze_weather(city) -``` - -**📊 Features Active:** -- Complete run/step tracking -- Artifact persistence -- Dashboard integration -- Debug information - -# 🚀 Quick Start Guide - -## Prerequisites - -```bash -# Install ZenML with serving support -pip install zenml - -# Optional: For LLM analysis (otherwise uses rule-based fallback) -export OPENAI_API_KEY=your_openai_api_key_here -pip install openai -``` - -## Example 1: Weather Agent (Run-Only Mode) - -### Step 1: Create and Deploy Pipeline +1) Create and deploy the pipeline ```bash python weather_pipeline.py ``` -**Expected Output:** -``` -🌤️ Creating Weather Agent Pipeline Deployment... -📦 Creating deployment for serving... -✅ Deployment ID: 12345678-1234-5678-9abc-123456789abc - -🚀 Start serving with: -export ZENML_PIPELINE_DEPLOYMENT_ID=12345678-1234-5678-9abc-123456789abc -python -m zenml.deployers.serving.app -``` - -### Step 2: Start Serving Service +2) Start the serving service ```bash -export ZENML_PIPELINE_DEPLOYMENT_ID=12345678-1234-5678-9abc-123456789abc +export ZENML_PIPELINE_DEPLOYMENT_ID= python -m zenml.deployers.serving.app ``` -**Service Configuration:** -- **Mode**: Run-only (millisecond latency) -- **Host**: `http://localhost:8000` -- **Optimizations**: All I/O operations bypassed - -### Step 3: Test Ultra-Fast Weather Analysis +3) Invoke ```bash -# Basic request (millisecond response time) curl -X POST "http://localhost:8000/invoke" \ -H "Content-Type: application/json" \ -d '{"parameters": {"city": "Paris"}}' - -# Response format: -{ - "success": true, - "outputs": { - "weather_analysis": "Weather in Paris is sunny with 22°C..." - }, - "execution_time": 0.003, # Milliseconds! - "metadata": { - "pipeline_name": "weather_agent_pipeline", - "parameters_used": {"city": "Paris"}, - "steps_executed": 3 - } -} -``` - -## Example 2: Streaming Chat Agent (Run-Only Mode) - -### Step 1: Create Chat Pipeline - -```bash -python chat_agent_pipeline.py -``` - -### Step 2: Start Serving Service - -```bash -export ZENML_PIPELINE_DEPLOYMENT_ID= -python -m zenml.deployers.serving.app -``` - -### Step 3: Test Ultra-Fast Chat - -#### Method A: Instant Response (Milliseconds) -```bash -curl -X POST "http://localhost:8000/invoke" \ - -H "Content-Type: application/json" \ - -d '{"parameters": {"message": "Hello!", "user_name": "Alice"}}' - -# Ultra-fast response: -{ - "success": true, - "outputs": {"chat_response": "Hello Alice! How can I help you today?"}, - "execution_time": 0.002 # Milliseconds! 
-} -``` - -#### Method B: Streaming Mode (Optional) -```bash -# Create async job -JOB_ID=$(curl -X POST 'http://localhost:8000/invoke?mode=async' \ - -H 'Content-Type: application/json' \ - -d '{"parameters": {"message": "Tell me about AI", "enable_streaming": true}}' \ - | jq -r .job_id) - -# Stream real-time results -curl -N "http://localhost:8000/stream/$JOB_ID" ``` -#### Method C: WebSocket Streaming -```bash -# Install wscat: npm install -g wscat -wscat -c ws://localhost:8000/stream - -# Send message: -{"parameters": {"message": "Hi there!", "user_name": "Alice", "enable_streaming": true}} -``` - -## 📊 Performance Comparison - -| Feature | Run-Only Mode | Full Tracking | -|---------|---------------|---------------| -| **Response Time** | 1-5ms | 100-500ms | -| **Throughput** | 1000+ RPS | 10-50 RPS | -| **Memory Usage** | Minimal | Standard | -| **DB Operations** | Zero | Full tracking | -| **FS Operations** | Zero | Artifact storage | -| **Use Cases** | Production serving | Development/debug | +Service defaults -## 🛠️ Advanced Configuration +- Host: `http://localhost:8000` +- Serving: async by default -### Performance Tuning +## Configuration ```bash -# Set capture mode explicitly -export ZENML_SERVING_CAPTURE_DEFAULT=none # Run-only mode - -# Multi-worker deployment -export ZENML_SERVICE_WORKERS=4 +export ZENML_PIPELINE_DEPLOYMENT_ID= python -m zenml.deployers.serving.app ``` -### Override Modes Per Request - -```bash -# Force tracking for a single request (slower but tracked) -curl -X POST "http://localhost:8000/invoke" \ - -H "Content-Type: application/json" \ - -d '{ - "parameters": {"city": "Tokyo"}, - "capture_override": {"mode": "full"} - }' -``` - -### Monitor Performance - -```bash -# Service health and performance -curl http://localhost:8000/health -curl http://localhost:8000/metrics - -# Pipeline information -curl http://localhost:8000/info -``` +To enable memory-only mode, set it in code: -## 🏗️ Architecture Deep Dive - -### Run-Only Execution Flow +```python +from zenml import pipeline +from zenml.capture.config import Capture -``` -Request → ServingOverrides → Effective Config → StepRunner → ServingBuffer → Response - (Parameters) (No mutations) (No I/O) (In-memory) (JSON) +@pipeline(capture=Capture(memory_only=True)) +def serve_max_speed(...): + ... ``` -1. **Request Arrives**: JSON parameters received -2. **ServingOverrides**: Per-request parameter injection via ContextVar -3. **Effective Config**: Runtime configuration merging (no model mutations) -4. **Step Execution**: Direct execution with serving buffer storage -5. **Response Building**: Only declared outputs returned as JSON +## Execution Flow (serving) -### Key Components +Request → Parameter merge → StepRunner → Response -- **`ServingOverrides`**: Thread-safe parameter injection -- **`ServingBuffer`**: In-memory step output handoff -- **Effective Configuration**: Runtime config merging without mutations -- **ContextVar Isolation**: Multi-worker safe execution +- Parameters under `parameters` are merged into step config. +- Serving is async; background updates do not block the response. 
-## 📚 API Reference +## API Reference -### Core Endpoints +Core endpoints -| Endpoint | Method | Purpose | Performance | -|----------|---------|---------|-------------| -| `/invoke` | POST | Execute pipeline | Milliseconds | -| `/health` | GET | Service health | Instant | -| `/info` | GET | Pipeline schema | Instant | -| `/metrics` | GET | Performance stats | Instant | +| Endpoint | Method | Purpose | +|----------|--------|---------| +| `/invoke` | POST | Execute pipeline (sync or async) | +| `/health` | GET | Service health | +| `/info` | GET | Pipeline schema & deployment info | +| `/metrics` | GET | Runtime metrics (if enabled) | +| `/jobs`, `/jobs/{id}` | GET | Manage async jobs | +| `/stream/{id}` | GET | Server‑Sent Events stream | -### Request Format +Request format ```json { "parameters": { - "city": "string", - "temperature": "number", - "enable_streaming": "boolean" - }, - "capture_override": { - "mode": "none|metadata|full" + "city": "string" } } ``` -### Response Format +## Troubleshooting -```json -{ - "success": true, - "outputs": { - "output_name": "output_value" - }, - "execution_time": 0.003, - "metadata": { - "pipeline_name": "string", - "parameters_used": {}, - "steps_executed": 0 - } -} -``` - -## 🔧 Troubleshooting +- Missing deployment ID: set `ZENML_PIPELINE_DEPLOYMENT_ID`. +- Slow responses: serving is async by default; for prototypes consider `Capture(memory_only=True)`. +- Monitor: use `/metrics` for queue depth, cache hit rate, and latencies. -### Performance Issues -- ✅ **Ensure run-only mode**: No capture settings or `capture="none"` -- ✅ **Check environment**: `ZENML_SERVING_CAPTURE_DEFAULT=none` -- ✅ **Monitor metrics**: Use `/metrics` endpoint - -### Common Problems -- **Slow responses**: Verify run-only mode is active -- **Import errors**: Run-only mode bypasses unnecessary integrations -- **Memory leaks**: Serving contexts auto-cleared per request -- **Multi-worker issues**: ContextVar provides thread isolation - -### Debug Mode -```bash -# Enable full tracking for debugging -curl -X POST "http://localhost:8000/invoke" \ - -d '{"parameters": {...}, "capture_override": {"mode": "full"}}' -``` - -## 🎯 Production Deployment - -### Docker Example +## Docker ```dockerfile -FROM python:3.9-slim - -# Install ZenML +FROM python:3.11-slim RUN pip install zenml - -# Set serving configuration -ENV ZENML_SERVING_CAPTURE_DEFAULT=none ENV ZENML_SERVICE_HOST=0.0.0.0 ENV ZENML_SERVICE_PORT=8000 - -# Start serving CMD ["python", "-m", "zenml.deployers.serving.app"] ``` -### Kubernetes Example +## Kubernetes (snippet) ```yaml apiVersion: apps/v1 @@ -340,7 +120,7 @@ kind: Deployment metadata: name: zenml-serving spec: - replicas: 3 + replicas: 2 template: spec: containers: @@ -349,18 +129,7 @@ spec: env: - name: ZENML_PIPELINE_DEPLOYMENT_ID value: "your-deployment-id" - - name: ZENML_SERVING_CAPTURE_DEFAULT - value: "none" ports: - containerPort: 8000 ``` -## 🚀 Next Steps - -1. **Deploy Examples**: Try both weather and chat examples -2. **Measure Performance**: Use the `/metrics` endpoint -3. **Scale Up**: Deploy with multiple workers -4. **Monitor**: Integrate with your observability stack -5. **Optimize**: Fine-tune capture policies for your use case - -The new run-only architecture delivers production-ready performance for real-time AI applications! 
🎉 \ No newline at end of file diff --git a/examples/serving/weather_pipeline.py b/examples/serving/weather_pipeline.py index 2c8f5a0cf14..ebc3960a79a 100644 --- a/examples/serving/weather_pipeline.py +++ b/examples/serving/weather_pipeline.py @@ -12,11 +12,13 @@ Perfect for real-time inference and AI applications. """ +import logging import os import random from typing import Dict from zenml import pipeline, step +from zenml.capture.config import Capture from zenml.config import DockerSettings # Import enums for type-safe capture mode configuration @@ -24,6 +26,9 @@ from zenml.config.resource_settings import ResourceSettings from zenml.steps.step_context import get_step_context +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + # Note: You can use either approach: # 1. String literals: "full", "metadata", "sampled", "errors_only", "none" # 2. Type-safe enums: CaptureMode.FULL, CaptureMode.METADATA, etc. @@ -70,7 +75,7 @@ def init_hook() -> PipelineState: return PipelineState() -@step +@step(enable_cache=False) def get_weather(city: str) -> Dict[str, float]: """Simulate getting weather data for a city. @@ -87,13 +92,15 @@ def get_weather(city: str) -> Dict[str, float]: } -@step +@step(enable_cache=False) def analyze_weather_with_llm(weather_data: Dict[str, float], city: str) -> str: """Use LLM to analyze weather and provide intelligent recommendations. In run-only mode, this step receives weather data via in-memory handoff and returns analysis with no database or filesystem writes. """ + import time + temp = weather_data["temperature"] humidity = weather_data["humidity"] wind = weather_data["wind_speed"] @@ -103,11 +110,12 @@ def analyze_weather_with_llm(weather_data: Dict[str, float], city: str) -> str: client = None if pipeline_state: + logger.debug("Pipeline state is a PipelineState") assert isinstance(pipeline_state, PipelineState), ( "Pipeline state is not a PipelineState" ) client = pipeline_state.client - + logger.debug("Client is %s", client) if client: # Create a prompt for the LLM weather_prompt = f"""You are a weather expert AI assistant. Analyze the following weather data for {city} and provide detailed insights and recommendations. @@ -126,9 +134,10 @@ def analyze_weather_with_llm(weather_data: Dict[str, float], city: str) -> str: 5. 
Any weather warnings or tips Keep your response concise but informative.""" - + logger.info("[LLM] Starting OpenAI request for city=%s", city) + t0 = time.perf_counter() response = client.chat.completions.create( - model="gpt-3.5-turbo", + model="gpt-5-mini", messages=[ { "role": "system", @@ -136,9 +145,9 @@ def analyze_weather_with_llm(weather_data: Dict[str, float], city: str) -> str: }, {"role": "user", "content": weather_prompt}, ], - max_tokens=300, - temperature=0.7, ) + dt = time.perf_counter() - t0 + logger.info("[LLM] OpenAI request finished in %.3fs", dt) llm_analysis = response.choices[0].message.content @@ -214,6 +223,7 @@ def analyze_weather_with_llm(weather_data: Dict[str, float], city: str) -> str: @pipeline( on_init=init_hook, + capture=Capture(memory_only=True), settings={ "docker": docker_settings, "deployer.gcp": { @@ -264,6 +274,8 @@ def weather_agent_pipeline(city: str = "London") -> str: # Create deployment without running deployment = weather_agent_pipeline._create_deployment() + weather_agent_pipeline() + print("\n✅ Pipeline deployed for run-only serving!") print(f"📋 Deployment ID: {deployment.id}") print("\n🚀 Start serving with millisecond latency:") diff --git a/src/zenml/capture/config.py b/src/zenml/capture/config.py new file mode 100644 index 00000000000..16e36dce1d1 --- /dev/null +++ b/src/zenml/capture/config.py @@ -0,0 +1,38 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Capture configuration for ZenML (single, typed).""" + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class Capture: + """Single capture configuration. + + Semantics are derived from context: + - Batch (orchestrated) runs use blocking publishes. + - Serving uses async publishes; `memory_only` switches to in-process handoff. + + Only observability toggles are exposed; they never affect dataflow except + `memory_only`, which is serving-only and ignored elsewhere. 
+ """ + + # Serving-only: run without DB/artifact persistence using in-process handoff + memory_only: bool = False + # Observability toggles + code: bool = True + logs: bool = True + metadata: bool = True + visualizations: bool = True + metrics: bool = True diff --git a/src/zenml/config/compiler.py b/src/zenml/config/compiler.py index a4554761d37..f4efefc933f 100644 --- a/src/zenml/config/compiler.py +++ b/src/zenml/config/compiler.py @@ -26,6 +26,7 @@ ) from zenml import __version__ +from zenml.capture.config import Capture from zenml.config.base_settings import BaseSettings, ConfigurationLevel from zenml.config.pipeline_configurations import PipelineConfiguration from zenml.config.pipeline_run_configuration import PipelineRunConfiguration @@ -150,6 +151,26 @@ def compile( pipeline_spec=pipeline_spec, ) + # Populate canonical capture fields from typed pipeline configuration + cap: Optional[Capture] = pipeline.configuration.capture + mem_only = bool(getattr(cap, "memory_only", False)) if cap else False + code = bool(getattr(cap, "code", True)) if cap else True + logs = bool(getattr(cap, "logs", True)) if cap else True + metadata_enabled = ( + bool(getattr(cap, "metadata", True)) if cap else True + ) + visuals = bool(getattr(cap, "visualizations", True)) if cap else True + metrics = bool(getattr(cap, "metrics", True)) if cap else True + try: + setattr(deployment, "capture_memory_only", mem_only) + setattr(deployment, "capture_code", code) + setattr(deployment, "capture_logs", logs) + setattr(deployment, "capture_metadata", metadata_enabled) + setattr(deployment, "capture_visualizations", visuals) + setattr(deployment, "capture_metrics", metrics) + except Exception: + pass + logger.debug("Compiled pipeline deployment: %s", deployment) return deployment @@ -202,6 +223,7 @@ def _apply_run_configuration( enable_artifact_visualization=config.enable_artifact_visualization, enable_step_logs=config.enable_step_logs, enable_pipeline_logs=config.enable_pipeline_logs, + capture=config.capture, settings=config.settings, tags=config.tags, extra=config.extra, diff --git a/src/zenml/config/pipeline_configurations.py b/src/zenml/config/pipeline_configurations.py index d9bec935859..5ba44eff4e3 100644 --- a/src/zenml/config/pipeline_configurations.py +++ b/src/zenml/config/pipeline_configurations.py @@ -18,6 +18,7 @@ from pydantic import SerializeAsAny, field_validator +from zenml.capture.config import Capture from zenml.config.cache_policy import CachePolicyWithValidator from zenml.config.constants import DOCKER_SETTINGS_KEY, RESOURCE_SETTINGS_KEY from zenml.config.retry_config import StepRetryConfig @@ -43,6 +44,8 @@ class PipelineConfigurationUpdate(StrictBaseModel): enable_artifact_visualization: Optional[bool] = None enable_step_logs: Optional[bool] = None enable_pipeline_logs: Optional[bool] = None + # Capture configuration (typed only) + capture: Optional[Capture] = None settings: Dict[str, SerializeAsAny[BaseSettings]] = {} tags: Optional[List[Union[str, "Tag"]]] = None extra: Dict[str, Any] = {} @@ -87,6 +90,18 @@ class PipelineConfiguration(PipelineConfigurationUpdate): name: str + @field_validator("capture") + @classmethod + def validate_capture_mode( + cls, value: Optional[Capture] + ) -> Optional[Capture]: + """Validates the capture config (typed only).""" + if value is None: + return value + if isinstance(value, Capture): + return value + raise ValueError("'capture' must be a typed Capture.") + @field_validator("name") @classmethod def ensure_pipeline_name_allowed(cls, name: str) -> str: diff 
--git a/src/zenml/config/pipeline_run_configuration.py b/src/zenml/config/pipeline_run_configuration.py index 9b8ec30275f..b8b3efee959 100644 --- a/src/zenml/config/pipeline_run_configuration.py +++ b/src/zenml/config/pipeline_run_configuration.py @@ -18,6 +18,7 @@ from pydantic import Field, SerializeAsAny +from zenml.capture.config import Capture from zenml.config.base_settings import BaseSettings from zenml.config.cache_policy import CachePolicyWithValidator from zenml.config.retry_config import StepRetryConfig @@ -70,6 +71,11 @@ class PipelineRunConfiguration( union_mode="left_to_right", description="The build to use for the pipeline run.", ) + # Optional typed capture override per run (no dicts/strings) + capture: Optional[Capture] = Field( + default=None, + description="The capture to use for the pipeline run.", + ) steps: Optional[Dict[str, StepConfigurationUpdate]] = Field( default=None, description="Configurations for the steps of the pipeline run.", diff --git a/src/zenml/deployers/serving/service.py b/src/zenml/deployers/serving/service.py index bcb0face929..254db972fa5 100644 --- a/src/zenml/deployers/serving/service.py +++ b/src/zenml/deployers/serving/service.py @@ -19,17 +19,31 @@ """ import asyncio +import inspect +import json +import os import time +import traceback from datetime import datetime, timedelta, timezone from typing import Any, Dict, List, Optional from uuid import UUID, uuid4 +import numpy as np + from zenml.client import Client from zenml.integrations.registry import integration_registry from zenml.logger import get_logger from zenml.models import PipelineDeploymentResponse -from zenml.orchestrators import utils as orchestrator_utils from zenml.orchestrators.topsort import topsorted_layers +from zenml.orchestrators.utils import ( + extract_return_contract, + is_tracking_disabled, + response_tap_clear, + response_tap_get_all, + set_pipeline_state, + set_return_targets, + set_serving_context, +) from zenml.stack import Stack from zenml.utils import source_utils @@ -126,8 +140,6 @@ async def initialize(self) -> None: except Exception as e: logger.error(f"❌ Failed to initialize service: {str(e)}") logger.error(f" Error type: {type(e)}") - import traceback - logger.error(f" Traceback: {traceback.format_exc()}") raise @@ -191,10 +203,6 @@ def _extract_parameter_schema(self) -> Dict[str, Any]: self.deployment.pipeline_configuration, "spec", None ) if pipeline_spec and getattr(pipeline_spec, "source", None): - import inspect - - from zenml.utils import source_utils - # Load the pipeline function pipeline_func = source_utils.load(pipeline_spec.source) @@ -285,16 +293,12 @@ def _serialize_for_json(self, value: Any) -> Any: JSON-serializable representation of the value """ try: - import json - # Handle common ML types that aren't JSON serializable if hasattr(value, "tolist"): # numpy arrays, pandas Series return value.tolist() elif hasattr(value, "to_dict"): # pandas DataFrames return value.to_dict() elif hasattr(value, "__array__"): # numpy-like arrays - import numpy as np - return np.asarray(value).tolist() # Test if it's already JSON serializable @@ -321,26 +325,27 @@ async def execute_pipeline( logger.info("Starting pipeline execution") # Set up response capture - orchestrator_utils.response_tap_clear() + response_tap_clear() self._setup_return_targets() try: # Resolve request parameters resolved_params = self._resolve_parameters(parameters) + # Expose resolved params to launcher/runner via env for memory-only path + os.environ["ZENML_SERVING_REQUEST_PARAMS"] = 
json.dumps( + resolved_params + ) + # Expose pipeline state via serving context var + set_pipeline_state(self.pipeline_state) # Get deployment and check if we're in no-capture mode deployment = self.deployment - _ = orchestrator_utils.is_tracking_disabled( + _ = is_tracking_disabled( deployment.pipeline_configuration.settings ) - # Set serving capture default for this request (no model mutations needed) - import os - - original_capture_default = os.environ.get( - "ZENML_SERVING_CAPTURE_DEFAULT" - ) - os.environ["ZENML_SERVING_CAPTURE_DEFAULT"] = "none" + # Mark serving context for the orchestrator/launcher + set_serving_context(True) # Build execution order using the production-tested topsort utility steps = deployment.step_configurations @@ -374,8 +379,9 @@ async def execute_pipeline( orchestrator = stack.orchestrator # Ensure a stable run id for StepLauncher to reuse the same PipelineRun + run_uuid = str(uuid4()) if hasattr(orchestrator, "_orchestrator_run_id"): - setattr(orchestrator, "_orchestrator_run_id", str(uuid4())) + setattr(orchestrator, "_orchestrator_run_id", run_uuid) # No serving overrides population in local orchestrator path @@ -387,16 +393,25 @@ async def execute_pipeline( finally: orchestrator._cleanup_run() - # Restore original capture default environment variable - if original_capture_default is None: - os.environ.pop("ZENML_SERVING_CAPTURE_DEFAULT", None) - else: - os.environ["ZENML_SERVING_CAPTURE_DEFAULT"] = ( - original_capture_default + # Clear serving context marker + set_serving_context(False) + # Clear request params env and shared runtime state + os.environ.pop("ZENML_SERVING_REQUEST_PARAMS", None) + set_pipeline_state(None) + # No per-request capture override to clear + try: + from zenml.orchestrators.runtime_manager import ( + clear_shared_runtime, + reset_memory_runtime_for_run, ) + reset_memory_runtime_for_run(run_uuid) + clear_shared_runtime() + except Exception: + pass + # Get captured outputs from response tap - outputs = orchestrator_utils.response_tap_get_all() + outputs = response_tap_get_all() execution_time = time.time() - start self._update_execution_stats(True, execution_time) @@ -435,7 +450,7 @@ async def execute_pipeline( } finally: # Clean up response tap - orchestrator_utils.response_tap_clear() + response_tap_clear() async def submit_pipeline( self, @@ -574,7 +589,7 @@ def _setup_return_targets(self) -> None: else None ) contract = ( - orchestrator_utils.extract_return_contract(pipeline_source) + extract_return_contract(pipeline_source) if pipeline_source else None ) @@ -614,12 +629,12 @@ def _setup_return_targets(self) -> None: ) logger.debug(f"Return targets: {return_targets}") - orchestrator_utils.set_return_targets(return_targets) + set_return_targets(return_targets) except Exception as e: logger.warning(f"Failed to setup return targets: {e}") # Set empty targets as fallback - orchestrator_utils.set_return_targets({}) + set_return_targets({}) def is_healthy(self) -> bool: """Check if the service is healthy and ready to serve requests. diff --git a/src/zenml/execution/__init__.py b/src/zenml/execution/__init__.py new file mode 100644 index 00000000000..7b2a01ce5d0 --- /dev/null +++ b/src/zenml/execution/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Execution runtime abstractions. + +This module defines the runtime interface used by the step runner / launcher +to interact with artifacts, metadata, and server updates. It is introduced as +an internal scaffolding to consolidate execution-time responsibilities behind +one facade without changing current behavior. + +NOTE: This is an internal API and subject to change. +""" + diff --git a/src/zenml/execution/default_runtime.py b/src/zenml/execution/default_runtime.py new file mode 100644 index 00000000000..6a6dfaea11c --- /dev/null +++ b/src/zenml/execution/default_runtime.py @@ -0,0 +1,291 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Default step runtime implementation (blocking publish, standard persistence).""" + +import time +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type + +from zenml.artifacts.unmaterialized_artifact import UnmaterializedArtifact +from zenml.client import Client +from zenml.enums import ArtifactSaveType +from zenml.execution.step_runtime import BaseStepRuntime +from zenml.logger import get_logger +from zenml.materializers.base_materializer import BaseMaterializer +from zenml.materializers.materializer_registry import materializer_registry +from zenml.models import ArtifactVersionResponse +from zenml.steps.step_context import get_step_context +from zenml.utils import materializer_utils, source_utils, tag_utils +from zenml.utils.typing_utils import get_origin, is_union + +if TYPE_CHECKING: + from zenml.artifact_stores import BaseArtifactStore + from zenml.config.step_configurations import Step + from zenml.materializers.base_materializer import BaseMaterializer + from zenml.models import PipelineRunResponse, StepRunResponse + from zenml.models.v2.core.step_run import StepRunInputResponse + from zenml.stack import Stack + from zenml.steps.utils import OutputSignature + +logger = get_logger(__name__) + + +class DefaultStepRuntime(BaseStepRuntime): + """Default runtime delegating to existing ZenML utilities. + + This keeps current behavior intact while providing a single place for the + step runner to call into. It intentionally mirrors logic from + `step_runner.py` and `orchestrators/input_utils.py`. + """ + + # --- Input Resolution --- + def resolve_step_inputs( + self, + *, + step: "Step", + pipeline_run: "PipelineRunResponse", + step_runs: Optional[Dict[str, "StepRunResponse"]] = None, + ) -> Dict[str, "StepRunInputResponse"]: + """Resolve step inputs. + + Args: + step: The step to resolve inputs for. + pipeline_run: The pipeline run to resolve inputs for. 
+ step_runs: Optional map of step runs. + + Returns: + Mapping from input name to resolved step run input. + """ + # Local import to avoid circular import issues + from zenml.orchestrators import input_utils + + return input_utils.resolve_step_inputs( + step=step, pipeline_run=pipeline_run, step_runs=step_runs + ) + + # --- Artifact Load --- + def load_input_artifact( + self, + *, + artifact: ArtifactVersionResponse, + data_type: Type[Any], + stack: "Stack", + ) -> Any: + """Load an input artifact. + + Args: + artifact: The artifact to load. + data_type: The data type of the artifact. + stack: The stack to load the artifact from. + + Returns: + The loaded Python value for the input artifact. + """ + # Skip materialization for `UnmaterializedArtifact`. + if data_type == UnmaterializedArtifact: + return UnmaterializedArtifact( + **artifact.get_hydrated_version().model_dump() + ) + + if data_type in (None, Any) or is_union(get_origin(data_type)): + # Use the stored artifact datatype when function annotation is not specific + data_type = source_utils.load(artifact.data_type) + + materializer_class: Type[BaseMaterializer] = ( + source_utils.load_and_validate_class( + artifact.materializer, expected_class=BaseMaterializer + ) + ) + + def _load(artifact_store: "BaseArtifactStore") -> Any: + materializer: BaseMaterializer = materializer_class( + uri=artifact.uri, artifact_store=artifact_store + ) + materializer.validate_load_type_compatibility(data_type) + return materializer.load(data_type=data_type) + + if artifact.artifact_store_id == stack.artifact_store.id: + stack.artifact_store._register() + return _load(artifact_store=stack.artifact_store) + else: + # Local import to avoid circular import issues + from zenml.orchestrators.utils import ( + register_artifact_store_filesystem, + ) + + with register_artifact_store_filesystem( + artifact.artifact_store_id + ) as target_store: + return _load(artifact_store=target_store) + + # --- Artifact Store --- + def store_output_artifacts( + self, + *, + output_data: Dict[str, Any], + output_materializers: Dict[str, Tuple[Type["BaseMaterializer"], ...]], + output_artifact_uris: Dict[str, str], + output_annotations: Dict[str, "OutputSignature"], + artifact_metadata_enabled: bool, + artifact_visualization_enabled: bool, + ) -> Dict[str, ArtifactVersionResponse]: + """Store output artifacts. + + Args: + output_data: The output data. + output_materializers: The output materializers. + output_artifact_uris: The output artifact URIs. + output_annotations: The output annotations. + artifact_metadata_enabled: Whether artifact metadata is enabled. + artifact_visualization_enabled: Whether artifact visualization is enabled. + + Returns: + Mapping from output name to stored artifact version. + + Raises: + RuntimeError: If artifact batch creation fails after retries or + the number of responses does not match requests. 
+ """ + # Apply capture toggles for metadata and visualizations + artifact_metadata_enabled = artifact_metadata_enabled and bool( + getattr(self, "_metadata_enabled", True) + ) + artifact_visualization_enabled = ( + artifact_visualization_enabled + and bool(getattr(self, "_visualizations_enabled", True)) + ) + + step_context = get_step_context() + artifact_requests: List[Any] = [] + + for output_name, return_value in output_data.items(): + data_type = type(return_value) + materializer_classes = output_materializers[output_name] + if materializer_classes: + materializer_class: Type[BaseMaterializer] = ( + materializer_utils.select_materializer( + data_type=data_type, + materializer_classes=materializer_classes, + ) + ) + else: + # Runtime selection if no explicit materializer recorded + default_materializer_source = ( + step_context.step_run.config.outputs[ + output_name + ].default_materializer_source + if step_context and step_context.step_run + else None + ) + + if default_materializer_source: + default_materializer_class: Type[BaseMaterializer] = ( + source_utils.load_and_validate_class( + default_materializer_source, + expected_class=BaseMaterializer, + ) + ) + materializer_registry.default_materializer = ( + default_materializer_class + ) + + materializer_class = materializer_registry[data_type] + + uri = output_artifact_uris[output_name] + artifact_config = output_annotations[output_name].artifact_config + + artifact_type = None + if artifact_config is not None: + has_custom_name = bool(artifact_config.name) + version = artifact_config.version + artifact_type = artifact_config.artifact_type + else: + has_custom_name, version = False, None + + # Name resolution mirrors existing behavior + if has_custom_name: + artifact_name = output_name + else: + if step_context.pipeline_run.pipeline: + pipeline_name = step_context.pipeline_run.pipeline.name + else: + pipeline_name = "unlisted" + step_name = step_context.step_run.name + artifact_name = f"{pipeline_name}::{step_name}::{output_name}" + + # Collect user metadata and tags + user_metadata = step_context.get_output_metadata(output_name) + tags = step_context.get_output_tags(output_name) + if step_context.pipeline_run.config.tags is not None: + for tag in step_context.pipeline_run.config.tags: + if isinstance(tag, tag_utils.Tag) and tag.cascade is True: + tags.append(tag.name) + + # Store artifact data and prepare a request to the server. + from zenml.artifacts.utils import ( + _store_artifact_data_and_prepare_request, + ) + + artifact_request = _store_artifact_data_and_prepare_request( + name=artifact_name, + data=return_value, + materializer_class=materializer_class, + uri=uri, + artifact_type=artifact_type, + store_metadata=artifact_metadata_enabled, + store_visualizations=artifact_visualization_enabled, + has_custom_name=has_custom_name, + version=version, + tags=tags, + save_type=ArtifactSaveType.STEP_OUTPUT, + metadata=user_metadata, + ) + artifact_requests.append(artifact_request) + + max_retries = 2 + delay = 1.0 + + for attempt in range(max_retries + 1): + try: + responses = Client().zen_store.batch_create_artifact_versions( + artifact_requests + ) + if len(responses) != len(artifact_requests): + raise RuntimeError( + f"Artifact batch creation returned {len(responses)}/{len(artifact_requests)} responses" + ) + return dict(zip(output_data.keys(), responses)) + except Exception as e: + if attempt < max_retries: + logger.warning( + "Artifact creation attempt %s failed: %s. 
Retrying in %.1fs...", + attempt + 1, + e, + delay, + ) + time.sleep(delay) + delay *= 1.5 + else: + logger.error( + "Failed to create artifacts after %s attempts: %s. Failing step to avoid inconsistency.", + max_retries + 1, + e, + ) + raise + + # TODO(beta->prod): Align with server to provide atomic batch create or + # compensating deletes. Consider idempotent requests and retriable error + # categories with jittered backoff. + raise RuntimeError( + "Artifact creation failed unexpectedly without raising" + ) diff --git a/src/zenml/execution/factory.py b/src/zenml/execution/factory.py new file mode 100644 index 00000000000..c8958010f42 --- /dev/null +++ b/src/zenml/execution/factory.py @@ -0,0 +1,51 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Factory to construct a step runtime based on context and capture.""" + +from zenml.execution.default_runtime import DefaultStepRuntime +from zenml.execution.memory_runtime import MemoryStepRuntime +from zenml.execution.step_runtime import BaseStepRuntime + + +def get_runtime( + *, serving: bool, memory_only: bool, metrics_enabled: bool = True +) -> BaseStepRuntime: + """Return a runtime implementation for the given context. + + Args: + serving: True if executing in serving context. + memory_only: True if serving should use in-process handoff. + metrics_enabled: Enable runtime metrics collection (realtime only). + + Returns: + The runtime implementation. + """ + if not serving: + return DefaultStepRuntime() + if memory_only: + return MemoryStepRuntime() + + # Import here to avoid circular imports + from zenml.execution.realtime_runtime import RealtimeStepRuntime + + rt = RealtimeStepRuntime() + # Gate metrics at the runtime if supported + if not metrics_enabled: + try: + setattr( + rt, "_metrics_disabled", True + ) # runtime may optionally read this + except Exception: + pass + return rt diff --git a/src/zenml/execution/memory_runtime.py b/src/zenml/execution/memory_runtime.py new file mode 100644 index 00000000000..25923bd0b76 --- /dev/null +++ b/src/zenml/execution/memory_runtime.py @@ -0,0 +1,336 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+"""Memory-only step runtime (in-process handoff, no DB/FS persistence).""" + +import threading +from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Tuple, Type + +from zenml.execution.step_runtime import BaseStepRuntime +from zenml.logger import get_logger +from zenml.steps.step_context import get_step_context +from zenml.utils import string_utils + +if TYPE_CHECKING: + from zenml.config.step_configurations import Step + from zenml.models import PipelineRunResponse, StepRunResponse + +logger = get_logger(__name__) + + +class MemoryStepRuntime(BaseStepRuntime): + """Pure in-memory execution runtime: no server calls, no persistence. + + Instance-scoped store to isolate requests. Values are accessible within the + same process for the same run id and step chain only. + """ + + @staticmethod + def make_handle_id(run_id: str, step_name: str, output_name: str) -> str: + """Make a handle ID for an output artifact. + + Args: + run_id: The run ID. + step_name: The step name. + output_name: The output name. + + Returns: + The handle ID. + """ + return f"mem://{run_id}/{step_name}/{output_name}" + + @staticmethod + def parse_handle_id(handle_id: str) -> Tuple[str, str, str]: + """Parse a handle ID for an output artifact. + + Args: + handle_id: The handle ID. + + Returns: + The run ID, step name, and output name. + + Raises: + ValueError: If the handle id is malformed. + """ + if not isinstance(handle_id, str) or not handle_id.startswith( + "mem://" + ): + raise ValueError("Invalid memory handle id") + rest = handle_id[len("mem://") :] + # split into exactly 3 parts: run_id, step_name, output_name + parts = rest.split("/", 2) + if len(parts) != 3: + raise ValueError("Invalid memory handle id") + run_id, step_name, output_name = parts + # basic sanitization + for p in (run_id, step_name, output_name): + if not p or "\n" in p or "\r" in p: + raise ValueError("Invalid memory handle component") + return run_id, step_name, output_name + + class Handle: + """A handle for an output artifact.""" + + def __init__(self, id: str) -> None: + """Initialize the handle. + + Args: + id: The handle ID. + """ + self.id = id + + # Instance-scoped context for handle resolution (set by launcher) + def __init__(self) -> None: + """Initialize the memory runtime.""" + super().__init__() + self._ctx_run_id: Optional[str] = None + self._ctx_substitutions: Dict[str, str] = {} + self._active_run_ids: set[str] = set() + # Instance-scoped storage and locks per run_id + self._store: Dict[str, Dict[Tuple[str, str], Any]] = {} + self._run_locks: Dict[str, Any] = {} + self._global_lock: Any = threading.RLock() + + def set_context( + self, *, run_id: str, substitutions: Optional[Dict[str, str]] = None + ) -> None: + """Set current memory-only context for handle resolution. + + Args: + run_id: The run ID. + substitutions: The substitutions. + """ + self._ctx_run_id = run_id + self._ctx_substitutions = substitutions or {} + try: + if run_id: + self._active_run_ids.add(run_id) + except Exception: + pass + + def resolve_step_inputs( + self, + *, + step: "Step", + pipeline_run: "PipelineRunResponse", + step_runs: Optional[Dict[str, "StepRunResponse"]] = None, + ) -> Dict[str, Any]: + """Resolve step inputs by constructing in-memory handles. + + Args: + step: The step to resolve inputs for. + pipeline_run: The pipeline run to resolve inputs for. + step_runs: The step runs to resolve inputs for. + + Returns: + A mapping of input name to MemoryStepRuntime.Handle. 
+ """ + run_id = self._ctx_run_id or str(getattr(pipeline_run, "id", "local")) + subs = self._ctx_substitutions or {} + handles: Dict[str, Any] = {} + for name, input_ in step.spec.inputs.items(): + resolved_output_name = string_utils.format_name_template( + input_.output_name, substitutions=subs + ) + handle_id = self.make_handle_id( + run_id=run_id, + step_name=input_.step_name, + output_name=resolved_output_name, + ) + handles[name] = MemoryStepRuntime.Handle(handle_id) + return handles + + def load_input_artifact( + self, *, artifact: Any, data_type: Type[Any], stack: Any + ) -> Any: + """Load an input artifact. + + Args: + artifact: The artifact to load. + data_type: The data type of the artifact. + stack: The stack to load the artifact from. + + Returns: + The loaded artifact. + + Raises: + ValueError: If the memory handle id is invalid or malformed. + """ + handle_id_any = getattr(artifact, "id", None) + if not isinstance(handle_id_any, str): + raise ValueError("Invalid memory handle id") + run_id, step_name, output_name = self.parse_handle_id(handle_id_any) + # Use per-run lock to avoid cross-run interference + with self._global_lock: + rlock = self._run_locks.setdefault(run_id, threading.RLock()) + with rlock: + return self._store.get(run_id, {}).get((step_name, output_name)) + + def store_output_artifacts( + self, + *, + output_data: Dict[str, Any], + output_materializers: Dict[str, Tuple[Type[Any], ...]], + output_artifact_uris: Dict[str, str], + output_annotations: Dict[str, Any], + artifact_metadata_enabled: bool, + artifact_visualization_enabled: bool, + ) -> Dict[str, Any]: + """Store output artifacts. + + Args: + output_data: The output data. + output_materializers: The output materializers. + output_artifact_uris: The output artifact URIs. + output_annotations: The output annotations. + artifact_metadata_enabled: Whether artifact metadata is enabled. + artifact_visualization_enabled: Whether artifact visualization is enabled. + + Returns: + The stored artifacts. + """ + ctx = get_step_context() + run_id = str(getattr(ctx.pipeline_run, "id", "local")) + try: + if run_id: + self._active_run_ids.add(run_id) + except Exception: + pass + step_name = str(getattr(ctx.step_run, "name", "step")) + handles: Dict[str, Any] = {} + with self._global_lock: + rlock = self._run_locks.setdefault(run_id, threading.RLock()) + with rlock: + rr = self._store.setdefault(run_id, {}) + for output_name, value in output_data.items(): + rr[(step_name, output_name)] = value + handle_id = self.make_handle_id(run_id, step_name, output_name) + handles[output_name] = MemoryStepRuntime.Handle(handle_id) + return handles + + def compute_cache_key( + self, + *, + step: Any, + input_artifacts: Mapping[str, Any], + artifact_store: Any, + project_id: Any, + ) -> str: + """Compute a cache key. + + Args: + step: The step to compute the cache key for. + input_artifacts: The input artifacts for the step. + artifact_store: The artifact store to compute the cache key for. + project_id: The project ID to compute the cache key for. + + Returns: + The computed cache key. + """ + return "" + + def get_cached_step_run(self, *, cache_key: str) -> None: + """Get a cached step run. + + Args: + cache_key: The cache key to get the cached step run for. + + Returns: + The cached step run if available, otherwise None. + """ + return None + + def publish_pipeline_run_metadata( + self, *, pipeline_run_id: Any, pipeline_run_metadata: Any + ) -> None: + """Publish pipeline run metadata. 
+ + Args: + pipeline_run_id: The pipeline run ID. + pipeline_run_metadata: The pipeline run metadata. + """ + return + + def publish_step_run_metadata( + self, *, step_run_id: Any, step_run_metadata: Any + ) -> None: + """Publish step run metadata. + + Args: + step_run_id: The step run ID. + step_run_metadata: The step run metadata. + """ + return + + def publish_successful_step_run( + self, *, step_run_id: Any, output_artifact_ids: Any + ) -> None: + """Publish a successful step run. + + Args: + step_run_id: The step run ID. + output_artifact_ids: The output artifact IDs. + """ + return + + def publish_failed_step_run(self, *, step_run_id: Any) -> None: + """Publish a failed step run. + + Args: + step_run_id: The step run ID. + """ + return + + def start(self) -> None: + """Start the memory runtime.""" + return + + def on_step_start(self) -> None: + """Optional hook when a step starts execution.""" + return + + def flush(self) -> None: + """Flush the memory runtime.""" + return + + def on_step_end(self) -> None: + """Optional hook when a step ends execution.""" + return + + def shutdown(self) -> None: + """Shutdown the memory runtime.""" + return + + def __del__(self) -> None: # noqa: D401 + """Best-effort cleanup of per-run memory when GC collects the runtime.""" + try: + for run_id in list(self._active_run_ids): + try: + self.reset(run_id) + except Exception: + pass + except Exception: + pass + + # --- Unified path helpers --- + def reset(self, run_id: str) -> None: + """Clear all in-memory data associated with a specific run. + + Args: + run_id: The run id to clear. + """ + with self._global_lock: + try: + self._store.pop(run_id, None) + finally: + self._run_locks.pop(run_id, None) diff --git a/src/zenml/execution/realtime_runtime.py b/src/zenml/execution/realtime_runtime.py new file mode 100644 index 00000000000..a982b70dc92 --- /dev/null +++ b/src/zenml/execution/realtime_runtime.py @@ -0,0 +1,619 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Realtime runtime with simple in-memory caching and async updates. + +This implementation prioritizes in-memory loads when available and otherwise +delegates to the default runtime persistence. It lays groundwork for future +write-behind persistence without changing current behavior. 
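+
+Illustrative lifecycle (a sketch, not a usage contract):
+
+    runtime = RealtimeStepRuntime()
+    runtime.start()     # spawn the background publisher thread
+    ...                 # steps enqueue publish events and use the cache
+    runtime.flush()     # drain queued events synchronously
+    runtime.shutdown()  # flush, signal stop, join worker with timeout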
+""" + +import os +import queue +import threading +import time +from collections import OrderedDict +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type + +from zenml.execution.default_runtime import DefaultStepRuntime +from zenml.logger import get_logger +from zenml.materializers.base_materializer import BaseMaterializer +from zenml.models import ArtifactVersionResponse +from zenml.orchestrators import publish_utils +from zenml.stack.stack import Stack +from zenml.steps.utils import OutputSignature + +if TYPE_CHECKING: + from uuid import UUID + + from zenml.metadata.metadata_types import MetadataType + + +class RealtimeStepRuntime(DefaultStepRuntime): + """Realtime runtime optimized for low-latency loads via memory cache. + + TODO(beta->prod): scale background publishing either by + - adding a small multi-worker thread pool (ThreadPoolExecutor), or + - migrating to an asyncio-based runtime once the client/publish calls have + async variants and we want an async mode in serving. + Both paths must keep bounded backpressure and orderly shutdown. + """ + + def __init__( + self, + ttl_seconds: Optional[int] = None, + max_entries: Optional[int] = None, + ) -> None: + """Initialize the realtime runtime. + + Args: + ttl_seconds: The TTL in seconds. + max_entries: The maximum number of entries in the cache. + """ + super().__init__() + # Simple LRU cache with TTL + self._cache: OrderedDict[str, Tuple[Any, float]] = OrderedDict() + self._lock = threading.RLock() + # Event queue: (kind, args, kwargs) + Event = Tuple[str, Tuple[Any, ...], Dict[str, Any]] + self._q: queue.Queue[Event] = queue.Queue(maxsize=1024) + # TODO(beta->prod): when scaling per-process publishing, prefer either + # (1) a small thread pool consuming this queue, or (2) an asyncio loop + # with an asyncio.Queue and async workers, once the client has async + # publish calls and we opt into async serving. + self._worker: Optional[threading.Thread] = None + self._stop = threading.Event() + self._errors_since_last_flush: int = 0 + self._total_errors: int = 0 + self._last_error: Optional[BaseException] = None + self._error_reported: bool = False + self._last_report_ts: float = 0.0 + self._logger = get_logger(__name__) + self._queued_count: int = 0 + self._processed_count: int = 0 + # Metrics: cache and op latencies + self._cache_hits: int = 0 + self._cache_misses: int = 0 + self._op_latencies: List[float] = [] + # TODO(beta->prod): add process memory monitoring and expose worker + # liveness/health at the service layer. 
+        # Tunables via env: TTL seconds and max entries
+        # Options precedence: explicit args > env > defaults
+        if ttl_seconds is not None:
+            self._ttl_seconds = int(ttl_seconds)
+        else:
+            try:
+                self._ttl_seconds = int(
+                    os.getenv("ZENML_RT_CACHE_TTL_SECONDS", "60")
+                )
+            except Exception:
+                self._ttl_seconds = 60
+        if max_entries is not None:
+            self._max_entries = int(max_entries)
+        else:
+            try:
+                self._max_entries = int(
+                    os.getenv("ZENML_RT_CACHE_MAX_ENTRIES", "256")
+                )
+            except Exception:
+                self._max_entries = 256
+        # Circuit breaker controls
+        try:
+            self._cb_threshold = float(
+                os.getenv("ZENML_RT_CB_ERR_THRESHOLD", "0.1")
+            )
+            self._cb_min_events = int(
+                os.getenv("ZENML_RT_CB_MIN_EVENTS", "100")
+            )
+            self._cb_open_seconds = float(
+                os.getenv("ZENML_RT_CB_OPEN_SECONDS", "300")
+            )
+        except Exception:
+            self._cb_threshold = 0.1
+            self._cb_min_events = 100
+            self._cb_open_seconds = 300.0
+        self._cb_errors_window: int = 0
+        self._cb_total_window: int = 0
+        self._cb_open_until_ts: float = 0.0
+        # Error report interval (seconds)
+        try:
+            self._err_report_interval = float(
+                os.getenv("ZENML_RT_ERR_REPORT_INTERVAL", "15")
+            )
+        except Exception:
+            self._err_report_interval = 15.0
+        # Serving is async by default (non-blocking)
+        self._flush_on_step_end: bool = False
+
+    # --- lifecycle ---
+    def start(self) -> None:
+        """Start the realtime runtime."""
+        if self._worker is not None:
+            return
+
+        def _run() -> None:
+            idle_sleep = 0.05
+            while not self._stop.is_set():
+                try:
+                    kind, args, kwargs = self._q.get(timeout=idle_sleep)
+                except queue.Empty:
+                    # Opportunistic cache sweep: evict expired from head
+                    self._sweep_expired()
+                    idle_sleep = min(idle_sleep * 2.0, 2.0)
+                    continue
+                # Track failure per event; checking the sticky `_last_error`
+                # marker would count every event after the first failure as
+                # an error and skew the circuit breaker window.
+                had_error = False
+                try:
+                    start = time.time()
+                    if kind == "pipeline_metadata":
+                        publish_utils.publish_pipeline_run_metadata(
+                            *args, **kwargs
+                        )
+                    elif kind == "step_metadata":
+                        publish_utils.publish_step_run_metadata(
+                            *args, **kwargs
+                        )
+                    elif kind == "step_success":
+                        publish_utils.publish_successful_step_run(
+                            *args, **kwargs
+                        )
+                    elif kind == "step_failed":
+                        publish_utils.publish_failed_step_run(*args, **kwargs)
+                except BaseException as e:  # noqa: BLE001
+                    had_error = True
+                    with self._lock:
+                        self._errors_since_last_flush += 1
+                        self._total_errors += 1
+                        self._last_error = e
+                    self._logger.warning(
+                        "Realtime runtime failed to publish '%s': %s", kind, e
+                    )
+                finally:
+                    with self._lock:
+                        self._processed_count += 1
+                        # Update circuit breaker window
+                        self._cb_total_window += 1
+                        if had_error:
+                            self._cb_errors_window += 1
+                        # Record latency (bounded sample)
+                        try:
+                            self._op_latencies.append(
+                                max(0.0, time.time() - start)
+                            )
+                            if len(self._op_latencies) > 512:
+                                self._op_latencies = self._op_latencies[-256:]
+                        except Exception:
+                            pass
+                    self._q.task_done()
+                    idle_sleep = 0.01
+
+        self._worker = threading.Thread(
+            target=_run, name="zenml-realtime-runtime", daemon=True
+        )
+        self._worker.start()
+
+    def on_step_start(self) -> None:
+        """Optional hook when a step begins execution."""
+        # no-op for now
+        return
+
+    # Prefer in-memory values if available
+    def load_input_artifact(
+        self,
+        *,
+        artifact: ArtifactVersionResponse,
+        data_type: Type[Any],
+        stack: "Stack",
+    ) -> Any:
+        """Load an input artifact.
+
+        Args:
+            artifact: The artifact to load.
+            data_type: The data type of the artifact.
+            stack: The stack of the artifact.
+
+        Returns:
+            The loaded artifact.
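+
+        Note:
+            Values cached by `store_output_artifacts` are returned directly
+            while unexpired (refreshing their LRU position); expired entries
+            are evicted on access and the default persistence-based load
+            path is used instead.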
+ """ + key = str(artifact.id) + with self._lock: + if key in self._cache: + value, expires_at = self._cache.get(key, (None, 0)) + now = time.time() + if now <= expires_at: + # Touch entry for LRU + self._cache.move_to_end(key) + self._cache_hits += 1 + return value + else: + # Expired + try: + del self._cache[key] + except KeyError: + pass + self._cache_misses += 1 + + # Fallback to default loading + return super().load_input_artifact( + artifact=artifact, data_type=data_type, stack=stack + ) + + # Store synchronously (behavior parity), and cache the raw values in memory + def store_output_artifacts( + self, + *, + output_data: Dict[str, Any], + output_materializers: Dict[str, Tuple[Type["BaseMaterializer"], ...]], + output_artifact_uris: Dict[str, str], + output_annotations: Dict[str, "OutputSignature"], + artifact_metadata_enabled: bool, + artifact_visualization_enabled: bool, + ) -> Dict[str, ArtifactVersionResponse]: + """Store output artifacts. + + Args: + output_data: The output data. + output_materializers: The output materializers. + output_artifact_uris: The output artifact URIs. + output_annotations: The output annotations. + artifact_metadata_enabled: Whether artifact metadata is enabled. + artifact_visualization_enabled: Whether artifact visualization is enabled. + + Returns: + The stored artifacts. + """ + responses = super().store_output_artifacts( + output_data=output_data, + output_materializers=output_materializers, + output_artifact_uris=output_artifact_uris, + output_annotations=output_annotations, + artifact_metadata_enabled=artifact_metadata_enabled, + artifact_visualization_enabled=artifact_visualization_enabled, + ) + + # Cache by artifact id for later fast loads with TTL and LRU bounds + with self._lock: + now = time.time() + for name, resp in responses.items(): + if name in output_data: + expires_at = now + max(0, self._ttl_seconds) + self._cache[str(resp.id)] = (output_data[name], expires_at) + # Touch to end (most recently used) + self._cache.move_to_end(str(resp.id)) + # Enforce size bound + while len(self._cache) > max(1, self._max_entries): + try: + self._cache.popitem(last=False) # Evict LRU + except KeyError: + break + + return responses + + # --- async server updates --- + def publish_pipeline_run_metadata( + self, + *, + pipeline_run_id: "UUID", + pipeline_run_metadata: Dict["UUID", Dict[str, "MetadataType"]], + ) -> None: + """Publish pipeline run metadata. + + Args: + pipeline_run_id: The pipeline run ID. + pipeline_run_metadata: The pipeline run metadata. + """ + # Inline if circuit open, else enqueue + if self._should_process_inline(): + publish_utils.publish_pipeline_run_metadata( + pipeline_run_id=pipeline_run_id, + pipeline_run_metadata=pipeline_run_metadata, + ) + return + self._q.put( + ( + "pipeline_metadata", + (), + { + "pipeline_run_id": pipeline_run_id, + "pipeline_run_metadata": pipeline_run_metadata, + }, + ) + ) + with self._lock: + self._queued_count += 1 + + def publish_step_run_metadata( + self, + *, + step_run_id: "UUID", + step_run_metadata: Dict["UUID", Dict[str, "MetadataType"]], + ) -> None: + """Publish step run metadata. + + Args: + step_run_id: The step run ID. + step_run_metadata: The step run metadata. 
+ """ + if self._should_process_inline(): + publish_utils.publish_step_run_metadata( + step_run_id=step_run_id, step_run_metadata=step_run_metadata + ) + return + self._q.put( + ( + "step_metadata", + (), + { + "step_run_id": step_run_id, + "step_run_metadata": step_run_metadata, + }, + ) + ) + with self._lock: + self._queued_count += 1 + + def publish_successful_step_run( + self, + *, + step_run_id: "UUID", + output_artifact_ids: Dict[str, List["UUID"]], + ) -> None: + """Publish a successful step run. + + Args: + step_run_id: The step run ID. + output_artifact_ids: The output artifact IDs. + """ + if self._should_process_inline(): + publish_utils.publish_successful_step_run( + step_run_id=step_run_id, + output_artifact_ids=output_artifact_ids, + ) + return + self._q.put( + ( + "step_success", + (), + { + "step_run_id": step_run_id, + "output_artifact_ids": output_artifact_ids, + }, + ) + ) + with self._lock: + self._queued_count += 1 + + def publish_failed_step_run( + self, + *, + step_run_id: "UUID", + ) -> None: + """Publish a failed step run. + + Args: + step_run_id: The step run ID. + """ + if self._should_process_inline(): + publish_utils.publish_failed_step_run(step_run_id) + return + try: + self._q.put_nowait( + ("step_failed", (), {"step_run_id": step_run_id}) + ) + with self._lock: + self._queued_count += 1 + except queue.Full: + self._logger.debug("Queue full, processing step_failed inline") + try: + publish_utils.publish_failed_step_run(step_run_id) + except Exception as e: + self._logger.warning("Inline processing failed: %s", e) + + def flush(self) -> None: + """Flush the realtime runtime by draining queued events synchronously. + + Raises: + RuntimeError: If background errors were encountered while draining. + """ + # Drain the queue in the calling thread to avoid waiting on the worker + while True: + try: + kind, args, kwargs = self._q.get_nowait() + except queue.Empty: + break + try: + if kind == "pipeline_metadata": + publish_utils.publish_pipeline_run_metadata( + *args, **kwargs + ) + elif kind == "step_metadata": + publish_utils.publish_step_run_metadata(*args, **kwargs) + elif kind == "step_success": + publish_utils.publish_successful_step_run(*args, **kwargs) + elif kind == "step_failed": + publish_utils.publish_failed_step_run(*args, **kwargs) + except BaseException as e: # noqa: BLE001 + with self._lock: + self._errors_since_last_flush += 1 + self._total_errors += 1 + self._last_error = e + self._logger.warning( + "Realtime runtime flush failed to publish '%s': %s", + kind, + e, + ) + finally: + with self._lock: + self._processed_count += 1 + try: + self._q.task_done() + except ValueError: + # If task_done called more than put() count due to races, ignore + pass + # Post-flush maintenance + self._sweep_expired() + with self._lock: + if self._errors_since_last_flush: + count = self._errors_since_last_flush + last = self._last_error + self._errors_since_last_flush = 0 + self._error_reported = True + raise RuntimeError( + f"Realtime runtime encountered {count} error(s) while publishing. Last error: {last}" + ) + + def on_step_end(self) -> None: + """Optional hook when a step ends execution.""" + # no-op for now + return + + def shutdown(self) -> None: + """Shutdown the realtime runtime. + + TODO(beta->prod): expose worker liveness/health signals to the service. 
+ """ + # Wait for remaining tasks and stop + self.flush() + self._stop.set() + # Join worker with timeout + worker = self._worker + if worker is not None: + worker.join(timeout=15.0) + if worker.is_alive(): + self._logger.warning( + "Realtime runtime worker did not terminate gracefully within timeout." + ) + self._worker = None + + # Flush behavior controls + def set_flush_on_step_end(self, value: bool) -> None: + """Set the flush on step end behavior. + + Args: + value: The value to set. + """ + self._flush_on_step_end = bool(value) + + def should_flush_on_step_end(self) -> bool: + """Whether the runtime should flush on step end. + + Returns: + Whether the runtime should flush on step end. + """ + return self._flush_on_step_end + + def get_metrics(self) -> Dict[str, Any]: + """Return runtime metrics snapshot. + + Returns: + The runtime metrics snapshot. + """ + # TODO(beta->prod): export to an external sink (e.g., Prometheus) and + # expand with additional histograms / event counters as needed. + if bool(getattr(self, "_metrics_disabled", False)): + return {} + with self._lock: + queued = self._queued_count + processed = self._processed_count + failed_total = self._total_errors + ttl_seconds = getattr(self, "_ttl_seconds", None) + max_entries = getattr(self, "_max_entries", None) + cache_hits = self._cache_hits + cache_misses = self._cache_misses + latencies = list(self._op_latencies) + try: + depth = self._q.qsize() + except Exception: + depth = 0 + # Compute simple percentiles + p50 = p95 = p99 = 0.0 + if latencies: + s = sorted(latencies) + n = len(s) + p50 = s[int(0.5 * (n - 1))] + p95 = s[int(0.95 * (n - 1))] + p99 = s[int(0.99 * (n - 1))] + hit_rate = ( + float(cache_hits) / float(cache_hits + cache_misses) + if (cache_hits + cache_misses) > 0 + else 0.0 + ) + return { + "queued": queued, + "processed": processed, + "failed_total": failed_total, + "queue_depth": depth, + "ttl_seconds": ttl_seconds, + "max_entries": max_entries, + "cache_hits": cache_hits, + "cache_misses": cache_misses, + "cache_hit_rate": hit_rate, + "op_latency_p50_s": p50, + "op_latency_p95_s": p95, + "op_latency_p99_s": p99, + } + + # Surface background errors even when not flushing + def check_async_errors(self) -> None: + """Log and mark any background errors on an interval.""" + with self._lock: + if self._last_error: + now = time.time() + if (not self._error_reported) or ( + now - self._last_report_ts > self._err_report_interval + ): + self._logger.error( + "Background realtime runtime error: %s", + self._last_error, + ) + self._error_reported = True + self._last_report_ts = now + + # --- internal helpers --- + def _sweep_expired(self) -> None: + """Remove expired entries using a snapshot within a small time budget.""" + deadline = time.time() + 0.005 + with self._lock: + snapshot = list(self._cache.items()) + expired: List[str] = [] + now = time.time() + for key, (_val, expires_at) in snapshot: + if time.time() > deadline: + break + if now > expires_at: + expired.append(key) + if expired: + with self._lock: + for key in expired: + self._cache.pop(key, None) + + def _should_process_inline(self) -> bool: + """Return True if circuit breaker is open and we should publish inline. + + Returns: + True if inline processing should be used, False otherwise. 
+ """ + with self._lock: + now = time.time() + if now < self._cb_open_until_ts: + return True + total = self._cb_total_window + errors = self._cb_errors_window + if total >= self._cb_min_events: + err_rate = (float(errors) / float(total)) if total > 0 else 0.0 + if err_rate >= self._cb_threshold: + self._cb_open_until_ts = now + self._cb_open_seconds + self._logger.warning( + "Realtime runtime circuit opened for %.0fs due to error rate %.2f", + self._cb_open_seconds, + err_rate, + ) + return True + return False diff --git a/src/zenml/execution/step_runtime.py b/src/zenml/execution/step_runtime.py new file mode 100644 index 00000000000..c1318927513 --- /dev/null +++ b/src/zenml/execution/step_runtime.py @@ -0,0 +1,262 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Step runtime facade for step execution responsibilities. + +This scaffolds a minimal, behavior-preserving runtime abstraction that the +step runner can call into for artifact I/O and input resolution. The default +implementation delegates to existing ZenML utilities. + +Enable usage by setting environment variable `ZENML_ENABLE_STEP_RUNTIME=true`. +""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Tuple, Type +from uuid import UUID + +from zenml.logger import get_logger +from zenml.materializers.base_materializer import BaseMaterializer +from zenml.models import ArtifactVersionResponse + +# Note: avoid importing zenml.orchestrators modules at import time to prevent +# circular dependencies. Where needed, import locally within methods. + +if TYPE_CHECKING: + from zenml.artifact_stores import BaseArtifactStore + from zenml.config.step_configurations import Step + from zenml.materializers.base_materializer import BaseMaterializer + from zenml.models import PipelineRunResponse, StepRunResponse + from zenml.models.v2.core.step_run import StepRunInputResponse + from zenml.stack import Stack + from zenml.steps.utils import OutputSignature + +logger = get_logger(__name__) + + +class BaseStepRuntime(ABC): + """Abstract execution-time interface for step I/O and interactions. + + Implementations may optimize persistence, caching, logging, and server + updates based on capture policy. This base class only covers the minimal + responsibilities we want to centralize first. + """ + + @abstractmethod + def resolve_step_inputs( + self, + *, + step: "Step", + pipeline_run: "PipelineRunResponse", + step_runs: Optional[Dict[str, "StepRunResponse"]] = None, + ) -> Dict[str, "StepRunInputResponse"]: + """Resolve input artifacts for the given step. + + Args: + step: The step to resolve inputs for. + pipeline_run: The pipeline run to resolve inputs for. + step_runs: The step runs to resolve inputs for. + + Returns: + The resolved inputs. 
+ """ + + @abstractmethod + def load_input_artifact( + self, + *, + artifact: ArtifactVersionResponse, + data_type: Type[Any], + stack: "Stack", + ) -> Any: + """Load materialized value for an input artifact. + + Args: + artifact: The artifact to load. + data_type: The data type of the artifact. + stack: The stack to load the artifact from. + """ + + @abstractmethod + def store_output_artifacts( + self, + *, + output_data: Dict[str, Any], + output_materializers: Dict[str, Tuple[Type["BaseMaterializer"], ...]], + output_artifact_uris: Dict[str, str], + output_annotations: Dict[str, "OutputSignature"], + artifact_metadata_enabled: bool, + artifact_visualization_enabled: bool, + ) -> Dict[str, ArtifactVersionResponse]: + """Materialize and persist output artifacts and return their versions. + + Args: + output_data: The output data. + output_materializers: The output materializers. + output_artifact_uris: The output artifact URIs. + output_annotations: The output annotations. + artifact_metadata_enabled: Whether artifact metadata is enabled. + artifact_visualization_enabled: Whether artifact visualization is enabled. + """ + + # --- Cache Helpers (optional) --- + def compute_cache_key( + self, + *, + step: "Step", + input_artifacts: Mapping[str, "ArtifactVersionResponse"], + artifact_store: "BaseArtifactStore", + project_id: UUID, + ) -> str: + """Compute a cache key for a step using existing utilities. + + Default implementation delegates to `cache_utils`. + + Args: + step: The step to compute the cache key for. + input_artifacts: The input artifacts. + artifact_store: The artifact store to compute the cache key for. + project_id: The project ID to compute the cache key for. + + Returns: + The computed cache key. + """ + # Local import to avoid circular import issues + from zenml.orchestrators import cache_utils + + return cache_utils.generate_cache_key( + step=step, + input_artifacts=input_artifacts, + artifact_store=artifact_store, + project_id=project_id, + ) + + def get_cached_step_run( + self, *, cache_key: str + ) -> Optional["StepRunResponse"]: + """Return a cached step run if available. + + Default implementation delegates to `cache_utils`. + + Args: + cache_key: The cache key to get the cached step run for. + + Returns: + The cached step run if available, otherwise None. + """ + # Local import to avoid circular import issues + from zenml.orchestrators import cache_utils + + return cache_utils.get_cached_step_run(cache_key=cache_key) + + # --- Server update helpers (may be batched/async by implementations) --- + def start(self) -> None: + """Optional start hook for runtime lifecycles.""" + + def on_step_start(self) -> None: + """Optional hook when a step begins execution.""" + + def publish_pipeline_run_metadata( + self, *, pipeline_run_id: Any, pipeline_run_metadata: Any + ) -> None: + """Publish pipeline run metadata. + + Args: + pipeline_run_id: The pipeline run ID. + pipeline_run_metadata: The pipeline run metadata. + """ + if not bool(getattr(self, "_metadata_enabled", True)): + return + from zenml.orchestrators.publish_utils import ( + publish_pipeline_run_metadata as _pub_run_md, + ) + + _pub_run_md( + pipeline_run_id=pipeline_run_id, + pipeline_run_metadata=pipeline_run_metadata, + ) + + def publish_step_run_metadata( + self, *, step_run_id: Any, step_run_metadata: Any + ) -> None: + """Publish step run metadata. + + Args: + step_run_id: The step run ID. + step_run_metadata: The step run metadata. 
+ """ + if not bool(getattr(self, "_metadata_enabled", True)): + return + from zenml.orchestrators.publish_utils import ( + publish_step_run_metadata as _pub_step_md, + ) + + _pub_step_md( + step_run_id=step_run_id, step_run_metadata=step_run_metadata + ) + + def publish_successful_step_run( + self, *, step_run_id: Any, output_artifact_ids: Any + ) -> None: + """Publish a successful step run. + + Args: + step_run_id: The step run ID. + output_artifact_ids: The output artifact IDs. + """ + from zenml.orchestrators.publish_utils import ( + publish_successful_step_run as _pub_step_success, + ) + + _pub_step_success( + step_run_id=step_run_id, output_artifact_ids=output_artifact_ids + ) + + def publish_failed_step_run(self, *, step_run_id: Any) -> None: + """Publish a failed step run. + + Args: + step_run_id: The step run ID. + """ + from zenml.orchestrators.publish_utils import ( + publish_failed_step_run as _pub_step_failed, + ) + + _pub_step_failed(step_run_id) + + def flush(self) -> None: + """Ensure all queued updates are sent.""" + + def on_step_end(self) -> None: + """Optional hook when a step finishes execution.""" + + def shutdown(self) -> None: + """Optional shutdown hook for runtime lifecycles.""" + + def get_metrics(self) -> Dict[str, Any]: + """Optional runtime metrics for observability. + + Returns: + Dictionary of runtime metrics; empty by default. + """ + return {} + + # --- Flush behavior --- + def should_flush_on_step_end(self) -> bool: + """Whether the runner should call flush() at step end. + + Implementations may override to disable flush for non-blocking serving. + + Returns: + True to flush on step end; False otherwise. + """ + return True diff --git a/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py b/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py index c66f60f75b2..c913d982db5 100644 --- a/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py +++ b/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py @@ -317,10 +317,19 @@ def main() -> None: for owner_reference in owner_references: owner_reference.controller = False + # Build a runtime for request factory (batch context) + try: + from zenml.execution.factory import get_runtime + + _runtime = get_runtime(serving=False, memory_only=False) + except Exception: + _runtime = None + step_run_request_factory = StepRunRequestFactory( deployment=deployment, pipeline_run=pipeline_run, stack=active_stack, + runtime=_runtime, ) step_runs = {} diff --git a/src/zenml/models/v2/core/pipeline_deployment.py b/src/zenml/models/v2/core/pipeline_deployment.py index 947185e7c7b..40085ade47c 100644 --- a/src/zenml/models/v2/core/pipeline_deployment.py +++ b/src/zenml/models/v2/core/pipeline_deployment.py @@ -75,6 +75,24 @@ class PipelineDeploymentBase(BaseZenModel): default=None, title="The pipeline spec of the deployment.", ) + # Canonical capture fields (single source of truth at runtime) + capture_memory_only: bool = Field( + default=False, + title="Serving-only: execute in memory without persistence.", + ) + capture_code: bool = Field( + default=True, title="Capture code/source/docstrings in metadata." + ) + capture_logs: bool = Field(default=True, title="Persist step logs.") + capture_metadata: bool = Field( + default=True, title="Publish run/step metadata." + ) + capture_visualizations: bool = Field( + default=True, title="Persist artifact visualizations." 
+ ) + capture_metrics: bool = Field( + default=True, title="Emit runtime metrics (realtime)." + ) @property def should_prevent_build_reuse(self) -> bool: @@ -165,6 +183,24 @@ class PipelineDeploymentResponseMetadata(ProjectScopedResponseMetadata): default=None, title="Optional path where the code is stored in the artifact store.", ) + # Canonical capture fields (mirrored on response) + capture_memory_only: bool = Field( + default=False, + title="Serving-only: execute in memory without persistence.", + ) + capture_code: bool = Field( + default=True, title="Capture code/source/docstrings in metadata." + ) + capture_logs: bool = Field(default=True, title="Persist step logs.") + capture_metadata: bool = Field( + default=True, title="Publish run/step metadata." + ) + capture_visualizations: bool = Field( + default=True, title="Persist artifact visualizations." + ) + capture_metrics: bool = Field( + default=True, title="Emit runtime metrics (realtime)." + ) pipeline: Optional[PipelineResponse] = Field( default=None, title="The pipeline associated with the deployment." diff --git a/src/zenml/orchestrators/run_entity_manager.py b/src/zenml/orchestrators/run_entity_manager.py new file mode 100644 index 00000000000..229e6754907 --- /dev/null +++ b/src/zenml/orchestrators/run_entity_manager.py @@ -0,0 +1,189 @@ +"""Run entity manager scaffolding for unified execution. + +Abstracts creation and finalization of pipeline/step runs so we can plug in +either DB-backed behavior or stubbed in-memory entities for memory-only runs. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Dict, Optional, Protocol, Tuple, cast + + +class RunEntityManager(Protocol): + """Protocol for managing pipeline/step run entities.""" + + def create_or_reuse_run(self) -> Tuple[Any, bool]: + """Create or reuse a pipeline run entity. + + Returns: + A tuple of (pipeline_run, was_created). + """ + + def create_step_run(self, request: Any) -> Any: + """Create a step run entity. + + Args: + request: StepRunRequest-like object. + + Returns: + A step run entity. + """ + + def finalize_step_run_success( + self, step_run_id: Any, outputs: Any + ) -> None: + """Mark a step run successful.""" + + def finalize_step_run_failed(self, step_run_id: Any) -> None: + """Mark a step run failed.""" + + +@dataclass +class DefaultRunEntityManager: + """Placeholder for DB-backed manager (to be wired in Phase 2).""" + + launcher: Any + + def create_or_reuse_run(self) -> Tuple[Any, bool]: + """Create or reuse a pipeline run entity. + + Returns: + A tuple of (pipeline_run, was_created). + """ + return cast(Tuple[Any, bool], self.launcher._create_or_reuse_run()) + + def create_step_run(self, request: Any) -> Any: + """Create a step run entity. + + Args: + request: StepRunRequest-like object. + + Returns: + A step run entity. + """ + from zenml.client import Client + + return Client().zen_store.create_run_step(request) + + def finalize_step_run_success( + self, step_run_id: Any, outputs: Any + ) -> None: + """Mark a step run successful. + + Args: + step_run_id: The step run ID. + outputs: The outputs of the step run. + """ + # Defer to runtime publish for now. + return None + + def finalize_step_run_failed(self, step_run_id: Any) -> None: + """Mark a step run failed. + + Args: + step_run_id: The step run ID. + """ + # Defer to runtime publish for now. 
+ return None + + +@dataclass +class MemoryRunEntityManager: + """Stubbed manager for memory-only execution (Phase 2 wiring).""" + + launcher: Any + + def create_or_reuse_run(self) -> Tuple[Any, bool]: + """Create or reuse a pipeline run entity. + + Returns: + A tuple of (pipeline_run, was_created). + """ + # Build a minimal pipeline run stub compatible with StepRunner expectations + run_id = self.launcher._orchestrator_run_id # noqa: SLF001 + + @dataclass + class _PRCfg: + tags: Any = None + enable_step_logs: Any = False + enable_artifact_metadata: Any = False + enable_artifact_visualization: Any = False + + @dataclass + class _PipelineRunStub: + id: str + name: str + model_version: Any = None + pipeline: Any = None + config: Any = field(default_factory=_PRCfg) + + return _PipelineRunStub(id=run_id, name=run_id), True + + def create_step_run(self, request: Any) -> Any: + """Create a step run entity. + + Args: + request: StepRunRequest-like object. + + Returns: + A step run entity. + """ + # Return a minimal step run stub + run_id = self.launcher._orchestrator_run_id # noqa: SLF001 + step_name = self.launcher._step_name # noqa: SLF001 + + @dataclass + class _StatusStub: + is_finished: bool = False + + @dataclass + class _StepRunStub: + id: str + name: str + model_version: Any = None + logs: Optional[Any] = None + status: Any = field(default_factory=_StatusStub) + outputs: Dict[str, Any] = None # type: ignore[assignment] + regular_inputs: Dict[str, Any] = None # type: ignore[assignment] + + # Minimal config for template substitutions used by StepRunner + @dataclass + class _Cfg: + # Step-level toggles are optional and may be None + enable_step_logs: Optional[bool] = None + enable_artifact_metadata: Optional[bool] = None + enable_artifact_visualization: Optional[bool] = None + substitutions: Dict[str, str] = None # type: ignore[assignment] + + config: Any = field(default_factory=_Cfg) + + def __post_init__(self) -> None: # noqa: D401 + self.outputs = {} + self.regular_inputs = {} + # Default to empty substitutions mapping + try: + self.config.substitutions = {} + except Exception: + pass + + return _StepRunStub(id=run_id, name=step_name) + + def finalize_step_run_success( + self, step_run_id: Any, outputs: Any + ) -> None: + """Mark a step run successful. + + Args: + step_run_id: The step run ID. + outputs: The outputs of the step run. + """ + return None + + def finalize_step_run_failed(self, step_run_id: Any) -> None: + """Mark a step run failed. + + Args: + step_run_id: The step run ID. + """ + return None diff --git a/src/zenml/orchestrators/runtime_manager.py b/src/zenml/orchestrators/runtime_manager.py new file mode 100644 index 00000000000..69ee02f01f5 --- /dev/null +++ b/src/zenml/orchestrators/runtime_manager.py @@ -0,0 +1,84 @@ +"""Runtime manager for unified runtime-driven execution paths. + +Provides helpers to reuse a shared runtime instance across all steps of a +single serving request (e.g., MemoryStepRuntime for memory-only execution), +and utilities to reset per-run state when the request completes. +""" + +from __future__ import annotations + +from contextvars import ContextVar +from typing import Optional + +from zenml.execution.memory_runtime import MemoryStepRuntime +from zenml.execution.step_runtime import BaseStepRuntime + +# Shared runtime context for the lifetime of a single request. 
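+# Usage sketch (illustrative):
+#   set_shared_runtime(MemoryStepRuntime())  # once, when the request starts
+#   rt = get_shared_runtime()                # reused by every step launcher
+#   clear_shared_runtime()                   # when the request completes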
+_shared_runtime: ContextVar[Optional[BaseStepRuntime]] = ContextVar(
+    "zenml_shared_runtime", default=None
+)
+
+
+def set_shared_runtime(runtime: BaseStepRuntime) -> None:
+    """Set a runtime instance to be reused across steps for the current request.
+
+    Args:
+        runtime: The runtime instance to set.
+    """
+    _shared_runtime.set(runtime)
+
+
+def get_shared_runtime() -> Optional[BaseStepRuntime]:
+    """Get the shared runtime instance for the current request, if any.
+
+    Returns:
+        The shared runtime instance for the current request, if any.
+    """
+    return _shared_runtime.get()
+
+
+def clear_shared_runtime() -> None:
+    """Clear the shared runtime instance for the current request."""
+    _shared_runtime.set(None)
+
+
+def get_or_create_shared_memory_runtime() -> MemoryStepRuntime:
+    """Get or create a shared MemoryStepRuntime for the current request.
+
+    Returns:
+        The shared memory runtime for the current request, creating and
+        registering one if none exists yet.
+    """
+    rt = _shared_runtime.get()
+    if isinstance(rt, MemoryStepRuntime):
+        return rt
+    mem = MemoryStepRuntime()
+    set_shared_runtime(mem)
+    return mem
+
+
+def reset_memory_runtime_for_run(run_id: str) -> None:
+    """Reset per-run memory state on the shared memory runtime if present.
+
+    Args:
+        run_id: The run ID.
+    """
+    rt = _shared_runtime.get()
+    if isinstance(rt, MemoryStepRuntime):
+        try:
+            rt.reset(run_id)
+        except Exception as e:
+            # Best-effort cleanup; log at debug level and continue
+            try:
+                from zenml.logger import get_logger
+
+                get_logger(__name__).debug(
+                    "Ignoring error during memory runtime reset for run %s: %s",
+                    run_id,
+                    e,
+                )
+            except Exception:
+                pass
diff --git a/src/zenml/orchestrators/step_launcher.py b/src/zenml/orchestrators/step_launcher.py
index ebda46f608d..7f5111d7c35 100644
--- a/src/zenml/orchestrators/step_launcher.py
+++ b/src/zenml/orchestrators/step_launcher.py
@@ -13,13 +13,15 @@
 # permissions and limitations under the License.
"""Class to launch (run directly or using a step operator) steps.""" +import json +import os import signal import time from contextlib import nullcontext from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple from zenml.client import Client -from zenml.config.step_configurations import Step +from zenml.config.step_configurations import Step, StepConfiguration from zenml.config.step_run_info import StepRunInfo from zenml.constants import ( ENV_ZENML_DISABLE_STEP_LOGS_STORAGE, @@ -29,6 +31,8 @@ from zenml.enums import ExecutionStatus from zenml.environment import get_run_environment_dict from zenml.exceptions import RunInterruptedException, RunStoppedException +from zenml.execution.factory import get_runtime +from zenml.execution.memory_runtime import MemoryStepRuntime from zenml.logger import get_logger from zenml.logging import step_logging from zenml.models import ( @@ -41,7 +45,16 @@ from zenml.models.v2.core.step_run import StepRunInputResponse from zenml.orchestrators import output_utils, publish_utils, step_run_utils from zenml.orchestrators import utils as orchestrator_utils +from zenml.orchestrators.run_entity_manager import ( + DefaultRunEntityManager, + MemoryRunEntityManager, + RunEntityManager, +) +from zenml.orchestrators.runtime_manager import ( + get_or_create_shared_memory_runtime, +) from zenml.orchestrators.step_runner import StepRunner +from zenml.orchestrators.utils import is_serving_context from zenml.stack import Stack from zenml.utils import exception_utils, string_utils from zenml.utils.time_utils import utc_now @@ -135,6 +148,90 @@ def __init__( self._step_run: Optional[StepRunResponse] = None self._setup_signal_handlers() + # --- Serving helpers --- + def _validate_and_merge_request_params( + self, + req_params: Dict[str, Any], + effective_step_config: StepConfiguration, + ) -> Dict[str, Any]: + """Safely merge request parameters with allowlist and light validation. + + Only keys already declared in the pipeline parameters are merged. + Performs simple type-coercion against defaults where possible and + applies size limits to avoid oversized payloads. + + TODO(beta->prod): derive expected types from the pipeline entrypoint + annotations (or a generated parameter schema) instead of the current + defaults-based heuristic; add a total payload size limit. + + Args: + req_params: Raw parameters dictionary from the request. + effective_step_config: The current effective step configuration. + + Returns: + Merged and validated parameters dictionary. 
+ """ + if not req_params: + return effective_step_config.parameters or {} + + declared = set((effective_step_config.parameters or {}).keys()) + allowed = {k: v for k, v in req_params.items() if k in declared} + dropped = set(req_params.keys()) - declared + if dropped: + logger.warning( + "Dropping unknown request parameters: %s", sorted(dropped) + ) + + validated: Dict[str, Any] = {} + for key, value in allowed.items(): + # Size limits + try: + if isinstance(value, str) and len(value) > 10_000: + logger.warning( + "Dropping oversized string parameter '%s' (%s chars)", + key, + len(value), + ) + continue + if ( + isinstance(value, (list, dict)) + and len(str(value)) > 50_000 + ): + logger.warning( + "Dropping oversized collection parameter '%s'", key + ) + continue + except Exception: + # If size introspection fails, keep conservative and drop + logger.warning( + "Dropping parameter '%s' due to size check error", key + ) + continue + + # Type coercion against defaults, if present + try: + defaults = effective_step_config.parameters or {} + if key in defaults and defaults[key] is not None: + expected_t = type(defaults[key]) + if not isinstance(value, expected_t): + try: + value = expected_t(value) # best-effort coercion + except Exception: + logger.warning( + "Type mismatch for parameter '%s', dropping", + key, + ) + continue + except Exception: + # On any error, accept original value (already allowlisted) + pass + + validated[key] = value + + merged = dict(effective_step_config.parameters or {}) + merged.update(validated) + return merged + def _setup_signal_handlers(self) -> None: """Set up signal handlers for graceful shutdown, chaining previous handlers.""" try: @@ -169,14 +266,21 @@ def signal_handler(signum: int, frame: Any) -> None: client = Client() pipeline_run = None - if self._step_run: + # Memory-only stubs do not have a pipeline_run_id; handle gracefully + if self._step_run and hasattr( + self._step_run, "pipeline_run_id" + ): pipeline_run = client.get_pipeline_run( self._step_run.pipeline_run_id ) + elif self._step_run is None: + raise RunInterruptedException( + "The execution was interrupted and the step does not exist yet." + ) else: + # Memory-only: no server-side run to update; just signal interruption raise RunInterruptedException( - "The execution was interrupted and the step does not " - "exist yet." + "The execution was interrupted." 
) if pipeline_run and pipeline_run.status in [ @@ -199,15 +303,23 @@ def signal_handler(signum: int, frame: Any) -> None: except Exception as e: raise RunInterruptedException(str(e)) finally: - # Chain to previous handler if it exists and is not default/ignore - if signum == signal.SIGTERM and callable( - self._prev_sigterm_handler - ): - self._prev_sigterm_handler(signum, frame) - elif signum == signal.SIGINT and callable( - self._prev_sigint_handler + # Chain to previous handler if it exists, not default/ignore, + # and not this handler to avoid recursion + prev = None + if signum == signal.SIGTERM: + prev = self._prev_sigterm_handler + elif signum == signal.SIGINT: + prev = self._prev_sigint_handler + if ( + prev + and prev not in (signal.SIG_DFL, signal.SIG_IGN) + and prev is not signal_handler ): - self._prev_sigint_handler(signum, frame) + try: + if callable(prev): + prev(signum, frame) + except Exception: + pass # Register handlers for common termination signals try: @@ -215,7 +327,7 @@ def signal_handler(signum: int, frame: Any) -> None: signal.signal(signal.SIGINT, signal_handler) except ValueError as e: # This happens when not in the main thread - logger.debug(f"Cannot register signal handlers: {e}") + logger.debug("Cannot register signal handlers: %s", e) # Continue without signal handling - the step will still run def launch(self) -> None: @@ -232,12 +344,96 @@ def launch(self) -> None: if self._deployment.pipeline_configuration.settings else None ) - pipeline_run, run_was_created = self._create_or_reuse_run() + + # Determine serving context and canonical capture flags + in_serving_ctx = is_serving_context() + mem_only_flag = bool( + getattr(self._deployment, "capture_memory_only", False) + ) + # Dev fallback: if canonical field missing or False, derive from typed capture + if not mem_only_flag: + try: + from zenml.capture.config import Capture as _Cap + + cap_typed = getattr( + self._deployment.pipeline_configuration, "capture", None + ) + if isinstance(cap_typed, _Cap) and bool( + getattr(cap_typed, "memory_only", False) + ): + mem_only_flag = True + except Exception: + pass + # memory_only applies only in serving; warn and ignore otherwise + memory_only = mem_only_flag if in_serving_ctx else False + if mem_only_flag and not in_serving_ctx: + logger.warning( + "memory_only=True configured but not in serving; ignoring." 
+ ) + + metrics_enabled = bool( + getattr(self._deployment, "capture_metrics", True) + ) + if metrics_enabled is True: + try: + from zenml.capture.config import Capture as _Cap + + cap_typed = getattr( + self._deployment.pipeline_configuration, "capture", None + ) + if isinstance(cap_typed, _Cap): + metrics_enabled = bool(getattr(cap_typed, "metrics", True)) + except Exception: + pass + runtime = get_runtime( + serving=in_serving_ctx, + memory_only=memory_only, + metrics_enabled=metrics_enabled, + ) + # Store for later use + self._runtime = runtime + # Apply observability toggles to runtime + try: + setattr( + runtime, + "_metadata_enabled", + bool(getattr(self._deployment, "capture_metadata", True)), + ) + setattr( + runtime, + "_visualizations_enabled", + bool( + getattr(self._deployment, "capture_visualizations", True) + ), + ) + except Exception: + pass + + # Select entity manager and, if memory-only, set up shared runtime + is_memory_only_path = memory_only and in_serving_ctx + # Declare entity manager type for typing + entity_manager: RunEntityManager + if is_memory_only_path: + try: + shared = get_or_create_shared_memory_runtime() + self._runtime = shared + except Exception: + pass + logger.info( + "[Memory-only] Serving context detected; using in-process memory handoff (no runs/artifacts)." + ) + entity_manager = MemoryRunEntityManager(self) + else: + entity_manager = DefaultRunEntityManager(self) + + pipeline_run, run_was_created = entity_manager.create_or_reuse_run() + # No flush configuration: batch is blocking, serving is async by default # Enable or disable step logs storage if ( handle_bool_env_var(ENV_ZENML_DISABLE_STEP_LOGS_STORAGE, False) or tracking_disabled + or is_memory_only_path # never persist logs in memory-only ): step_logging_enabled = False else: @@ -249,6 +445,11 @@ def launch(self) -> None: logs_context = nullcontext() logs_model = None + # Apply observability toggle from canonical capture + capture_logs = bool(getattr(self._deployment, "capture_logs", True)) + if not capture_logs: + step_logging_enabled = False + if step_logging_enabled and not tracking_disabled: # Configure the logs logs_uri = step_logging.prepare_logs_uri( @@ -272,7 +473,8 @@ def launch(self) -> None: pipeline_run_metadata = self._stack.get_pipeline_run_metadata( run_id=pipeline_run.id ) - publish_utils.publish_pipeline_run_metadata( + runtime.start() + runtime.publish_pipeline_run_metadata( pipeline_run_id=pipeline_run.id, pipeline_run_metadata=pipeline_run_metadata, ) @@ -281,93 +483,167 @@ def launch(self) -> None: model_version=model_version ) + # Honor capture.code flag (default True) + code_enabled = bool(getattr(self._deployment, "capture_code", True)) + + # Prepare step run creation + if isinstance(entity_manager, DefaultRunEntityManager): request_factory = step_run_utils.StepRunRequestFactory( deployment=self._deployment, pipeline_run=pipeline_run, stack=self._stack, + runtime=runtime, + skip_code_capture=not code_enabled, ) step_run_request = request_factory.create_request( invocation_id=self._step_name ) step_run_request.logs = logs_model + # If this step has upstream dependencies and runtime uses non-blocking + # publishes, ensure previous step updates are flushed so input + # resolution via server succeeds. 
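+        # (For reference, the runtime interface relied on throughout this
+        # module is roughly: start(), flush(), shutdown(), get_metrics(),
+        # should_flush_on_step_end() -> bool, and an optional
+        # check_async_errors() on non-blocking runtimes.)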
+ if ( + self._step.spec.upstream_steps + and not runtime.should_flush_on_step_end() + ): try: - # Always populate request to ensure proper input/output flow + runtime.flush() + except Exception as e: + logger.debug( + "Non-blocking flush before input resolution failed: %s", e + ) + + try: + # Always populate request to ensure proper input/output flow + if isinstance(entity_manager, DefaultRunEntityManager): request_factory.populate_request(request=step_run_request) - # In no-capture mode, force fresh execution (bypass cache) - if tracking_disabled: + # In no-capture mode, force fresh execution (bypass cache) + if tracking_disabled: + if isinstance(entity_manager, DefaultRunEntityManager): step_run_request.original_step_run_id = None step_run_request.outputs = {} step_run_request.status = ExecutionStatus.RUNNING - except BaseException as e: - logger.exception(f"Failed preparing step `{self._step_name}`.") + except BaseException as e: + logger.exception("Failed preparing step `%s`.", self._step_name) + if isinstance(entity_manager, DefaultRunEntityManager): step_run_request.status = ExecutionStatus.FAILED step_run_request.end_time = utc_now() step_run_request.exception_info = ( exception_utils.collect_exception_information(e) ) - raise - finally: - # Always create real step run for proper input/output flow - step_run = Client().zen_store.create_run_step(step_run_request) - self._step_run = step_run - if not tracking_disabled and ( - model_version := step_run.model_version - ): - step_run_utils.log_model_version_dashboard_url( - model_version=model_version - ) + raise + finally: + # Create step run (DB-backed or stubbed) + if isinstance(entity_manager, DefaultRunEntityManager): + step_run = entity_manager.create_step_run(step_run_request) + else: + step_run = entity_manager.create_step_run(None) + self._step_run = step_run + if not tracking_disabled and ( + model_version := step_run.model_version + ): + step_run_utils.log_model_version_dashboard_url( + model_version=model_version + ) - if not step_run.status.is_finished: - logger.info(f"Step `{self._step_name}` has started.") + if not step_run.status.is_finished: + logger.info(f"Step `{self._step_name}` has started.") - try: - # here pass a forced save_to_file callable to be - # used as a dump function to use before starting - # the external jobs in step operators - if isinstance( - logs_context, - step_logging.PipelineLogsStorageContext, - ): - force_write_logs = ( - logs_context.storage.send_merge_event - ) - else: + try: + # here pass a forced save_to_file callable to be + # used as a dump function to use before starting + # the external jobs in step operators + if isinstance( + logs_context, + step_logging.PipelineLogsStorageContext, + ): + force_write_logs = logs_context.storage.send_merge_event + else: - def _bypass() -> None: - return None + def _bypass() -> None: + return None - force_write_logs = _bypass - self._run_step( - pipeline_run=pipeline_run, - step_run=step_run, - force_write_logs=force_write_logs, - ) - except RunStoppedException as e: - raise e - except BaseException as e: # noqa: E722 - logger.error( - "Failed to run step `%s`: %s", - self._step_name, - e, + force_write_logs = _bypass + self._run_step( + pipeline_run=pipeline_run, + step_run=step_run, + force_write_logs=force_write_logs, + ) + except RunStoppedException as e: + raise e + except BaseException as e: # noqa: E722 + logger.error( + "Failed to run step `%s`: %s", + self._step_name, + e, + ) + if not tracking_disabled: + # Delegate finalization to entity 
manager (DB-backed or no-op) + try: + entity_manager.finalize_step_run_failed(step_run.id) + except Exception: + try: + runtime.publish_failed_step_run( + step_run_id=step_run.id + ) + except Exception: + pass + if runtime.should_flush_on_step_end(): + runtime.flush() + else: + try: + getattr( + runtime, "check_async_errors", lambda: None + )() + except Exception: + pass + raise + else: + logger.info(f"Using cached version of step `{self._step_name}`.") + if not tracking_disabled: + if ( + model_version := step_run.model_version + or pipeline_run.model_version + ): + step_run_utils.link_output_artifacts_to_model_version( + artifacts=step_run.outputs, + model_version=model_version, ) - if not tracking_disabled: - publish_utils.publish_failed_step_run(step_run.id) - raise - else: + # Ensure any queued updates are flushed for cached path (if enabled) + if runtime.should_flush_on_step_end(): + runtime.flush() + else: + try: + getattr(runtime, "check_async_errors", lambda: None)() + except Exception: + pass + # Notify entity manager of successful completion (default no-op) + try: + entity_manager.finalize_step_run_success( + step_run.id, step_run.outputs + ) + except Exception: + pass + # Ensure runtime shutdown after launch + try: + metrics = {} + try: + metrics = runtime.get_metrics() or {} + except Exception: + metrics = {} + runtime.shutdown() + if metrics: logger.info( - f"Using cached version of step `{self._step_name}`." + "Runtime metrics: queued=%s processed=%s failed_total=%s queue_depth=%s", + metrics.get("queued"), + metrics.get("processed"), + metrics.get("failed_total"), + metrics.get("queue_depth"), ) - if not tracking_disabled: - if ( - model_version := step_run.model_version - or pipeline_run.model_version - ): - step_run_utils.link_output_artifacts_to_model_version( - artifacts=step_run.outputs, - model_version=model_version, - ) + except Exception as e: + logger.debug("Runtime shutdown/metrics retrieval error: %s", e) def _create_or_reuse_run(self) -> Tuple[PipelineRunResponse, bool]: """Creates a pipeline run or reuses an existing one. @@ -419,8 +695,6 @@ def _run_step( force_write_logs: The context for the step logs. 
""" # Create effective step config with serving overrides and no-capture optimizations - from zenml.orchestrators import utils as orchestrator_utils - effective_step_config = self._step.config.model_copy(deep=True) # In no-capture mode, disable caching and step operators for speed @@ -440,6 +714,29 @@ def _run_step( } ) + # Merge request-level parameters in serving (applies to all runtimes) + runtime = getattr(self, "_runtime", None) + if is_serving_context(): + try: + req_env = os.getenv("ZENML_SERVING_REQUEST_PARAMS") + req_params = json.loads(req_env) if req_env else {} + if req_params: + merged = self._validate_and_merge_request_params( + req_params, effective_step_config + ) + effective_step_config = effective_step_config.model_copy( + update={"parameters": merged} + ) + try: + logger.info( + "[Serving] Request parameters merged into step config: %s", + sorted(list(req_params.keys())), + ) + except Exception: + pass + except Exception: + pass + # Prepare step run information with effective config step_run_info = StepRunInfo( config=effective_step_config, @@ -451,10 +748,21 @@ def _run_step( force_write_logs=force_write_logs, ) - # Always prepare output URIs for proper artifact flow - output_artifact_uris = output_utils.prepare_output_artifact_uris( - step_run=step_run, stack=self._stack, step=self._step - ) + # Prepare output URIs + if isinstance(runtime, MemoryStepRuntime): + # Build memory:// URIs from declared outputs (no FS writes) + run_id = str( + getattr(pipeline_run, "id", self._orchestrator_run_id) + ) + output_names = list(self._step.config.outputs.keys()) + output_artifact_uris = { + name: f"memory://{run_id}/{self._step_name}/{name}" + for name in output_names + } + else: + output_artifact_uris = output_utils.prepare_output_artifact_uris( + step_run=step_run, stack=self._stack, step=self._step + ) # Run the step. start_time = time.time() @@ -469,17 +777,27 @@ def _run_step( step_run_info=step_run_info, ) else: + # Resolve inputs via runtime in memory-only; otherwise use server-resolved inputs + if isinstance(runtime, MemoryStepRuntime): + input_artifacts = runtime.resolve_step_inputs( + step=self._step, pipeline_run=pipeline_run + ) + else: + input_artifacts = step_run.regular_inputs + self._run_step_without_step_operator( pipeline_run=pipeline_run, step_run=step_run, step_run_info=step_run_info, - input_artifacts=step_run.regular_inputs, + input_artifacts=input_artifacts, output_artifact_uris=output_artifact_uris, ) except: # noqa: E722 - output_utils.remove_artifact_dirs( - artifact_uris=list(output_artifact_uris.values()) - ) + # Best-effort cleanup only for filesystem URIs + if not isinstance(runtime, MemoryStepRuntime): + output_utils.remove_artifact_dirs( + artifact_uris=list(output_artifact_uris.values()) + ) raise duration = time.time() - start_time @@ -488,6 +806,24 @@ def _run_step( f"`{string_utils.get_human_readable_time(duration)}`." ) + # If runtime is non-blocking, consider a best-effort flush at step end. 
+ # - If there are downstream steps, flush to ensure server has updates + # - If no downstream (leaf step), flush in serving to publish outputs so UI shows previews immediately + if runtime is not None and not runtime.should_flush_on_step_end(): + has_downstream = any( + self._step_name in cfg.spec.upstream_steps + for name, cfg in self._deployment.step_configurations.items() + ) + should_flush = has_downstream or is_serving_context() + if should_flush: + try: + runtime.flush() + except Exception as e: + logger.debug( + "Non-blocking runtime flush after step finish failed: %s", + e, + ) + def _run_step_with_step_operator( self, step_operator_name: Optional[str], @@ -520,6 +856,7 @@ def _run_step_with_step_operator( environment.update(secrets) environment[ENV_ZENML_STEP_OPERATOR] = "True" + # No capture mode propagation; runtime behavior derived from context logger.info( "Using step operator `%s` to run step `%s`.", step_operator.name, @@ -548,7 +885,11 @@ def _run_step_without_step_operator( input_artifacts: The input artifact versions of the current step. output_artifact_uris: The output artifact URIs of the current step. """ - runner = StepRunner(step=self._step, stack=self._stack) + # Use runtime determined at launch + runtime = getattr(self, "_runtime", None) + runner = StepRunner( + step=self._step, stack=self._stack, runtime=runtime + ) runner.run( pipeline_run=pipeline_run, step_run=step_run, diff --git a/src/zenml/orchestrators/step_run_utils.py b/src/zenml/orchestrators/step_run_utils.py index 808ab121769..aaf34583151 100644 --- a/src/zenml/orchestrators/step_run_utils.py +++ b/src/zenml/orchestrators/step_run_utils.py @@ -20,6 +20,8 @@ from zenml.config.step_configurations import Step from zenml.constants import CODE_HASH_PARAMETER_NAME, TEXT_FIELD_MAX_LENGTH from zenml.enums import ExecutionStatus +from zenml.execution.default_runtime import DefaultStepRuntime +from zenml.execution.step_runtime import BaseStepRuntime from zenml.logger import get_logger from zenml.model.utils import link_artifact_version_to_model_version from zenml.models import ( @@ -30,7 +32,7 @@ StepRunRequest, StepRunResponse, ) -from zenml.orchestrators import cache_utils, input_utils, utils +from zenml.orchestrators import utils from zenml.stack import Stack from zenml.utils import pagination_utils from zenml.utils.time_utils import utc_now @@ -46,6 +48,8 @@ def __init__( deployment: "PipelineDeploymentResponse", pipeline_run: "PipelineRunResponse", stack: "Stack", + runtime: Optional[BaseStepRuntime] = None, + skip_code_capture: bool = False, ) -> None: """Initialize the object. @@ -54,10 +58,14 @@ def __init__( pipeline_run: The pipeline run for which to create step run requests. stack: The stack on which the pipeline run is happening. + runtime: The runtime to use for the step run requests. + skip_code_capture: Whether to skip code/docstring capture. """ self.deployment = deployment self.pipeline_run = pipeline_run self.stack = stack + self.runtime: BaseStepRuntime = runtime or DefaultStepRuntime() + self.skip_code_capture = skip_code_capture def has_caching_enabled(self, invocation_id: str) -> bool: """Check if the step has caching enabled. 
@@ -112,7 +120,7 @@ def populate_request(
        """
        step = self.deployment.step_configurations[request.name]

-        input_artifacts = input_utils.resolve_step_inputs(
+        input_artifacts = self.runtime.resolve_step_inputs(
            step=step,
            pipeline_run=self.pipeline_run,
            step_runs=step_runs,
@@ -122,7 +130,7 @@
            name: [artifact.id]
            for name, artifact in input_artifacts.items()
        }

-        cache_key = cache_utils.generate_cache_key(
+        cache_key = self.runtime.compute_cache_key(
            step=step,
            input_artifacts=input_artifacts,
            artifact_store=self.stack.artifact_store,
@@ -130,13 +138,14 @@
        )
        request.cache_key = cache_key

-        (
-            docstring,
-            source_code,
-        ) = self._get_docstring_and_source_code(invocation_id=request.name)
+        if not self.skip_code_capture:
+            (
+                docstring,
+                source_code,
+            ) = self._get_docstring_and_source_code(invocation_id=request.name)

-        request.docstring = docstring
-        request.source_code = source_code
+            request.docstring = docstring
+            request.source_code = source_code
        request.code_hash = step.config.parameters.get(
            CODE_HASH_PARAMETER_NAME
        )
@@ -147,7 +156,7 @@
        )

        if cache_enabled:
-            if cached_step_run := cache_utils.get_cached_step_run(
+            if cached_step_run := self.runtime.get_cached_step_run(
                cache_key=cache_key
            ):
                request.inputs = {
diff --git a/src/zenml/orchestrators/step_runner.py b/src/zenml/orchestrators/step_runner.py
index 8cad58ad1f7..d8cd127d805 100644
--- a/src/zenml/orchestrators/step_runner.py
+++ b/src/zenml/orchestrators/step_runner.py
@@ -26,6 +26,7 @@
    Optional,
    Tuple,
    Type,
+    cast,
)

from zenml.artifacts.unmaterialized_artifact import UnmaterializedArtifact
@@ -40,6 +41,8 @@
)
from zenml.enums import ArtifactSaveType
from zenml.exceptions import StepInterfaceError
+from zenml.execution.default_runtime import DefaultStepRuntime
+from zenml.execution.step_runtime import BaseStepRuntime
from zenml.logger import get_logger
from zenml.logging.step_logging import PipelineLogsStorageContext, redirected
from zenml.materializers.base_materializer import BaseMaterializer
@@ -49,7 +52,6 @@
)
from zenml.orchestrators.publish_utils import (
    publish_step_run_metadata,
-    publish_successful_step_run,
    step_exception_info,
)
from zenml.orchestrators.utils import (
@@ -90,15 +92,24 @@ class StepRunner:
    """Class to run steps."""

-    def __init__(self, step: "Step", stack: "Stack"):
+    def __init__(
+        self,
+        step: "Step",
+        stack: "Stack",
+        runtime: Optional[BaseStepRuntime] = None,
+    ):
        """Initializes the step runner.

        Args:
            step: The step to run.
            stack: The stack on which the step should run.
+            runtime: The runtime to use for the step run.
        """
        self._step = step
        self._stack = stack
+        # Always have a runtime to avoid branching; default to the
+        # behavior-parity DefaultStepRuntime.
+        self._runtime: BaseStepRuntime = runtime or DefaultStepRuntime()

    @property
    def configuration(self) -> StepConfiguration:
        """Returns the configuration of the step.

        Returns:
            The step configuration.
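+
+            In serving, this reflects the effective, per-request
+            configuration rather than the compiled defaults.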
""" + # Prefer effective config from step_run_info if available (serving overrides) + effective = getattr(self, "_step_run_info", None) + if effective: + return cast(StepConfiguration, effective.config) return self._step.config def run( @@ -194,6 +209,9 @@ def run( # Initialize the step context singleton StepContext._clear() + # Pass pipeline state if serving provided one + from zenml.orchestrators import utils as _orch_utils + step_context = StepContext( pipeline_run=pipeline_run, step_run=step_run, @@ -202,6 +220,7 @@ def run( output_artifact_configs={ k: v.artifact_config for k, v in output_annotations.items() }, + pipeline_state=_orch_utils.get_pipeline_state(), ) # Parse the inputs for the entrypoint function. @@ -213,6 +232,8 @@ def run( step_failed = False try: + if self._runtime is not None: + self._runtime.on_step_start() return_values = step_instance.call_entrypoint( **function_params ) @@ -253,10 +274,16 @@ def run( step_run_metadata = self._stack.get_step_run_metadata( info=step_run_info, ) - publish_step_run_metadata( - step_run_id=step_run_info.step_run_id, - step_run_metadata=step_run_metadata, - ) + if self._runtime is not None: + self._runtime.publish_step_run_metadata( + step_run_id=step_run_info.step_run_id, + step_run_metadata=step_run_metadata, + ) + else: + publish_step_run_metadata( + step_run_id=step_run_info.step_run_id, + step_run_metadata=step_run_metadata, + ) self._stack.cleanup_step_run( info=step_run_info, step_failed=step_failed ) @@ -302,14 +329,24 @@ def run( is_enabled_on_step=step_run_info.config.enable_artifact_visualization, is_enabled_on_pipeline=step_run_info.pipeline.enable_artifact_visualization, ) - output_artifacts = self._store_output_artifacts( - output_data=output_data, - output_artifact_uris=output_artifact_uris, - output_materializers=output_materializers, - output_annotations=output_annotations, - artifact_metadata_enabled=artifact_metadata_enabled, - artifact_visualization_enabled=artifact_visualization_enabled, - ) + if self._runtime is not None: + output_artifacts = self._runtime.store_output_artifacts( + output_data=output_data, + output_artifact_uris=output_artifact_uris, + output_materializers=output_materializers, + output_annotations=output_annotations, + artifact_metadata_enabled=artifact_metadata_enabled, + artifact_visualization_enabled=artifact_visualization_enabled, + ) + else: + output_artifacts = self._store_output_artifacts( + output_data=output_data, + output_artifact_uris=output_artifact_uris, + output_materializers=output_materializers, + output_annotations=output_annotations, + artifact_metadata_enabled=artifact_metadata_enabled, + artifact_visualization_enabled=artifact_visualization_enabled, + ) if ( model_version := step_run.model_version @@ -336,10 +373,14 @@ def run( ] for output_name, artifact in output_artifacts.items() } - publish_successful_step_run( + self._runtime.publish_successful_step_run( step_run_id=step_run_info.step_run_id, output_artifact_ids=output_artifact_ids, ) + # Ensure updates are flushed at end of step unless disabled + self._runtime.on_step_end() + if self._runtime.should_flush_on_step_end(): + self._runtime.flush() def _evaluate_artifact_names_in_collections( self, @@ -441,9 +482,16 @@ def _parse_inputs( arg_type = resolve_type_annotation(arg_type) if arg in input_artifacts: - function_params[arg] = self._load_input_artifact( - input_artifacts[arg], arg_type - ) + if self._runtime is not None: + function_params[arg] = self._runtime.load_input_artifact( + artifact=input_artifacts[arg], + 
data_type=arg_type, + stack=self._stack, + ) + else: + function_params[arg] = self._load_input_artifact( + input_artifacts[arg], arg_type + ) elif arg in self.configuration.parameters: param_value = self.configuration.parameters[arg] # Pydantic bridging: convert dict to Pydantic model if possible diff --git a/src/zenml/orchestrators/utils.py b/src/zenml/orchestrators/utils.py index c9899938961..a3deffd1aa7 100644 --- a/src/zenml/orchestrators/utils.py +++ b/src/zenml/orchestrators/utils.py @@ -112,7 +112,7 @@ def is_tracking_enabled( - 'none' (case-insensitive) or False -> disable tracking - any other value or missing -> enable tracking - For serving, respects ZENML_SERVING_CAPTURE_DEFAULT when pipeline settings are absent. + Serving context does not change this; capture options are typed-only. Args: pipeline_settings: Pipeline configuration settings mapping, if any. @@ -121,27 +121,11 @@ def is_tracking_enabled( Whether tracking should be enabled. """ if not pipeline_settings: - # Check for serving default when no pipeline settings - import os - - serving_default = ( - os.getenv("ZENML_SERVING_CAPTURE_DEFAULT", "").strip().lower() - ) - if serving_default in {"none", "off", "false", "0", "disabled"}: - return False return True try: capture_value = pipeline_settings.get("capture") if capture_value is None: - # Check for serving default when capture setting is missing - import os - - serving_default = ( - os.getenv("ZENML_SERVING_CAPTURE_DEFAULT", "").strip().lower() - ) - if serving_default in {"none", "off", "false", "0", "disabled"}: - return False return True if isinstance(capture_value, bool): return capture_value @@ -176,7 +160,14 @@ def is_tracking_enabled( def is_tracking_disabled( pipeline_settings: Optional[Dict[str, Any]] = None, ) -> bool: - """True if tracking/persistence should be disabled completely.""" + """True if tracking/persistence should be disabled completely. + + Args: + pipeline_settings: Optional pipeline settings mapping. + + Returns: + True if tracking should be disabled, False otherwise. + """ return not is_tracking_enabled(pipeline_settings) @@ -187,17 +178,75 @@ def is_tracking_disabled( def tap_store_step_outputs(step_name: str, outputs: Dict[str, Any]) -> None: - """Store step outputs in the serve tap for in-memory handoff.""" + """Store step outputs in the serve tap for in-memory handoff. + + Args: + step_name: Name of the step producing outputs. + outputs: Mapping of output name to value. + """ current_tap = _serve_output_tap.get({}) current_tap[step_name] = outputs _serve_output_tap.set(current_tap) def tap_get_step_outputs(step_name: str) -> Optional[Dict[str, Any]]: - """Get step outputs from the serve tap.""" + """Get step outputs from the serve tap. + + Args: + step_name: Name of the step whose outputs to fetch. + + Returns: + Optional mapping of outputs for the step if present, else None. + """ return _serve_output_tap.get({}).get(step_name) +# Serving context marker +_serving_ctx: ContextVar[bool] = ContextVar("serving_ctx", default=False) + + +def set_serving_context(value: bool) -> None: + """Set whether the current execution is in a serving context. + + Args: + value: True if running inside the serving service, else False. + """ + _serving_ctx.set(bool(value)) + + +def is_serving_context() -> bool: + """Return True if running inside a serving context. + + Returns: + True if serving context is active, otherwise False. 
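+
+    Example (illustrative):
+
+        >>> set_serving_context(True)
+        >>> is_serving_context()
+        True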
+ """ + return _serving_ctx.get() + + +# Serve pipeline state context +_serve_pipeline_state: ContextVar[Optional[Any]] = ContextVar( + "serve_pipeline_state", default=None +) + + +def set_pipeline_state(state: Optional[Any]) -> None: + """Set pipeline state for serving context. + + Args: + state: Optional pipeline state object to associate with this request. + """ + _serve_pipeline_state.set(state) + + +def get_pipeline_state() -> Optional[Any]: + """Get pipeline state for serving context. + + Returns: + Optional pipeline state object if set, else None. + """ + return _serve_pipeline_state.get(None) + + def tap_clear() -> None: """Clear the serve tap for a fresh request.""" _serve_output_tap.set({}) diff --git a/src/zenml/pipelines/pipeline_decorator.py b/src/zenml/pipelines/pipeline_decorator.py index fdee3e5890e..cbc54bc988c 100644 --- a/src/zenml/pipelines/pipeline_decorator.py +++ b/src/zenml/pipelines/pipeline_decorator.py @@ -25,6 +25,7 @@ overload, ) +from zenml.capture.config import Capture from zenml.logger import get_logger if TYPE_CHECKING: @@ -63,6 +64,7 @@ def pipeline( model: Optional["Model"] = None, retry: Optional["StepRetryConfig"] = None, substitutions: Optional[Dict[str, str]] = None, + capture: Optional[Capture] = None, cache_policy: Optional["CachePolicyOrString"] = None, ) -> Callable[["F"], "Pipeline"]: ... @@ -85,6 +87,7 @@ def pipeline( model: Optional["Model"] = None, retry: Optional["StepRetryConfig"] = None, substitutions: Optional[Dict[str, str]] = None, + capture: Optional[Capture] = None, cache_policy: Optional["CachePolicyOrString"] = None, ) -> Union["Pipeline", Callable[["F"], "Pipeline"]]: """Decorator to create a pipeline. @@ -116,6 +119,7 @@ def pipeline( model: configuration of the model in the Model Control Plane. retry: Retry configuration for the pipeline steps. substitutions: Extra placeholders to use in the name templates. + capture: Capture policy for the pipeline (typed only). cache_policy: Cache policy for this pipeline. Returns: @@ -142,6 +146,7 @@ def inner_decorator(func: "F") -> "Pipeline": model=model, retry=retry, substitutions=substitutions, + capture=capture, cache_policy=cache_policy, ) diff --git a/src/zenml/pipelines/pipeline_definition.py b/src/zenml/pipelines/pipeline_definition.py index 91caab72cca..ea0d93c6217 100644 --- a/src/zenml/pipelines/pipeline_definition.py +++ b/src/zenml/pipelines/pipeline_definition.py @@ -42,6 +42,7 @@ from zenml import constants from zenml.analytics.enums import AnalyticsEvent from zenml.analytics.utils import track_handler +from zenml.capture.config import Capture from zenml.client import Client from zenml.config.compiler import Compiler from zenml.config.pipeline_configurations import ( @@ -149,6 +150,7 @@ def __init__( model: Optional["Model"] = None, retry: Optional["StepRetryConfig"] = None, substitutions: Optional[Dict[str, str]] = None, + capture: Optional[Capture] = None, cache_policy: Optional["CachePolicyOrString"] = None, ) -> None: """Initializes a pipeline. @@ -182,6 +184,7 @@ def __init__( model: configuration of the model in the Model Control Plane. retry: Retry configuration for the pipeline steps. substitutions: Extra placeholders to use in the name templates. + capture: Capture configuration for the pipeline (typed only). cache_policy: Cache policy for this pipeline. 
""" self._invocations: Dict[str, StepInvocation] = {} @@ -208,6 +211,7 @@ def __init__( model=model, retry=retry, substitutions=substitutions, + capture=capture, cache_policy=cache_policy, ) self.entrypoint = entrypoint @@ -334,6 +338,7 @@ def configure( parameters: Optional[Dict[str, Any]] = None, merge: bool = True, substitutions: Optional[Dict[str, str]] = None, + capture: Optional[Capture] = None, cache_policy: Optional["CachePolicyOrString"] = None, ) -> Self: """Configures the pipeline. @@ -381,6 +386,7 @@ def configure( retry: Retry configuration for the pipeline steps. parameters: input parameters for the pipeline. substitutions: Extra placeholders to use in the name templates. + capture: Capture configuration for the pipeline (typed only). cache_policy: Cache policy for this pipeline. Returns: @@ -411,6 +417,9 @@ def configure( # merges dicts tags = self._configuration.tags + tags + # Directly store typed capture config + cap_norm = capture + values = dict_utils.remove_none_values( { "enable_cache": enable_cache, @@ -429,6 +438,7 @@ def configure( "retry": retry, "parameters": parameters, "substitutions": substitutions, + "capture": cap_norm, "cache_policy": cache_policy, } ) diff --git a/src/zenml/zen_stores/migrations/versions/0d69e308846a_add_pipeline_endpoints.py b/src/zenml/zen_stores/migrations/versions/a848b2980c54_pipeline_endpoint_capture.py similarity index 65% rename from src/zenml/zen_stores/migrations/versions/0d69e308846a_add_pipeline_endpoints.py rename to src/zenml/zen_stores/migrations/versions/a848b2980c54_pipeline_endpoint_capture.py index 694dc0998c9..10cbd6168a2 100644 --- a/src/zenml/zen_stores/migrations/versions/0d69e308846a_add_pipeline_endpoints.py +++ b/src/zenml/zen_stores/migrations/versions/a848b2980c54_pipeline_endpoint_capture.py @@ -1,8 +1,8 @@ -"""add pipeline endpoints [0d69e308846a]. +"""pipeline endpoint + capture [a848b2980c54]. -Revision ID: 0d69e308846a -Revises: 83ef3cb746a5 -Create Date: 2025-08-26 10:30:52.737833 +Revision ID: a848b2980c54 +Revises: aae4eed923b5 +Create Date: 2025-09-07 18:04:15.320419 """ @@ -12,8 +12,8 @@ from sqlalchemy.dialects import mysql # revision identifiers, used by Alembic. 
-revision = "0d69e308846a" -down_revision = "83ef3cb746a5" +revision = "a848b2980c54" +down_revision = "aae4eed923b5" branch_labels = None depends_on = None @@ -36,7 +36,9 @@ def upgrade() -> None: sa.Column("auth_key", sa.TEXT(), nullable=True), sa.Column( "endpoint_metadata", - sa.String(length=16777215).with_variant(mysql.MEDIUMTEXT, "mysql"), + sa.String(length=16777215).with_variant( + mysql.MEDIUMTEXT(), "mysql" + ), nullable=False, ), sa.Column( @@ -45,18 +47,18 @@ def upgrade() -> None: nullable=True, ), sa.Column("deployer_id", sqlmodel.sql.sqltypes.GUID(), nullable=True), - sa.ForeignKeyConstraint( - ["pipeline_deployment_id"], - ["pipeline_deployment.id"], - name="fk_pipeline_endpoint_pipeline_deployment_id_pipeline_deployment", - ondelete="SET NULL", - ), sa.ForeignKeyConstraint( ["deployer_id"], ["stack_component.id"], name="fk_pipeline_endpoint_deployer_id_stack_component", ondelete="SET NULL", ), + sa.ForeignKeyConstraint( + ["pipeline_deployment_id"], + ["pipeline_deployment.id"], + name="fk_pipeline_endpoint_pipeline_deployment_id_pipeline_deployment", + ondelete="SET NULL", + ), sa.ForeignKeyConstraint( ["project_id"], ["project.id"], @@ -76,11 +78,39 @@ def upgrade() -> None: name="unique_pipeline_endpoint_name_in_project", ), ) + with op.batch_alter_table("pipeline_deployment", schema=None) as batch_op: + batch_op.add_column( + sa.Column("capture_memory_only", sa.Boolean(), nullable=False) + ) + batch_op.add_column( + sa.Column("capture_code", sa.Boolean(), nullable=False) + ) + batch_op.add_column( + sa.Column("capture_logs", sa.Boolean(), nullable=False) + ) + batch_op.add_column( + sa.Column("capture_metadata", sa.Boolean(), nullable=False) + ) + batch_op.add_column( + sa.Column("capture_visualizations", sa.Boolean(), nullable=False) + ) + batch_op.add_column( + sa.Column("capture_metrics", sa.Boolean(), nullable=False) + ) + # ### end Alembic commands ### def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table("pipeline_deployment", schema=None) as batch_op: + batch_op.drop_column("capture_metrics") + batch_op.drop_column("capture_visualizations") + batch_op.drop_column("capture_metadata") + batch_op.drop_column("capture_logs") + batch_op.drop_column("capture_code") + batch_op.drop_column("capture_memory_only") + op.drop_table("pipeline_endpoint") # ### end Alembic commands ### diff --git a/src/zenml/zen_stores/schemas/pipeline_deployment_schemas.py b/src/zenml/zen_stores/schemas/pipeline_deployment_schemas.py index 7641fea8df9..d2321d8afe9 100644 --- a/src/zenml/zen_stores/schemas/pipeline_deployment_schemas.py +++ b/src/zenml/zen_stores/schemas/pipeline_deployment_schemas.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Any, List, Optional, Sequence from uuid import UUID -from sqlalchemy import TEXT, Column, String, UniqueConstraint +from sqlalchemy import TEXT, Boolean, Column, String, UniqueConstraint from sqlalchemy.dialects.mysql import MEDIUMTEXT from sqlalchemy.orm import joinedload, object_session from sqlalchemy.sql.base import ExecutableOption @@ -87,6 +87,26 @@ class PipelineDeploymentSchema(BaseSchema, table=True): ) code_path: Optional[str] = Field(nullable=True) + # Canonical capture fields + capture_memory_only: bool = Field( + sa_column=Column(Boolean, nullable=False, default=False), default=False + ) + capture_code: bool = Field( + sa_column=Column(Boolean, nullable=False, default=True), default=True + ) + capture_logs: bool = Field( + sa_column=Column(Boolean, nullable=False, default=True), default=True + ) + capture_metadata: bool = Field( + sa_column=Column(Boolean, nullable=False, default=True), default=True + ) + capture_visualizations: bool = Field( + sa_column=Column(Boolean, nullable=False, default=True), default=True + ) + capture_metrics: bool = Field( + sa_column=Column(Boolean, nullable=False, default=True), default=True + ) + # Foreign keys user_id: Optional[UUID] = build_foreign_key_field( source=__tablename__, @@ -319,6 +339,14 @@ def from_request( if request.pipeline_spec else None, code_path=request.code_path, + capture_memory_only=getattr(request, "capture_memory_only", False), + capture_code=getattr(request, "capture_code", True), + capture_logs=getattr(request, "capture_logs", True), + capture_metadata=getattr(request, "capture_metadata", True), + capture_visualizations=getattr( + request, "capture_visualizations", True + ), + capture_metrics=getattr(request, "capture_metrics", True), ) def to_model( @@ -390,6 +418,12 @@ def to_model( else None, code_path=self.code_path, template_id=self.template_id, + capture_memory_only=self.capture_memory_only, + capture_code=self.capture_code, + capture_logs=self.capture_logs, + capture_metadata=self.capture_metadata, + capture_visualizations=self.capture_visualizations, + capture_metrics=self.capture_metrics, ) resources = None diff --git a/tests/unit/execution/test_default_runtime_metadata_toggle.py b/tests/unit/execution/test_default_runtime_metadata_toggle.py new file mode 100644 index 00000000000..f85666600f6 --- /dev/null +++ b/tests/unit/execution/test_default_runtime_metadata_toggle.py @@ -0,0 +1,40 @@ +"""Unit tests for DefaultStepRuntime metadata/visualization toggles.""" + +from types import SimpleNamespace + +from zenml.execution.default_runtime import DefaultStepRuntime + + +def test_publish_metadata_skips_when_disabled(monkeypatch): + """Test that metadata is not published when disabled.""" + rt = DefaultStepRuntime() + setattr(rt, "_metadata_enabled", False) + + called 
= {"run": 0, "step": 0} + + def _pub_run_md(*a, **k): + """Mock publish pipeline run metadata.""" + called["run"] += 1 + + def _pub_step_md(*a, **k): + """Mock publish step run metadata.""" + called["step"] += 1 + + monkeypatch.setattr( + "zenml.orchestrators.publish_utils.publish_pipeline_run_metadata", + _pub_run_md, + ) + monkeypatch.setattr( + "zenml.orchestrators.publish_utils.publish_step_run_metadata", + _pub_step_md, + ) + + rt.publish_pipeline_run_metadata( + pipeline_run_id=SimpleNamespace(), pipeline_run_metadata={} + ) + rt.publish_step_run_metadata( + step_run_id=SimpleNamespace(), step_run_metadata={} + ) + + assert called["run"] == 0 + assert called["step"] == 0 diff --git a/tests/unit/execution/test_memory_runtime.py b/tests/unit/execution/test_memory_runtime.py new file mode 100644 index 00000000000..fda2314743a --- /dev/null +++ b/tests/unit/execution/test_memory_runtime.py @@ -0,0 +1,69 @@ +"""Unit tests for MemoryStepRuntime instance-scoped isolation.""" + +from types import SimpleNamespace + +from zenml.execution.memory_runtime import MemoryStepRuntime + + +def test_memory_runtime_instance_isolated_store(monkeypatch): + """Each runtime instance isolates values by run id; no cross leakage.""" + # Create two independent runtimes (instance-scoped stores) + rt1 = MemoryStepRuntime() + rt2 = MemoryStepRuntime() + + # Patch get_step_context to return minimal stubs + class _Ctx: + def __init__(self, run_id: str, step_name: str): + self.pipeline_run = SimpleNamespace(id=run_id) + self.step_run = SimpleNamespace(name=step_name) + + def get_output_metadata(self, name: str): + return {} + + def get_output_tags(self, name: str): + return [] + + monkeypatch.setattr( + "zenml.execution.step_runtime.get_step_context", + lambda: _Ctx("run-1", "s1"), + ) + + # Store with rt1 + outputs = {"out": 123} + handles1 = rt1.store_output_artifacts( + output_data=outputs, + output_materializers={"out": ()}, + output_artifact_uris={"out": "memory://run-1/s1/out"}, + output_annotations={"out": SimpleNamespace(artifact_config=None)}, + artifact_metadata_enabled=False, + artifact_visualization_enabled=False, + ) + h1 = handles1["out"] + + # Switch context for rt2 + monkeypatch.setattr( + "zenml.execution.step_runtime.get_step_context", + lambda: _Ctx("run-2", "s2"), + ) + handles2 = rt2.store_output_artifacts( + output_data={"out": 999}, + output_materializers={"out": ()}, + output_artifact_uris={"out": "memory://run-2/s2/out"}, + output_annotations={"out": SimpleNamespace(artifact_config=None)}, + artifact_metadata_enabled=False, + artifact_visualization_enabled=False, + ) + h2 = handles2["out"] + + # rt1 should load its own value + v1 = rt1.load_input_artifact(artifact=h1, data_type=int, stack=None) + assert v1 == 123 + + # rt2 should load its own value + v2 = rt2.load_input_artifact(artifact=h2, data_type=int, stack=None) + assert v2 == 999 + + # rt1 should NOT see rt2 value + assert ( + rt1.load_input_artifact(artifact=h2, data_type=int, stack=None) is None + ) diff --git a/tests/unit/execution/test_realtime_runtime.py b/tests/unit/execution/test_realtime_runtime.py new file mode 100644 index 00000000000..d1223bf02ed --- /dev/null +++ b/tests/unit/execution/test_realtime_runtime.py @@ -0,0 +1,41 @@ +"""Unit tests for RealtimeStepRuntime queue/backpressure and sweep.""" + +import queue +from types import SimpleNamespace + +from zenml.execution.realtime_runtime import RealtimeStepRuntime + + +def test_realtime_queue_full_inline_fallback(monkeypatch): + """When queue is full, publish events are 
processed inline as fallback.""" + rt = RealtimeStepRuntime(ttl_seconds=1, max_entries=8) + + # Replace queue with a tiny one and fill it + rt._q = queue.Queue(maxsize=1) # type: ignore[attr-defined] + rt._q.put(("dummy", (), {})) # fill once + + called = {"step": 0} + + def _pub_step_run_metadata(*args, **kwargs): + called["step"] += 1 + + monkeypatch.setattr( + "zenml.orchestrators.publish_utils.publish_step_run_metadata", + _pub_step_run_metadata, + ) + + # This put_nowait should hit Full and process inline + rt.publish_step_run_metadata( + step_run_id=SimpleNamespace(), step_run_metadata={} + ) + assert called["step"] == 1 + + +def test_realtime_sweep_expired_no_keyerror(): + """Expired cache entries are swept safely without KeyError races.""" + rt = RealtimeStepRuntime(ttl_seconds=0, max_entries=8) + # Insert an expired cache entry manually + with rt._lock: # type: ignore[attr-defined] + rt._cache["k1"] = ("v", 0.0) # type: ignore[attr-defined] + # Should not raise + rt._sweep_expired() diff --git a/tests/unit/execution/test_step_runtime_artifact_write.py b/tests/unit/execution/test_step_runtime_artifact_write.py new file mode 100644 index 00000000000..c1658e544b4 --- /dev/null +++ b/tests/unit/execution/test_step_runtime_artifact_write.py @@ -0,0 +1,78 @@ +"""Unit test for defensive artifact write behavior (retry + validate).""" + +from types import SimpleNamespace + +from zenml.execution.default_runtime import DefaultStepRuntime + + +def test_artifact_write_retry_and_validate(monkeypatch): + """First batch create fails, retry succeeds; responses length validated.""" + rt = DefaultStepRuntime() + + # Patch helpers used to build requests + monkeypatch.setattr( + "zenml.orchestrators.publish_utils.publish_successful_step_run", + lambda *a, **k: None, + ) + + # Minimal step context stub + class _Ctx: + def __init__(self): + self.pipeline_run = SimpleNamespace( + config=SimpleNamespace(tags=None), pipeline=None + ) + self.step_run = SimpleNamespace(name="step") + + def get_output_metadata(self, name: str): + return {} + + def get_output_tags(self, name: str): + return [] + + monkeypatch.setattr( + "zenml.execution.step_runtime.get_step_context", + lambda: _Ctx(), + ) + + # Patch request preparation to avoid heavy imports + monkeypatch.setattr( + "zenml.execution.step_runtime._store_artifact_data_and_prepare_request", + lambda **k: {"req": k}, + ) + # Patch materializer selection + monkeypatch.setattr( + "zenml.execution.step_runtime.materializer_utils.select_materializer", + lambda data_type, materializer_classes: object, + ) + monkeypatch.setattr( + "zenml.execution.step_runtime.source_utils.load_and_validate_class", + lambda *a, **k: object, + ) + + calls = {"attempts": 0} + + class _Client: + class _Store: + def batch_create_artifact_versions(self, reqs): + calls["attempts"] += 1 + if calls["attempts"] == 1: + raise RuntimeError("transient") + # Return matching length list + return [SimpleNamespace(id=i) for i in range(len(reqs))] + + zen_store = _Store() + + monkeypatch.setattr( + "zenml.execution.step_runtime.Client", lambda: _Client() + ) + + res = rt.store_output_artifacts( + output_data={"out": 1}, + output_materializers={"out": ()}, + output_artifact_uris={"out": "uri://out"}, + output_annotations={"out": SimpleNamespace(artifact_config=None)}, + artifact_metadata_enabled=False, + artifact_visualization_enabled=False, + ) + assert "out" in res + assert calls["attempts"] == 2 diff --git a/tests/unit/orchestrators/test_step_launcher_params.py 
b/tests/unit/orchestrators/test_step_launcher_params.py
new file mode 100644
index 00000000000..74d259198ac
--- /dev/null
+++ b/tests/unit/orchestrators/test_step_launcher_params.py
@@ -0,0 +1,49 @@
+"""Unit tests for StepLauncher request parameter validation/merge.
+
+These tests verify allowlisting, simple type coercion, and size caps when
+merging request parameters into the effective step configuration in serving.
+"""
+
+from zenml.orchestrators.step_launcher import StepLauncher
+
+
+def test_validate_and_merge_request_params_allowlist_and_types():
+    """Allowlist known params and coerce simple types; drop unknowns."""
+    # Bind the real method to a StepLauncher instance with minimal init
+    sl = StepLauncher.__new__(StepLauncher)  # type: ignore
+
+    class Cfg:
+        def __init__(self):
+            self.parameters = {"city": "paris", "count": 1}
+
+    effective = Cfg()
+    req = {
+        "city": "munich",  # allowed, string
+        "count": "2",  # allowed, coercible to int
+        "unknown": "drop-me",  # not declared
+    }
+
+    merged = StepLauncher._validate_and_merge_request_params(
+        sl, req, effective
+    )
+    assert merged["city"] == "munich"
+    assert merged["count"] == 2
+    assert "unknown" not in merged
+
+
+def test_validate_and_merge_request_params_size_caps():
+    """Drop oversized string/collection parameters per safety caps."""
+    sl = StepLauncher.__new__(StepLauncher)  # type: ignore
+
+    class Cfg:
+        def __init__(self):
+            self.parameters = {"text": "ok"}
+
+    effective = Cfg()
+    big = "x" * 20000  # 20KB string -> exceeds the 10K cap, dropped
+    req = {"text": big}
+    merged = StepLauncher._validate_and_merge_request_params(
+        sl, req, effective
+    )
+    # Should keep the default, drop oversize
+    assert merged["text"] == "ok"
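+
+
+def test_validate_and_merge_request_params_empty_request():
+    """Empty payloads fall back to the declared defaults.
+
+    A minimal extra sketch exercising the same private helper as the tests
+    above; it relies only on behavior shown in this change (an empty request
+    returns the compiled parameters unchanged).
+    """
+    sl = StepLauncher.__new__(StepLauncher)  # type: ignore
+
+    class Cfg:
+        def __init__(self):
+            self.parameters = {"city": "paris"}
+
+    merged = StepLauncher._validate_and_merge_request_params(sl, {}, Cfg())
+    assert merged == {"city": "paris"}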