Skip to content

Commit a0949a8

Browse files
[OpenVINO Backend] support comparison ops (>, <, <=, >=, ==, !=) for OpenVINOKerasTensor (#21348)
* [OpenVINO Backend] support comparison ops (>, <, <=, >=, ==, !=) * support comparison ops
1 parent 379be26 commit a0949a8

File tree

1 file changed

+50
-0
lines changed

1 file changed

+50
-0
lines changed

keras/src/backend/openvino/core.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,56 @@ def __rpow__(self, other):
257257
)
258258
return OpenVINOKerasTensor(ov_opset.power(other, first).output(0))
259259

260+
def __lt__(self, other):
261+
first = self.output
262+
other = get_ov_output(other)
263+
first, other = align_operand_types(
264+
first, other, "OpenVINOKerasTensor::__lt__"
265+
)
266+
return OpenVINOKerasTensor(ov_opset.less(first, other).output(0))
267+
268+
def __gt__(self, other):
269+
first = self.output
270+
other = get_ov_output(other)
271+
first, other = align_operand_types(
272+
first, other, "OpenVINOKerasTensor::__gt__"
273+
)
274+
return OpenVINOKerasTensor(ov_opset.greater(first, other).output(0))
275+
276+
def __le__(self, other):
277+
first = self.output
278+
other = get_ov_output(other)
279+
first, other = align_operand_types(
280+
first, other, "OpenVINOKerasTensor::__le__"
281+
)
282+
return OpenVINOKerasTensor(ov_opset.less_equal(first, other).output(0))
283+
284+
def __ge__(self, other):
285+
first = self.output
286+
other = get_ov_output(other)
287+
first, other = align_operand_types(
288+
first, other, "OpenVINOKerasTensor::__ge__"
289+
)
290+
return OpenVINOKerasTensor(
291+
ov_opset.greater_equal(first, other).output(0)
292+
)
293+
294+
def __eq__(self, other):
295+
first = self.output
296+
other = get_ov_output(other)
297+
first, other = align_operand_types(
298+
first, other, "OpenVINOKerasTensor::__eq__"
299+
)
300+
return OpenVINOKerasTensor(ov_opset.equal(first, other).output(0))
301+
302+
def __ne__(self, other):
303+
first = self.output
304+
other = get_ov_output(other)
305+
first, other = align_operand_types(
306+
first, other, "OpenVINOKerasTensor::__ne__"
307+
)
308+
return OpenVINOKerasTensor(ov_opset.not_equal(first, other).output(0))
309+
260310
def __getitem__(self, indices):
261311
# now it has limited functionaly
262312
# and supports only a case with one integer index in indices

0 commit comments

Comments
 (0)