Skip to content

Commit d16424e

Browse files
[OpenVINO Backend] update __getitem__ (#21359)
* [OpenVINO Backend] update getitem * [OpenVINO backend] update getitem
1 parent fa3e8df commit d16424e

File tree

3 files changed

+314
-20
lines changed

3 files changed

+314
-20
lines changed

keras/src/backend/openvino/core.py

Lines changed: 148 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import builtins
12
import contextlib
23
import warnings
34

@@ -308,30 +309,158 @@ def __ne__(self, other):
308309
return OpenVINOKerasTensor(ov_opset.not_equal(first, other).output(0))
309310

310311
def __getitem__(self, indices):
311-
# now it has limited functionaly
312-
# and supports only a case with one integer index in indices
313-
# other indices must be None
314312
data = self.output
315-
axis = []
316-
gather_index = None
317-
if isinstance(indices, int):
313+
rank = len(data.get_partial_shape())
314+
axes, gather_indices_nodes = [], []
315+
slice_axes, slice_starts, slice_ends, slice_steps = [], [], [], []
316+
unsqueeze_axes = []
317+
318+
if not isinstance(indices, tuple):
318319
indices = (indices,)
319-
assert isinstance(indices, tuple), "only tuple is supported"
320+
321+
if any(i is Ellipsis for i in indices):
322+
ellipsis_pos = indices.index(Ellipsis)
323+
num_specified = sum(
324+
i is not Ellipsis and i is not None for i in indices
325+
)
326+
num_missing = rank - num_specified
327+
indices = (
328+
indices[:ellipsis_pos]
329+
+ (builtins.slice(None),) * num_missing
330+
+ indices[ellipsis_pos + 1 :]
331+
)
332+
333+
def count_unsqueeze_before(dim):
334+
return sum(1 for i in range(dim) if indices[i] is None)
335+
336+
partial_shape = ov_opset.shape_of(data, Type.i32)
337+
zero_const = ov_opset.constant(0, Type.i32)
338+
320339
for dim, index in enumerate(indices):
321-
if isinstance(index, int):
322-
axis.append(dim)
323-
gather_index = ov_opset.constant(index, Type.i32)
340+
if isinstance(index, bool):
341+
raise ValueError(
342+
"OpenVINO backend does not support boolean indexing"
343+
)
344+
elif isinstance(index, (int, np.integer)):
345+
if isinstance(index, np.integer):
346+
index = int(index)
347+
actual_dim = dim - count_unsqueeze_before(dim)
348+
if not (0 <= actual_dim < rank):
349+
raise IndexError(
350+
f"Index {index} is out of bounds for "
351+
"axis {dim} with rank {rank}"
352+
)
353+
length = ov_opset.gather(
354+
partial_shape,
355+
ov_opset.constant([actual_dim], Type.i32),
356+
zero_const,
357+
)
358+
if index >= 0:
359+
idx_value = ov_opset.constant([index], Type.i32)
360+
else:
361+
idx_value = ov_opset.add(
362+
ov_opset.constant([index], Type.i32), length
363+
)
364+
axes.append(dim)
365+
gather_indices_nodes.append(idx_value.output(0))
366+
elif isinstance(index, builtins.slice):
367+
if index == builtins.slice(None):
368+
continue
369+
if index.step is not None and index.step < 0:
370+
raise ValueError("OpenVINO doesn't support negative steps")
371+
slice_axes.append(dim)
372+
slice_starts.append(0 if index.start is None else index.start)
373+
slice_ends.append(
374+
2**31 - 1 if index.stop is None else index.stop
375+
)
376+
slice_steps.append(1 if index.step is None else index.step)
377+
elif index is None:
378+
unsqueeze_axes.append(dim)
379+
elif isinstance(index, OpenVINOKerasTensor):
380+
index = get_ov_output(index)
381+
index_type = index.get_element_type()
382+
index_shape = index.get_partial_shape()
383+
if index_type == Type.boolean or not index_type.is_integral():
384+
raise ValueError(
385+
"OpenVINO backend does not "
386+
"support {index_type} indexing"
387+
)
388+
axes.append(dim)
389+
if len(index_shape) > 1:
390+
raise ValueError(
391+
"OpenVINO backend does not "
392+
"support multi-dimensional indexing"
393+
)
394+
if len(index_shape) == 0:
395+
index = ov_opset.unsqueeze(index, zero_const).output(0)
396+
if index_type != Type.i32:
397+
index = ov_opset.convert(index, Type.i32).output(0)
398+
shape_tensor = ov_opset.shape_of(data, Type.i32)
399+
axis_i32 = ov_opset.constant([dim], dtype=Type.i32)
400+
dim_size = ov_opset.gather(shape_tensor, axis_i32, zero_const)
401+
is_negative = ov_opset.less(index, zero_const)
402+
adjusted_index = ov_opset.add(index, dim_size)
403+
index = ov_opset.select(
404+
is_negative, adjusted_index, index
405+
).output(0)
406+
gather_indices_nodes.append(index)
324407
else:
325-
assert (
326-
index.start is None
327-
and index.stop is None
328-
and index.step is None
408+
raise ValueError(
409+
f"Unsupported index type {type(index)} "
410+
"in OpenVINOKerasTensor.__getitem__"
329411
)
330-
assert len(axis) == 1, "axis must contain one element"
331-
axis = ov_opset.constant(axis, Type.i32)
332-
return OpenVINOKerasTensor(
333-
ov_opset.gather(data, gather_index, axis).output(0)
334-
)
412+
413+
if slice_axes:
414+
step = ov_opset.constant(slice_steps, Type.i32).output(0)
415+
start = ov_opset.constant(slice_starts, Type.i32).output(0)
416+
stop = ov_opset.constant(slice_ends, Type.i32).output(0)
417+
adjusted_slice_axes = [
418+
ax - sum(1 for unsq in unsqueeze_axes if unsq <= ax)
419+
for ax in slice_axes
420+
]
421+
axes_const = ov_opset.constant(
422+
adjusted_slice_axes, Type.i32
423+
).output(0)
424+
data = ov_opset.slice(data, start, stop, step, axes_const).output(0)
425+
426+
if axes:
427+
gather_indices_const = (
428+
gather_indices_nodes[0]
429+
if len(gather_indices_nodes) == 1
430+
else ov_opset.concat(gather_indices_nodes, axis=0).output(0)
431+
)
432+
adjusted_axes = [
433+
ax - sum(1 for unsq in unsqueeze_axes if unsq <= ax)
434+
for ax in axes
435+
]
436+
if len(axes) == 1:
437+
data = ov_opset.gather(
438+
data, gather_indices_const, adjusted_axes[0]
439+
).output(0)
440+
data = ov_opset.squeeze(data, adjusted_axes[0]).output(0)
441+
else:
442+
rank = len(data.get_partial_shape())
443+
remaining_axes = [
444+
i for i in range(rank) if i not in adjusted_axes
445+
]
446+
perm = ov_opset.constant(
447+
adjusted_axes + remaining_axes, Type.i32
448+
)
449+
data = ov_opset.transpose(data, perm).output(0)
450+
data = ov_opset.gather_nd(data, gather_indices_const).output(0)
451+
452+
if unsqueeze_axes:
453+
adjusted_unsqueeze = []
454+
for ax in unsqueeze_axes:
455+
ax -= sum(1 for s in axes if s < ax)
456+
ax -= sum(1 for s in slice_axes if s < ax)
457+
adjusted_unsqueeze.append(ax)
458+
unsqueeze_const = ov_opset.constant(
459+
adjusted_unsqueeze, Type.i32
460+
).output(0)
461+
data = ov_opset.unsqueeze(data, unsqueeze_const).output(0)
462+
463+
return OpenVINOKerasTensor(data)
335464

336465
def __len__(self):
337466
ov_output = self.output

keras/src/backend/openvino/excluded_tests.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
keras/src/activations
22
keras/src/backend/common/dtypes_test.py
3-
keras/src/backend/common/variables_test.py
43
keras/src/callbacks/early_stopping_test.py
54
keras/src/dtype_policies/dtype_policy_map_test.py
65
keras/src/layers/attention

keras/src/ops/core_test.py

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,172 @@ def test_convert_to_tensor(self):
169169

170170

171171
class CoreOpsCorrectnessTest(testing.TestCase):
172+
def test_getitem(self):
173+
self.np_tensor = np.arange(24).reshape(2, 3, 4)
174+
self.tensor = ops.convert_to_tensor(self.np_tensor)
175+
176+
t = self.tensor[1]
177+
n = self.np_tensor[1]
178+
self.assertEqual(t.shape, n.shape)
179+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
180+
181+
t = self.tensor[1, 2, 3]
182+
n = self.np_tensor[1, 2, 3]
183+
self.assertEqual(t.shape, n.shape)
184+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
185+
186+
t = self.tensor[1:2]
187+
n = self.np_tensor[1:2]
188+
self.assertEqual(t.shape, n.shape)
189+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
190+
191+
t = self.tensor[1:2, 2:3, 3:4]
192+
n = self.np_tensor[1:2, 2:3, 3:4]
193+
self.assertEqual(t.shape, n.shape)
194+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
195+
196+
t = self.tensor[1:2, None]
197+
n = self.np_tensor[1:2, None]
198+
self.assertEqual(t.shape, n.shape)
199+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
200+
201+
t = self.tensor[1:2, 2:3, ...]
202+
n = self.np_tensor[1:2, 2:3, ...]
203+
self.assertEqual(t.shape, n.shape)
204+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
205+
206+
t = self.tensor[1:2, ..., 3:4]
207+
n = self.np_tensor[1:2, ..., 3:4]
208+
self.assertEqual(t.shape, n.shape)
209+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
210+
211+
t = self.tensor[None, ..., 3:4, None]
212+
n = self.np_tensor[None, ..., 3:4, None]
213+
self.assertEqual(t.shape, n.shape)
214+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
215+
216+
t = self.tensor[1:2:None]
217+
n = self.np_tensor[1:2:None]
218+
self.assertEqual(t.shape, n.shape)
219+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
220+
221+
t = self.tensor[:, 2]
222+
n = self.np_tensor[:, 2]
223+
self.assertEqual(t.shape, n.shape)
224+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
225+
226+
t = self.tensor[None]
227+
n = self.np_tensor[None]
228+
self.assertEqual(t.shape, n.shape)
229+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
230+
231+
t = self.tensor[None, None]
232+
n = self.np_tensor[None, None]
233+
self.assertEqual(t.shape, n.shape)
234+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
235+
236+
t = self.tensor[...]
237+
n = self.np_tensor[...]
238+
self.assertEqual(t.shape, n.shape)
239+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
240+
241+
t = self.tensor[..., 1]
242+
n = self.np_tensor[..., 1]
243+
self.assertEqual(t.shape, n.shape)
244+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
245+
246+
t = self.tensor[..., 1, 2]
247+
n = self.np_tensor[..., 1, 2]
248+
self.assertEqual(t.shape, n.shape)
249+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
250+
251+
t = self.tensor[..., -1, 2]
252+
n = self.np_tensor[..., -1, 2]
253+
self.assertEqual(t.shape, n.shape)
254+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
255+
256+
t = self.tensor[..., -1:-2, 2]
257+
n = self.np_tensor[..., -1:-2, 2]
258+
self.assertEqual(t.shape, n.shape)
259+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
260+
261+
t = self.tensor[..., None, None]
262+
n = self.np_tensor[..., None, None]
263+
self.assertEqual(t.shape, n.shape)
264+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
265+
266+
t = self.tensor[None, ..., None]
267+
n = self.np_tensor[None, ..., None]
268+
self.assertEqual(t.shape, n.shape)
269+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
270+
271+
t = self.tensor[1, 2, None, ..., None]
272+
n = self.np_tensor[1, 2, None, ..., None]
273+
self.assertEqual(t.shape, n.shape)
274+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
275+
276+
t = self.tensor[None, ..., 1, 2]
277+
n = self.np_tensor[None, ..., 1, 2]
278+
self.assertEqual(t.shape, n.shape)
279+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
280+
281+
t = self.tensor[1, None, 2]
282+
n = self.np_tensor[1, None, 2]
283+
self.assertEqual(t.shape, n.shape)
284+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
285+
286+
index_tensor = ops.convert_to_tensor(np.array(1, dtype=np.int32))
287+
t = self.tensor[index_tensor]
288+
n = self.np_tensor[ops.convert_to_numpy(index_tensor)]
289+
self.assertEqual(t.shape, n.shape)
290+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
291+
292+
index_tensor = ops.convert_to_tensor(np.array(1, dtype=np.int32))
293+
t = self.tensor[index_tensor, 2, None]
294+
n = self.np_tensor[ops.convert_to_numpy(index_tensor), 2, None]
295+
self.assertEqual(t.shape, n.shape)
296+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
297+
298+
index_tensor = ops.convert_to_tensor(np.array(-2, dtype=np.int32))
299+
t = self.tensor[index_tensor, 1]
300+
n = self.np_tensor[ops.convert_to_numpy(index_tensor), 1]
301+
self.assertEqual(t.shape, n.shape)
302+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
303+
304+
index_tensor = ops.convert_to_tensor(np.array(-1, dtype=np.int32))
305+
t = self.tensor[-2, index_tensor]
306+
n = self.np_tensor[-2, ops.convert_to_numpy(index_tensor)]
307+
self.assertEqual(t.shape, n.shape)
308+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
309+
310+
# Negative indexing
311+
t = self.tensor[-1]
312+
n = self.np_tensor[-1]
313+
self.assertEqual(t.shape, n.shape)
314+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
315+
316+
t = self.tensor[1, -1, -2]
317+
n = self.np_tensor[1, -1, -2]
318+
self.assertEqual(t.shape, n.shape)
319+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
320+
321+
# Slicing with step
322+
t = self.tensor[::2]
323+
n = self.np_tensor[::2]
324+
self.assertEqual(t.shape, n.shape)
325+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
326+
327+
# Mixed slices and integers
328+
t = self.tensor[1, :, 1:4]
329+
n = self.np_tensor[1, :, 1:4]
330+
self.assertEqual(t.shape, n.shape)
331+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
332+
333+
t = self.tensor[:, 1:2, 3]
334+
n = self.np_tensor[:, 1:2, 3]
335+
self.assertEqual(t.shape, n.shape)
336+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
337+
172338
def test_map(self):
173339
def f(x):
174340
return x**2

0 commit comments

Comments
 (0)