@@ -4579,13 +4579,19 @@ def not_equal(x1, x2):
4579
4579
4580
4580
4581
4581
class OnesLike (Operation ):
4582
- def call (self , x , dtype = None ):
4583
- return backend .numpy .ones_like (x , dtype = dtype )
4582
+ def __init__ (self , dtype = None ):
4583
+ super ().__init__ ()
4584
+ self .dtype = (
4585
+ backend .standardize_dtype (dtype ) if dtype is not None else None
4586
+ )
4584
4587
4585
- def compute_output_spec (self , x , dtype = None ):
4586
- if dtype is None :
4587
- dtype = x .dtype
4588
- return KerasTensor (x .shape , dtype = dtype )
4588
+ def call (self , x ):
4589
+ return backend .numpy .ones_like (x , dtype = self .dtype )
4590
+
4591
+ def compute_output_spec (self , x ):
4592
+ dtype = x .dtype if self .dtype is None else self .dtype
4593
+ sparse = getattr (x , "sparse" , False )
4594
+ return KerasTensor (x .shape , dtype = dtype , sparse = sparse )
4589
4595
4590
4596
4591
4597
@keras_export (["keras.ops.ones_like" , "keras.ops.numpy.ones_like" ])
@@ -4600,18 +4606,24 @@ def ones_like(x, dtype=None):
4600
4606
A tensor of ones with the same shape and type as `x`.
4601
4607
"""
4602
4608
if any_symbolic_tensors ((x ,)):
4603
- return OnesLike ().symbolic_call (x , dtype = dtype )
4609
+ return OnesLike (dtype = dtype ).symbolic_call (x )
4604
4610
return backend .numpy .ones_like (x , dtype = dtype )
4605
4611
4606
4612
4607
4613
class ZerosLike (Operation ):
4608
- def call (self , x , dtype = None ):
4609
- return backend .numpy .zeros_like (x , dtype = dtype )
4614
+ def __init__ (self , dtype = None ):
4615
+ super ().__init__ ()
4616
+ self .dtype = (
4617
+ backend .standardize_dtype (dtype ) if dtype is not None else None
4618
+ )
4619
+
4620
+ def call (self , x ):
4621
+ return backend .numpy .zeros_like (x , dtype = self .dtype )
4610
4622
4611
4623
def compute_output_spec (self , x , dtype = None ):
4612
- if dtype is None :
4613
- dtype = x . dtype
4614
- return KerasTensor (x .shape , dtype = dtype )
4624
+ dtype = x . dtype if self . dtype is None else self . dtype
4625
+ sparse = getattr ( x , "sparse" , False )
4626
+ return KerasTensor (x .shape , dtype = dtype , sparse = sparse )
4615
4627
4616
4628
4617
4629
@keras_export (
@@ -4631,7 +4643,7 @@ def zeros_like(x, dtype=None):
4631
4643
A tensor of zeros with the same shape and type as `x`.
4632
4644
"""
4633
4645
if any_symbolic_tensors ((x ,)):
4634
- return ZerosLike ().symbolic_call (x , dtype = dtype )
4646
+ return ZerosLike (dtype = dtype ).symbolic_call (x )
4635
4647
return backend .numpy .zeros_like (x , dtype = dtype )
4636
4648
4637
4649
0 commit comments