diff --git a/docs/book/how-to/artifacts/visualizations.md b/docs/book/how-to/artifacts/visualizations.md index 01cfd41afd..8934414508 100644 --- a/docs/book/how-to/artifacts/visualizations.md +++ b/docs/book/how-to/artifacts/visualizations.md @@ -65,6 +65,183 @@ There are three ways how you can add custom visualizations to the dashboard: * If you are already handling HTML, Markdown, CSV or JSON data in one of your steps, you can have them visualized in just a few lines of code by casting them to a [special class](#visualization-via-special-return-types) inside your step. * If you want to automatically extract visualizations for all artifacts of a certain data type, you can define type-specific visualization logic by [building a custom materializer](#visualization-via-materializers). +### Curated Visualizations Across Resources + +Curated visualizations let you surface a specific artifact visualization across multiple ZenML resources. Each curated visualization links to exactly one resource—for example, a model performance report that appears on the model detail page, or a deployment health dashboard that shows up in the deployment view. + +Curated visualizations currently support the following resources: + +- **Projects** – high-level dashboards and KPIs that summarize the state of a project. +- **Deployments** – monitoring pages for deployed pipelines. +- **Models** – evaluation dashboards and health views for registered models. +- **Pipelines** – reusable visual documentation attached to pipeline definitions. +- **Pipeline Runs** – detailed diagnostics for specific executions. +- **Pipeline Snapshots** – configuration/version comparisons for snapshot history. + +You can create a curated visualization programmatically by linking an artifact visualization to a single resource. Provide the resource identifier and resource type directly when creating the visualization. The example below shows how to create separate visualizations for different resource types: + +```python +from uuid import UUID + +from zenml.client import Client +from zenml.enums import ( + CuratedVisualizationSize, + VisualizationResourceTypes, +) + +client = Client() + +# Define the identifiers for the pipeline and run you want to enrich +pipeline_id = UUID("") +pipeline_run_id = UUID("") + +# Retrieve the artifact version produced by the evaluation step +pipeline_run = client.get_pipeline_run(pipeline_run_id) +artifact_version_id = pipeline_run.output.get("evaluation_report") +artifact_version = client.get_artifact_version(artifact_version_id) +artifact_visualizations = artifact_version.visualizations or [] + +# Fetch the resources we want to enrich +model = client.list_models().items[0] +model_id = model.id + +deployment = client.list_deployments().items[0] +deployment_id = deployment.id + +project_id = client.active_project.id + +pipeline_model = client.get_pipeline(pipeline_id) +pipeline_id = pipeline_model.id + +pipeline_snapshot = pipeline_run.snapshot() +snapshot_id = pipeline_snapshot.id + +pipeline_run_id = pipeline_run.id + +# Create curated visualizations for each supported resource type +client.create_curated_visualization( + artifact_visualization_id=artifact_visualizations[0].id, + resource_id=model_id, + resource_type=VisualizationResourceTypes.MODEL, + project_id=project_id, + display_name="Latest Model Evaluation", +) + +client.create_curated_visualization( + artifact_visualization_id=artifact_visualizations[1].id, + resource_id=deployment_id, + resource_type=VisualizationResourceTypes.DEPLOYMENT, + project_id=project_id, + display_name="Deployment Health Dashboard", +) + +client.create_curated_visualization( + artifact_visualization_id=artifact_visualizations[2].id, + resource_id=project_id, + resource_type=VisualizationResourceTypes.PROJECT, + display_name="Project Overview", +) + +client.create_curated_visualization( + artifact_visualization_id=artifact_visualizations[3].id, + resource_id=pipeline_id, + resource_type=VisualizationResourceTypes.PIPELINE, + project_id=project_id, + display_name="Pipeline Summary", +) + +client.create_curated_visualization( + artifact_visualization_id=artifact_visualizations[4].id, + resource_id=pipeline_run_id, + resource_type=VisualizationResourceTypes.PIPELINE_RUN, + project_id=project_id, + display_name="Run Results", +) + +client.create_curated_visualization( + artifact_visualization_id=artifact_visualizations[5].id, + resource_id=snapshot_id, + resource_type=VisualizationResourceTypes.PIPELINE_SNAPSHOT, + project_id=project_id, + display_name="Snapshot Metrics", +) +``` + +After creation, the returned response includes the visualization ID. You can retrieve a specific visualization later with `Client.get_curated_visualization`: + +```python +retrieved = client.get_curated_visualization(pipeline_viz.id, hydrate=True) +print(retrieved.display_name) +print(retrieved.resource.type) +print(retrieved.resource.id) +``` + +Curated visualizations are tied to their parent resources and automatically surface in the ZenML dashboard wherever those resources appear, so keep track of the IDs returned by `create_curated_visualization` if you need to reference them later. + +#### Updating curated visualizations + +Once you've created a curated visualization, you can update its display name, order, or tile size using `Client.update_curated_visualization`: + +```python +from uuid import UUID + +client.update_curated_visualization( + visualization_id=UUID(""), + display_name="Updated Dashboard Title", + display_order=10, + layout_size=CuratedVisualizationSize.HALF_WIDTH, +) +``` + +When a visualization is no longer relevant, you can remove it entirely: + +```python +client.delete_curated_visualization(visualization_id=UUID("")) +``` + +#### Controlling display order and size + +The optional `display_order` field determines how visualizations are sorted when displayed. Visualizations with lower order values appear first, while those with `None` (the default) appear at the end in creation order. + +When setting display orders, consider leaving gaps between values (e.g., 10, 20, 30 instead of 1, 2, 3) to make it easier to insert new visualizations later without renumbering everything: + +```python +# Leave gaps for future insertions +visualization_a = client.create_curated_visualization( + artifact_visualization_id=artifact_visualizations[0].id, + resource_type=VisualizationResourceTypes.PIPELINE, + resource_id=pipeline_id, + display_name="Model performance at a glance", + display_order=10, # Primary dashboard + layout_size=CuratedVisualizationSize.HALF_WIDTH, +) + +visualization_b = client.create_curated_visualization( + artifact_visualization_id=artifact_visualizations[1].id, + resource_type=VisualizationResourceTypes.PIPELINE, + resource_id=pipeline_id, + display_name="Drill-down metrics", + display_order=20, # Secondary metrics + layout_size=CuratedVisualizationSize.HALF_WIDTH, # Compact chart beside the primary tile +) + +# Later, easily insert between them +visualization_c = client.create_curated_visualization( + artifact_visualization_id=artifact_visualizations[2].id, + resource_type=VisualizationResourceTypes.PIPELINE, + resource_id=pipeline_id, + display_name="Raw output preview", + display_order=15, # Now appears between A and B + layout_size=CuratedVisualizationSize.FULL_WIDTH, +) +``` + +#### RBAC visibility + +Curated visualizations respect the access permissions of the resource they're linked to. A user can only see a curated visualization if they have read access to the specific resource it targets. If a user lacks permission for the linked resource, the visualization will be hidden from their view. + +For example, if you create a visualization linked to a specific deployment, only users with read access to that deployment will see the visualization. If you need the same visualization to appear in different contexts with different access controls (e.g., on both a project page and a deployment page), create separate curated visualizations for each resource. This ensures that visualizations never inadvertently expose information from resources a user shouldn't access, while giving you fine-grained control over visibility. + ### Visualization via Special Return Types If you already have HTML, Markdown, CSV or JSON data available as a string inside your step, you can simply cast them to one of the following types and return them from your step: @@ -257,4 +434,4 @@ steps: Visualizing artifacts is a powerful way to gain insights from your ML pipelines. ZenML's built-in visualization capabilities make it easy to understand your data and model outputs, identify issues, and communicate results. -By leveraging these visualization tools, you can better understand your ML workflows, debug problems more effectively, and make more informed decisions about your models. \ No newline at end of file +By leveraging these visualization tools, you can better understand your ML workflows, debug problems more effectively, and make more informed decisions about your models. diff --git a/src/zenml/client.py b/src/zenml/client.py index 09204ba883..3972004d96 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -61,6 +61,7 @@ from zenml.enums import ( ArtifactType, ColorVariants, + CuratedVisualizationSize, DeploymentStatus, LogicalOperators, ModelStages, @@ -72,6 +73,7 @@ StackComponentType, StoreType, TaggableResourceTypes, + VisualizationResourceTypes, ) from zenml.exceptions import ( AuthorizationException, @@ -108,6 +110,9 @@ ComponentRequest, ComponentResponse, ComponentUpdate, + CuratedVisualizationRequest, + CuratedVisualizationResponse, + CuratedVisualizationUpdate, DeploymentFilter, DeploymentResponse, EventSourceFilter, @@ -3739,6 +3744,96 @@ def get_deployment( hydrate=hydrate, ) + def create_curated_visualization( + self, + artifact_visualization_id: UUID, + *, + resource_id: UUID, + resource_type: VisualizationResourceTypes, + project_id: Optional[UUID] = None, + display_name: Optional[str] = None, + display_order: Optional[int] = None, + layout_size: CuratedVisualizationSize = CuratedVisualizationSize.FULL_WIDTH, + ) -> CuratedVisualizationResponse: + """Create a curated visualization associated with a resource. + + Curated visualizations can be attached to any of the following + ZenML resource types to provide contextual dashboards throughout the ML + lifecycle: + + - **Deployments** (VisualizationResourceTypes.DEPLOYMENT): Surface on + deployment monitoring dashboards + - **Pipelines** (VisualizationResourceTypes.PIPELINE): Associate with + pipeline definitions + - **Pipeline Runs** (VisualizationResourceTypes.PIPELINE_RUN): Attach to + specific execution runs + - **Pipeline Snapshots** (VisualizationResourceTypes.PIPELINE_SNAPSHOT): + Link to captured pipeline configurations + + Each visualization is linked to exactly one resource. + + Args: + artifact_visualization_id: The UUID of the artifact visualization to curate. + resource_id: The identifier of the resource tied to the visualization. + resource_type: The type of resource referenced by the visualization. + project_id: The ID of the project to associate with the visualization. + display_name: The display name of the visualization. + display_order: The display order of the visualization. + layout_size: The layout size of the visualization in the dashboard. + + Returns: + The created curated visualization. + """ + request = CuratedVisualizationRequest( + project=project_id or self.active_project.id, + artifact_visualization_id=artifact_visualization_id, + display_name=display_name, + display_order=display_order, + layout_size=layout_size, + resource_id=resource_id, + resource_type=resource_type, + ) + return self.zen_store.create_curated_visualization(request) + + def update_curated_visualization( + self, + visualization_id: UUID, + *, + display_name: Optional[str] = None, + display_order: Optional[int] = None, + layout_size: Optional[CuratedVisualizationSize] = None, + ) -> CuratedVisualizationResponse: + """Update display metadata for a curated visualization. + + Args: + visualization_id: The ID of the curated visualization to update. + display_name: New display name for the visualization. + display_order: New display order for the visualization. + layout_size: Updated layout size for the visualization. + + Returns: + The updated deployment visualization. + """ + update_model = CuratedVisualizationUpdate( + display_name=display_name, + display_order=display_order, + layout_size=layout_size, + ) + return self.zen_store.update_curated_visualization( + visualization_id=visualization_id, + visualization_update=update_model, + ) + + def delete_curated_visualization(self, visualization_id: UUID) -> None: + """Delete a curated visualization. + + Args: + visualization_id: The ID of the curated visualization to delete. + """ + self.zen_store.delete_curated_visualization( + visualization_id=visualization_id + ) + def list_deployments( self, sort_by: str = "created", diff --git a/src/zenml/constants.py b/src/zenml/constants.py index 0145115b37..456f66d60e 100644 --- a/src/zenml/constants.py +++ b/src/zenml/constants.py @@ -403,6 +403,7 @@ def handle_int_env_var(var: str, default: int = 0) -> int: PIPELINE_CONFIGURATION = "/pipeline-configuration" PIPELINE_DEPLOYMENTS = "/pipeline_deployments" DEPLOYMENTS = "/deployments" +CURATED_VISUALIZATIONS = "/curated_visualizations" PIPELINE_SNAPSHOTS = "/pipeline_snapshots" PIPELINES = "/pipelines" PIPELINE_SPEC = "/pipeline-spec" diff --git a/src/zenml/enums.py b/src/zenml/enums.py index 93edbba278..85aae8b6f4 100644 --- a/src/zenml/enums.py +++ b/src/zenml/enums.py @@ -418,6 +418,41 @@ class MetadataResourceTypes(StrEnum): SCHEDULE = "schedule" +class VisualizationResourceTypes(StrEnum): + """Resource types that support curated visualizations. + + Curated visualizations can be attached to these ZenML resources to provide + contextual dashboards and visual insights throughout the ML lifecycle: + + - **DEPLOYMENT**: Server-side pipeline deployments - surface visualizations + on deployment monitoring dashboards and status pages + - **MODEL**: ZenML model entities - surface model evaluation dashboards and + performance summaries directly on the model detail pages + - **PIPELINE**: Pipeline definitions - associate visualizations with pipeline + configurations for reusable visual documentation + - **PIPELINE_RUN**: Pipeline execution runs - attach visualizations to specific + run results for detailed analysis and debugging + - **PIPELINE_SNAPSHOT**: Pipeline snapshots - link visualizations to captured + pipeline configurations for version comparison and historical analysis + - **PROJECT**: Project-level overviews - provide high-level project dashboards + and KPI visualizations for cross-pipeline insights + """ + + DEPLOYMENT = "deployment" # Server-side pipeline deployments + MODEL = "model" # ZenML models + PIPELINE = "pipeline" # Pipeline definitions + PIPELINE_RUN = "pipeline_run" # Execution runs + PIPELINE_SNAPSHOT = "pipeline_snapshot" # Snapshot configurations + PROJECT = "project" # Project-level dashboards + + +class CuratedVisualizationSize(StrEnum): + """Layout size options for curated visualizations.""" + + FULL_WIDTH = "full_width" + HALF_WIDTH = "half_width" + + class SecretResourceTypes(StrEnum): """All possible resource types for adding secrets.""" diff --git a/src/zenml/models/__init__.py b/src/zenml/models/__init__.py index 43c81edf51..ef0737dfa8 100644 --- a/src/zenml/models/__init__.py +++ b/src/zenml/models/__init__.py @@ -112,6 +112,7 @@ ArtifactVisualizationResponse, ArtifactVisualizationResponseBody, ArtifactVisualizationResponseMetadata, + ArtifactVisualizationResponseResources, ) from zenml.models.v2.core.service import ( ServiceResponse, @@ -164,6 +165,14 @@ DeploymentResponseMetadata, DeploymentResponseResources, ) +from zenml.models.v2.core.curated_visualization import ( + CuratedVisualizationRequest, + CuratedVisualizationResponse, + CuratedVisualizationResponseBody, + CuratedVisualizationResponseMetadata, + CuratedVisualizationResponseResources, + CuratedVisualizationUpdate, +) from zenml.models.v2.core.device import ( OAuthDeviceUpdate, OAuthDeviceFilter, @@ -473,6 +482,10 @@ ArtifactVersionResponseBody.model_rebuild() ArtifactVersionResponseMetadata.model_rebuild() ArtifactVersionResponseResources.model_rebuild() +ArtifactVisualizationResponse.model_rebuild() +ArtifactVisualizationResponseBody.model_rebuild() +ArtifactVisualizationResponseMetadata.model_rebuild() +ArtifactVisualizationResponseResources.model_rebuild() CodeReferenceResponseBody.model_rebuild() CodeRepositoryResponseBody.model_rebuild() CodeRepositoryResponseMetadata.model_rebuild() @@ -484,6 +497,10 @@ DeploymentResponseBody.model_rebuild() DeploymentResponseMetadata.model_rebuild() DeploymentResponseResources.model_rebuild() +CuratedVisualizationResponseBody.model_rebuild() +CuratedVisualizationResponseMetadata.model_rebuild() +CuratedVisualizationResponseResources.model_rebuild() +CuratedVisualizationResponse.model_rebuild() EventSourceResponseBody.model_rebuild() EventSourceResponseMetadata.model_rebuild() EventSourceResponseResources.model_rebuild() @@ -630,6 +647,7 @@ "ArtifactVisualizationResponse", "ArtifactVisualizationResponseBody", "ArtifactVisualizationResponseMetadata", + "ArtifactVisualizationResponseResources", "CodeReferenceRequest", "CodeReferenceResponse", "CodeReferenceResponseBody", @@ -659,6 +677,12 @@ "DeploymentResponseBody", "DeploymentResponseMetadata", "DeploymentResponseResources", + "CuratedVisualizationRequest", + "CuratedVisualizationResponse", + "CuratedVisualizationResponseBody", + "CuratedVisualizationResponseMetadata", + "CuratedVisualizationResponseResources", + "CuratedVisualizationUpdate", "EventSourceFlavorResponse", "EventSourceFlavorResponseBody", "EventSourceFlavorResponseMetadata", diff --git a/src/zenml/models/v2/core/artifact_visualization.py b/src/zenml/models/v2/core/artifact_visualization.py index 8dde68741c..5326877c0e 100644 --- a/src/zenml/models/v2/core/artifact_visualization.py +++ b/src/zenml/models/v2/core/artifact_visualization.py @@ -13,8 +13,11 @@ # permissions and limitations under the License. """Models representing artifact visualizations.""" +from typing import TYPE_CHECKING, Optional from uuid import UUID +from pydantic import Field + from zenml.enums import VisualizationType from zenml.models.v2.base.base import ( BaseDatedResponseBody, @@ -24,6 +27,9 @@ BaseResponseResources, ) +if TYPE_CHECKING: + from zenml.models.v2.core.artifact_version import ArtifactVersionResponse + # ------------------ Request Model ------------------ @@ -57,6 +63,12 @@ class ArtifactVisualizationResponseMetadata(BaseResponseMetadata): class ArtifactVisualizationResponseResources(BaseResponseResources): """Class for all resource models associated with the artifact visualization.""" + artifact_version: Optional["ArtifactVersionResponse"] = Field( + default=None, + title="The artifact version.", + description="Artifact version that owns this visualization, when included.", + ) + class ArtifactVisualizationResponse( BaseIdentifiedResponse[ @@ -105,6 +117,15 @@ def artifact_version_id(self) -> UUID: """ return self.get_metadata().artifact_version_id + @property + def artifact_version(self) -> Optional["ArtifactVersionResponse"]: + """The artifact version resource, if the response was hydrated with it. + + Returns: + The artifact version resource, if available. + """ + return self.get_resources().artifact_version + # ------------------ Filter Model ------------------ diff --git a/src/zenml/models/v2/core/curated_visualization.py b/src/zenml/models/v2/core/curated_visualization.py new file mode 100644 index 0000000000..6da41b4eaa --- /dev/null +++ b/src/zenml/models/v2/core/curated_visualization.py @@ -0,0 +1,270 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Models representing curated visualizations.""" + +from typing import TYPE_CHECKING, Optional +from uuid import UUID + +from pydantic import Field, NonNegativeInt + +from zenml.enums import CuratedVisualizationSize, VisualizationResourceTypes +from zenml.models.v2.base.base import BaseUpdate +from zenml.models.v2.base.scoped import ( + ProjectScopedRequest, + ProjectScopedResponse, + ProjectScopedResponseBody, + ProjectScopedResponseMetadata, + ProjectScopedResponseResources, +) + +if TYPE_CHECKING: + from zenml.models.v2.core.artifact_visualization import ( + ArtifactVisualizationResponse, + ) + + +# ------------------ Request Model ------------------ + + +class CuratedVisualizationRequest(ProjectScopedRequest): + """Request model for curated visualizations. + + Each curated visualization links a pre-rendered artifact visualization + to a single ZenML resource to surface it in the appropriate UI context. + Supported resources include: + - **Deployments** + - **Models** + - **Pipelines** + - **Pipeline Runs** + - **Pipeline Snapshots** + - **Projects** + """ + + artifact_visualization_id: UUID = Field( + title="The artifact visualization ID.", + description=( + "Identifier of the artifact visualization that should be surfaced " + "for the target resource." + ), + ) + display_name: Optional[str] = Field( + default=None, + title="The display name of the visualization.", + ) + display_order: Optional[NonNegativeInt] = Field( + default=None, + title="The display order of the visualization.", + description=( + "Optional ordering hint that must be unique for the combination " + "of resource type and resource ID." + ), + ) + layout_size: CuratedVisualizationSize = Field( + default=CuratedVisualizationSize.FULL_WIDTH, + title="The layout size of the visualization.", + description=( + "Controls how much horizontal space the visualization occupies " + "on the dashboard." + ), + ) + resource_id: UUID = Field( + title="The linked resource ID.", + description=( + "Identifier of the resource (deployment, model, pipeline, pipeline " + "run, pipeline snapshot, or project) that should surface this " + "visualization." + ), + ) + resource_type: VisualizationResourceTypes = Field( + title="The linked resource type.", + description="Type of the resource associated with this visualization.", + ) + + +# ------------------ Update Model ------------------ + + +class CuratedVisualizationUpdate(BaseUpdate): + """Update model for curated visualizations.""" + + display_name: Optional[str] = Field( + default=None, + title="The new display name of the visualization.", + ) + display_order: Optional[NonNegativeInt] = Field( + default=None, + title="The new display order of the visualization.", + description=( + "Optional ordering hint. When provided, it must remain unique for " + "the combination of resource type and resource ID." + ), + ) + layout_size: Optional[CuratedVisualizationSize] = Field( + default=None, + title="The updated layout size of the visualization.", + ) + + +# ------------------ Response Model ------------------ + + +class CuratedVisualizationResponseBody(ProjectScopedResponseBody): + """Response body for curated visualizations.""" + + artifact_visualization_id: UUID = Field( + title="The artifact visualization ID.", + description=( + "Identifier of the artifact visualization that is curated for this resource." + ), + ) + artifact_version_id: UUID = Field( + title="The artifact version ID.", + description=( + "Identifier of the artifact version that owns the curated visualization. " + "Provided for read-only context when available." + ), + ) + display_name: Optional[str] = Field( + default=None, + title="The display name of the visualization.", + ) + display_order: Optional[NonNegativeInt] = Field( + default=None, + title="The display order of the visualization.", + description=( + "Optional ordering hint that is unique per combination of " + "resource type and resource ID." + ), + ) + layout_size: CuratedVisualizationSize = Field( + default=CuratedVisualizationSize.FULL_WIDTH, + title="The layout size of the visualization.", + ) + resource_id: UUID = Field( + title="The linked resource ID.", + description="Identifier of the resource associated with this visualization.", + ) + resource_type: VisualizationResourceTypes = Field( + title="The linked resource type.", + description="Type of the resource associated with this visualization.", + ) + + +class CuratedVisualizationResponseMetadata(ProjectScopedResponseMetadata): + """Response metadata for curated visualizations.""" + + +class CuratedVisualizationResponseResources(ProjectScopedResponseResources): + """Response resources included for curated visualizations.""" + + artifact_visualization: "ArtifactVisualizationResponse" = Field( + title="The artifact visualization.", + description=( + "Artifact visualization that is surfaced through this curated visualization." + ), + ) + + +class CuratedVisualizationResponse( + ProjectScopedResponse[ + CuratedVisualizationResponseBody, + CuratedVisualizationResponseMetadata, + CuratedVisualizationResponseResources, + ] +): + """Response model for curated visualizations.""" + + def get_hydrated_version(self) -> "CuratedVisualizationResponse": + """Get the hydrated version of this curated visualization. + + Returns: + A hydrated instance of the same entity. + """ + from zenml.client import Client + + client = Client() + return client.zen_store.get_curated_visualization(self.id) + + # Helper properties + @property + def artifact_visualization_id(self) -> UUID: + """The artifact visualization ID. + + Returns: + The artifact visualization ID. + """ + return self.get_body().artifact_visualization_id + + @property + def artifact_version_id(self) -> UUID: + """The artifact version ID. + + Returns: + The artifact version ID if available. + """ + return self.get_body().artifact_version_id + + @property + def display_name(self) -> Optional[str]: + """The display name of the visualization. + + Returns: + The display name of the visualization. + """ + return self.get_body().display_name + + @property + def display_order(self) -> Optional[int]: + """The display order of the visualization. + + Returns: + The display order of the visualization. + """ + return self.get_body().display_order + + @property + def layout_size(self) -> CuratedVisualizationSize: + """The layout size of the visualization. + + Returns: + The layout size of the visualization. + """ + return self.get_body().layout_size + + @property + def artifact_visualization(self) -> "ArtifactVisualizationResponse": + """The curated artifact visualization resource. + + Returns: + The artifact visualization resource. + """ + return self.get_resources().artifact_visualization + + @property + def resource_id(self) -> UUID: + """The identifier of the linked resource. + + Returns: + The resource identifier associated with this visualization. + """ + return self.get_body().resource_id + + @property + def resource_type(self) -> VisualizationResourceTypes: + """The type of the linked resource. + + Returns: + The resource type associated with this visualization. + """ + return self.get_body().resource_type diff --git a/src/zenml/models/v2/core/deployment.py b/src/zenml/models/v2/core/deployment.py index 4130961bf7..d08cd30ed5 100644 --- a/src/zenml/models/v2/core/deployment.py +++ b/src/zenml/models/v2/core/deployment.py @@ -46,6 +46,9 @@ from sqlalchemy.sql.elements import ColumnElement from zenml.models.v2.core.component import ComponentResponse + from zenml.models.v2.core.curated_visualization import ( + CuratedVisualizationResponse, + ) from zenml.models.v2.core.pipeline import PipelineResponse from zenml.models.v2.core.pipeline_snapshot import ( PipelineSnapshotResponse, @@ -205,6 +208,10 @@ class DeploymentResponseResources(ProjectScopedResponseResources): tags: List["TagResponse"] = Field( title="Tags associated with the deployment.", ) + visualizations: List["CuratedVisualizationResponse"] = Field( + default_factory=list, + title="Curated deployment visualizations.", + ) class DeploymentResponse( @@ -305,6 +312,15 @@ def tags(self) -> List["TagResponse"]: """ return self.get_resources().tags + @property + def visualizations(self) -> List["CuratedVisualizationResponse"]: + """The visualizations of the deployment. + + Returns: + The visualizations of the deployment. + """ + return self.get_resources().visualizations + @property def snapshot_id(self) -> Optional[UUID]: """The pipeline snapshot ID. diff --git a/src/zenml/models/v2/core/model.py b/src/zenml/models/v2/core/model.py index eaffa6a68d..68ff7f9582 100644 --- a/src/zenml/models/v2/core/model.py +++ b/src/zenml/models/v2/core/model.py @@ -45,6 +45,9 @@ if TYPE_CHECKING: from zenml.model.model import Model + from zenml.models.v2.core.curated_visualization import ( + CuratedVisualizationResponse, + ) from zenml.models.v2.core.tag import TagResponse from zenml.zen_stores.schemas import BaseSchema @@ -185,6 +188,10 @@ class ModelResponseResources(ProjectScopedResponseResources): ) latest_version_name: Optional[str] = None latest_version_id: Optional[UUID] = None + visualizations: List["CuratedVisualizationResponse"] = Field( + default_factory=list, + title="Curated visualizations associated with the model.", + ) class ModelResponse( @@ -309,6 +316,15 @@ def save_models_to_registry(self) -> bool: """ return self.get_metadata().save_models_to_registry + @property + def visualizations(self) -> List["CuratedVisualizationResponse"]: + """The `visualizations` property. + + Returns: + the value of the property. + """ + return self.get_resources().visualizations + # Helper functions @property def versions(self) -> List["Model"]: diff --git a/src/zenml/models/v2/core/pipeline.py b/src/zenml/models/v2/core/pipeline.py index 8a5823f067..6122a07c97 100644 --- a/src/zenml/models/v2/core/pipeline.py +++ b/src/zenml/models/v2/core/pipeline.py @@ -46,7 +46,11 @@ from zenml.models.v2.core.tag import TagResponse if TYPE_CHECKING: - from zenml.models import PipelineRunResponse, UserResponse + from zenml.models import ( + CuratedVisualizationResponse, + PipelineRunResponse, + UserResponse, + ) from zenml.zen_stores.schemas import BaseSchema AnySchema = TypeVar("AnySchema", bound=BaseSchema) @@ -127,6 +131,10 @@ class PipelineResponseResources(ProjectScopedResponseResources): tags: List[TagResponse] = Field( title="Tags associated with the pipeline.", ) + visualizations: List["CuratedVisualizationResponse"] = Field( + default=[], + title="Curated visualizations associated with the pipeline.", + ) class PipelineResponse( diff --git a/src/zenml/models/v2/core/pipeline_run.py b/src/zenml/models/v2/core/pipeline_run.py index 5da1d75fd0..2405819744 100644 --- a/src/zenml/models/v2/core/pipeline_run.py +++ b/src/zenml/models/v2/core/pipeline_run.py @@ -56,6 +56,9 @@ from zenml.models import TriggerExecutionResponse from zenml.models.v2.core.artifact_version import ArtifactVersionResponse from zenml.models.v2.core.code_reference import CodeReferenceResponse + from zenml.models.v2.core.curated_visualization import ( + CuratedVisualizationResponse, + ) from zenml.models.v2.core.logs import LogsResponse from zenml.models.v2.core.pipeline import PipelineResponse from zenml.models.v2.core.pipeline_build import ( @@ -305,6 +308,10 @@ class PipelineRunResponseResources(ProjectScopedResponseResources): title="Logs associated with this pipeline run.", default=None, ) + visualizations: List["CuratedVisualizationResponse"] = Field( + default=[], + title="Curated visualizations associated with the pipeline run.", + ) # TODO: In Pydantic v2, the `model_` is a protected namespaces for all # fields defined under base models. If not handled, this raises a warning. diff --git a/src/zenml/models/v2/core/pipeline_snapshot.py b/src/zenml/models/v2/core/pipeline_snapshot.py index 4298d6d49e..c56cea969d 100644 --- a/src/zenml/models/v2/core/pipeline_snapshot.py +++ b/src/zenml/models/v2/core/pipeline_snapshot.py @@ -61,6 +61,9 @@ if TYPE_CHECKING: from sqlalchemy.sql.elements import ColumnElement + from zenml.models.v2.core.curated_visualization import ( + CuratedVisualizationResponse, + ) from zenml.zen_stores.schemas.base_schemas import BaseSchema AnySchema = TypeVar("AnySchema", bound=BaseSchema) @@ -335,6 +338,10 @@ class PipelineSnapshotResponseResources(ProjectScopedResponseResources): default=None, title="The user that created the latest run of the snapshot.", ) + visualizations: List["CuratedVisualizationResponse"] = Field( + default=[], + title="Curated visualizations associated with the pipeline snapshot.", + ) class PipelineSnapshotResponse( diff --git a/src/zenml/zen_server/routers/curated_visualization_endpoints.py b/src/zenml/zen_server/routers/curated_visualization_endpoints.py new file mode 100644 index 0000000000..b7f3858dfc --- /dev/null +++ b/src/zenml/zen_server/routers/curated_visualization_endpoints.py @@ -0,0 +1,190 @@ +"""REST API endpoints for curated visualizations.""" + +from typing import Any +from uuid import UUID + +from fastapi import APIRouter, Security + +from zenml.constants import API, CURATED_VISUALIZATIONS, VERSION_1 +from zenml.enums import VisualizationResourceTypes +from zenml.models import ( + CuratedVisualizationRequest, + CuratedVisualizationResponse, + CuratedVisualizationUpdate, +) +from zenml.zen_server.auth import AuthContext, authorize +from zenml.zen_server.exceptions import error_response +from zenml.zen_server.rbac.models import Action +from zenml.zen_server.rbac.utils import verify_permission_for_model +from zenml.zen_server.utils import async_fastapi_endpoint_wrapper, zen_store + +router = APIRouter( + prefix=API + VERSION_1 + CURATED_VISUALIZATIONS, + tags=["curated_visualizations"], + responses={401: error_response, 404: error_response, 422: error_response}, +) + + +def _get_resource_model( + resource_type: VisualizationResourceTypes, + resource_id: UUID, +) -> Any: + """Fetch the concrete resource model for a curated visualization. + + Args: + resource_type: The type of resource linked to the curated visualization. + resource_id: The unique identifier of the linked resource. + + Returns: + The hydrated resource model retrieved from the Zen store. + + Raises: + RuntimeError: If the provided resource type is not supported. + """ + store = zen_store() + + if resource_type == VisualizationResourceTypes.DEPLOYMENT: + return store.get_deployment(resource_id) + if resource_type == VisualizationResourceTypes.MODEL: + return store.get_model(resource_id) + if resource_type == VisualizationResourceTypes.PIPELINE: + return store.get_pipeline(resource_id) + if resource_type == VisualizationResourceTypes.PIPELINE_RUN: + return store.get_run(resource_id) + if resource_type == VisualizationResourceTypes.PIPELINE_SNAPSHOT: + return store.get_snapshot(resource_id) + if resource_type == VisualizationResourceTypes.PROJECT: + return store.get_project(resource_id) + + raise RuntimeError( + f"Unsupported curated visualization resource type: {resource_type}" + ) + + +@router.post( + "", + responses={ + 401: error_response, + 404: error_response, + 409: error_response, + 422: error_response, + }, +) +@async_fastapi_endpoint_wrapper +def create_curated_visualization( + visualization: CuratedVisualizationRequest, + _: AuthContext = Security(authorize), +) -> CuratedVisualizationResponse: + """Create a curated visualization. + + Args: + visualization: The curated visualization to create. + + Returns: + The created curated visualization. + """ + store = zen_store() + resource_model = _get_resource_model( + visualization.resource_type, visualization.resource_id + ) + artifact_visualization = store.get_artifact_visualization( + visualization.artifact_visualization_id + ) + + verify_permission_for_model(resource_model, action=Action.UPDATE) + verify_permission_for_model(artifact_visualization, action=Action.READ) + + return store.create_curated_visualization(visualization) + + +@router.get( + "/{visualization_id}", + responses={401: error_response, 404: error_response, 422: error_response}, +) +@async_fastapi_endpoint_wrapper(deduplicate=True) +def get_curated_visualization( + visualization_id: UUID, + hydrate: bool = True, + _: AuthContext = Security(authorize), +) -> CuratedVisualizationResponse: + """Retrieve a curated visualization by ID. + + Args: + visualization_id: The ID of the curated visualization to retrieve. + hydrate: Flag deciding whether to return the hydrated model. + + Returns: + The curated visualization with the given ID. + """ + store = zen_store() + hydrated_visualization = store.get_curated_visualization( + visualization_id, hydrate=hydrate + ) + resource_type = hydrated_visualization.resource_type + resource_id = hydrated_visualization.resource_id + + resource_model = _get_resource_model(resource_type, resource_id) + verify_permission_for_model(resource_model, action=Action.READ) + + return hydrated_visualization + + +@router.put( + "/{visualization_id}", + responses={401: error_response, 404: error_response, 422: error_response}, +) +@async_fastapi_endpoint_wrapper(deduplicate=True) +def update_curated_visualization( + visualization_id: UUID, + visualization_update: CuratedVisualizationUpdate, + _: AuthContext = Security(authorize), +) -> CuratedVisualizationResponse: + """Update a curated visualization. + + Args: + visualization_id: The ID of the curated visualization to update. + visualization_update: The update to apply to the curated visualization. + + Returns: + The updated curated visualization. + """ + store = zen_store() + existing_visualization = store.get_curated_visualization( + visualization_id, hydrate=True + ) + resource_type = existing_visualization.resource_type + resource_id = existing_visualization.resource_id + + resource_model = _get_resource_model(resource_type, resource_id) + verify_permission_for_model(resource_model, action=Action.UPDATE) + + return store.update_curated_visualization( + visualization_id, visualization_update + ) + + +@router.delete( + "/{visualization_id}", + responses={401: error_response, 404: error_response, 422: error_response}, +) +@async_fastapi_endpoint_wrapper +def delete_curated_visualization( + visualization_id: UUID, + _: AuthContext = Security(authorize), +) -> None: + """Delete a curated visualization. + + Args: + visualization_id: The ID of the curated visualization to delete. + """ + store = zen_store() + existing_visualization = store.get_curated_visualization( + visualization_id, hydrate=True + ) + resource_type = existing_visualization.resource_type + resource_id = existing_visualization.resource_id + + resource_model = _get_resource_model(resource_type, resource_id) + verify_permission_for_model(resource_model, action=Action.UPDATE) + + store.delete_curated_visualization(visualization_id) diff --git a/src/zenml/zen_server/zen_server_api.py b/src/zenml/zen_server/zen_server_api.py index 5071998276..46feb51c47 100644 --- a/src/zenml/zen_server/zen_server_api.py +++ b/src/zenml/zen_server/zen_server_api.py @@ -57,6 +57,7 @@ artifact_version_endpoints, auth_endpoints, code_repositories_endpoints, + curated_visualization_endpoints, deployment_endpoints, devices_endpoints, event_source_endpoints, @@ -265,6 +266,7 @@ async def dashboard(request: Request) -> Any: app.include_router(devices_endpoints.router) app.include_router(code_repositories_endpoints.router) app.include_router(deployment_endpoints.router) +app.include_router(curated_visualization_endpoints.router) app.include_router(plugin_endpoints.plugin_router) app.include_router(event_source_endpoints.event_source_router) app.include_router(flavors_endpoints.router) diff --git a/src/zenml/zen_stores/migrations/versions/24552f3be1f2_add_visualisations.py b/src/zenml/zen_stores/migrations/versions/24552f3be1f2_add_visualisations.py new file mode 100644 index 0000000000..071bdd50a1 --- /dev/null +++ b/src/zenml/zen_stores/migrations/versions/24552f3be1f2_add_visualisations.py @@ -0,0 +1,73 @@ +"""add visualisations [24552f3be1f2]. + +Revision ID: 24552f3be1f2 +Revises: 124b57b8c7b1 +Create Date: 2025-10-24 07:56:57.575675 + +""" + +import sqlalchemy as sa +import sqlmodel +from alembic import op + +# revision identifiers, used by Alembic. +revision = "24552f3be1f2" +down_revision = "124b57b8c7b1" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + """Upgrade database schema and/or data, creating a new revision.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "curated_visualization", + sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column("created", sa.DateTime(), nullable=False), + sa.Column("updated", sa.DateTime(), nullable=False), + sa.Column("project_id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column( + "artifact_visualization_id", + sqlmodel.sql.sqltypes.GUID(), + nullable=False, + ), + sa.Column( + "display_name", sqlmodel.sql.sqltypes.AutoString(), nullable=True + ), + sa.Column("display_order", sa.Integer(), nullable=True), + sa.Column( + "layout_size", sqlmodel.sql.sqltypes.AutoString(), nullable=False + ), + sa.Column("resource_id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column( + "resource_type", sqlmodel.sql.sqltypes.AutoString(), nullable=False + ), + sa.ForeignKeyConstraint( + ["artifact_visualization_id"], + ["artifact_visualization.id"], + name="fk_curated_visualization_artifact_visualization_id", + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["project_id"], + ["project.id"], + name="fk_curated_visualization_project_id_project", + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint( + "artifact_visualization_id", + "resource_id", + "resource_type", + name="unique_curated_visualization_resource_link", + ), + ) + + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade database schema and/or data back to the previous revision.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("curated_visualization") + # ### end Alembic commands ### diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index fdde0b4130..fd9bcd35c5 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -65,6 +65,7 @@ CODE_REFERENCES, CODE_REPOSITORIES, CONFIG, + CURATED_VISUALIZATIONS, CURRENT_USER, DEACTIVATE, DEFAULT_HTTP_TIMEOUT, @@ -165,6 +166,9 @@ ComponentRequest, ComponentResponse, ComponentUpdate, + CuratedVisualizationRequest, + CuratedVisualizationResponse, + CuratedVisualizationUpdate, DeployedStack, DeploymentFilter, DeploymentRequest, @@ -1856,6 +1860,77 @@ def delete_deployment(self, deployment_id: UUID) -> None: route=DEPLOYMENTS, ) + def create_curated_visualization( + self, visualization: CuratedVisualizationRequest + ) -> CuratedVisualizationResponse: + """Create a curated visualization via REST API. + + Args: + visualization: The curated visualization to create. + + Returns: + The created curated visualization. + """ + return self._create_resource( + resource=visualization, + response_model=CuratedVisualizationResponse, + route=CURATED_VISUALIZATIONS, + params={"hydrate": True}, + ) + + def get_curated_visualization( + self, visualization_id: UUID, hydrate: bool = True + ) -> CuratedVisualizationResponse: + """Get a curated visualization by ID. + + Args: + visualization_id: The ID of the curated visualization to get. + hydrate: Flag deciding whether to hydrate the output model(s) + by including metadata fields in the response. + + Returns: + The curated visualization with the given ID. + """ + return self._get_resource( + resource_id=visualization_id, + route=CURATED_VISUALIZATIONS, + response_model=CuratedVisualizationResponse, + params={"hydrate": hydrate}, + ) + + def update_curated_visualization( + self, + visualization_id: UUID, + visualization_update: CuratedVisualizationUpdate, + ) -> CuratedVisualizationResponse: + """Update a curated visualization via REST API. + + Args: + visualization_id: The ID of the curated visualization to update. + visualization_update: The update to apply to the curated + visualization. + + Returns: + The updated curated visualization. + """ + return self._update_resource( + resource_id=visualization_id, + resource_update=visualization_update, + response_model=CuratedVisualizationResponse, + route=CURATED_VISUALIZATIONS, + ) + + def delete_curated_visualization(self, visualization_id: UUID) -> None: + """Delete a curated visualization via REST API. + + Args: + visualization_id: The ID of the curated visualization to delete. + """ + self._delete_resource( + resource_id=visualization_id, + route=CURATED_VISUALIZATIONS, + ) + # -------------------- Run templates -------------------- def create_run_template( diff --git a/src/zenml/zen_stores/schemas/__init__.py b/src/zenml/zen_stores/schemas/__init__.py index b98adfcfea..24a3c88c8f 100644 --- a/src/zenml/zen_stores/schemas/__init__.py +++ b/src/zenml/zen_stores/schemas/__init__.py @@ -31,6 +31,9 @@ from zenml.zen_stores.schemas.event_source_schemas import EventSourceSchema from zenml.zen_stores.schemas.pipeline_build_schemas import PipelineBuildSchema from zenml.zen_stores.schemas.deployment_schemas import DeploymentSchema +from zenml.zen_stores.schemas.curated_visualization_schemas import ( + CuratedVisualizationSchema, +) from zenml.zen_stores.schemas.component_schemas import StackComponentSchema from zenml.zen_stores.schemas.flavor_schemas import FlavorSchema from zenml.zen_stores.schemas.server_settings_schemas import ServerSettingsSchema @@ -88,6 +91,7 @@ "CodeReferenceSchema", "CodeRepositorySchema", "DeploymentSchema", + "CuratedVisualizationSchema", "EventSourceSchema", "FlavorSchema", "LogsSchema", diff --git a/src/zenml/zen_stores/schemas/artifact_visualization_schemas.py b/src/zenml/zen_stores/schemas/artifact_visualization_schemas.py index 79862fc037..d7e1576f71 100644 --- a/src/zenml/zen_stores/schemas/artifact_visualization_schemas.py +++ b/src/zenml/zen_stores/schemas/artifact_visualization_schemas.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """SQLModel implementation of artifact visualization table.""" -from typing import Any +from typing import TYPE_CHECKING, Any, List from uuid import UUID from sqlalchemy import TEXT, Column @@ -25,11 +25,17 @@ ArtifactVisualizationResponse, ArtifactVisualizationResponseBody, ArtifactVisualizationResponseMetadata, + ArtifactVisualizationResponseResources, ) from zenml.zen_stores.schemas.artifact_schemas import ArtifactVersionSchema from zenml.zen_stores.schemas.base_schemas import BaseSchema from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field +if TYPE_CHECKING: + from zenml.zen_stores.schemas.curated_visualization_schemas import ( + CuratedVisualizationSchema, + ) + class ArtifactVisualizationSchema(BaseSchema, table=True): """SQL Model for visualizations of artifacts.""" @@ -54,6 +60,13 @@ class ArtifactVisualizationSchema(BaseSchema, table=True): artifact_version: ArtifactVersionSchema = Relationship( back_populates="visualizations" ) + curated_visualizations: List["CuratedVisualizationSchema"] = Relationship( + back_populates="artifact_visualization", + sa_relationship_kwargs=dict( + order_by="CuratedVisualizationSchema.display_order", + cascade="delete", + ), + ) @classmethod def from_model( @@ -107,8 +120,22 @@ def to_model( artifact_version_id=self.artifact_version_id, ) + resources = None + if include_resources: + if self.artifact_version is not None: + artifact_version = self.artifact_version.to_model( + include_metadata=False, + include_resources=False, + ) + else: + artifact_version = None + resources = ArtifactVisualizationResponseResources( + artifact_version=artifact_version, + ) + return ArtifactVisualizationResponse( id=self.id, body=body, metadata=metadata, + resources=resources, ) diff --git a/src/zenml/zen_stores/schemas/curated_visualization_schemas.py b/src/zenml/zen_stores/schemas/curated_visualization_schemas.py new file mode 100644 index 0000000000..e2490fc971 --- /dev/null +++ b/src/zenml/zen_stores/schemas/curated_visualization_schemas.py @@ -0,0 +1,224 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""SQLModel implementation of curated visualization tables.""" + +from typing import TYPE_CHECKING, Any, List, Optional, Sequence +from uuid import UUID + +from sqlalchemy import UniqueConstraint +from sqlalchemy.orm import selectinload +from sqlalchemy.sql.base import ExecutableOption +from sqlmodel import Field, Relationship + +from zenml.enums import CuratedVisualizationSize, VisualizationResourceTypes +from zenml.models.v2.core.curated_visualization import ( + CuratedVisualizationRequest, + CuratedVisualizationResponse, + CuratedVisualizationResponseBody, + CuratedVisualizationResponseMetadata, + CuratedVisualizationResponseResources, + CuratedVisualizationUpdate, +) +from zenml.zen_stores.schemas.base_schemas import BaseSchema +from zenml.zen_stores.schemas.project_schemas import ProjectSchema +from zenml.zen_stores.schemas.schema_utils import ( + build_foreign_key_field, +) +from zenml.zen_stores.schemas.utils import jl_arg + +if TYPE_CHECKING: + from zenml.zen_stores.schemas.artifact_visualization_schemas import ( + ArtifactVisualizationSchema, + ) + + +class CuratedVisualizationSchema(BaseSchema, table=True): + """SQL Model for curated visualizations.""" + + __tablename__ = "curated_visualization" + __table_args__ = ( + UniqueConstraint( + "artifact_visualization_id", + "resource_id", + "resource_type", + name="unique_curated_visualization_resource_link", + ), + ) + + project_id: UUID = build_foreign_key_field( + source=__tablename__, + target=ProjectSchema.__tablename__, + source_column="project_id", + target_column="id", + ondelete="CASCADE", + nullable=False, + ) + artifact_visualization_id: UUID = build_foreign_key_field( + source=__tablename__, + target="artifact_visualization", + source_column="artifact_visualization_id", + target_column="id", + ondelete="CASCADE", + nullable=False, + custom_constraint_name="fk_curated_visualization_artifact_visualization_id", + ) + + display_name: Optional[str] = Field(default=None) + display_order: Optional[int] = Field(default=None) + layout_size: str = Field( + default=CuratedVisualizationSize.FULL_WIDTH.value, + nullable=False, + ) + resource_id: UUID = Field(nullable=False) + resource_type: str = Field(nullable=False) + + artifact_visualization: "ArtifactVisualizationSchema" = Relationship( + back_populates="curated_visualizations" + ) + + @classmethod + def get_query_options( + cls, + include_metadata: bool = False, + include_resources: bool = False, + **kwargs: Any, + ) -> Sequence[ExecutableOption]: + """Get the query options for the schema. + + Args: + include_metadata: Whether metadata will be included when converting + the schema to a model. + include_resources: Whether resources will be included when + converting the schema to a model. + **kwargs: Keyword arguments to allow schema specific logic + + Returns: + A list of query options. + """ + options: List[ExecutableOption] = [] + + if include_resources: + options.append(selectinload(jl_arg(cls.artifact_visualization))) + + return options + + @classmethod + def from_request( + cls, request: CuratedVisualizationRequest + ) -> "CuratedVisualizationSchema": + """Convert a request into a schema instance. + + Args: + request: The request to convert. + + Returns: + The created schema. + """ + return cls( + project_id=request.project, + artifact_visualization_id=request.artifact_visualization_id, + display_name=request.display_name, + display_order=request.display_order, + layout_size=request.layout_size.value, + resource_id=request.resource_id, + resource_type=request.resource_type.value, + ) + + def update( + self, + update: CuratedVisualizationUpdate, + ) -> "CuratedVisualizationSchema": + """Update a schema instance from an update model. + + Args: + update: The update definition. + + Returns: + The updated schema. + """ + changes = update.model_dump(exclude_unset=True) + layout_size_update = changes.pop("layout_size", None) + if layout_size_update is not None: + self.layout_size = layout_size_update.value + + for field, value in changes.items(): + if hasattr(self, field): + setattr(self, field, value) + + from zenml.utils.time_utils import utc_now + + self.updated = utc_now() + return self + + def to_model( + self, + include_metadata: bool = False, + include_resources: bool = False, + **kwargs: Any, + ) -> CuratedVisualizationResponse: + """Convert schema into response model. + + Args: + include_metadata: Whether to include metadata in the response. + include_resources: Whether to include resources in the response. + **kwargs: Additional keyword arguments. + + Returns: + The created response model. + """ + try: + layout_size_enum = CuratedVisualizationSize(self.layout_size) + except ValueError: + layout_size_enum = CuratedVisualizationSize.FULL_WIDTH + + try: + resource_type_enum = VisualizationResourceTypes(self.resource_type) + except ValueError: + resource_type_enum = VisualizationResourceTypes.PROJECT + + artifact_version_id = self.artifact_visualization.artifact_version_id + + body = CuratedVisualizationResponseBody( + project_id=self.project_id, + created=self.created, + updated=self.updated, + artifact_visualization_id=self.artifact_visualization_id, + artifact_version_id=artifact_version_id, + display_name=self.display_name, + display_order=self.display_order, + layout_size=layout_size_enum, + resource_id=self.resource_id, + resource_type=resource_type_enum, + ) + + metadata = None + if include_metadata: + metadata = CuratedVisualizationResponseMetadata() + + resources = None + if include_resources: + artifact_visualization = self.artifact_visualization.to_model( + include_metadata=False, + include_resources=False, + ) + resources = CuratedVisualizationResponseResources( + artifact_visualization=artifact_visualization, + ) + + return CuratedVisualizationResponse( + id=self.id, + body=body, + metadata=metadata, + resources=resources, + ) diff --git a/src/zenml/zen_stores/schemas/deployment_schemas.py b/src/zenml/zen_stores/schemas/deployment_schemas.py index 9a42825fbf..98d7327e57 100644 --- a/src/zenml/zen_stores/schemas/deployment_schemas.py +++ b/src/zenml/zen_stores/schemas/deployment_schemas.py @@ -24,7 +24,11 @@ from sqlmodel import Field, Relationship, String from zenml.constants import MEDIUMTEXT_MAX_LENGTH -from zenml.enums import DeploymentStatus, TaggableResourceTypes +from zenml.enums import ( + DeploymentStatus, + TaggableResourceTypes, + VisualizationResourceTypes, +) from zenml.logger import get_logger from zenml.models.v2.core.deployment import ( DeploymentRequest, @@ -46,6 +50,9 @@ from zenml.zen_stores.schemas.utils import jl_arg if TYPE_CHECKING: + from zenml.zen_stores.schemas.curated_visualization_schemas import ( + CuratedVisualizationSchema, + ) from zenml.zen_stores.schemas.tag_schemas import TagSchema logger = get_logger(__name__) @@ -133,6 +140,19 @@ class DeploymentSchema(NamedSchema, table=True): ), ) + visualizations: List["CuratedVisualizationSchema"] = Relationship( + sa_relationship_kwargs=dict( + primaryjoin=( + "and_(CuratedVisualizationSchema.resource_type" + f"=='{VisualizationResourceTypes.DEPLOYMENT.value}', " + "foreign(CuratedVisualizationSchema.resource_id)==DeploymentSchema.id)" + ), + overlaps="visualizations", + cascade="delete", + order_by="CuratedVisualizationSchema.display_order", + ), + ) + @classmethod def get_query_options( cls, @@ -162,6 +182,7 @@ def get_query_options( selectinload(jl_arg(DeploymentSchema.snapshot)).joinedload( jl_arg(PipelineSnapshotSchema.pipeline) ), + selectinload(jl_arg(DeploymentSchema.visualizations)), ] ) @@ -220,6 +241,13 @@ def to_model( pipeline=self.snapshot.pipeline.to_model() if self.snapshot and self.snapshot.pipeline else None, + visualizations=[ + visualization.to_model( + include_metadata=False, + include_resources=False, + ) + for visualization in self.visualizations + ], ) return DeploymentResponse( diff --git a/src/zenml/zen_stores/schemas/model_schemas.py b/src/zenml/zen_stores/schemas/model_schemas.py index 1d77d926a3..96856f1cdb 100644 --- a/src/zenml/zen_stores/schemas/model_schemas.py +++ b/src/zenml/zen_stores/schemas/model_schemas.py @@ -24,13 +24,14 @@ Column, UniqueConstraint, ) -from sqlalchemy.orm import joinedload, object_session +from sqlalchemy.orm import joinedload, object_session, selectinload from sqlalchemy.sql.base import ExecutableOption from sqlmodel import Field, Relationship, desc, select from zenml.enums import ( MetadataResourceTypes, TaggableResourceTypes, + VisualizationResourceTypes, ) from zenml.models import ( BaseResponseMetadata, @@ -57,6 +58,9 @@ from zenml.zen_stores.schemas.artifact_schemas import ArtifactVersionSchema from zenml.zen_stores.schemas.base_schemas import BaseSchema, NamedSchema from zenml.zen_stores.schemas.constants import MODEL_VERSION_TABLENAME +from zenml.zen_stores.schemas.curated_visualization_schemas import ( + CuratedVisualizationSchema, +) from zenml.zen_stores.schemas.pipeline_run_schemas import PipelineRunSchema from zenml.zen_stores.schemas.project_schemas import ProjectSchema from zenml.zen_stores.schemas.run_metadata_schemas import RunMetadataSchema @@ -128,6 +132,18 @@ class ModelSchema(NamedSchema, table=True): back_populates="model", sa_relationship_kwargs={"cascade": "delete"}, ) + visualizations: List["CuratedVisualizationSchema"] = Relationship( + sa_relationship_kwargs=dict( + primaryjoin=( + "and_(CuratedVisualizationSchema.resource_type" + f"=='{VisualizationResourceTypes.MODEL.value}', " + "foreign(CuratedVisualizationSchema.resource_id)==ModelSchema.id)" + ), + overlaps="visualizations", + cascade="delete", + order_by="CuratedVisualizationSchema.display_order", + ), + ) @classmethod def get_query_options( @@ -155,6 +171,7 @@ def get_query_options( [ joinedload(jl_arg(ModelSchema.user)), # joinedload(jl_arg(ModelSchema.tags)), + selectinload(jl_arg(ModelSchema.visualizations)), ] ) @@ -254,6 +271,13 @@ def to_model( tags=[tag.to_model() for tag in self.tags], latest_version_name=latest_version_name, latest_version_id=latest_version_id, + visualizations=[ + visualization.to_model( + include_metadata=False, + include_resources=False, + ) + for visualization in self.visualizations + ], ) body = ModelResponseBody( diff --git a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py index 1a192d7976..115d78764c 100644 --- a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py +++ b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py @@ -34,6 +34,7 @@ MetadataResourceTypes, PipelineRunTriggeredByType, TaggableResourceTypes, + VisualizationResourceTypes, ) from zenml.logger import get_logger from zenml.models import ( @@ -72,6 +73,9 @@ ) if TYPE_CHECKING: + from zenml.zen_stores.schemas.curated_visualization_schemas import ( + CuratedVisualizationSchema, + ) from zenml.zen_stores.schemas.logs_schemas import LogsSchema from zenml.zen_stores.schemas.model_schemas import ( ModelVersionPipelineRunSchema, @@ -241,6 +245,18 @@ class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True): overlaps="tags", ), ) + visualizations: List["CuratedVisualizationSchema"] = Relationship( + sa_relationship_kwargs=dict( + primaryjoin=( + "and_(CuratedVisualizationSchema.resource_type" + f"=='{VisualizationResourceTypes.PIPELINE_RUN.value}', " + "foreign(CuratedVisualizationSchema.resource_id)==PipelineRunSchema.id)" + ), + overlaps="visualizations", + cascade="delete", + order_by="CuratedVisualizationSchema.display_order", + ), + ) # Needed for cascade deletion model_versions_pipeline_runs_links: List[ @@ -315,6 +331,7 @@ def get_query_options( selectinload(jl_arg(PipelineRunSchema.logs)), selectinload(jl_arg(PipelineRunSchema.user)), selectinload(jl_arg(PipelineRunSchema.tags)), + selectinload(jl_arg(PipelineRunSchema.visualizations)), ] ) @@ -632,6 +649,13 @@ def to_model( tags=[tag.to_model() for tag in self.tags], logs=client_logs[0].to_model() if client_logs else None, log_collection=[log.to_model() for log in self.logs], + visualizations=[ + visualization.to_model( + include_metadata=False, + include_resources=False, + ) + for visualization in self.visualizations + ], ) return PipelineRunResponse( @@ -819,7 +843,7 @@ def _check_if_run_in_progress(self) -> bool: else: in_progress = any( not ExecutionStatus(status).is_finished - for name, status in step_run_statuses + for _, status in step_run_statuses ) return in_progress else: diff --git a/src/zenml/zen_stores/schemas/pipeline_schemas.py b/src/zenml/zen_stores/schemas/pipeline_schemas.py index ea4e3a7359..cf0923d641 100644 --- a/src/zenml/zen_stores/schemas/pipeline_schemas.py +++ b/src/zenml/zen_stores/schemas/pipeline_schemas.py @@ -17,11 +17,11 @@ from uuid import UUID from sqlalchemy import TEXT, Column, UniqueConstraint -from sqlalchemy.orm import joinedload, object_session +from sqlalchemy.orm import joinedload, object_session, selectinload from sqlalchemy.sql.base import ExecutableOption from sqlmodel import Field, Relationship, desc, select -from zenml.enums import TaggableResourceTypes +from zenml.enums import TaggableResourceTypes, VisualizationResourceTypes from zenml.models import ( PipelineRequest, PipelineResponse, @@ -38,6 +38,9 @@ from zenml.zen_stores.schemas.utils import jl_arg if TYPE_CHECKING: + from zenml.zen_stores.schemas.curated_visualization_schemas import ( + CuratedVisualizationSchema, + ) from zenml.zen_stores.schemas.pipeline_build_schemas import ( PipelineBuildSchema, ) @@ -104,6 +107,18 @@ class PipelineSchema(NamedSchema, table=True): overlaps="tags", ), ) + visualizations: List["CuratedVisualizationSchema"] = Relationship( + sa_relationship_kwargs=dict( + primaryjoin=( + "and_(CuratedVisualizationSchema.resource_type" + f"=='{VisualizationResourceTypes.PIPELINE.value}', " + "foreign(CuratedVisualizationSchema.resource_id)==PipelineSchema.id)" + ), + overlaps="visualizations", + cascade="delete", + order_by="CuratedVisualizationSchema.display_order", + ), + ) @property def latest_run(self) -> Optional["PipelineRunSchema"]: @@ -159,6 +174,7 @@ def get_query_options( [ joinedload(jl_arg(PipelineSchema.user)), # joinedload(jl_arg(PipelineSchema.tags)), + selectinload(jl_arg(PipelineSchema.visualizations)), ] ) @@ -226,6 +242,13 @@ def to_model( latest_run_id=latest_run.id if latest_run else None, latest_run_status=latest_run.status if latest_run else None, tags=[tag.to_model() for tag in self.tags], + visualizations=[ + visualization.to_model( + include_metadata=False, + include_resources=False, + ) + for visualization in self.visualizations + ], ) return PipelineResponse( diff --git a/src/zenml/zen_stores/schemas/pipeline_snapshot_schemas.py b/src/zenml/zen_stores/schemas/pipeline_snapshot_schemas.py index 290dcec937..668b0ef04e 100644 --- a/src/zenml/zen_stores/schemas/pipeline_snapshot_schemas.py +++ b/src/zenml/zen_stores/schemas/pipeline_snapshot_schemas.py @@ -19,7 +19,7 @@ from sqlalchemy import TEXT, Column, String, UniqueConstraint from sqlalchemy.dialects.mysql import MEDIUMTEXT -from sqlalchemy.orm import joinedload, object_session +from sqlalchemy.orm import joinedload, object_session, selectinload from sqlalchemy.sql.base import ExecutableOption from sqlmodel import Field, Relationship, asc, col, desc, select @@ -27,7 +27,7 @@ from zenml.config.pipeline_spec import PipelineSpec from zenml.config.step_configurations import Step from zenml.constants import MEDIUMTEXT_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH -from zenml.enums import TaggableResourceTypes +from zenml.enums import TaggableResourceTypes, VisualizationResourceTypes from zenml.logger import get_logger from zenml.models import ( PipelineSnapshotRequest, @@ -53,6 +53,9 @@ from zenml.zen_stores.schemas.utils import jl_arg if TYPE_CHECKING: + from zenml.zen_stores.schemas.curated_visualization_schemas import ( + CuratedVisualizationSchema, + ) from zenml.zen_stores.schemas.deployment_schemas import ( DeploymentSchema, ) @@ -218,6 +221,18 @@ class PipelineSnapshotSchema(BaseSchema, table=True): overlaps="tags", ), ) + visualizations: List["CuratedVisualizationSchema"] = Relationship( + sa_relationship_kwargs=dict( + primaryjoin=( + "and_(CuratedVisualizationSchema.resource_type" + f"=='{VisualizationResourceTypes.PIPELINE_SNAPSHOT.value}', " + "foreign(CuratedVisualizationSchema.resource_id)==PipelineSnapshotSchema.id)" + ), + overlaps="visualizations", + cascade="delete", + order_by="CuratedVisualizationSchema.display_order", + ), + ) @property def latest_run(self) -> Optional["PipelineRunSchema"]: @@ -352,7 +367,14 @@ def get_query_options( ) if include_resources: - options.extend([joinedload(jl_arg(PipelineSnapshotSchema.user))]) + options.extend( + [ + joinedload(jl_arg(PipelineSnapshotSchema.user)), + selectinload( + jl_arg(PipelineSnapshotSchema.visualizations) + ), + ] + ) return options @@ -565,6 +587,13 @@ def to_model( latest_run_user=latest_run_user.to_model() if latest_run_user else None, + visualizations=[ + visualization.to_model( + include_metadata=False, + include_resources=False, + ) + for visualization in self.visualizations + ], ) return PipelineSnapshotResponse( diff --git a/src/zenml/zen_stores/schemas/project_schemas.py b/src/zenml/zen_stores/schemas/project_schemas.py index e639ba57c2..a4fe796e1a 100644 --- a/src/zenml/zen_stores/schemas/project_schemas.py +++ b/src/zenml/zen_stores/schemas/project_schemas.py @@ -18,6 +18,7 @@ from sqlalchemy import UniqueConstraint from sqlmodel import Relationship +from zenml.enums import VisualizationResourceTypes from zenml.models import ( ProjectRequest, ProjectResponse, @@ -33,6 +34,7 @@ ActionSchema, ArtifactVersionSchema, CodeRepositorySchema, + CuratedVisualizationSchema, DeploymentSchema, EventSourceSchema, ModelSchema, @@ -127,6 +129,18 @@ class ProjectSchema(NamedSchema, table=True): back_populates="project", sa_relationship_kwargs={"cascade": "delete"}, ) + visualizations: List["CuratedVisualizationSchema"] = Relationship( + sa_relationship_kwargs=dict( + primaryjoin=( + "and_(CuratedVisualizationSchema.resource_type" + f"=='{VisualizationResourceTypes.PROJECT.value}', " + "foreign(CuratedVisualizationSchema.resource_id)==ProjectSchema.id)" + ), + overlaps="visualizations", + cascade="delete", + order_by="CuratedVisualizationSchema.display_order", + ), + ) @classmethod def from_request(cls, project: ProjectRequest) -> "ProjectSchema": diff --git a/src/zenml/zen_stores/schemas/schema_utils.py b/src/zenml/zen_stores/schemas/schema_utils.py index 4925685442..c179f5543d 100644 --- a/src/zenml/zen_stores/schemas/schema_utils.py +++ b/src/zenml/zen_stores/schemas/schema_utils.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Utility functions for SQLModel schemas.""" -from typing import Any, List +from typing import Any, List, Optional from sqlalchemy import Column, ForeignKey, Index from sqlmodel import Field @@ -45,6 +45,7 @@ def build_foreign_key_field( target_column: str, ondelete: str, nullable: bool, + custom_constraint_name: Optional[str] = None, **sa_column_kwargs: Any, ) -> Any: """Build a SQLModel foreign key field. @@ -56,6 +57,7 @@ def build_foreign_key_field( target_column: Target column name. ondelete: On delete behavior. nullable: Whether the field is nullable. + custom_constraint_name: Custom name for the foreign key constraint. **sa_column_kwargs: Keyword arguments for the SQLAlchemy column. Returns: @@ -63,16 +65,22 @@ def build_foreign_key_field( Raises: ValueError: If the ondelete and nullable arguments are not compatible. + ValueError: If the foreign key constraint name is too long. """ if not nullable and ondelete == "SET NULL": raise ValueError( "Cannot set ondelete to SET NULL if the field is not nullable." ) - constraint_name = foreign_key_constraint_name( + constraint_name = custom_constraint_name or foreign_key_constraint_name( source=source, target=target, source_column=source_column, ) + if len(constraint_name) > 64: + raise ValueError( + f"Foreign key constraint name {constraint_name} is too long. " + "The maximum length is 64 characters." + ) return Field( sa_column=Column( ForeignKey( diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 3cf0af8b6d..adcd124f4d 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -69,13 +69,18 @@ field_validator, model_validator, ) -from sqlalchemy import QueuePool, func, update +from sqlalchemy import QueuePool, event, func, update from sqlalchemy.engine import URL, Engine, make_url from sqlalchemy.exc import ( ArgumentError, IntegrityError, ) -from sqlalchemy.orm import Mapped, load_only, noload, selectinload +from sqlalchemy.orm import ( + Mapped, + load_only, + noload, + selectinload, +) from sqlalchemy.sql.base import ExecutableOption from sqlalchemy.util import immutabledict from sqlmodel import Session as SqlModelSession @@ -148,6 +153,7 @@ StepRunInputArtifactType, StoreType, TaggableResourceTypes, + VisualizationResourceTypes, ) from zenml.exceptions import ( AuthorizationException, @@ -199,6 +205,9 @@ ComponentRequest, ComponentResponse, ComponentUpdate, + CuratedVisualizationRequest, + CuratedVisualizationResponse, + CuratedVisualizationUpdate, DefaultComponentRequest, DefaultStackRequest, DeployedStack, @@ -360,6 +369,7 @@ BaseSchema, CodeReferenceSchema, CodeRepositorySchema, + CuratedVisualizationSchema, DeploymentSchema, EventSourceSchema, FlavorSchema, @@ -1280,6 +1290,19 @@ def _initialize(self) -> None: ): self.migrate_database() + if self.config.driver == SQLDatabaseDriver.SQLITE: + # Enable foreign key checks at the SQLite database level, but only + # after any migration has been done. + @event.listens_for(self._engine, "connect") + def _(dbapi_connection: Any, connection_record: Any) -> None: + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + + # Discard existing connections created without the foreign key + # checks enabled + self._engine.dispose() + secrets_store_config = self.config.secrets_store # Initialize the secrets store @@ -5394,6 +5417,275 @@ def delete_deployment(self, deployment_id: UUID) -> None: session.delete(deployment) session.commit() + # -------------------- Curated visualizations -------------------- + + def _assert_curated_visualization_duplicate( + self, + session: Session, + *, + artifact_visualization_id: UUID, + resource_id: UUID, + resource_type: VisualizationResourceTypes, + ) -> None: + """Ensure a curated visualization link does not already exist. + + Args: + session: The database session. + artifact_visualization_id: The ID of the artifact visualization. + resource_id: The ID of the resource. + resource_type: The type of the resource. + + Raises: + EntityExistsError: If a curated visualization link already exists. + """ + existing = session.exec( + select(CuratedVisualizationSchema) + .where( + CuratedVisualizationSchema.artifact_visualization_id + == artifact_visualization_id + ) + .where(CuratedVisualizationSchema.resource_id == resource_id) + .where( + CuratedVisualizationSchema.resource_type == resource_type.value + ) + ).first() + if existing is not None: + raise EntityExistsError( + "A curated visualization for this resource already exists " + "for the specified artifact visualization." + ) + + def _assert_curated_visualization_display_order_unique( + self, + session: Session, + *, + resource_id: UUID, + resource_type: VisualizationResourceTypes, + display_order: Optional[int], + exclude_visualization_id: Optional[UUID] = None, + ) -> None: + """Ensure curated visualizations per resource use unique display orders. + + Args: + session: The database session. + resource_id: The ID of the resource. + resource_type: The type of the resource. + display_order: The display order to check. + exclude_visualization_id: The ID of the visualization to exclude. + + Raises: + EntityExistsError: If a curated visualization for this resource already uses the display order. + """ + if display_order is None: + return + + statement = ( + select(CuratedVisualizationSchema) + .where(CuratedVisualizationSchema.resource_id == resource_id) + .where( + CuratedVisualizationSchema.resource_type == resource_type.value + ) + .where(CuratedVisualizationSchema.display_order == display_order) + ) + if exclude_visualization_id is not None: + statement = statement.where( + CuratedVisualizationSchema.id != exclude_visualization_id + ) + + existing = session.exec(statement).first() + if existing is not None: + raise EntityExistsError( + "A curated visualization for this resource already uses the " + f"display order '{display_order}'. Please choose a different value." + ) + + def create_curated_visualization( + self, visualization: CuratedVisualizationRequest + ) -> CuratedVisualizationResponse: + """Persist a curated visualization link. + + Args: + visualization: The curated visualization to create. + + Returns: + The created curated visualization. + + Raises: + IllegalOperationError: If the curated visualization does not target the same project as the artifact visualization. + ValueError: If the resource type is invalid. + KeyError: If the resource is not found. + """ + with Session(self.engine) as session: + self._set_request_user_id( + request_model=visualization, session=session + ) + + artifact_visualization: ArtifactVisualizationSchema = ( + self._get_reference_schema_by_id( + resource=visualization, + reference_schema=ArtifactVisualizationSchema, + reference_id=visualization.artifact_visualization_id, + session=session, + ) + ) + + artifact_version = artifact_visualization.artifact_version + project_id = artifact_version.project_id + + if visualization.project != project_id: + raise IllegalOperationError( + "Curated visualizations must target the same project as " + "the artifact visualization." + ) + project_id = visualization.project + + resource_schema_map: Dict[ + VisualizationResourceTypes, Type[BaseSchema] + ] = { + VisualizationResourceTypes.DEPLOYMENT: DeploymentSchema, + VisualizationResourceTypes.MODEL: ModelSchema, + VisualizationResourceTypes.PIPELINE: PipelineSchema, + VisualizationResourceTypes.PIPELINE_RUN: PipelineRunSchema, + VisualizationResourceTypes.PIPELINE_SNAPSHOT: PipelineSnapshotSchema, + VisualizationResourceTypes.PROJECT: ProjectSchema, + } + + if visualization.resource_type not in resource_schema_map: + raise ValueError( + f"Invalid resource type: {visualization.resource_type}" + ) + + schema_class = resource_schema_map[visualization.resource_type] + resource_schema = session.exec( + select(schema_class).where( + schema_class.id == visualization.resource_id + ) + ).first() + + if not resource_schema: + raise KeyError( + f"Resource of type '{visualization.resource_type.value}' " + f"with ID {visualization.resource_id} not found." + ) + + if hasattr(resource_schema, "project_id"): + resource_project_id = resource_schema.project_id + if resource_project_id and resource_project_id != project_id: + raise IllegalOperationError( + f"Resource {visualization.resource_type.value} with ID " + f"{visualization.resource_id} belongs to a different project than " + f"the curated visualization (project ID: {project_id})." + ) + + self._assert_curated_visualization_duplicate( + session=session, + artifact_visualization_id=visualization.artifact_visualization_id, + resource_id=visualization.resource_id, + resource_type=visualization.resource_type, + ) + if visualization.display_order is not None: + self._assert_curated_visualization_display_order_unique( + session=session, + resource_id=visualization.resource_id, + resource_type=visualization.resource_type, + display_order=visualization.display_order, + ) + + schema = CuratedVisualizationSchema.from_request(visualization) + + session.add(schema) + session.commit() + session.refresh(schema) + + return schema.to_model( + include_metadata=True, + include_resources=True, + ) + + def get_curated_visualization( + self, visualization_id: UUID, hydrate: bool = True + ) -> CuratedVisualizationResponse: + """Fetch a curated visualization by ID. + + Args: + visualization_id: The ID of the curated visualization to fetch. + hydrate: Flag deciding whether to hydrate the output model(s) + by including metadata fields in the response. + + Returns: + The curated visualization with the given ID. + """ + with Session(self.engine) as session: + schema: CuratedVisualizationSchema = self._get_schema_by_id( + resource_id=visualization_id, + schema_class=CuratedVisualizationSchema, + session=session, + ) + return schema.to_model( + include_metadata=hydrate, + include_resources=hydrate, + ) + + def update_curated_visualization( + self, + visualization_id: UUID, + visualization_update: CuratedVisualizationUpdate, + ) -> CuratedVisualizationResponse: + """Update mutable fields on a curated visualization. + + Args: + visualization_id: The ID of the curated visualization to update. + visualization_update: The update to apply to the curated visualization. + + Returns: + The updated curated visualization. + """ + with Session(self.engine) as session: + schema = self._get_schema_by_id( + resource_id=visualization_id, + schema_class=CuratedVisualizationSchema, + session=session, + ) + update_fields = visualization_update.model_dump(exclude_unset=True) + if "display_order" in update_fields: + new_display_order = update_fields["display_order"] + if new_display_order is not None: + self._assert_curated_visualization_display_order_unique( + session=session, + resource_id=schema.resource_id, + resource_type=VisualizationResourceTypes( + schema.resource_type + ), + display_order=new_display_order, + exclude_visualization_id=visualization_id, + ) + # Explicit None clears the display order, so uniqueness validation is skipped. + + schema.update(visualization_update) + session.add(schema) + session.commit() + session.refresh(schema) + + return schema.to_model( + include_metadata=True, + include_resources=True, + ) + + def delete_curated_visualization(self, visualization_id: UUID) -> None: + """Delete a curated visualization. + + Args: + visualization_id: The ID of the curated visualization to delete. + """ + with Session(self.engine) as session: + schema = self._get_schema_by_id( + resource_id=visualization_id, + schema_class=CuratedVisualizationSchema, + session=session, + ) + session.delete(schema) + session.commit() + # -------------------- Run templates -------------------- @track_decorator(AnalyticsEvent.CREATED_RUN_TEMPLATE) diff --git a/src/zenml/zen_stores/zen_store_interface.py b/src/zenml/zen_stores/zen_store_interface.py index 8848a24b5a..350a74c038 100644 --- a/src/zenml/zen_stores/zen_store_interface.py +++ b/src/zenml/zen_stores/zen_store_interface.py @@ -48,6 +48,9 @@ ComponentRequest, ComponentResponse, ComponentUpdate, + CuratedVisualizationRequest, + CuratedVisualizationResponse, + CuratedVisualizationUpdate, DeployedStack, DeploymentFilter, DeploymentRequest, @@ -1471,6 +1474,60 @@ def delete_deployment(self, deployment_id: UUID) -> None: KeyError: If the deployment does not exist. """ + # -------------------- Curated visualizations -------------------- + + @abstractmethod + def create_curated_visualization( + self, visualization: CuratedVisualizationRequest + ) -> CuratedVisualizationResponse: + """Create a new curated visualization. + + Args: + visualization: The curated visualization to create. + + Returns: + The created curated visualization. + """ + + @abstractmethod + def get_curated_visualization( + self, visualization_id: UUID, hydrate: bool = True + ) -> CuratedVisualizationResponse: + """Get a curated visualization by ID. + + Args: + visualization_id: The ID of the curated visualization to get. + hydrate: Flag deciding whether to hydrate the output model(s) + by including metadata fields in the response. + + Returns: + The curated visualization with the given ID. + """ + + @abstractmethod + def update_curated_visualization( + self, + visualization_id: UUID, + visualization_update: CuratedVisualizationUpdate, + ) -> CuratedVisualizationResponse: + """Update a curated visualization. + + Args: + visualization_id: The ID of the curated visualization to update. + visualization_update: The update to apply to the curated visualization. + + Returns: + The updated curated visualization. + """ + + @abstractmethod + def delete_curated_visualization(self, visualization_id: UUID) -> None: + """Delete a curated visualization. + + Args: + visualization_id: The ID of the curated visualization to delete. + """ + # -------------------- Run templates -------------------- @abstractmethod diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py index 4ebcafd852..62b52f707f 100644 --- a/tests/integration/functional/zen_stores/test_zen_store.py +++ b/tests/integration/functional/zen_stores/test_zen_store.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing # permissions and limitations under the License. +import atexit import json import os import random @@ -26,7 +27,7 @@ import pytest from pydantic import ValidationError -from sqlalchemy.exc import IntegrityError +from sqlalchemy.exc import IntegrityError, OperationalError, ProgrammingError from tests.integration.functional.utils import sample_name from tests.integration.functional.zen_stores.utils import ( @@ -68,12 +69,15 @@ ArtifactSaveType, ArtifactType, ColorVariants, + CuratedVisualizationSize, ExecutionStatus, MetadataResourceTypes, ModelStages, StackComponentType, StoreType, TaggableResourceTypes, + VisualizationResourceTypes, + VisualizationType, ) from zenml.exceptions import ( AuthorizationException, @@ -87,11 +91,20 @@ APIKeyRequest, APIKeyRotateRequest, APIKeyUpdate, + ArtifactRequest, ArtifactVersionFilter, ArtifactVersionRequest, ArtifactVersionResponse, + ArtifactVisualizationRequest, ComponentFilter, + ComponentRequest, ComponentUpdate, + CuratedVisualizationRequest, + CuratedVisualizationUpdate, + DeploymentRequest, + ModelFilter, + ModelRequest, + ModelUpdate, ModelVersionArtifactFilter, ModelVersionArtifactRequest, ModelVersionFilter, @@ -101,10 +114,12 @@ ModelVersionUpdate, PipelineRequest, PipelineRunFilter, + PipelineRunRequest, PipelineRunResponse, PipelineSnapshotRequest, ProjectFilter, ProjectUpdate, + RunMetadataRequest, RunMetadataResource, ScheduleRequest, ServiceAccountFilter, @@ -117,24 +132,46 @@ StackRequest, StackUpdate, StepRunFilter, + StepRunRequest, StepRunUpdate, TagResourceRequest, + UserFilter, UserRequest, UserResponse, UserUpdate, ) -from zenml.models.v2.core.artifact import ArtifactRequest -from zenml.models.v2.core.component import ComponentRequest -from zenml.models.v2.core.model import ModelFilter, ModelRequest, ModelUpdate -from zenml.models.v2.core.pipeline_run import PipelineRunRequest -from zenml.models.v2.core.run_metadata import RunMetadataRequest -from zenml.models.v2.core.step_run import StepRunRequest -from zenml.models.v2.core.user import UserFilter from zenml.utils import code_repository_utils, source_utils from zenml.utils.enum_utils import StrEnum from zenml.zen_stores.rest_zen_store import RestZenStore from zenml.zen_stores.sql_zen_store import SqlZenStore +_ORIGINAL_INITIALIZE_DATABASE = SqlZenStore._initialize_database + + +def _patched_initialize_database(self): + try: + _ORIGINAL_INITIALIZE_DATABASE(self) + except (OperationalError, ProgrammingError) as error: + message = str(error).lower() + if ( + "no such column: stack.environment" in message + or "unknown column 'stack.environment'" in message + ): + self.migrate_database() + _ORIGINAL_INITIALIZE_DATABASE(self) + else: + raise + + +SqlZenStore._initialize_database = _patched_initialize_database + + +def _restore_sql_zen_store_initialize_database() -> None: + SqlZenStore._initialize_database = _ORIGINAL_INITIALIZE_DATABASE + + +atexit.register(_restore_sql_zen_store_initialize_database) + DEFAULT_NAME = "default" # .--------------. @@ -5715,3 +5752,394 @@ def test_tag_filter_with_resource_type(clean_client: "Client"): # Test filtering for a resource type that doesn't have tags tags = clean_client.list_tags(resource_type=TaggableResourceTypes.MODEL) assert len(tags) == 0 + + +class TestCuratedVisualizations: + """Test curated visualizations.""" + + def test_curated_visualizations_across_resources(self): + """Test creating, listing, updating, and deleting curated visualizations. + + Each curated visualization is linked to exactly one resource. This test + creates separate visualizations for each supported resource type: + + - **Deployments** (VisualizationResourceTypes.DEPLOYMENT) + - **Models** (VisualizationResourceTypes.MODEL) + - **Pipelines** (VisualizationResourceTypes.PIPELINE) + - **Pipeline Runs** (VisualizationResourceTypes.PIPELINE_RUN) + - **Pipeline Snapshots** (VisualizationResourceTypes.PIPELINE_SNAPSHOT) + - **Projects** (VisualizationResourceTypes.PROJECT) + """ + client = Client() + project_id = client.active_project.id + + resource_configs = [ + { + "resource_type": VisualizationResourceTypes.PIPELINE, + "resource_id": None, + }, + { + "resource_type": VisualizationResourceTypes.MODEL, + "resource_id": None, + }, + { + "resource_type": VisualizationResourceTypes.PIPELINE_RUN, + "resource_id": None, + }, + { + "resource_type": VisualizationResourceTypes.PIPELINE_SNAPSHOT, + "resource_id": None, + }, + { + "resource_type": VisualizationResourceTypes.DEPLOYMENT, + "resource_id": None, + }, + { + "resource_type": VisualizationResourceTypes.PROJECT, + "resource_id": None, + }, + ] + + def create_artifact_version(): + artifact = client.zen_store.create_artifact( + ArtifactRequest( + name=sample_name("artifact"), + project=project_id, + has_custom_name=True, + ) + ) + artifact_version = client.zen_store.create_artifact_version( + ArtifactVersionRequest( + artifact_id=artifact.id, + project=project_id, + version="1", + type=ArtifactType.DATA, + uri=sample_name("artifact_uri"), + materializer=Source( + module="acme.foo", type=SourceType.INTERNAL + ), + data_type=Source( + module="acme.foo", type=SourceType.INTERNAL + ), + save_type=ArtifactSaveType.STEP_OUTPUT, + visualizations=[ + ArtifactVisualizationRequest( + type=VisualizationType.HTML, + uri=f"s3://visualizations/{config['resource_type'].value}_{index}.html", + ) + for index, config in enumerate(resource_configs) + ], + ) + ) + + return artifact, artifact_version + + artifact, artifact_version = create_artifact_version() + + pipeline_model = client.zen_store.create_pipeline( + PipelineRequest( + name=sample_name("pipeline"), + project=project_id, + ) + ) + + step_name = sample_name("step") + snapshot = client.zen_store.create_snapshot( + PipelineSnapshotRequest( + project=project_id, + run_name_template=sample_name("run"), + pipeline_configuration=PipelineConfiguration( + name=sample_name("pipeline-config") + ), + pipeline=pipeline_model.id, + stack=client.active_stack.id, + client_version="0.1.0", + server_version="0.1.0", + step_configurations={ + step_name: Step( + spec=StepSpec( + source=Source( + module="acme.step", type=SourceType.INTERNAL + ), + upstream_steps=[], + ), + config=StepConfiguration(name=step_name), + ) + }, + ) + ) + + pipeline_run, _ = client.zen_store.get_or_create_run( + PipelineRunRequest( + project=project_id, + id=uuid4(), + name=sample_name("run"), + snapshot=snapshot.id, + status=ExecutionStatus.RUNNING, + ) + ) + model = client.zen_store.create_model( + ModelRequest( + project=project_id, + name=sample_name("model"), + ) + ) + + deployer = client.zen_store.create_stack_component( + ComponentRequest( + name=sample_name("foo"), + type=StackComponentType.DEPLOYER, + flavor="docker", + configuration={}, + ) + ) + + # Create a deployment + deployment = client.zen_store.create_deployment( + DeploymentRequest( + project=project_id, + name=sample_name("deployment"), + snapshot_id=snapshot.id, + deployer_id=deployer.id, + ) + ) + + def create_visualizations(artifact_version): + visualizations = {} + artifact_visualizations = artifact_version.visualizations or [] + for artifact_viz, config in zip( + artifact_visualizations, resource_configs + ): + resource_type = config["resource_type"] + resource_id = config["resource_id"] + viz = client.zen_store.create_curated_visualization( + CuratedVisualizationRequest( + project=project_id, + artifact_visualization_id=artifact_viz.id, + resource_id=resource_id, + resource_type=resource_type, + display_name=f"{resource_type.value} visualization", + ) + ) + hydrated = client.zen_store.get_curated_visualization( + visualization_id=viz.id, + hydrate=True, + ) + assert hydrated.resource_id == resource_id + assert hydrated.resource_type == resource_type + assert ( + hydrated.layout_size == CuratedVisualizationSize.FULL_WIDTH + ) + visualizations[resource_type] = viz + return visualizations + + try: + resource_configs[0]["resource_id"] = pipeline_model.id + resource_configs[1]["resource_id"] = model.id + resource_configs[2]["resource_id"] = pipeline_run.id + resource_configs[3]["resource_id"] = snapshot.id + resource_configs[4]["resource_id"] = deployment.id + resource_configs[5]["resource_id"] = project_id + + visualizations = create_visualizations(artifact_version) + + loaded = client.zen_store.get_curated_visualization( + visualizations[VisualizationResourceTypes.PIPELINE].id, + hydrate=True, + ) + assert ( + loaded.display_name + == f"{VisualizationResourceTypes.PIPELINE.value} visualization" + ) + assert loaded.layout_size == CuratedVisualizationSize.FULL_WIDTH + assert loaded.resource_id == pipeline_model.id + assert loaded.resource_type == VisualizationResourceTypes.PIPELINE + + # Test duplicate creation - same artifact visualization + resource should fail + with pytest.raises(EntityExistsError): + client.zen_store.create_curated_visualization( + CuratedVisualizationRequest( + project=project_id, + artifact_visualization_id=loaded.artifact_visualization_id, + resource_id=pipeline_model.id, + resource_type=VisualizationResourceTypes.PIPELINE, + ) + ) + + # Test update + updated = client.zen_store.update_curated_visualization( + visualization_id=visualizations[ + VisualizationResourceTypes.MODEL + ].id, + visualization_update=CuratedVisualizationUpdate( + display_name="Updated", + display_order=5, + layout_size=CuratedVisualizationSize.HALF_WIDTH, + ), + ) + assert updated.display_name == "Updated" + assert updated.display_order == 5 + assert updated.layout_size == CuratedVisualizationSize.HALF_WIDTH + + # Delete all visualizations + for viz in visualizations.values(): + client.zen_store.delete_curated_visualization(viz.id) + + for viz in visualizations.values(): + with pytest.raises(KeyError): + client.zen_store.get_curated_visualization(viz.id) + + visualizations = create_visualizations(artifact_version) + + # Clean up artifact + client.zen_store.delete_artifact(artifact.id) + + # Check that all visualizations have been auto-deleted + for viz in visualizations.values(): + with pytest.raises(KeyError): + client.zen_store.get_curated_visualization(viz.id) + + artifact, artifact_version = create_artifact_version() + visualizations = create_visualizations(artifact_version) + + # Clean up deployment + client.zen_store.delete_deployment(deployment.id) + + with pytest.raises(KeyError): + client.zen_store.get_curated_visualization( + visualizations[VisualizationResourceTypes.DEPLOYMENT].id + ) + + # Clean up pipeline run + client.zen_store.delete_run(pipeline_run.id) + with pytest.raises(KeyError): + client.zen_store.get_curated_visualization( + visualizations[VisualizationResourceTypes.PIPELINE_RUN].id + ) + + # Clean up model + client.zen_store.delete_model(model.id) + with pytest.raises(KeyError): + client.zen_store.get_curated_visualization( + visualizations[VisualizationResourceTypes.MODEL].id + ) + + # Clean up snapshot + client.zen_store.delete_snapshot(snapshot.id) + with pytest.raises(KeyError): + client.zen_store.get_curated_visualization( + visualizations[ + VisualizationResourceTypes.PIPELINE_SNAPSHOT + ].id + ) + + # Clean up pipeline + client.zen_store.delete_pipeline(pipeline_model.id) + with pytest.raises(KeyError): + client.zen_store.get_curated_visualization( + visualizations[VisualizationResourceTypes.PIPELINE].id + ) + + finally: + # Clean up deployment + try: + client.zen_store.delete_deployment(deployment.id) + except KeyError: + pass + + # Clean up pipeline run + try: + client.zen_store.delete_run(pipeline_run.id) + except KeyError: + pass + + # Clean up model + try: + client.zen_store.delete_model(model.id) + except KeyError: + pass + + # Clean up snapshot + try: + client.zen_store.delete_snapshot(snapshot.id) + except KeyError: + pass + + # Clean up pipeline + try: + client.zen_store.delete_pipeline(pipeline_model.id) + except KeyError: + pass + + # Clean up deployer + try: + client.zen_store.delete_stack_component(deployer.id) + except KeyError: + pass + + # Clean up artifact + try: + client.zen_store.delete_artifact(artifact.id) + except KeyError: + pass + + def test_curated_visualizations_project_only(self): + """Test project-level curated visualizations with single resource.""" + + client = Client() + project = client.active_project + + artifact = client.zen_store.create_artifact( + ArtifactRequest( + name=sample_name("artifact"), + project=project.id, + has_custom_name=True, + ) + ) + artifact_version = client.zen_store.create_artifact_version( + ArtifactVersionRequest( + artifact_id=artifact.id, + project=project.id, + version="1", + type=ArtifactType.DATA, + uri=sample_name("artifact_uri"), + materializer=Source( + module="acme.foo", type=SourceType.INTERNAL + ), + data_type=Source(module="acme.foo", type=SourceType.INTERNAL), + save_type=ArtifactSaveType.STEP_OUTPUT, + visualizations=[ + ArtifactVisualizationRequest( + type=VisualizationType.HTML, + uri="s3://visualizations/project.html", + ) + ], + ) + ) + + artifact_visualization = (artifact_version.visualizations or [])[0] + + visualization = client.create_curated_visualization( + artifact_visualization_id=artifact_visualization.id, + resource_id=project.id, + resource_type=VisualizationResourceTypes.PROJECT, + project_id=project.id, + display_name="Project visualization", + ) + + hydrated_visualization = client.zen_store.get_curated_visualization( + visualization.id, hydrate=True + ) + assert hydrated_visualization.resource_id == project.id + assert ( + hydrated_visualization.resource_type + == VisualizationResourceTypes.PROJECT + ) + assert hydrated_visualization.display_name == "Project visualization" + + client.delete_curated_visualization(visualization.id) + with pytest.raises(KeyError): + client.zen_store.get_curated_visualization(visualization.id) + + client.delete_artifact_version(artifact_version.id) + client.delete_artifact(artifact.id)