Skip to content

Commit 08daf79

Browse files
[OpenVNIO Backend] update __getitem__
1 parent 47c145c commit 08daf79

File tree

2 files changed

+185
-18
lines changed

2 files changed

+185
-18
lines changed

keras/src/backend/openvino/core.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,8 @@ def __getitem__(self, indices):
333333
def count_unsqueeze_before(dim):
334334
return sum(1 for i in range(dim) if indices[i] is None)
335335

336-
partial_shape = data.get_partial_shape()
336+
partial_shape = ov_opset.shape_of(data, Type.i32)
337+
zero_const = ov_opset.constant(0, Type.i32)
337338

338339
for dim, index in enumerate(indices):
339340
if isinstance(index, bool):
@@ -347,12 +348,19 @@ def count_unsqueeze_before(dim):
347348
f"Index {index} is out of bounds for "
348349
"axis {dim} with rank {rank}"
349350
)
350-
length = partial_shape[actual_dim].get_length()
351-
idx_value = index if index >= 0 else length + index
352-
axes.append(dim)
353-
gather_indices_nodes.append(
354-
ov_opset.constant([idx_value], Type.i32).output(0)
351+
length = ov_opset.gather(
352+
partial_shape,
353+
ov_opset.constant([actual_dim], Type.i32),
354+
zero_const,
355355
)
356+
if index >= 0:
357+
idx_value = ov_opset.constant([index], Type.i32)
358+
else:
359+
idx_value = ov_opset.add(
360+
ov_opset.constant([index], Type.i32), length
361+
)
362+
axes.append(dim)
363+
gather_indices_nodes.append(idx_value.output(0))
356364
elif isinstance(index, builtins.slice):
357365
if index == builtins.slice(None):
358366
continue
@@ -381,25 +389,18 @@ def count_unsqueeze_before(dim):
381389
"OpenVINO backend does not "
382390
"support multi-dimensional indexing"
383391
)
392+
if len(index_shape) == 0:
393+
index = ov_opset.unsqueeze(index, zero_const).output(0)
384394
if index_type != Type.i32:
385395
index = ov_opset.convert(index, Type.i32).output(0)
386-
shape_tensor = ov_opset.shape_of(data)
396+
shape_tensor = ov_opset.shape_of(data, Type.i32)
387397
axis_i32 = ov_opset.constant([dim], dtype=Type.i32)
388-
dim_size = ov_opset.gather(
389-
shape_tensor, axis_i32, ov_opset.constant(0, Type.i32)
390-
)
391-
dim_size = ov_opset.convert(dim_size, Type.i32)
392-
zero = ov_opset.constant(0, Type.i32)
393-
is_negative = ov_opset.less(index, zero)
398+
dim_size = ov_opset.gather(shape_tensor, axis_i32, zero_const)
399+
is_negative = ov_opset.less(index, zero_const)
394400
adjusted_index = ov_opset.add(index, dim_size)
395401
index = ov_opset.select(
396402
is_negative, adjusted_index, index
397403
).output(0)
398-
index_shape = index.get_partial_shape()
399-
if len(index_shape) == 0:
400-
index = ov_opset.unsqueeze(
401-
index, ov_opset.constant(0, Type.i32)
402-
).output(0)
403404
gather_indices_nodes.append(index)
404405
else:
405406
raise ValueError(

keras/src/ops/core_test.py

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,172 @@ def test_convert_to_tensor(self):
169169

170170

171171
class CoreOpsCorrectnessTest(testing.TestCase):
172+
def test_getitem(self):
173+
self.np_tensor = np.arange(24).reshape(2, 3, 4)
174+
self.tensor = ops.convert_to_tensor(self.np_tensor)
175+
176+
t = self.tensor[1]
177+
n = self.np_tensor[1]
178+
self.assertEqual(t.shape, n.shape)
179+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
180+
181+
t = self.tensor[1, 2, 3]
182+
n = self.np_tensor[1, 2, 3]
183+
self.assertEqual(t.shape, n.shape)
184+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
185+
186+
t = self.tensor[1:2]
187+
n = self.np_tensor[1:2]
188+
self.assertEqual(t.shape, n.shape)
189+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
190+
191+
t = self.tensor[1:2, 2:3, 3:4]
192+
n = self.np_tensor[1:2, 2:3, 3:4]
193+
self.assertEqual(t.shape, n.shape)
194+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
195+
196+
t = self.tensor[1:2, None]
197+
n = self.np_tensor[1:2, None]
198+
self.assertEqual(t.shape, n.shape)
199+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
200+
201+
t = self.tensor[1:2, 2:3, ...]
202+
n = self.np_tensor[1:2, 2:3, ...]
203+
self.assertEqual(t.shape, n.shape)
204+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
205+
206+
t = self.tensor[1:2, ..., 3:4]
207+
n = self.np_tensor[1:2, ..., 3:4]
208+
self.assertEqual(t.shape, n.shape)
209+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
210+
211+
t = self.tensor[None, ..., 3:4, None]
212+
n = self.np_tensor[None, ..., 3:4, None]
213+
self.assertEqual(t.shape, n.shape)
214+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
215+
216+
t = self.tensor[1:2:None]
217+
n = self.np_tensor[1:2:None]
218+
self.assertEqual(t.shape, n.shape)
219+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
220+
221+
t = self.tensor[:, 2]
222+
n = self.np_tensor[:, 2]
223+
self.assertEqual(t.shape, n.shape)
224+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
225+
226+
t = self.tensor[None]
227+
n = self.np_tensor[None]
228+
self.assertEqual(t.shape, n.shape)
229+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
230+
231+
t = self.tensor[None, None]
232+
n = self.np_tensor[None, None]
233+
self.assertEqual(t.shape, n.shape)
234+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
235+
236+
t = self.tensor[...]
237+
n = self.np_tensor[...]
238+
self.assertEqual(t.shape, n.shape)
239+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
240+
241+
t = self.tensor[..., 1]
242+
n = self.np_tensor[..., 1]
243+
self.assertEqual(t.shape, n.shape)
244+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
245+
246+
t = self.tensor[..., 1, 2]
247+
n = self.np_tensor[..., 1, 2]
248+
self.assertEqual(t.shape, n.shape)
249+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
250+
251+
t = self.tensor[..., -1, 2]
252+
n = self.np_tensor[..., -1, 2]
253+
self.assertEqual(t.shape, n.shape)
254+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
255+
256+
t = self.tensor[..., -1:-2, 2]
257+
n = self.np_tensor[..., -1:-2, 2]
258+
self.assertEqual(t.shape, n.shape)
259+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
260+
261+
t = self.tensor[..., None, None]
262+
n = self.np_tensor[..., None, None]
263+
self.assertEqual(t.shape, n.shape)
264+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
265+
266+
t = self.tensor[None, ..., None]
267+
n = self.np_tensor[None, ..., None]
268+
self.assertEqual(t.shape, n.shape)
269+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
270+
271+
t = self.tensor[1, 2, None, ..., None]
272+
n = self.np_tensor[1, 2, None, ..., None]
273+
self.assertEqual(t.shape, n.shape)
274+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
275+
276+
t = self.tensor[None, ..., 1, 2]
277+
n = self.np_tensor[None, ..., 1, 2]
278+
self.assertEqual(t.shape, n.shape)
279+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
280+
281+
t = self.tensor[1, None, 2]
282+
n = self.np_tensor[1, None, 2]
283+
self.assertEqual(t.shape, n.shape)
284+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
285+
286+
index_tensor = ops.convert_to_tensor(np.array(1, dtype=np.int32))
287+
t = self.tensor[index_tensor]
288+
n = self.np_tensor[ops.convert_to_numpy(index_tensor)]
289+
self.assertEqual(t.shape, n.shape)
290+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
291+
292+
index_tensor = ops.convert_to_tensor(np.array(1, dtype=np.int32))
293+
t = self.tensor[index_tensor, 2, None]
294+
n = self.np_tensor[ops.convert_to_numpy(index_tensor), 2, None]
295+
self.assertEqual(t.shape, n.shape)
296+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
297+
298+
index_tensor = ops.convert_to_tensor(np.array(-2, dtype=np.int32))
299+
t = self.tensor[index_tensor, 1]
300+
n = self.np_tensor[ops.convert_to_numpy(index_tensor), 1]
301+
self.assertEqual(t.shape, n.shape)
302+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
303+
304+
index_tensor = ops.convert_to_tensor(np.array(-1, dtype=np.int32))
305+
t = self.tensor[-2, index_tensor]
306+
n = self.np_tensor[-2, ops.convert_to_numpy(index_tensor)]
307+
self.assertEqual(t.shape, n.shape)
308+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
309+
310+
# Negative indexing
311+
t = self.tensor[-1]
312+
n = self.np_tensor[-1]
313+
self.assertEqual(t.shape, n.shape)
314+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
315+
316+
t = self.tensor[1, -1, -2]
317+
n = self.np_tensor[1, -1, -2]
318+
self.assertEqual(t.shape, n.shape)
319+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
320+
321+
# Slicing with step
322+
t = self.tensor[::2]
323+
n = self.np_tensor[::2]
324+
self.assertEqual(t.shape, n.shape)
325+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
326+
327+
# Mixed slices and integers
328+
t = self.tensor[1, :, 1:4]
329+
n = self.np_tensor[1, :, 1:4]
330+
self.assertEqual(t.shape, n.shape)
331+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
332+
333+
t = self.tensor[:, 1:2, 3]
334+
n = self.np_tensor[:, 1:2, 3]
335+
self.assertEqual(t.shape, n.shape)
336+
self.assertTrue(np.array_equal(ops.convert_to_numpy(t), n))
337+
172338
def test_map(self):
173339
def f(x):
174340
return x**2

0 commit comments

Comments
 (0)