
Commit 7c072d7
Author: Ryan Sepassi
Parent: 01f245f

Revert usage of Datasets API

PiperOrigin-RevId: 163421122

1 file changed: 10 additions, 10 deletions


--- a/tensor2tensor/utils/data_reader.py
+++ b/tensor2tensor/utils/data_reader.py
@@ -20,7 +20,6 @@
 
 import math
 import os
-import random
 
 # Dependency imports
 
@@ -114,17 +113,18 @@ def decode_record(record):
     return dict(zip(decode_items, decoded))
 
   with tf.name_scope("examples_in"):
+    # Read serialized examples using slim parallel_reader.
     data_files = tf.contrib.slim.parallel_reader.get_data_files(data_sources)
-    if training:
-      random.shuffle(data_files)
-    dataset = tf.contrib.data.TFRecordDataset(data_files)
     num_readers = min(4 if training else 1, len(data_files))
-    dataset = dataset.map(decode_record, num_threads=num_readers)
-    if training:
-      dataset = dataset.shuffle(capacity)
-    dataset = dataset.repeat(None if training else 1)
-    it = dataset.make_one_shot_iterator()
-    return it.get_next()
+    _, example_serialized = tf.contrib.slim.parallel_reader.parallel_read(
+        data_sources,
+        tf.TFRecordReader,
+        num_epochs=None if training else 1,
+        shuffle=training,
+        capacity=2 * capacity,
+        min_after_dequeue=capacity,
+        num_readers=num_readers)
+    return decode_record(example_serialized)
 
 
 def preprocessing(examples, data_file_pattern, mode):
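For reference, the commit replaces the tf.contrib.data (Datasets API) pipeline with the older queue-based tf.contrib.slim.parallel_reader. Below is a minimal, self-contained sketch of driving that queue-based reader in TF 1.x; the file pattern and the "inputs" feature schema are hypothetical placeholders, and note that parallel_read fills its queues via queue runners, which must be started in the session before dequeuing.

import tensorflow as tf

# Minimal TF 1.x sketch of the queue-based pipeline this commit reverts to.
# The file pattern and feature schema are hypothetical placeholders.
_, serialized = tf.contrib.slim.parallel_reader.parallel_read(
    "path/to/data-*",            # hypothetical TFRecord file pattern
    tf.TFRecordReader,
    num_epochs=None,             # None = loop indefinitely (training mode)
    shuffle=True,                # shuffle through a RandomShuffleQueue
    capacity=64,
    min_after_dequeue=32,
    num_readers=4)
features = tf.parse_single_example(
    serialized, {"inputs": tf.VarLenFeature(tf.int64)})

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())
  # parallel_read enqueues via queue runners; start them before dequeuing.
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  print(sess.run(features["inputs"]))   # one decoded example
  coord.request_stop()
  coord.join(threads)

This Coordinator/queue-runner choreography is what the removed Datasets-API code avoided, since a dataset's make_one_shot_iterator needs no queue runners.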
