Skip to content

Commit 379be26

Browse files
[OpenVINO Backend] support where operation (#21354)
1 parent 58313ea commit 379be26

File tree

2 files changed

+26
-3
lines changed

2 files changed

+26
-3
lines changed

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ NumpyDtypeTest::test_unravel
6565
NumpyDtypeTest::test_var
6666
NumpyDtypeTest::test_vdot
6767
NumpyDtypeTest::test_vstack
68-
NumpyDtypeTest::test_where
6968
NumpyDtypeTest::test_clip_bool
7069
NumpyDtypeTest::test_square_bool
7170
HistogramTest
@@ -147,7 +146,6 @@ NumpyTwoInputOpsCorrectnessTest::test_quantile
147146
NumpyTwoInputOpsCorrectnessTest::test_take_along_axis
148147
NumpyTwoInputOpsCorrectnessTest::test_tensordot
149148
NumpyTwoInputOpsCorrectnessTest::test_vdot
150-
NumpyTwoInputOpsCorrectnessTest::test_where
151149
NumpyOneInputOpsDynamicShapeTest::test_angle
152150
NumpyOneInputOpsDynamicShapeTest::test_bartlett
153151
NumpyOneInputOpsDynamicShapeTest::test_blackman

keras/src/backend/openvino/numpy.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1510,7 +1510,32 @@ def vectorize(pyfunc, *, excluded=None, signature=None):
15101510

15111511

15121512
def where(condition, x1=None, x2=None):
1513-
raise NotImplementedError("`where` is not supported with openvino backend")
1513+
condition = get_ov_output(condition)
1514+
if x1 is None and x2 is None:
1515+
nonzero_indices = ov_opset.non_zero(condition)
1516+
return OpenVINOKerasTensor(nonzero_indices.output(0))
1517+
if x1 is None:
1518+
return OpenVINOKerasTensor(condition)
1519+
if x2 is None:
1520+
raise ValueError("x2 must be provided if x1 is specified.")
1521+
1522+
def cast_literal_like_tensor(literal, x):
1523+
ov_type = get_ov_output(x).get_element_type()
1524+
is_bool = ov_type == Type.boolean
1525+
is_float_to_int = isinstance(literal, float) and ov_type.is_integral()
1526+
if is_bool or is_float_to_int:
1527+
return get_ov_output(literal), get_ov_output(x)
1528+
return get_ov_output(literal, ov_type), get_ov_output(x)
1529+
1530+
if isinstance(x1, (int, float)):
1531+
x1, x2 = cast_literal_like_tensor(x1, x2)
1532+
elif isinstance(x2, (int, float)):
1533+
x2, x1 = cast_literal_like_tensor(x2, x1)
1534+
else:
1535+
x1 = get_ov_output(x1)
1536+
x2 = get_ov_output(x2)
1537+
x1, x2 = _align_operand_types(x1, x2, "select()")
1538+
return OpenVINOKerasTensor(ov_opset.select(condition, x1, x2).output(0))
15141539

15151540

15161541
def divide(x1, x2):

0 commit comments

Comments
 (0)