Skip to content

Commit b12a8e7

Browse files
committed
fix: sub_path slices with with_sequence
fix: slices of HDF5Datasets Signed-off-by: zjgemi <liuxin_zijian@163.com>
1 parent f0945d2 commit b12a8e7

File tree

3 files changed

+29
-9
lines changed

3 files changed

+29
-9
lines changed

src/dflow/python/utils.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,14 +78,15 @@ def handle_input_artifact(name, sign, slices=None, data_root="/tmp",
7878
path_object = assemble_path_object(art_path)
7979
path_object = get_slices(path_object, prefix)
8080

81-
path_object = get_slices(path_object, slices)
82-
8381
sign_type = sign.type
8482
if getattr(sign_type, "__origin__", None) == Union:
8583
args = sign_type.__args__
8684
if HDF5Datasets in args:
87-
if isinstance(path_object, list) and all([isinstance(
88-
p, str) and p.endswith(".h5") for p in path_object]):
85+
if isinstance(path_object, list) and len(path_object) > 0 and all([
86+
isinstance(p, str) and p.endswith(".h5")
87+
for p in path_object]):
88+
sign_type = HDF5Datasets
89+
elif art_path.endswith(".h5"):
8990
sign_type = HDF5Datasets
9091
elif args[0] == HDF5Datasets:
9192
sign_type = args[1]
@@ -94,6 +95,8 @@ def handle_input_artifact(name, sign, slices=None, data_root="/tmp",
9495

9596
if sign_type == HDF5Datasets:
9697
import h5py
98+
if os.path.isfile(art_path):
99+
path_object = [art_path]
97100
assert isinstance(path_object, list)
98101
res = None
99102
for path in path_object:
@@ -108,6 +111,10 @@ def handle_input_artifact(name, sign, slices=None, data_root="/tmp",
108111
if res is None:
109112
res = {}
110113
res.update(datasets)
114+
res = get_slices(res, slices)
115+
else:
116+
path_object = get_slices(path_object, slices)
117+
111118
if sign_type in [str, Path]:
112119
if path_object is None or isinstance(path_object, str):
113120
res = path_object
@@ -146,7 +153,7 @@ def handle_input_artifact(name, sign, slices=None, data_root="/tmp",
146153
return None
147154

148155
_cls = res.__class__
149-
res = artifact_classes[_cls](res)
156+
res = artifact_classes.get(_cls, lambda x: x)(res)
150157
res.art_root = root
151158
return res
152159

@@ -230,6 +237,9 @@ def handle_output_artifact(name, value, sign, slices=None, data_root="/tmp",
230237
d.attrs["type"] = "dir"
231238
d.attrs["path"] = str(v)
232239
d.attrs["dtype"] = "binary"
240+
elif isinstance(v, HDF5Dataset):
241+
d = f.create_dataset(s, data=v.dataset[()])
242+
d.attrs.update(v.dataset.attrs)
233243
else:
234244
d = f.create_dataset(s, data=v)
235245
d.attrs["type"] = "data"

src/dflow/step.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ def __init__(
484484
auto_loop_artifacts:
485485
self.template = self.template.deepcopy()
486486
sequence_format = self.with_sequence.format \
487-
if self.with_sequence is not None else "%d"
487+
if self.with_sequence is not None else None
488488
init_template = InitArtifactForSlices(
489489
self.template, self.util_image, self.util_command,
490490
self.util_image_pull_policy, self.key, sliced_output_artifact,
@@ -598,6 +598,13 @@ def merge_output_artifact(art, parent, layer=0):
598598
elif v is not None:
599599
self.prepare_step.set_artifacts({name: v})
600600
self.inputs.artifacts[name].sp = "{{item.%s}}" % name
601+
if self.with_sequence is not None:
602+
for par in self.inputs.parameters.values():
603+
if hasattr(par, "value") and isinstance(
604+
par.value, str):
605+
par.value = par.value.replace(
606+
"{{item}}", "{{item.order}}")
607+
self.with_sequence = None
601608
self.with_param = self.prepare_step.outputs.parameters[
602609
"dflow_slices_path"]
603610

src/dflow/util_ops.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ class InitArtifactForSlices(PythonScriptOPTemplate):
1010
def __init__(self, template, image, command, image_pull_policy, key,
1111
sliced_output_artifact=None, sliced_input_artifact=None,
1212
sum_var=None, concat_var=None, auto_loop_artifacts=None,
13-
group_size=None, format="%d", post_script="",
13+
group_size=None, format=None, post_script="",
1414
tmp_root="/tmp"):
1515
name = template.name
1616
super().__init__(name="%s-init-artifact" % name, image=image,
@@ -171,7 +171,7 @@ def render_script(self):
171171
script += " continue\n"
172172
script += " group_dir = r'%s/outputs/artifacts/%s/"\
173173
"group_' + ('%s' %% i)\n" % (
174-
self.tmp_root, name, self.format)
174+
self.tmp_root, name, self.format or "%d")
175175
script += " os.makedirs(group_dir, exist_ok=True)\n"
176176
if self.template.slices.shuffle:
177177
script += " path_list = [path_list_%s[j] for j in "\
@@ -206,7 +206,10 @@ def render_script(self):
206206
else:
207207
script += "slices_path = []\n"
208208
script += "for i in range(len(path_list_%s)):\n" % required[0]
209-
script += " item = {'order': i}\n"
209+
if self.format:
210+
script += " item = {'order': '%s' %% i}\n" % self.format
211+
else:
212+
script += " item = {'order': i}\n"
210213
for i, name in enumerate(self.sliced_input_artifact):
211214
script += " item['%s'] = path_list_%s[i]"\
212215
"['dflow_list_item'] if path_list_%s else None\n" % (

0 commit comments

Comments
 (0)