Skip to content

Commit e14b036

Browse files
Feature:4015 Allow both infer & explicit options
1 parent cf5e542 commit e14b036

File tree

3 files changed

+68
-29
lines changed

3 files changed

+68
-29
lines changed

docs/book/how-to/metadata/metadata.md

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,38 @@ def get_train_test_datasets():
261261
```
262262

263263
Keep in mind that when using the `infer_artifacts` option, the `bulk_log_metadata` function logs metadata to all output artifacts of the step.
264+
When logging metadata, you may need the option to use `infer` options in combination with identifier references. For instance, you may want
265+
to log metadata to a step's outputs but also to its inputs. The `bulk_log_metadata` function enables you to use both options in one go:
266+
267+
```python
268+
from zenml import bulk_log_metadata, get_step_context, step
269+
from zenml.models import ArtifactVersionIdentifier
270+
271+
272+
def calculate_metrics(model, test_dataset):
273+
...
274+
275+
276+
def summarize_metrics(metrics_report):
277+
...
278+
279+
280+
@step
281+
def model_evaluation(test_dataset, model):
282+
metrics_report = calculate_metrics(model, test_dataset)
283+
284+
slim_metrics_version = summarize_metrics(metrics_report)
285+
286+
bulk_log_metadata(
287+
metadata=slim_metrics_version,
288+
infer_artifacts=True, # log metadata for outputs
289+
artifact_versions=[
290+
ArtifactVersionIdentifier(id=get_step_context().inputs["model"].id)
291+
] # log metadata for the model input
292+
)
293+
294+
return metrics_report
295+
```
264296

265297
### Performance improvements hints
266298

src/zenml/utils/metadata_utils.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,7 @@ def bulk_log_metadata(
395395
infer_artifacts: Flag - when enabled infer artifact to log metadata for from step context.
396396
397397
Raises:
398-
ValueError: If options are not passed correctly (infer options with explicit declarations) or
398+
ValueError: If options are not passed correctly (empty metadata or no identifier options) or
399399
invocation with `infer` options is done outside of a step context.
400400
"""
401401
client = Client()
@@ -420,16 +420,6 @@ def bulk_log_metadata(
420420
"You must select at least one entity to log metadata to."
421421
)
422422

423-
if infer_models and model_versions:
424-
raise ValueError(
425-
"You can either specify model versions or use the infer option."
426-
)
427-
428-
if infer_artifacts and artifact_versions:
429-
raise ValueError(
430-
"You can either specify artifact versions or use the infer option."
431-
)
432-
433423
try:
434424
step_context = get_step_context()
435425
except RuntimeError:

tests/unit/utils/test_metadata_utils.py

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,24 +15,6 @@
1515

1616

1717
def test_bulk_log_metadata_validations(monkeypatch):
18-
# test one can not specify both explicit and infer options for models
19-
20-
with pytest.raises(ValueError):
21-
bulk_log_metadata(
22-
metadata={"x": 1},
23-
infer_models=True,
24-
model_versions=[ModelVersionIdentifier(id=uuid4())],
25-
)
26-
27-
# test one can not specify both explicit and infer options for a
28-
29-
with pytest.raises(ValueError):
30-
bulk_log_metadata(
31-
metadata={"x": 1},
32-
infer_artifacts=True,
33-
model_versions=[ModelVersionIdentifier(id=uuid4())],
34-
)
35-
3618
def boom():
3719
raise RuntimeError("boom!")
3820

@@ -167,3 +149,38 @@ def test_bulk_log_metadata_infer_model(monkeypatch):
167149
assert (
168150
len(mock_client.create_run_metadata.call_args.kwargs["resources"]) == 1
169151
)
152+
153+
154+
def test_combined_infer_with_explicit_options(monkeypatch):
155+
mock_step_context = MagicMock(
156+
model_version=MagicMock(id=uuid4()),
157+
_outputs={"a": MagicMock(id=uuid4()), "b": MagicMock(id=uuid4())},
158+
)
159+
160+
mock_client = MagicMock()
161+
mock_client.create_run_metadata = MagicMock()
162+
163+
with monkeypatch.context() as m:
164+
m.setattr(metadata_utils, "Client", lambda: mock_client)
165+
m.setattr(
166+
metadata_utils, "get_step_context", lambda: mock_step_context
167+
)
168+
169+
bulk_log_metadata(
170+
metadata={"x": 1},
171+
infer_models=True,
172+
infer_artifacts=True,
173+
model_versions=[ModelVersionIdentifier(id=uuid4())],
174+
artifact_versions=[ArtifactVersionIdentifier(id=uuid4())],
175+
)
176+
177+
assert mock_client.create_run_metadata.call_count == 1
178+
179+
assert mock_client.create_run_metadata.call_args.kwargs[
180+
"metadata"
181+
] == {"x": 1}
182+
assert (
183+
len(mock_client.create_run_metadata.call_args.kwargs["resources"])
184+
== 3
185+
)
186+
assert mock_step_context.add_output_metadata.call_count == 2

0 commit comments

Comments
 (0)