From 47c145c718e918e9ce8beb5b25bacf85c3fd1025 Mon Sep 17 00:00:00 2001 From: Mohamed-Ashraf273 Date: Sun, 8 Jun 2025 20:55:26 +0300 Subject: [PATCH 1/2] [OpenVINO Backend] update getitem --- keras/src/backend/openvino/core.py | 164 ++++++++++++++++-- keras/src/backend/openvino/excluded_tests.txt | 1 - 2 files changed, 145 insertions(+), 20 deletions(-) diff --git a/keras/src/backend/openvino/core.py b/keras/src/backend/openvino/core.py index 8ae342e27f3..a6fea9732a7 100644 --- a/keras/src/backend/openvino/core.py +++ b/keras/src/backend/openvino/core.py @@ -1,3 +1,4 @@ +import builtins import contextlib import warnings @@ -308,30 +309,155 @@ def __ne__(self, other): return OpenVINOKerasTensor(ov_opset.not_equal(first, other).output(0)) def __getitem__(self, indices): - # now it has limited functionaly - # and supports only a case with one integer index in indices - # other indices must be None data = self.output - axis = [] - gather_index = None - if isinstance(indices, int): + rank = len(data.get_partial_shape()) + axes, gather_indices_nodes = [], [] + slice_axes, slice_starts, slice_ends, slice_steps = [], [], [], [] + unsqueeze_axes = [] + + if not isinstance(indices, tuple): indices = (indices,) - assert isinstance(indices, tuple), "only tuple is supported" + + if any(i is Ellipsis for i in indices): + ellipsis_pos = indices.index(Ellipsis) + num_specified = sum( + i is not Ellipsis and i is not None for i in indices + ) + num_missing = rank - num_specified + indices = ( + indices[:ellipsis_pos] + + (builtins.slice(None),) * num_missing + + indices[ellipsis_pos + 1 :] + ) + + def count_unsqueeze_before(dim): + return sum(1 for i in range(dim) if indices[i] is None) + + partial_shape = data.get_partial_shape() + for dim, index in enumerate(indices): - if isinstance(index, int): - axis.append(dim) - gather_index = ov_opset.constant(index, Type.i32) + if isinstance(index, bool): + raise ValueError( + "OpenVINO backend does not support boolean indexing" + ) + elif isinstance(index, int): + actual_dim = dim - count_unsqueeze_before(dim) + if not (0 <= actual_dim < rank): + raise IndexError( + f"Index {index} is out of bounds for " + "axis {dim} with rank {rank}" + ) + length = partial_shape[actual_dim].get_length() + idx_value = index if index >= 0 else length + index + axes.append(dim) + gather_indices_nodes.append( + ov_opset.constant([idx_value], Type.i32).output(0) + ) + elif isinstance(index, builtins.slice): + if index == builtins.slice(None): + continue + if index.step is not None and index.step < 0: + raise ValueError("OpenVINO doesn't support negative steps") + slice_axes.append(dim) + slice_starts.append(0 if index.start is None else index.start) + slice_ends.append( + 2**31 - 1 if index.stop is None else index.stop + ) + slice_steps.append(1 if index.step is None else index.step) + elif index is None: + unsqueeze_axes.append(dim) + elif isinstance(index, OpenVINOKerasTensor): + index = get_ov_output(index) + index_type = index.get_element_type() + index_shape = index.get_partial_shape() + if index_type == Type.boolean or not index_type.is_integral(): + raise ValueError( + "OpenVINO backend does not " + "support {index_type} indexing" + ) + axes.append(dim) + if len(index_shape) > 1: + raise ValueError( + "OpenVINO backend does not " + "support multi-dimensional indexing" + ) + if index_type != Type.i32: + index = ov_opset.convert(index, Type.i32).output(0) + shape_tensor = ov_opset.shape_of(data) + axis_i32 = ov_opset.constant([dim], dtype=Type.i32) + dim_size = ov_opset.gather( + shape_tensor, axis_i32, ov_opset.constant(0, Type.i32) + ) + dim_size = ov_opset.convert(dim_size, Type.i32) + zero = ov_opset.constant(0, Type.i32) + is_negative = ov_opset.less(index, zero) + adjusted_index = ov_opset.add(index, dim_size) + index = ov_opset.select( + is_negative, adjusted_index, index + ).output(0) + index_shape = index.get_partial_shape() + if len(index_shape) == 0: + index = ov_opset.unsqueeze( + index, ov_opset.constant(0, Type.i32) + ).output(0) + gather_indices_nodes.append(index) else: - assert ( - index.start is None - and index.stop is None - and index.step is None + raise ValueError( + f"Unsupported index type {type(index)} " + "in OpenVINOKerasTensor.__getitem__" ) - assert len(axis) == 1, "axis must contain one element" - axis = ov_opset.constant(axis, Type.i32) - return OpenVINOKerasTensor( - ov_opset.gather(data, gather_index, axis).output(0) - ) + + if slice_axes: + step = ov_opset.constant(slice_steps, Type.i32).output(0) + start = ov_opset.constant(slice_starts, Type.i32).output(0) + stop = ov_opset.constant(slice_ends, Type.i32).output(0) + adjusted_slice_axes = [ + ax - sum(1 for unsq in unsqueeze_axes if unsq <= ax) + for ax in slice_axes + ] + axes_const = ov_opset.constant( + adjusted_slice_axes, Type.i32 + ).output(0) + data = ov_opset.slice(data, start, stop, step, axes_const).output(0) + + if axes: + gather_indices_const = ( + gather_indices_nodes[0] + if len(gather_indices_nodes) == 1 + else ov_opset.concat(gather_indices_nodes, axis=0).output(0) + ) + adjusted_axes = [ + ax - sum(1 for unsq in unsqueeze_axes if unsq <= ax) + for ax in axes + ] + if len(axes) == 1: + data = ov_opset.gather( + data, gather_indices_const, adjusted_axes[0] + ).output(0) + data = ov_opset.squeeze(data, adjusted_axes[0]).output(0) + else: + rank = len(data.get_partial_shape()) + remaining_axes = [ + i for i in range(rank) if i not in adjusted_axes + ] + perm = ov_opset.constant( + adjusted_axes + remaining_axes, Type.i32 + ) + data = ov_opset.transpose(data, perm).output(0) + data = ov_opset.gather_nd(data, gather_indices_const).output(0) + + if unsqueeze_axes: + adjusted_unsqueeze = [] + for ax in unsqueeze_axes: + ax -= sum(1 for s in axes if s < ax) + ax -= sum(1 for s in slice_axes if s < ax) + adjusted_unsqueeze.append(ax) + unsqueeze_const = ov_opset.constant( + adjusted_unsqueeze, Type.i32 + ).output(0) + data = ov_opset.unsqueeze(data, unsqueeze_const).output(0) + + return OpenVINOKerasTensor(data) def __len__(self): ov_output = self.output diff --git a/keras/src/backend/openvino/excluded_tests.txt b/keras/src/backend/openvino/excluded_tests.txt index 545453dc9cd..ed97442338e 100644 --- a/keras/src/backend/openvino/excluded_tests.txt +++ b/keras/src/backend/openvino/excluded_tests.txt @@ -1,6 +1,5 @@ keras/src/activations keras/src/backend/common/dtypes_test.py -keras/src/backend/common/variables_test.py keras/src/callbacks/early_stopping_test.py keras/src/dtype_policies/dtype_policy_map_test.py keras/src/layers/attention From 48e69dbff86065a40eb7c0504a355312e3e9d7b9 Mon Sep 17 00:00:00 2001 From: Mohamed-Ashraf273 Date: Thu, 19 Jun 2025 19:02:24 +0300 Subject: [PATCH 2/2] [OpenVINO backend] update getitem --- keras/src/backend/openvino/core.py | 41 +++---- keras/src/ops/core_test.py | 166 +++++++++++++++++++++++++++++ 2 files changed, 188 insertions(+), 19 deletions(-) diff --git a/keras/src/backend/openvino/core.py b/keras/src/backend/openvino/core.py index a6fea9732a7..1ebbc605d4f 100644 --- a/keras/src/backend/openvino/core.py +++ b/keras/src/backend/openvino/core.py @@ -333,26 +333,36 @@ def __getitem__(self, indices): def count_unsqueeze_before(dim): return sum(1 for i in range(dim) if indices[i] is None) - partial_shape = data.get_partial_shape() + partial_shape = ov_opset.shape_of(data, Type.i32) + zero_const = ov_opset.constant(0, Type.i32) for dim, index in enumerate(indices): if isinstance(index, bool): raise ValueError( "OpenVINO backend does not support boolean indexing" ) - elif isinstance(index, int): + elif isinstance(index, (int, np.integer)): + if isinstance(index, np.integer): + index = int(index) actual_dim = dim - count_unsqueeze_before(dim) if not (0 <= actual_dim < rank): raise IndexError( f"Index {index} is out of bounds for " "axis {dim} with rank {rank}" ) - length = partial_shape[actual_dim].get_length() - idx_value = index if index >= 0 else length + index - axes.append(dim) - gather_indices_nodes.append( - ov_opset.constant([idx_value], Type.i32).output(0) + length = ov_opset.gather( + partial_shape, + ov_opset.constant([actual_dim], Type.i32), + zero_const, ) + if index >= 0: + idx_value = ov_opset.constant([index], Type.i32) + else: + idx_value = ov_opset.add( + ov_opset.constant([index], Type.i32), length + ) + axes.append(dim) + gather_indices_nodes.append(idx_value.output(0)) elif isinstance(index, builtins.slice): if index == builtins.slice(None): continue @@ -381,25 +391,18 @@ def count_unsqueeze_before(dim): "OpenVINO backend does not " "support multi-dimensional indexing" ) + if len(index_shape) == 0: + index = ov_opset.unsqueeze(index, zero_const).output(0) if index_type != Type.i32: index = ov_opset.convert(index, Type.i32).output(0) - shape_tensor = ov_opset.shape_of(data) + shape_tensor = ov_opset.shape_of(data, Type.i32) axis_i32 = ov_opset.constant([dim], dtype=Type.i32) - dim_size = ov_opset.gather( - shape_tensor, axis_i32, ov_opset.constant(0, Type.i32) - ) - dim_size = ov_opset.convert(dim_size, Type.i32) - zero = ov_opset.constant(0, Type.i32) - is_negative = ov_opset.less(index, zero) + dim_size = ov_opset.gather(shape_tensor, axis_i32, zero_const) + is_negative = ov_opset.less(index, zero_const) adjusted_index = ov_opset.add(index, dim_size) index = ov_opset.select( is_negative, adjusted_index, index ).output(0) - index_shape = index.get_partial_shape() - if len(index_shape) == 0: - index = ov_opset.unsqueeze( - index, ov_opset.constant(0, Type.i32) - ).output(0) gather_indices_nodes.append(index) else: raise ValueError( diff --git a/keras/src/ops/core_test.py b/keras/src/ops/core_test.py index 638ad933e50..dbddfd2cba4 100644 --- a/keras/src/ops/core_test.py +++ b/keras/src/ops/core_test.py @@ -169,6 +169,172 @@ def test_convert_to_tensor(self): class CoreOpsCorrectnessTest(testing.TestCase): + def test_getitem(self): + self.np_tensor = np.arange(24).reshape(2, 3, 4) + self.tensor = ops.convert_to_tensor(self.np_tensor) + + t = self.tensor[1] + n = self.np_tensor[1] + self.assertEqual(t.shape, n.shape) + self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n)) + + t = self.tensor[1, 2, 3] + n = self.np_tensor[1, 2, 3] + self.assertEqual(t.shape, n.shape) + self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n)) + + t = self.tensor[1:2] + n = self.np_tensor[1:2] + self.assertEqual(t.shape, n.shape) + self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n)) + + t = self.tensor[1:2, 2:3, 3:4] + n = self.np_tensor[1:2, 2:3, 3:4] + self.assertEqual(t.shape, n.shape) + self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n)) + + t = self.tensor[1:2, None] + n = self.np_tensor[1:2, None] + self.assertEqual(t.shape, n.shape) + self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n)) + + t = self.tensor[1:2, 2:3, ...] + n = self.np_tensor[1:2, 2:3, ...] + self.assertEqual(t.shape, n.shape) + self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n)) + + t = self.tensor[1:2, ..., 3:4] + n = self.np_tensor[1:2, ..., 3:4] + self.assertEqual(t.shape, n.shape) + self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n)) + + t = self.tensor[None, ..., 3:4, None] + n = self.np_tensor[None, ..., 3:4, None] + self.assertEqual(t.shape, n.shape) + self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n)) + + t = self.tensor[1:2:None] + n = self.np_tensor[1:2:None] + self.assertEqual(t.shape, n.shape) + self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n)) + + t = self.tensor[:, 2] + n = self.np_tensor[:, 2] + self.assertEqual(t.shape, n.shape) + self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n)) + + t = self.tensor[None] + n = self.np_tensor[None] + self.assertEqual(t.shape, n.shape) + self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n)) + + t = self.tensor[None, None] + n = self.np_tensor[None, None] + self.assertEqual(t.shape, n.shape) + self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n)) + + t = self.tensor[...] + n = self.np_tensor[...] + self.assertEqual(t.shape, n.shape) + self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n)) + + t = self.tensor[..., 1] + n = self.np_tensor[..., 1] + self.assertEqual(t.shape, n.shape) + self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n)) + + t = self.tensor[..., 1, 2] + n = self.np_tensor[..., 1, 2] + self.assertEqual(t.shape, n.shape) + self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n)) + + t = self.tensor[..., -1, 2] + n = self.np_tensor[..., -1, 2] + self.assertEqual(t.shape, n.shape) + self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n)) + + t = self.tensor[..., -1:-2, 2] + n = self.np_tensor[..., -1:-2, 2] + self.assertEqual(t.shape, n.shape) + self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n)) + + t = self.tensor[..., None, None] + n = self.np_tensor[..., None, None] + self.assertEqual(t.shape, n.shape) + self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n)) + + t = self.tensor[None, ..., None] + n = self.np_tensor[None, ..., None] + self.assertEqual(t.shape, n.shape) + self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n)) + + t = self.tensor[1, 2, None, ..., None] + n = self.np_tensor[1, 2, None, ..., None] + self.assertEqual(t.shape, n.shape) + self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n)) + + t = self.tensor[None, ..., 1, 2] + n = self.np_tensor[None, ..., 1, 2] + self.assertEqual(t.shape, n.shape) + self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n)) + + t = self.tensor[1, None, 2] + n = self.np_tensor[1, None, 2] + self.assertEqual(t.shape, n.shape) + self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n)) + + index_tensor = ops.convert_to_tensor(np.array(1, dtype=np.int32)) + t = self.tensor[index_tensor] + n = self.np_tensor[ops.convert_to_numpy(index_tensor)] + self.assertEqual(t.shape, n.shape) + self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n)) + + index_tensor = ops.convert_to_tensor(np.array(1, dtype=np.int32)) + t = self.tensor[index_tensor, 2, None] + n = self.np_tensor[ops.convert_to_numpy(index_tensor), 2, None] + self.assertEqual(t.shape, n.shape) + self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n)) + + index_tensor = ops.convert_to_tensor(np.array(-2, dtype=np.int32)) + t = self.tensor[index_tensor, 1] + n = self.np_tensor[ops.convert_to_numpy(index_tensor), 1] + self.assertEqual(t.shape, n.shape) + self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n)) + + index_tensor = ops.convert_to_tensor(np.array(-1, dtype=np.int32)) + t = self.tensor[-2, index_tensor] + n = self.np_tensor[-2, ops.convert_to_numpy(index_tensor)] + self.assertEqual(t.shape, n.shape) + self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n)) + + # Negative indexing + t = self.tensor[-1] + n = self.np_tensor[-1] + self.assertEqual(t.shape, n.shape) + self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n)) + + t = self.tensor[1, -1, -2] + n = self.np_tensor[1, -1, -2] + self.assertEqual(t.shape, n.shape) + self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n)) + + # Slicing with step + t = self.tensor[::2] + n = self.np_tensor[::2] + self.assertEqual(t.shape, n.shape) + self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n)) + + # Mixed slices and integers + t = self.tensor[1, :, 1:4] + n = self.np_tensor[1, :, 1:4] + self.assertEqual(t.shape, n.shape) + self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n)) + + t = self.tensor[:, 1:2, 3] + n = self.np_tensor[:, 1:2, 3] + self.assertEqual(t.shape, n.shape) + self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n)) + def test_map(self): def f(x): return x**2