Skip to content

Commit 384302f

Browse files
authored
Merge pull request #849 from deepmodeling/zjgemi
fix: support continue_on_success_ratio and continue_on_num_success fo…
2 parents da9ca61 + c779c5c commit 384302f

File tree

2 files changed

+17
-21
lines changed

2 files changed

+17
-21
lines changed

src/dflow/dag.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,10 @@ def run(self, workflow_id=None, context=None):
269269
traceback.print_exc()
270270
self.tasks[j].phase = "Failed"
271271
if not self.tasks[j].continue_on_failed:
272+
if sys.version_info.minor >= 9:
273+
pool.shutdown(wait=False)
274+
else:
275+
pool.shutdown(wait=True)
272276
raise RuntimeError("Task %s failed" % self.tasks[j])
273277
else:
274278
for name, value in pars.items():

src/dflow/step.py

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -874,6 +874,7 @@ def merge_output_artifact(art, parent, layer=0):
874874
self.template = self.template.deepcopy()
875875

876876
def add_success_tag(templ):
877+
from .dag import DAG
877878
from .steps import Steps
878879
if isinstance(templ, ScriptOPTemplate):
879880
templ.outputs.parameters["dflow_success_tag"] = \
@@ -889,21 +890,10 @@ def add_success_tag(templ):
889890
elif isinstance(templ, PythonScriptOPTemplate):
890891
templ.script += "\nwith open('/tmp/outputs"\
891892
"/success_tag', 'w') as f:\n f.write('1')\n"
892-
elif isinstance(templ, Steps):
893-
last_step = templ.steps[-1]
894-
last_templ = last_step.template
895-
add_success_tag(last_templ)
896-
last_step.outputs.parameters["dflow_success_tag"] = \
897-
deepcopy(
898-
last_templ.outputs.parameters["dflow_success_tag"])
893+
elif isinstance(templ, (Steps, DAG)):
899894
templ.outputs.parameters["dflow_success_tag"] = \
900895
OutputParameter(
901-
value_from_parameter=last_step.outputs.parameters[
902-
"dflow_success_tag"], default="0")
903-
else:
904-
raise RuntimeError(
905-
"Unsupported type of OPTemplate for "
906-
"continue_on_num_success or continue_on_success_ratio")
896+
value_from_parameter="nonexist", default="1")
907897

908898
add_success_tag(self.template)
909899
self.outputs.parameters["dflow_success_tag"] = deepcopy(
@@ -1619,15 +1609,12 @@ def handle_expr(val, scope):
16191609
for name, par in self.outputs.parameters.items():
16201610
par.value = []
16211611
for ps in self.parallel_steps:
1622-
if not hasattr(ps.outputs.parameters[name], "value") and \
1623-
hasattr(ps.outputs.parameters[name], "default"):
1624-
value = ps.outputs.parameters[name].default
1625-
else:
1612+
if hasattr(ps.outputs.parameters[name], "value"):
16261613
value = ps.outputs.parameters[name].value
1627-
if isinstance(value, str):
1628-
par.value.append(value)
1629-
else:
1630-
par.value.append(jsonpickle.dumps(value))
1614+
if isinstance(value, str):
1615+
par.value.append(value)
1616+
else:
1617+
par.value.append(jsonpickle.dumps(value))
16311618
for name, art in self.outputs.artifacts.items():
16321619
for save in self.template.outputs.artifacts[name].save:
16331620
if isinstance(save, S3Artifact):
@@ -2017,6 +2004,11 @@ def exec_pod(self, scope, parameters, item=None):
20172004
else:
20182005
raise ValueError("Unsupported copy method for debug mode.")
20192006

2007+
# set default output parameters
2008+
for name, par in self.outputs.parameters.items():
2009+
if hasattr(par, "default"):
2010+
par.value = par.default
2011+
20202012
script_path = os.path.join(stepdir, "script")
20212013
if self.phase == "Pending":
20222014
# render variables in the script

0 commit comments

Comments
 (0)