Skip to content

Commit 88938fd

Browse files
[OpenVINO Backend] update __getitem__
1 parent 771b001 commit 88938fd

File tree

2 files changed

+59
-15
lines changed

2 files changed

+59
-15
lines changed

keras/src/backend/openvino/core.py

Lines changed: 59 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import builtins
12
import contextlib
23
import warnings
34

@@ -308,30 +309,74 @@ def __ne__(self, other):
308309
return OpenVINOKerasTensor(ov_opset.not_equal(first, other).output(0))
309310

310311
def __getitem__(self, indices):
311-
# now it has limited functionaly
312-
# and supports only a case with one integer index in indices
313-
# other indices must be None
314312
data = self.output
315-
axis = []
316-
gather_index = None
317-
if isinstance(indices, int):
313+
axes, gather_indices = [], []
314+
slice_axes, slice_starts, slice_ends, slice_steps = [], [], [], []
315+
unsqueeze_axes, ellipsis_index = [], None
316+
317+
if not isinstance(indices, tuple):
318318
indices = (indices,)
319-
assert isinstance(indices, tuple), "only tuple is supported"
319+
320320
for dim, index in enumerate(indices):
321321
if isinstance(index, int):
322-
axis.append(dim)
323-
gather_index = ov_opset.constant(index, Type.i32)
322+
axes.append(dim)
323+
gather_indices.append(index)
324+
elif isinstance(index, builtins.slice):
325+
slice_axes.append(dim)
326+
slice_starts.append(0 if index.start is None else index.start)
327+
slice_ends.append(
328+
2**31 - 1 if index.stop is None else index.stop
329+
)
330+
slice_steps.append(1 if index.step is None else index.step)
331+
elif index is Ellipsis:
332+
ellipsis_index = dim
333+
elif index is None:
334+
unsqueeze_axes.append(dim)
335+
elif isinstance(index, OpenVINOKerasTensor):
336+
axes.append(dim)
337+
gather_indices.append(convert_to_numpy(index))
324338
else:
325339
assert (
326340
index.start is None
327341
and index.stop is None
328342
and index.step is None
329343
)
330-
assert len(axis) == 1, "axis must contain one element"
331-
axis = ov_opset.constant(axis, Type.i32)
332-
return OpenVINOKerasTensor(
333-
ov_opset.gather(data, gather_index, axis).output(0)
334-
)
344+
345+
if slice_axes:
346+
step = ov_opset.constant(slice_steps, Type.i32).output(0)
347+
start = ov_opset.constant(slice_starts, Type.i32).output(0)
348+
stop = ov_opset.constant(slice_ends, Type.i32).output(0)
349+
axes_const = ov_opset.constant(slice_axes, Type.i32).output(0)
350+
data = ov_opset.slice(data, start, stop, step, axes_const).output(0)
351+
352+
if axes:
353+
adjusted_axes = [
354+
ax - sum(1 for unsq in unsqueeze_axes if unsq <= ax)
355+
for ax in axes
356+
]
357+
rank = len(data.get_partial_shape())
358+
remaining_axes = [i for i in range(rank) if i not in adjusted_axes]
359+
perm = ov_opset.constant(adjusted_axes + remaining_axes, Type.i32)
360+
gather_indices_const = ov_opset.constant(gather_indices, Type.i32)
361+
data = ov_opset.transpose(data, perm).output(0)
362+
data = ov_opset.gather_nd(data, gather_indices_const).output(0)
363+
364+
if unsqueeze_axes:
365+
expanded_rank = data.get_partial_shape().rank.get_length() + len(
366+
unsqueeze_axes
367+
)
368+
adjusted_unsqueeze = []
369+
for ax in unsqueeze_axes:
370+
ax -= sum(1 for s in axes if s < ax)
371+
if ellipsis_index is not None and ax > ellipsis_index:
372+
ax += expanded_rank - len(indices)
373+
adjusted_unsqueeze.append(ax)
374+
unsqueeze_const = ov_opset.constant(
375+
adjusted_unsqueeze, Type.i32
376+
).output(0)
377+
data = ov_opset.unsqueeze(data, unsqueeze_const).output(0)
378+
379+
return OpenVINOKerasTensor(data)
335380

336381
def __len__(self):
337382
ov_output = self.output

keras/src/backend/openvino/excluded_tests.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
keras/src/activations
22
keras/src/backend/common/dtypes_test.py
3-
keras/src/backend/common/variables_test.py
43
keras/src/callbacks/early_stopping_test.py
54
keras/src/dtype_policies/dtype_policy_map_test.py
65
keras/src/layers/attention

0 commit comments

Comments
 (0)