@@ -692,10 +692,111 @@ def while_loop(
692
692
loop_vars ,
693
693
maximum_iterations = None ,
694
694
):
695
- raise NotImplementedError (
696
- "`while_loop` is not supported with openvino backend"
695
+ def flatten_structure (data ):
696
+ if isinstance (data , dict ):
697
+ return [v for k in sorted (data ) for v in flatten_structure (data [k ])]
698
+ elif isinstance (data , (tuple , list )):
699
+ return [v for item in data for v in flatten_structure (item )]
700
+ else :
701
+ return [data ]
702
+
703
+ def pack_structure (template , flat ):
704
+ if isinstance (template , dict ):
705
+ keys = sorted (template )
706
+ packed = {}
707
+ for k in keys :
708
+ value , flat = pack_structure (template [k ], flat )
709
+ packed [k ] = value
710
+ return packed , flat
711
+ elif isinstance (template , (tuple , list )):
712
+ packed = []
713
+ for item in template :
714
+ value , flat = pack_structure (item , flat )
715
+ packed .append (value )
716
+ return (
717
+ tuple (packed ) if isinstance (template , tuple ) else packed
718
+ ), flat
719
+ else :
720
+ return flat [0 ], flat [1 :]
721
+
722
+ is_scalar_input = _is_scalar (loop_vars )
723
+
724
+ if is_scalar_input :
725
+ loop_vars = (loop_vars ,)
726
+ elif isinstance (loop_vars , (list , np .ndarray )):
727
+ loop_vars = tuple (loop_vars )
728
+ else :
729
+ assert isinstance (loop_vars , (tuple , dict )), (
730
+ f"Unsupported type { type (loop_vars )} for loop_vars"
731
+ )
732
+
733
+ flat_loop_vars = flatten_structure (loop_vars )
734
+ loop_vars_ov = [get_ov_output (var ) for var in flat_loop_vars ]
735
+
736
+ maximum_iterations = (
737
+ ov_opset .constant (- 1 , Type .i32 ).output (0 )
738
+ if maximum_iterations is None
739
+ else get_ov_output (maximum_iterations )
697
740
)
698
741
742
+ trip_count = maximum_iterations
743
+ execution_condition = ov_opset .constant (True , Type .boolean ).output (0 )
744
+ loop = ov_opset .loop (trip_count , execution_condition )
745
+
746
+ shapes = [var .get_partial_shape () for var in loop_vars_ov ]
747
+ types = [var .get_element_type () for var in loop_vars_ov ]
748
+ params = [
749
+ ov_opset .parameter (shape , dtype ) for shape , dtype in zip (shapes , types )
750
+ ]
751
+ param_tensors = [OpenVINOKerasTensor (p .output (0 )) for p in params ]
752
+
753
+ packed_args , _ = pack_structure (loop_vars , param_tensors )
754
+ if isinstance (packed_args , dict ):
755
+ body_out = body (packed_args )
756
+ else :
757
+ body_out = body (* packed_args )
758
+
759
+ if not isinstance (body_out , (list , tuple , dict )):
760
+ body_out = (body_out ,)
761
+
762
+ flat_body_out = flatten_structure (body_out )
763
+ if isinstance (packed_args , dict ):
764
+ cond_output = get_ov_output (cond (body_out ))
765
+ else :
766
+ cond_output = get_ov_output (cond (* body_out ))
767
+
768
+ if len (cond_output .get_partial_shape ()) != 0 :
769
+ raise ValueError (
770
+ "`cond` function must return a scalar boolean value, "
771
+ "but got shape {}" .format (cond_output .get_partial_shape ())
772
+ )
773
+
774
+ for p , out in zip (params , flat_body_out ):
775
+ out_shape = get_ov_output (out ).get_partial_shape ()
776
+ p .set_partial_shape (out_shape )
777
+
778
+ results = [cond_output ] + [get_ov_output (x ) for x in flat_body_out ]
779
+ body_func = Model (results = results , parameters = params )
780
+ loop .set_function (body_func )
781
+ loop .set_special_body_ports ([- 1 , 0 ])
782
+
783
+ for param , init_val , next_val in zip (params , loop_vars_ov , flat_body_out ):
784
+ loop .set_merged_input (param , init_val , get_ov_output (next_val ))
785
+
786
+ outputs_flat = [
787
+ OpenVINOKerasTensor (loop .get_iter_value (get_ov_output (val )))
788
+ for val in flat_body_out
789
+ ]
790
+ final_output , _ = pack_structure (loop_vars , outputs_flat )
791
+
792
+ if is_scalar_input :
793
+ if isinstance (final_output , tuple ):
794
+ return final_output [0 ]
795
+ else :
796
+ return final_output
797
+ else :
798
+ return final_output
799
+
699
800
700
801
def fori_loop (lower , upper , body_fun , init_val ):
701
802
raise NotImplementedError (
0 commit comments