Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/zenml/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,7 @@ def handle_int_env_var(var: str, default: int = 0) -> int:
STATUS = "/status"
STEP_CONFIGURATION = "/step-configuration"
STEPS = "/steps"
HEARTBEAT = "heartbeat"
STOP = "/stop"
TAGS = "/tags"
TAG_RESOURCES = "/tag_resources"
Expand Down
4 changes: 3 additions & 1 deletion src/zenml/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,8 @@
StepRunResponse,
StepRunResponseBody,
StepRunResponseMetadata,
StepRunResponseResources
StepRunResponseResources,
StepHeartbeatResponse,
)
from zenml.models.v2.core.tag import (
TagFilter,
Expand Down Expand Up @@ -874,4 +875,5 @@
"ProjectStatistics",
"PipelineRunDAG",
"ExceptionInfo",
"StepHeartbeatResponse",
]
26 changes: 25 additions & 1 deletion src/zenml/models/v2/core/step_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
)
from uuid import UUID

from pydantic import ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field

from zenml.config.step_configurations import StepConfiguration, StepSpec
from zenml.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH
Expand Down Expand Up @@ -200,6 +200,10 @@ class StepRunResponseBody(ProjectScopedResponseBody):
title="The end time of the step run.",
default=None,
)
latest_heartbeat: Optional[datetime] = Field(
title="The latest heartbeat of the step run.",
default=None,
)
model_version_id: Optional[UUID] = Field(
title="The ID of the model version that was "
"configured by this step run explicitly.",
Expand Down Expand Up @@ -565,6 +569,15 @@ def end_time(self) -> Optional[datetime]:
"""
return self.get_body().end_time

@property
def latest_heartbeat(self) -> Optional[datetime]:
"""The `latest_heartbeat` property.

Returns:
the value of the property.
"""
return self.get_body().latest_heartbeat

@property
def logs(self) -> Optional["LogsResponse"]:
"""The `logs` property.
Expand Down Expand Up @@ -747,3 +760,14 @@ def get_custom_filters(
)

return custom_filters


# ------------------ Heartbeat Model ---------------


class StepHeartbeatResponse(BaseModel):
"""Light-weight model for Step Heartbeat responses."""

id: UUID
status: str
latest_heartbeat: datetime
178 changes: 178 additions & 0 deletions src/zenml/steps/heartbeat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
# Copyright (c) ZenML GmbH 2022. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""ZenML Step HeartBeat functionality."""

import _thread
import logging
import threading
import time
from typing import Annotated
from uuid import UUID

from pydantic import BaseModel, conint, model_validator

from zenml.enums import ExecutionStatus

logger = logging.getLogger(__name__)


class StepHeartBeatTerminationException(Exception):
"""Custom exception class for heartbeat termination."""

pass


class StepHeartBeatOptions(BaseModel):
"""Options group for step heartbeat execution."""

step_id: UUID
interval: Annotated[int, conint(ge=10, le=60)]
name: str | None = None

@model_validator(mode="after")
def set_default_name(self) -> "StepHeartBeatOptions":
"""Model validator - set name value if missing.

Returns:
The validated step heartbeat options.
"""
if not self.name:
self.name = f"HeartBeatWorker-{self.step_id}"

return self


class HeartbeatWorker:
"""Worker class implementing heartbeat polling and remote termination."""

def __init__(self, options: StepHeartBeatOptions):
"""Heartbeat worker constructor.

Args:
options: Parameter group - polling interval, step id, etc.
"""
self.options = options

self._thread: threading.Thread | None = None
self._running: bool = False
self._terminated: bool = (
False # one-shot guard to avoid repeated interrupts
)

# properties

@property
def interval(self) -> int:
"""Property function for heartbeat interval.

Returns:
The heartbeat polling interval value.
"""
return self.options.interval

@property
def name(self) -> str:
"""Property function for heartbeat worker name.

Returns:
The name of the heartbeat worker.
"""
return str(self.options.name)

@property
def step_id(self) -> UUID:
"""Property function for heartbeat worker step ID.

Returns:
The id of the step heartbeat is running for.
"""
return self.options.step_id

# public functions

def start(self) -> None:
"""Start the heartbeat worker on a background thread."""
if self._thread and self._thread.is_alive():
logger.info("%s already running; start() is a no-op", self.name)
return

self._running = True
self._terminated = False
self._thread = threading.Thread(
target=self._run, name=self.name, daemon=True
)
self._thread.start()
logger.info(
"Daemon thread %s started (interval=%s)", self.name, self.interval
)

def stop(self) -> None:
"""Stops the heartbeat worker."""
if not self._running:
return
self._running = False
logger.info("%s stop requested", self.name)

def is_alive(self) -> bool:
"""Liveness of the heartbeat worker thread.

Returns:
True if the heartbeat worker thread is alive, False otherwise.
"""
t = self._thread
return bool(t and t.is_alive())

def _run(self) -> None:
logger.info("%s run() loop entered", self.name)
try:
while self._running:
try:
self._heartbeat()
except StepHeartBeatTerminationException:
# One-shot: signal the main thread and stop the loop.
if not self._terminated:
self._terminated = True
logger.info(
"%s received HeartBeatTerminationException; "
"interrupting main thread",
self.name,
)
_thread.interrupt_main() # raises KeyboardInterrupt in main thread
# Ensure we stop our own loop as well.
self._running = False
except Exception:
Copy link
Contributor Author

@Json-Andriopoulos Json-Andriopoulos Oct 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: Improve this. For sure try to capture HTTP errors in more verbose logs to avoid excessive log generation if the error is for instance server raising 500 status code.

# Log-and-continue policy for all other errors.
logger.exception(
"%s heartbeat() failed; continuing", self.name
)
# Sleep after each attempt (even after errors, unless stopped).
if self._running:
time.sleep(self.interval)
finally:
logger.info("%s run() loop exiting", self.name)

def _heartbeat(self) -> None:
from zenml.config.global_config import GlobalConfiguration

store = GlobalConfiguration().zen_store

response = store.update_step_heartbeat(step_run_id=self.step_id)

if response.status in {
ExecutionStatus.STOPPED,
ExecutionStatus.STOPPING,
}:
raise StepHeartBeatTerminationException(
f"Step {self.step_id} remotely stopped with status {response.status}."
)
26 changes: 26 additions & 0 deletions src/zenml/zen_server/routers/steps_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from zenml.constants import (
API,
HEARTBEAT,
LOGS,
STATUS,
STEP_CONFIGURATION,
Expand All @@ -38,6 +39,7 @@
StepRunResponse,
StepRunUpdate,
)
from zenml.models.v2.core.step_run import StepHeartbeatResponse
from zenml.zen_server.auth import (
AuthContext,
authorize,
Expand Down Expand Up @@ -200,6 +202,30 @@ def update_step(
return dehydrate_response_model(updated_step)


@router.put(
"/{step_run_id}/" + HEARTBEAT,
responses={401: error_response, 404: error_response, 422: error_response},
)
@async_fastapi_endpoint_wrapper(deduplicate=True)
def update_heartbeat(
step_run_id: UUID,
_: AuthContext = Security(authorize),
) -> StepHeartbeatResponse:
"""Updates a step.

Args:
step_run_id: ID of the step.

Returns:
The step heartbeat response (id, status, last_heartbeat).
"""
step = zen_store().get_run_step(step_run_id, hydrate=True)
pipeline_run = zen_store().get_run(step.pipeline_run_id)
verify_permission_for_model(pipeline_run, action=Action.UPDATE)

return zen_store().update_step_heartbeat(step_run_id=step_run_id)


@router.get(
"/{step_id}" + STEP_CONFIGURATION,
responses={401: error_response, 404: error_response, 422: error_response},
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""Add heartbeat column for step runs [a5a17015b681].

Revision ID: a5a17015b681
Revises: 0.90.0
Create Date: 2025-10-13 12:24:12.470803

"""

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "a5a17015b681"
down_revision = "0.90.0"
branch_labels = None
depends_on = None


def upgrade() -> None:
"""Upgrade database schema and/or data, creating a new revision."""
with op.batch_alter_table("step_run", schema=None) as batch_op:
batch_op.add_column(
sa.Column("latest_heartbeat", sa.DateTime(), nullable=True)
)


def downgrade() -> None:
"""Downgrade database schema and/or data back to the previous revision."""
with op.batch_alter_table("step_run", schema=None) as batch_op:
batch_op.drop_column("latest_heartbeat")
19 changes: 19 additions & 0 deletions src/zenml/zen_stores/rest_zen_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
ENV_ZENML_DISABLE_CLIENT_SERVER_MISMATCH_WARNING,
EVENT_SOURCES,
FLAVORS,
HEARTBEAT,
INFO,
LOGIN,
LOGS,
Expand Down Expand Up @@ -254,6 +255,7 @@
StackRequest,
StackResponse,
StackUpdate,
StepHeartbeatResponse,
StepRunFilter,
StepRunRequest,
StepRunResponse,
Expand Down Expand Up @@ -3303,6 +3305,23 @@ def update_run_step(
route=STEPS,
)

def update_step_heartbeat(
self, step_run_id: UUID
) -> StepHeartbeatResponse:
"""Updates a step run heartbeat.

Args:
step_run_id: The ID of the step to update.

Returns:
The step heartbeat response.
"""
response_body = self.put(
f"{STEPS}/{str(step_run_id)}/{HEARTBEAT}", body=None, params=None
)

return StepHeartbeatResponse.model_validate(response_body)

# -------------------- Triggers --------------------

def create_trigger(self, trigger: TriggerRequest) -> TriggerResponse:
Expand Down
4 changes: 4 additions & 0 deletions src/zenml/zen_stores/schemas/step_run_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ class StepRunSchema(NamedSchema, RunMetadataInterface, table=True):
# Fields
start_time: Optional[datetime] = Field(nullable=True)
end_time: Optional[datetime] = Field(nullable=True)
latest_heartbeat: Optional[datetime] = Field(
nullable=True,
description="The latest execution heartbeat.",
)
status: str = Field(nullable=False)

docstring: Optional[str] = Field(sa_column=Column(TEXT, nullable=True))
Expand Down
Loading