From aef07430d69d0cf4bf11de56c92ffde4574b8d4e Mon Sep 17 00:00:00 2001 From: Mohamed-Ashraf273 Date: Sat, 7 Jun 2025 04:38:00 +0300 Subject: [PATCH 1/2] [OpenVINO Backend] support slice_update --- keras/src/backend/openvino/core.py | 78 +++++++++++++++++++++++++++++- 1 file changed, 76 insertions(+), 2 deletions(-) diff --git a/keras/src/backend/openvino/core.py b/keras/src/backend/openvino/core.py index 8ae342e27f3..4ae45c27849 100644 --- a/keras/src/backend/openvino/core.py +++ b/keras/src/backend/openvino/core.py @@ -657,9 +657,83 @@ def slice(inputs, start_indices, shape): def slice_update(inputs, start_indices, updates): - raise NotImplementedError( - "`slice_update` is not supported with openvino backend" + inputs = get_ov_output(inputs) + if isinstance(start_indices, (list, np.ndarray)): + start_indices = tuple(start_indices) + assert isinstance(start_indices, tuple), ( + "`slice_update` is not supported by openvino backend" + " for `start_indices` of type {}".format(type(start_indices)) ) + processed_start_indices = [] + for idx in start_indices: + val = get_ov_output(idx) + val_type = val.get_element_type() + if not val_type.is_integral(): + raise ValueError( + "`slice` is not supported by OpenVINO backend " + "for `start_indices` or `shape` with non-integer types" + ) + if val_type != Type.i32: + val = ov_opset.convert(val, Type.i32).output(0) + if len(val.get_partial_shape()) == 0: + val = ov_opset.unsqueeze( + val, ov_opset.constant(0, Type.i32) + ).output(0) + processed_start_indices.append(val) + start_indices_tensor = ov_opset.concat(processed_start_indices, axis=0) + + rank = len(updates.shape) + ranges = [] + for dim in updates.shape: + r = ov_opset.range( + ov_opset.constant(0, Type.i32), + ov_opset.constant(dim, Type.i32), + ov_opset.constant(1, Type.i32), + output_type=Type.i32, + ) + ranges.append(r) + + broadcasted_ranges = [] + for i, r in enumerate(ranges): + shape = [1] * rank + shape[i] = updates.shape[i] + r_reshaped = ov_opset.reshape( + r, ov_opset.constant(shape, Type.i32), special_zero=False + ).output(0) + target_shape = ov_opset.constant(list(updates.shape), Type.i32) + r_broadcasted = ov_opset.broadcast(r_reshaped, target_shape).output(0) + broadcasted_ranges.append(r_broadcasted) + + indices_stack = ov_opset.concat(broadcasted_ranges, axis=0).output(0) + + num_updates = 1 + for dim in updates.shape: + num_updates *= dim + new_shape = ov_opset.constant([rank, num_updates], Type.i32) + indices_reshaped = ov_opset.reshape( + indices_stack, new_shape, special_zero=False + ).output(0) + absolute_indices = ov_opset.transpose( + indices_reshaped, ov_opset.constant([1, 0], Type.i32) + ).output(0) + + start_indices_expanded = ov_opset.broadcast( + start_indices_tensor, ov_opset.constant([num_updates, rank], Type.i32) + ).output(0) + absolute_indices = ov_opset.add( + absolute_indices, start_indices_expanded + ).output(0) + + updates_tensor = get_ov_output(updates) + updates_flat = ov_opset.reshape( + updates_tensor, + ov_opset.constant([num_updates], Type.i32), + special_zero=False, + ).output(0) + updated = ov_opset.scatter_nd_update( + inputs, absolute_indices, updates_flat + ).output(0) + return OpenVINOKerasTensor(updated) def while_loop( From a285aa7600f14a909f497fe7c4052b7c515bbc3f Mon Sep 17 00:00:00 2001 From: Mohamed-Ashraf273 Date: Fri, 20 Jun 2025 22:14:36 +0300 Subject: [PATCH 2/2] enable tests for slice_update --- keras/src/backend/openvino/excluded_concrete_tests.txt | 3 --- 1 file changed, 3 deletions(-) diff --git a/keras/src/backend/openvino/excluded_concrete_tests.txt b/keras/src/backend/openvino/excluded_concrete_tests.txt index 5bbebf80dfb..92abf531fa3 100644 --- a/keras/src/backend/openvino/excluded_concrete_tests.txt +++ b/keras/src/backend/openvino/excluded_concrete_tests.txt @@ -164,17 +164,14 @@ CoreOpsCallsTests::test_map_basic_call CoreOpsCallsTests::test_scan_basic_call CoreOpsCallsTests::test_scatter_basic_call CoreOpsCallsTests::test_scatter_update_basic_call -CoreOpsCallsTests::test_slice_update_basic_call CoreOpsCallsTests::test_switch_basic_call CoreOpsCallsTests::test_unstack_basic_functionality CoreOpsCorrectnessTest::test_associative_scan CoreOpsCorrectnessTest::test_cond -CoreOpsCorrectnessTest::test_dynamic_slice CoreOpsCorrectnessTest::test_fori_loop CoreOpsCorrectnessTest::test_map CoreOpsCorrectnessTest::test_scan CoreOpsCorrectnessTest::test_scatter -CoreOpsCorrectnessTest::test_slice_update CoreOpsCorrectnessTest::test_switch CoreOpsCorrectnessTest::test_unstack CoreOpsCorrectnessTest::test_vectorized_map