Skip to content

Commit 47c145c

Browse files
[OpenVINO Backend] update getitem
1 parent 771b001 commit 47c145c

File tree

2 files changed

+145
-20
lines changed

2 files changed

+145
-20
lines changed

keras/src/backend/openvino/core.py

Lines changed: 145 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import builtins
12
import contextlib
23
import warnings
34

@@ -308,30 +309,155 @@ def __ne__(self, other):
308309
return OpenVINOKerasTensor(ov_opset.not_equal(first, other).output(0))
309310

310311
def __getitem__(self, indices):
311-
# now it has limited functionaly
312-
# and supports only a case with one integer index in indices
313-
# other indices must be None
314312
data = self.output
315-
axis = []
316-
gather_index = None
317-
if isinstance(indices, int):
313+
rank = len(data.get_partial_shape())
314+
axes, gather_indices_nodes = [], []
315+
slice_axes, slice_starts, slice_ends, slice_steps = [], [], [], []
316+
unsqueeze_axes = []
317+
318+
if not isinstance(indices, tuple):
318319
indices = (indices,)
319-
assert isinstance(indices, tuple), "only tuple is supported"
320+
321+
if any(i is Ellipsis for i in indices):
322+
ellipsis_pos = indices.index(Ellipsis)
323+
num_specified = sum(
324+
i is not Ellipsis and i is not None for i in indices
325+
)
326+
num_missing = rank - num_specified
327+
indices = (
328+
indices[:ellipsis_pos]
329+
+ (builtins.slice(None),) * num_missing
330+
+ indices[ellipsis_pos + 1 :]
331+
)
332+
333+
def count_unsqueeze_before(dim):
334+
return sum(1 for i in range(dim) if indices[i] is None)
335+
336+
partial_shape = data.get_partial_shape()
337+
320338
for dim, index in enumerate(indices):
321-
if isinstance(index, int):
322-
axis.append(dim)
323-
gather_index = ov_opset.constant(index, Type.i32)
339+
if isinstance(index, bool):
340+
raise ValueError(
341+
"OpenVINO backend does not support boolean indexing"
342+
)
343+
elif isinstance(index, int):
344+
actual_dim = dim - count_unsqueeze_before(dim)
345+
if not (0 <= actual_dim < rank):
346+
raise IndexError(
347+
f"Index {index} is out of bounds for "
348+
"axis {dim} with rank {rank}"
349+
)
350+
length = partial_shape[actual_dim].get_length()
351+
idx_value = index if index >= 0 else length + index
352+
axes.append(dim)
353+
gather_indices_nodes.append(
354+
ov_opset.constant([idx_value], Type.i32).output(0)
355+
)
356+
elif isinstance(index, builtins.slice):
357+
if index == builtins.slice(None):
358+
continue
359+
if index.step is not None and index.step < 0:
360+
raise ValueError("OpenVINO doesn't support negative steps")
361+
slice_axes.append(dim)
362+
slice_starts.append(0 if index.start is None else index.start)
363+
slice_ends.append(
364+
2**31 - 1 if index.stop is None else index.stop
365+
)
366+
slice_steps.append(1 if index.step is None else index.step)
367+
elif index is None:
368+
unsqueeze_axes.append(dim)
369+
elif isinstance(index, OpenVINOKerasTensor):
370+
index = get_ov_output(index)
371+
index_type = index.get_element_type()
372+
index_shape = index.get_partial_shape()
373+
if index_type == Type.boolean or not index_type.is_integral():
374+
raise ValueError(
375+
"OpenVINO backend does not "
376+
"support {index_type} indexing"
377+
)
378+
axes.append(dim)
379+
if len(index_shape) > 1:
380+
raise ValueError(
381+
"OpenVINO backend does not "
382+
"support multi-dimensional indexing"
383+
)
384+
if index_type != Type.i32:
385+
index = ov_opset.convert(index, Type.i32).output(0)
386+
shape_tensor = ov_opset.shape_of(data)
387+
axis_i32 = ov_opset.constant([dim], dtype=Type.i32)
388+
dim_size = ov_opset.gather(
389+
shape_tensor, axis_i32, ov_opset.constant(0, Type.i32)
390+
)
391+
dim_size = ov_opset.convert(dim_size, Type.i32)
392+
zero = ov_opset.constant(0, Type.i32)
393+
is_negative = ov_opset.less(index, zero)
394+
adjusted_index = ov_opset.add(index, dim_size)
395+
index = ov_opset.select(
396+
is_negative, adjusted_index, index
397+
).output(0)
398+
index_shape = index.get_partial_shape()
399+
if len(index_shape) == 0:
400+
index = ov_opset.unsqueeze(
401+
index, ov_opset.constant(0, Type.i32)
402+
).output(0)
403+
gather_indices_nodes.append(index)
324404
else:
325-
assert (
326-
index.start is None
327-
and index.stop is None
328-
and index.step is None
405+
raise ValueError(
406+
f"Unsupported index type {type(index)} "
407+
"in OpenVINOKerasTensor.__getitem__"
329408
)
330-
assert len(axis) == 1, "axis must contain one element"
331-
axis = ov_opset.constant(axis, Type.i32)
332-
return OpenVINOKerasTensor(
333-
ov_opset.gather(data, gather_index, axis).output(0)
334-
)
409+
410+
if slice_axes:
411+
step = ov_opset.constant(slice_steps, Type.i32).output(0)
412+
start = ov_opset.constant(slice_starts, Type.i32).output(0)
413+
stop = ov_opset.constant(slice_ends, Type.i32).output(0)
414+
adjusted_slice_axes = [
415+
ax - sum(1 for unsq in unsqueeze_axes if unsq <= ax)
416+
for ax in slice_axes
417+
]
418+
axes_const = ov_opset.constant(
419+
adjusted_slice_axes, Type.i32
420+
).output(0)
421+
data = ov_opset.slice(data, start, stop, step, axes_const).output(0)
422+
423+
if axes:
424+
gather_indices_const = (
425+
gather_indices_nodes[0]
426+
if len(gather_indices_nodes) == 1
427+
else ov_opset.concat(gather_indices_nodes, axis=0).output(0)
428+
)
429+
adjusted_axes = [
430+
ax - sum(1 for unsq in unsqueeze_axes if unsq <= ax)
431+
for ax in axes
432+
]
433+
if len(axes) == 1:
434+
data = ov_opset.gather(
435+
data, gather_indices_const, adjusted_axes[0]
436+
).output(0)
437+
data = ov_opset.squeeze(data, adjusted_axes[0]).output(0)
438+
else:
439+
rank = len(data.get_partial_shape())
440+
remaining_axes = [
441+
i for i in range(rank) if i not in adjusted_axes
442+
]
443+
perm = ov_opset.constant(
444+
adjusted_axes + remaining_axes, Type.i32
445+
)
446+
data = ov_opset.transpose(data, perm).output(0)
447+
data = ov_opset.gather_nd(data, gather_indices_const).output(0)
448+
449+
if unsqueeze_axes:
450+
adjusted_unsqueeze = []
451+
for ax in unsqueeze_axes:
452+
ax -= sum(1 for s in axes if s < ax)
453+
ax -= sum(1 for s in slice_axes if s < ax)
454+
adjusted_unsqueeze.append(ax)
455+
unsqueeze_const = ov_opset.constant(
456+
adjusted_unsqueeze, Type.i32
457+
).output(0)
458+
data = ov_opset.unsqueeze(data, unsqueeze_const).output(0)
459+
460+
return OpenVINOKerasTensor(data)
335461

336462
def __len__(self):
337463
ov_output = self.output

keras/src/backend/openvino/excluded_tests.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
keras/src/activations
22
keras/src/backend/common/dtypes_test.py
3-
keras/src/backend/common/variables_test.py
43
keras/src/callbacks/early_stopping_test.py
54
keras/src/dtype_policies/dtype_policy_map_test.py
65
keras/src/layers/attention

0 commit comments

Comments
 (0)