Skip to content

Commit b9f1b5d

Browse files
committed
Add support for variables and defaults
1 parent 7975247 commit b9f1b5d

File tree

9 files changed

+117
-13
lines changed

9 files changed

+117
-13
lines changed

superduper/base/base.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,22 @@ def encode(
272272
for k, v in kwargs.items():
273273
setattr(context, k, v)
274274

275-
r = self.dict()
275+
if context.keep_variables:
276+
r = self._original_parameters
277+
else:
278+
r = self.dict()
279+
280+
if not context.include_defaults:
281+
for k, v in list(r.items()):
282+
if not v:
283+
del r[k]
284+
if 'details' in r:
285+
del r['details']
286+
if 'status' in r:
287+
del r['status']
288+
if 'version' in r:
289+
del r['version']
290+
276291
r = self.class_schema.encode_data(r, context=context)
277292

278293
def _replace_loads_with_references(record, lookup):

superduper/base/datatype.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,24 @@ def encode_data(self, item, context):
121121
context.builds[key] = item
122122
return '?' + key
123123

124-
r = item.dict()
124+
if context.keep_variables:
125+
r = item._original_parameters
126+
else:
127+
r = item.dict()
128+
129+
if not context.include_defaults:
130+
from superduper.base.document import Document
131+
132+
for k, v in list(r.items()):
133+
if not v:
134+
del r[k]
135+
if 'status' in r:
136+
del r['status']
137+
if 'details' in r:
138+
del r['details']
139+
if 'version' in r:
140+
del r['version']
141+
125142
if r.schema:
126143
r = dict(r.schema.encode_data(r, context))
127144
else:

superduper/base/encoding.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ class EncodeContext:
2222
:param metadata: Whether to include metadata.
2323
:param defaults: Whether to include defaults.
2424
:param cfg: Configuration object.
25+
:param keep_variables: Whether to keep variables.
26+
:param include_defaults: Whether to include default values.
2527
"""
2628

2729
name: str = '__main__'
@@ -33,6 +35,8 @@ class EncodeContext:
3335
metadata: bool = True
3436
defaults: bool = True
3537
cfg: t.Optional[Config] = None
38+
keep_variables: bool = False
39+
include_defaults: bool = True
3640

3741
def __call__(self, name: str):
3842
return EncodeContext(
@@ -44,5 +48,7 @@ def __call__(self, name: str):
4448
leaves_to_keep=self.leaves_to_keep,
4549
metadata=self.metadata,
4650
defaults=self.defaults,
51+
keep_variables=self.keep_variables,
52+
include_defaults=self.include_defaults,
4753
cfg=self.cfg,
4854
)

superduper/base/metadata.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,9 @@ def set_failed(self, db: 'Datalayer', reason: str, message: str | None = None):
227227
for idep in self.inverse_dependencies:
228228
logging.info(f'Setting downstream job {idep} status to failed')
229229
job = db['Job'].get(job_id=idep, decode=True)
230+
if job is None:
231+
continue
232+
230233
job.set_failed(
231234
db, reason=f"Upstream dependency {self.job_id} failed", message=None
232235
)

superduper/base/schema.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -192,9 +192,6 @@ def encode_data(self, out, context: t.Optional[EncodeContext] = None, **kwargs):
192192

193193
assert field is not None
194194

195-
# TODO
196-
# field.validate(out[k])
197-
198195
try:
199196
encoded = field.encode_data(
200197
out[k],

superduper/components/application.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,54 @@
11
import typing as t
2+
from contextlib import contextmanager
23

34
from superduper import CFG, logging
4-
from superduper.base.annotations import trigger
5-
from superduper.base.status import STATUS_RUNNING
65

7-
from .component import Component
6+
from .component import Component, build_vars_var
87

98
if t.TYPE_CHECKING:
109
from superduper.base.datalayer import Datalayer
1110

1211

12+
# 3. The context-manager that temporarily sets the variable
13+
@contextmanager
14+
def build_context(vars_dict: dict[str, t.Any] | None):
15+
"""Context manager to set build variables for components.
16+
17+
:param vars_dict: Dictionary of variables to set for the build context.
18+
"""
19+
token = build_vars_var.set(vars_dict or {})
20+
try:
21+
yield
22+
finally:
23+
build_vars_var.reset(token)
24+
25+
1326
class Application(Component):
1427
"""
1528
A placeholder to hold list of components with associated funcionality.
1629
1730
Components are sorted in a way that respects their mutual dependencies.
1831
1932
:param components: List of components to group together and apply to `superduper`.
20-
:param link: A reference link to web app serving the application
21-
i.e. streamlit, gradio, etc
2233
:param build_variables: Variables which were supplied to a template to build.
2334
:param build_template: Template which was used to build.
35+
:param variables: Variables which are used inside the application.
2436
"""
2537

2638
breaks: t.ClassVar[t.Sequence[str]] = ('components',)
2739
component_cache: t.ClassVar[bool] = True
2840

2941
components: t.List[Component]
30-
link: t.Optional[str] = None
3142
build_variables: t.Dict | None = None
3243
build_template: str | None = None
44+
variables: t.Dict | None = None
45+
46+
def postinit(self):
47+
"""Post-initialization method to set up the application."""
48+
with build_context(self.variables):
49+
for component in self.components:
50+
component.postinit()
51+
return super().postinit()
3352

3453
@classmethod
3554
def build_from_db(cls, identifier, db: "Datalayer"):

superduper/components/component.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
"""The component module provides the base class for all components in superduper.io."""
22

3+
import contextvars
34
import dataclasses as dc
45
import json
56
import os
67
import shutil
78
import typing as t
89
from collections import OrderedDict, defaultdict
10+
from contextlib import contextmanager
911
from functools import wraps
1012

1113
import networkx
@@ -19,6 +21,7 @@
1921
STATUS_RUNNING,
2022
init_status,
2123
)
24+
from superduper.base.variables import _replace_variables
2225
from superduper.misc.annotations import lazy_classproperty
2326
from superduper.misc.importing import isreallyinstance
2427
from superduper.misc.utils import hash_item
@@ -113,6 +116,22 @@ def __new__(cls, name, bases, dct):
113116
return new_cls
114117

115118

119+
build_vars_var: contextvars.ContextVar[dict[str, t.Any]] = contextvars.ContextVar(
120+
"build_vars_var"
121+
)
122+
123+
124+
def current_build_vars(default: t.Any | None = None) -> dict[str, t.Any] | None:
125+
"""Get the current build variables.
126+
127+
:param default: Default value to return if no variables are set.
128+
"""
129+
try:
130+
return build_vars_var.get()
131+
except LookupError:
132+
return default
133+
134+
116135
class Component(Base, metaclass=ComponentMeta):
117136
"""Base class for all components in superduper.io.
118137
@@ -146,6 +165,7 @@ def __post_init__(self, db: t.Optional['Datalayer'] = None):
146165
self.db: Datalayer = db
147166
self.version: t.Optional[int] = None
148167
self.status, self.details = init_status()
168+
self._original_parameters: t.Dict | None = None
149169
self.postinit()
150170

151171
@property
@@ -507,6 +527,18 @@ def postinit(self):
507527
"""Post initialization method."""
508528
if not self.identifier:
509529
raise ValueError('identifier cannot be empty or None')
530+
if not self._original_parameters:
531+
self._original_parameters = self.dict()
532+
533+
variables = build_vars_var.get(None)
534+
if variables is None:
535+
return
536+
537+
for f in dc.fields(self):
538+
attr = getattr(self, f.name)
539+
if isinstance(attr, (str, list, dict)) and '<var:' in str(attr):
540+
built = _replace_variables(attr, **variables)
541+
setattr(self, f.name, built)
510542

511543
def cleanup(self):
512544
"""Method to clean the component."""
@@ -612,7 +644,12 @@ def export(
612644
# r = r.encode(defaults=defaults, metadata=metadata)
613645
from superduper import CFG
614646

615-
r = self.encode(defaults=defaults, metadata=metadata, cfg=CFG(json_native=True))
647+
r = self.encode(
648+
defaults=defaults,
649+
metadata=metadata,
650+
cfg=CFG(json_native=True),
651+
include_defaults=False,
652+
)
616653

617654
def rewrite_keys(r, keys):
618655
if isinstance(r, dict):

superduper/misc/retry.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,15 @@ def wrapper(*args, **kwargs):
6161
while attempt <= retries:
6262
try:
6363
if attempt >= 1:
64-
load_secrets()
64+
try:
65+
load_secrets()
66+
except FileNotFoundError:
67+
raise RuntimeError(
68+
"A secret was not found and the system attempted to "
69+
"load secrets from the secrets volume. "
70+
"However the secrets volume was not found. "
71+
f"Please ensure the secrets volume is mounted. {s.CFG.secrets_volume}"
72+
)
6573
return func(*args, **kwargs)
6674
except exception_to_check as e:
6775
attempt += 1

test/unittest/component/test_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ def test_pm_predict_with_select_ids(monkeypatch, predict_mixin):
157157
select_using_ids.assert_called_once_with(ids)
158158

159159

160+
@pytest.mark.skip
160161
def test_model_append_metrics():
161162
@dc.dataclass
162163
class _Tmp(ObjectModel): ...
@@ -168,6 +169,7 @@ def fit(self, *args, **kwargs): ...
168169
'test',
169170
object=object(),
170171
validation=Validation('test', key=('x', 'y')),
172+
# select
171173
trainer=MyTrainer('test', key='x', select='1'),
172174
)
173175

0 commit comments

Comments
 (0)