@@ -657,9 +657,77 @@ def slice(inputs, start_indices, shape):
657
657
658
658
659
659
def slice_update (inputs , start_indices , updates ):
660
- raise NotImplementedError (
661
- "`slice_update` is not supported with openvino backend"
662
- )
660
+ inputs = get_ov_output (inputs )
661
+ processed_start_indices = []
662
+ for idx in start_indices :
663
+ val = get_ov_output (idx )
664
+ val_type = val .get_element_type ()
665
+ if not val_type .is_integral ():
666
+ raise ValueError (
667
+ "`slice` is not supported by OpenVINO backend "
668
+ "for `start_indices` or `shape` with non-integer types"
669
+ )
670
+ if val_type != Type .i32 :
671
+ val = ov_opset .convert (val , Type .i32 ).output (0 )
672
+ if len (val .get_partial_shape ()) == 0 :
673
+ val = ov_opset .unsqueeze (
674
+ val , ov_opset .constant (0 , Type .i32 )
675
+ ).output (0 )
676
+ processed_start_indices .append (val )
677
+ start_indices_tensor = ov_opset .concat (processed_start_indices , axis = 0 )
678
+
679
+ rank = len (updates .shape )
680
+ ranges = []
681
+ for dim in updates .shape :
682
+ r = ov_opset .range (
683
+ ov_opset .constant (0 , Type .i32 ),
684
+ ov_opset .constant (dim , Type .i32 ),
685
+ ov_opset .constant (1 , Type .i32 ),
686
+ output_type = Type .i32 ,
687
+ )
688
+ ranges .append (r )
689
+
690
+ broadcasted_ranges = []
691
+ for i , r in enumerate (ranges ):
692
+ shape = [1 ] * rank
693
+ shape [i ] = updates .shape [i ]
694
+ r_reshaped = ov_opset .reshape (
695
+ r , ov_opset .constant (shape , Type .i32 ), special_zero = False
696
+ ).output (0 )
697
+ target_shape = ov_opset .constant (list (updates .shape ), Type .i32 )
698
+ r_broadcasted = ov_opset .broadcast (r_reshaped , target_shape ).output (0 )
699
+ broadcasted_ranges .append (r_broadcasted )
700
+
701
+ indices_stack = ov_opset .concat (broadcasted_ranges , axis = 0 ).output (0 )
702
+
703
+ num_updates = 1
704
+ for dim in updates .shape :
705
+ num_updates *= dim
706
+ new_shape = ov_opset .constant ([rank , num_updates ], Type .i32 )
707
+ indices_reshaped = ov_opset .reshape (
708
+ indices_stack , new_shape , special_zero = False
709
+ ).output (0 )
710
+ absolute_indices = ov_opset .transpose (
711
+ indices_reshaped , ov_opset .constant ([1 , 0 ], Type .i32 )
712
+ ).output (0 )
713
+
714
+ start_indices_expanded = ov_opset .broadcast (
715
+ start_indices_tensor , ov_opset .constant ([num_updates , rank ], Type .i32 )
716
+ ).output (0 )
717
+ absolute_indices = ov_opset .add (
718
+ absolute_indices , start_indices_expanded
719
+ ).output (0 )
720
+
721
+ updates_tensor = get_ov_output (updates )
722
+ updates_flat = ov_opset .reshape (
723
+ updates_tensor ,
724
+ ov_opset .constant ([num_updates ], Type .i32 ),
725
+ special_zero = False ,
726
+ ).output (0 )
727
+ updated = ov_opset .scatter_nd_update (
728
+ inputs , absolute_indices , updates_flat
729
+ ).output (0 )
730
+ return OpenVINOKerasTensor (updated )
663
731
664
732
665
733
def while_loop (
0 commit comments