Skip to content

Commit 771b001

Browse files
authored
Make dtype a static parameter of OnesLike and ZerosLike. (#21358)
The `dtype` parameter is never a tensor and should therefore be a static parameter in the `OnesLike` and `ZerosLike` operations.
1 parent a0949a8 commit 771b001

File tree

1 file changed

+25
-13
lines changed

1 file changed

+25
-13
lines changed

keras/src/ops/numpy.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4579,13 +4579,19 @@ def not_equal(x1, x2):
45794579

45804580

45814581
class OnesLike(Operation):
4582-
def call(self, x, dtype=None):
4583-
return backend.numpy.ones_like(x, dtype=dtype)
4582+
def __init__(self, dtype=None):
4583+
super().__init__()
4584+
self.dtype = (
4585+
backend.standardize_dtype(dtype) if dtype is not None else None
4586+
)
45844587

4585-
def compute_output_spec(self, x, dtype=None):
4586-
if dtype is None:
4587-
dtype = x.dtype
4588-
return KerasTensor(x.shape, dtype=dtype)
4588+
def call(self, x):
4589+
return backend.numpy.ones_like(x, dtype=self.dtype)
4590+
4591+
def compute_output_spec(self, x):
4592+
dtype = x.dtype if self.dtype is None else self.dtype
4593+
sparse = getattr(x, "sparse", False)
4594+
return KerasTensor(x.shape, dtype=dtype, sparse=sparse)
45894595

45904596

45914597
@keras_export(["keras.ops.ones_like", "keras.ops.numpy.ones_like"])
@@ -4600,18 +4606,24 @@ def ones_like(x, dtype=None):
46004606
A tensor of ones with the same shape and type as `x`.
46014607
"""
46024608
if any_symbolic_tensors((x,)):
4603-
return OnesLike().symbolic_call(x, dtype=dtype)
4609+
return OnesLike(dtype=dtype).symbolic_call(x)
46044610
return backend.numpy.ones_like(x, dtype=dtype)
46054611

46064612

46074613
class ZerosLike(Operation):
4608-
def call(self, x, dtype=None):
4609-
return backend.numpy.zeros_like(x, dtype=dtype)
4614+
def __init__(self, dtype=None):
4615+
super().__init__()
4616+
self.dtype = (
4617+
backend.standardize_dtype(dtype) if dtype is not None else None
4618+
)
4619+
4620+
def call(self, x):
4621+
return backend.numpy.zeros_like(x, dtype=self.dtype)
46104622

46114623
def compute_output_spec(self, x, dtype=None):
4612-
if dtype is None:
4613-
dtype = x.dtype
4614-
return KerasTensor(x.shape, dtype=dtype)
4624+
dtype = x.dtype if self.dtype is None else self.dtype
4625+
sparse = getattr(x, "sparse", False)
4626+
return KerasTensor(x.shape, dtype=dtype, sparse=sparse)
46154627

46164628

46174629
@keras_export(
@@ -4631,7 +4643,7 @@ def zeros_like(x, dtype=None):
46314643
A tensor of zeros with the same shape and type as `x`.
46324644
"""
46334645
if any_symbolic_tensors((x,)):
4634-
return ZerosLike().symbolic_call(x, dtype=dtype)
4646+
return ZerosLike(dtype=dtype).symbolic_call(x)
46354647
return backend.numpy.zeros_like(x, dtype=dtype)
46364648

46374649

0 commit comments

Comments
 (0)