Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 2c48f89

Browse files
eli7copybara-github
authored andcommitted
Fixing feature encoder for tf.string variable length features.
PiperOrigin-RevId: 314208560
1 parent c104976 commit 2c48f89

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

tensor2tensor/data_generators/problem.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -718,10 +718,14 @@ def decode_example(self, serialized_example):
718718
[1], tf.int64, getattr(self._hparams, "sampling_keep_top_k", -1))
719719

720720
if data_items_to_decoders is None:
721-
data_items_to_decoders = {
722-
field: contrib.slim().tfexample_decoder.Tensor(field)
723-
for field in data_fields
724-
}
721+
data_items_to_decoders = {}
722+
for field in data_fields:
723+
if data_fields[field].dtype is tf.string:
724+
default_value = b""
725+
else:
726+
default_value = 0
727+
data_items_to_decoders[field] = contrib.slim().tfexample_decoder.Tensor(
728+
field, default_value=default_value)
725729

726730
decoder = contrib.slim().tfexample_decoder.TFExampleDecoder(
727731
data_fields, data_items_to_decoders)

0 commit comments

Comments
 (0)