Skip to content

Commit c968548

Browse files
[OpenVINO Backend] support slice_update
1 parent 771b001 commit c968548

File tree

1 file changed

+71
-3
lines changed

1 file changed

+71
-3
lines changed

keras/src/backend/openvino/core.py

Lines changed: 71 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -657,9 +657,77 @@ def slice(inputs, start_indices, shape):
657657

658658

659659
def slice_update(inputs, start_indices, updates):
660-
raise NotImplementedError(
661-
"`slice_update` is not supported with openvino backend"
662-
)
660+
inputs = get_ov_output(inputs)
661+
processed_start_indices = []
662+
for idx in start_indices:
663+
val = get_ov_output(idx)
664+
val_type = val.get_element_type()
665+
if not val_type.is_integral():
666+
raise ValueError(
667+
"`slice` is not supported by OpenVINO backend "
668+
"for `start_indices` or `shape` with non-integer types"
669+
)
670+
if val_type != Type.i32:
671+
val = ov_opset.convert(val, Type.i32).output(0)
672+
if len(val.get_partial_shape()) == 0:
673+
val = ov_opset.unsqueeze(
674+
val, ov_opset.constant(0, Type.i32)
675+
).output(0)
676+
processed_start_indices.append(val)
677+
start_indices_tensor = ov_opset.concat(processed_start_indices, axis=0)
678+
679+
rank = len(updates.shape)
680+
ranges = []
681+
for dim in updates.shape:
682+
r = ov_opset.range(
683+
ov_opset.constant(0, Type.i32),
684+
ov_opset.constant(dim, Type.i32),
685+
ov_opset.constant(1, Type.i32),
686+
output_type=Type.i32,
687+
)
688+
ranges.append(r)
689+
690+
broadcasted_ranges = []
691+
for i, r in enumerate(ranges):
692+
shape = [1] * rank
693+
shape[i] = updates.shape[i]
694+
r_reshaped = ov_opset.reshape(
695+
r, ov_opset.constant(shape, Type.i32), special_zero=False
696+
).output(0)
697+
target_shape = ov_opset.constant(list(updates.shape), Type.i32)
698+
r_broadcasted = ov_opset.broadcast(r_reshaped, target_shape).output(0)
699+
broadcasted_ranges.append(r_broadcasted)
700+
701+
indices_stack = ov_opset.concat(broadcasted_ranges, axis=0).output(0)
702+
703+
num_updates = 1
704+
for dim in updates.shape:
705+
num_updates *= dim
706+
new_shape = ov_opset.constant([rank, num_updates], Type.i32)
707+
indices_reshaped = ov_opset.reshape(
708+
indices_stack, new_shape, special_zero=False
709+
).output(0)
710+
absolute_indices = ov_opset.transpose(
711+
indices_reshaped, ov_opset.constant([1, 0], Type.i32)
712+
).output(0)
713+
714+
start_indices_expanded = ov_opset.broadcast(
715+
start_indices_tensor, ov_opset.constant([num_updates, rank], Type.i32)
716+
).output(0)
717+
absolute_indices = ov_opset.add(
718+
absolute_indices, start_indices_expanded
719+
).output(0)
720+
721+
updates_tensor = get_ov_output(updates)
722+
updates_flat = ov_opset.reshape(
723+
updates_tensor,
724+
ov_opset.constant([num_updates], Type.i32),
725+
special_zero=False,
726+
).output(0)
727+
updated = ov_opset.scatter_nd_update(
728+
inputs, absolute_indices, updates_flat
729+
).output(0)
730+
return OpenVINOKerasTensor(updated)
663731

664732

665733
def while_loop(

0 commit comments

Comments
 (0)