Skip to content

Commit 09890a8

Browse files
yangustc07tensorflower-gardener
authored andcommitted
Migrate tf.data.experimental.cardinality to tf.data.Dataset.cardinality.
PiperOrigin-RevId: 733923403
1 parent 8e69121 commit 09890a8

File tree

2 files changed

+9
-10
lines changed

2 files changed

+9
-10
lines changed

tf_keras/engine/data_adapter.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -794,8 +794,7 @@ def should_recreate_iterator(self):
794794
# each epoch.
795795
return (
796796
self._user_steps is None
797-
or tf.data.experimental.cardinality(self._dataset).numpy()
798-
== self._user_steps
797+
or self._dataset.cardinality().numpy() == self._user_steps
799798
)
800799

801800
def _validate_args(self, y, sample_weights, steps, pss_evaluation_shards):
@@ -819,8 +818,8 @@ def _validate_args(self, y, sample_weights, steps, pss_evaluation_shards):
819818
"specify the number of steps to run."
820819
)
821820
else:
822-
size = tf.data.experimental.cardinality(self._dataset).numpy()
823-
if size == tf.data.experimental.INFINITE_CARDINALITY:
821+
size = self._dataset.cardinality().numpy()
822+
if size == tf.data.INFINITE_CARDINALITY:
824823
if pss_evaluation_shards:
825824
raise ValueError(
826825
"When performing exact evaluation, the dataset "
@@ -1481,8 +1480,8 @@ def _infer_steps(self, steps, dataset):
14811480
if not isinstance(dataset, tf.data.Dataset):
14821481
return None
14831482

1484-
size = tf.data.experimental.cardinality(dataset)
1485-
if size == tf.data.experimental.INFINITE_CARDINALITY and steps is None:
1483+
size = dataset.cardinality()
1484+
if size == tf.data.INFINITE_CARDINALITY and steps is None:
14861485
raise ValueError(
14871486
"When passing an infinitely repeating dataset, please specify "
14881487
"a `steps_per_epoch` value so that epoch level "

tf_keras/engine/data_adapter_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,8 +1198,8 @@ def test_unknown_cardinality_dataset_with_steps_per_epoch(self):
11981198
ds = tf.data.Dataset.from_tensor_slices([0, 1, 2, 3, 4, 5, 6])
11991199
filtered_ds = ds.filter(lambda x: x < 4)
12001200
self.assertEqual(
1201-
tf.data.experimental.cardinality(filtered_ds).numpy(),
1202-
tf.data.experimental.UNKNOWN_CARDINALITY,
1201+
filtered_ds.cardinality().numpy(),
1202+
tf.data.UNKNOWN_CARDINALITY,
12031203
)
12041204

12051205
# User can choose to only partially consume `Dataset`.
@@ -1221,8 +1221,8 @@ def test_unknown_cardinality_dataset_without_steps_per_epoch(self):
12211221
ds = tf.data.Dataset.from_tensor_slices([0, 1, 2, 3, 4, 5, 6])
12221222
filtered_ds = ds.filter(lambda x: x < 4)
12231223
self.assertEqual(
1224-
tf.data.experimental.cardinality(filtered_ds).numpy(),
1225-
tf.data.experimental.UNKNOWN_CARDINALITY,
1224+
filtered_ds.cardinality().numpy(),
1225+
tf.data.UNKNOWN_CARDINALITY,
12261226
)
12271227

12281228
data_handler = data_adapter.DataHandler(

0 commit comments

Comments
 (0)