Skip to content

Commit fa3e8df

Browse files
[OpenVINO Backend] support while loop (#21369)
1 parent 9205298 commit fa3e8df

File tree

3 files changed

+111
-7
lines changed

3 files changed

+111
-7
lines changed

keras/src/backend/openvino/core.py

Lines changed: 103 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -692,10 +692,111 @@ def while_loop(
692692
loop_vars,
693693
maximum_iterations=None,
694694
):
695-
raise NotImplementedError(
696-
"`while_loop` is not supported with openvino backend"
695+
def flatten_structure(data):
696+
if isinstance(data, dict):
697+
return [v for k in sorted(data) for v in flatten_structure(data[k])]
698+
elif isinstance(data, (tuple, list)):
699+
return [v for item in data for v in flatten_structure(item)]
700+
else:
701+
return [data]
702+
703+
def pack_structure(template, flat):
704+
if isinstance(template, dict):
705+
keys = sorted(template)
706+
packed = {}
707+
for k in keys:
708+
value, flat = pack_structure(template[k], flat)
709+
packed[k] = value
710+
return packed, flat
711+
elif isinstance(template, (tuple, list)):
712+
packed = []
713+
for item in template:
714+
value, flat = pack_structure(item, flat)
715+
packed.append(value)
716+
return (
717+
tuple(packed) if isinstance(template, tuple) else packed
718+
), flat
719+
else:
720+
return flat[0], flat[1:]
721+
722+
is_scalar_input = _is_scalar(loop_vars)
723+
724+
if is_scalar_input:
725+
loop_vars = (loop_vars,)
726+
elif isinstance(loop_vars, (list, np.ndarray)):
727+
loop_vars = tuple(loop_vars)
728+
else:
729+
assert isinstance(loop_vars, (tuple, dict)), (
730+
f"Unsupported type {type(loop_vars)} for loop_vars"
731+
)
732+
733+
flat_loop_vars = flatten_structure(loop_vars)
734+
loop_vars_ov = [get_ov_output(var) for var in flat_loop_vars]
735+
736+
maximum_iterations = (
737+
ov_opset.constant(-1, Type.i32).output(0)
738+
if maximum_iterations is None
739+
else get_ov_output(maximum_iterations)
697740
)
698741

742+
trip_count = maximum_iterations
743+
execution_condition = ov_opset.constant(True, Type.boolean).output(0)
744+
loop = ov_opset.loop(trip_count, execution_condition)
745+
746+
shapes = [var.get_partial_shape() for var in loop_vars_ov]
747+
types = [var.get_element_type() for var in loop_vars_ov]
748+
params = [
749+
ov_opset.parameter(shape, dtype) for shape, dtype in zip(shapes, types)
750+
]
751+
param_tensors = [OpenVINOKerasTensor(p.output(0)) for p in params]
752+
753+
packed_args, _ = pack_structure(loop_vars, param_tensors)
754+
if isinstance(packed_args, dict):
755+
body_out = body(packed_args)
756+
else:
757+
body_out = body(*packed_args)
758+
759+
if not isinstance(body_out, (list, tuple, dict)):
760+
body_out = (body_out,)
761+
762+
flat_body_out = flatten_structure(body_out)
763+
if isinstance(packed_args, dict):
764+
cond_output = get_ov_output(cond(body_out))
765+
else:
766+
cond_output = get_ov_output(cond(*body_out))
767+
768+
if len(cond_output.get_partial_shape()) != 0:
769+
raise ValueError(
770+
"`cond` function must return a scalar boolean value, "
771+
"but got shape {}".format(cond_output.get_partial_shape())
772+
)
773+
774+
for p, out in zip(params, flat_body_out):
775+
out_shape = get_ov_output(out).get_partial_shape()
776+
p.set_partial_shape(out_shape)
777+
778+
results = [cond_output] + [get_ov_output(x) for x in flat_body_out]
779+
body_func = Model(results=results, parameters=params)
780+
loop.set_function(body_func)
781+
loop.set_special_body_ports([-1, 0])
782+
783+
for param, init_val, next_val in zip(params, loop_vars_ov, flat_body_out):
784+
loop.set_merged_input(param, init_val, get_ov_output(next_val))
785+
786+
outputs_flat = [
787+
OpenVINOKerasTensor(loop.get_iter_value(get_ov_output(val)))
788+
for val in flat_body_out
789+
]
790+
final_output, _ = pack_structure(loop_vars, outputs_flat)
791+
792+
if is_scalar_input:
793+
if isinstance(final_output, tuple):
794+
return final_output[0]
795+
else:
796+
return final_output
797+
else:
798+
return final_output
799+
699800

700801
def fori_loop(lower, upper, body_fun, init_val):
701802
raise NotImplementedError(

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,6 @@ CoreOpsCallsTests::test_scatter_update_basic_call
167167
CoreOpsCallsTests::test_slice_update_basic_call
168168
CoreOpsCallsTests::test_switch_basic_call
169169
CoreOpsCallsTests::test_unstack_basic_functionality
170-
CoreOpsCallsTests::test_while_loop_basic_functionality
171-
CoreOpsCallsTests::test_while_loop_with_max_iterations
172170
CoreOpsCorrectnessTest::test_associative_scan
173171
CoreOpsCorrectnessTest::test_cond
174172
CoreOpsCorrectnessTest::test_dynamic_slice
@@ -180,7 +178,6 @@ CoreOpsCorrectnessTest::test_slice_update
180178
CoreOpsCorrectnessTest::test_switch
181179
CoreOpsCorrectnessTest::test_unstack
182180
CoreOpsCorrectnessTest::test_vectorized_map
183-
CoreOpsCorrectnessTest::test_while_loop
184181
CoreOpsDtypeTest::test_convert_to_tensor0
185182
CoreOpsDtypeTest::test_convert_to_tensor1
186183
CoreOpsDtypeTest::test_convert_to_tensor2

keras/src/ops/core_test.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1113,7 +1113,10 @@ def body(i):
11131113
# Initial loop variable (i = 0)
11141114
loop_vars = (0,)
11151115
result = while_loop.call(loop_vars)
1116-
self.assertEqual(result[0], 5)
1116+
if backend.backend() == "openvino":
1117+
self.assertEqual(ops.convert_to_numpy(result[0]), 5)
1118+
else:
1119+
self.assertEqual(result[0], 5)
11171120

11181121
def test_while_loop_output_spec(self):
11191122
# Define dummy cond and body functions
@@ -1139,7 +1142,10 @@ def body(i):
11391142

11401143
while_loop = core.WhileLoop(cond, body, maximum_iterations=5)
11411144
result = while_loop.call((0,))
1142-
self.assertEqual(result[0], 5)
1145+
if backend.backend() == "openvino":
1146+
self.assertEqual(ops.convert_to_numpy(result[0]), 5)
1147+
else:
1148+
self.assertEqual(result[0], 5)
11431149

11441150
def test_whileloop_compute_output_spec(self):
11451151
# Define loop variables with different shapes and data types

0 commit comments

Comments
 (0)