
Commit 3704b1f

Lukasz Kaiser authored and Copybara-Service committed
Correct inference to run with RealModality for time series problems.
PiperOrigin-RevId: 200667028
1 parent ffff8ae commit 3704b1f

4 files changed: +45 −11 lines changed

tensor2tensor/data_generators/text_encoder.py

Lines changed: 1 addition & 1 deletion
@@ -1051,7 +1051,7 @@ def decode(self, ids, strip_extraneous=False):
       ValueError: if the ids are not of the appropriate size.
     """
     del strip_extraneous
-    return " ".join(ids)
+    return " ".join([str(i) for i in ids])


 def strip_ids(ids, ids_to_strip):

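To see why this one-line change matters: `str.join` only accepts string elements, and for real-valued (time series) targets the decoded `ids` are typically numbers rather than subword strings. A minimal standalone sketch of the before/after behaviour (illustrative values only, not the T2T encoder class itself):

```python
ids = [3, 1, 4, 1, 5]  # illustrative numeric ids, as a time series decode might produce

# Old code: " ".join(ids) raises TypeError because join() expects strings.
# New code: stringify each element first, which works for ints and floats alike.
decoded = " ".join([str(i) for i in ids])
print(decoded)  # -> "3 1 4 1 5"
```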
tensor2tensor/layers/modalities.py

Lines changed: 4 additions & 0 deletions
@@ -739,6 +739,10 @@ class RealModality(modality.Modality):
   * Top is a linear projection layer to vocab_size.
   """

+  @property
+  def top_is_pointwise(self):
+    return True
+
   def bottom(self, x):
     with tf.variable_scope("real"):
       return tf.layers.dense(tf.to_float(x), self._body_input_depth,

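`top_is_pointwise` defaults to False on the base `Modality` and, roughly, tells the decoding loop that the modality's `top` transform acts independently on each position, so only the newest time step needs to be projected during incremental inference. A hedged sketch of how a caller might use the flag (a hypothetical helper, not the library's actual decode loop):

```python
import tensorflow as tf

def project_last_step_only(modality, body_output, targets=None):
  """Hypothetical helper illustrating how `top_is_pointwise` could be used.

  If the top transform is pointwise, projecting just the final position of
  `body_output` (shape [batch, time, 1, depth]) gives the same result for
  that position as projecting the whole sequence, but is much cheaper.
  """
  if modality.top_is_pointwise:
    return modality.top(body_output[:, -1:, :, :], targets)
  # Non-pointwise tops may depend on the whole sequence, so pass everything.
  return modality.top(body_output, targets)
```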
tensor2tensor/models/transformer.py

Lines changed: 3 additions & 0 deletions
@@ -213,6 +213,9 @@ def _greedy_infer(self, features, decode_length):
     Raises:
       NotImplementedError: If there are multiple data shards.
     """
+    # For real-valued modalities use the slow decode path for now.
+    if self._target_modality_is_real:
+      return super(Transformer, self)._greedy_infer(features, decode_length)
     with tf.variable_scope(self.name):
       return self._fast_decode(features, decode_length)


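The `self._target_modality_is_real` flag used here is added in the `t2t_model.py` change below; it is just a prefix check on the modality's registered name. A tiny illustration of that check (the modality names are made up for the example):

```python
# Hypothetical modality names, chosen only to illustrate the prefix check
# performed by T2TModel._target_modality_is_real.
for name in ["real_modality_32_128", "symbol_modality_8192_512"]:
  is_real = name.startswith("real_")
  print(name, "->", "slow greedy decode" if is_real else "fast decode")
```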
tensor2tensor/utils/t2t_model.py

Lines changed: 37 additions & 10 deletions
@@ -141,6 +141,12 @@ def _custom_getter(self):
     else:
       return None

+  @property
+  def _target_modality_is_real(self):
+    """Whether the target modality is real-valued."""
+    target_modality = self._problem_hparams.target_modality
+    return target_modality.name.startswith("real_")
+
   def call(self, inputs, **kwargs):
     del kwargs
     features = inputs
@@ -732,7 +738,11 @@ def _slow_greedy_infer(self, features, decode_length):
     def infer_step(recent_output, recent_logits, unused_loss):
       """Inference step."""
       if not tf.contrib.eager.in_eager_mode():
-        recent_output.set_shape([None, None, None, 1])
+        if self._target_modality_is_real:
+          dim = self._problem_hparams.target_modality.top_dimensionality
+          recent_output.set_shape([None, None, None, dim])
+        else:
+          recent_output.set_shape([None, None, None, 1])
       padded = tf.pad(recent_output, [[0, 0], [0, 1], [0, 0], [0, 0]])
       features["targets"] = padded
       # This is inefficient in that it generates samples at all timesteps,
@@ -745,10 +755,14 @@ def infer_step(recent_output, recent_logits, unused_loss):
       else:
         cur_sample = samples[:,
                              common_layers.shape_list(recent_output)[1], :, :]
-      cur_sample = tf.to_int64(tf.expand_dims(cur_sample, axis=1))
-      samples = tf.concat([recent_output, cur_sample], axis=1)
-      if not tf.contrib.eager.in_eager_mode():
-        samples.set_shape([None, None, None, 1])
+      if self._target_modality_is_real:
+        cur_sample = tf.expand_dims(cur_sample, axis=1)
+        samples = tf.concat([recent_output, cur_sample], axis=1)
+      else:
+        cur_sample = tf.to_int64(tf.expand_dims(cur_sample, axis=1))
+        samples = tf.concat([recent_output, cur_sample], axis=1)
+        if not tf.contrib.eager.in_eager_mode():
+          samples.set_shape([None, None, None, 1])

       # Assuming we have one shard for logits.
       logits = tf.concat([recent_logits, logits[:, -1:]], 1)
@@ -764,7 +778,11 @@ def infer_step(recent_output, recent_logits, unused_loss):
       batch_size = common_layers.shape_list(initial_output)[0]
     else:
       batch_size = common_layers.shape_list(features["inputs"])[0]
-      initial_output = tf.zeros((batch_size, 0, 1, 1), dtype=tf.int64)
+      if self._target_modality_is_real:
+        dim = self._problem_hparams.target_modality.top_dimensionality
+        initial_output = tf.zeros((batch_size, 0, 1, dim), dtype=tf.float32)
+      else:
+        initial_output = tf.zeros((batch_size, 0, 1, 1), dtype=tf.int64)
     # Hack: foldl complains when the output shape is less specified than the
     # input shape, so we confuse it about the input shape.
     initial_output = tf.slice(initial_output, [0, 0, 0, 0],
@@ -783,10 +801,17 @@ def infer_step(recent_output, recent_logits, unused_loss):

     # Initial values of result, logits and loss.
     result = initial_output
-    # tensor of shape [batch_size, time, 1, 1, vocab_size]
-    logits = tf.zeros((batch_size, 0, 1, 1, target_modality.top_dimensionality))
+    if self._target_modality_is_real:
+      logits = tf.zeros((batch_size, 0, 1, target_modality.top_dimensionality))
+      logits_shape_inv = [None, None, None, None]
+    else:
+      # tensor of shape [batch_size, time, 1, 1, vocab_size]
+      logits = tf.zeros((batch_size, 0, 1, 1,
+                         target_modality.top_dimensionality))
+      logits_shape_inv = [None, None, None, None, None]
     if not tf.contrib.eager.in_eager_mode():
-      logits.set_shape([None, None, None, None, None])
+      logits.set_shape(logits_shape_inv)
+
     loss = 0.0

     def while_exit_cond(result, logits, loss):  # pylint: disable=unused-argument
@@ -822,7 +847,7 @@ def fn_not_eos():
         infer_step, [result, logits, loss],
         shape_invariants=[
             tf.TensorShape([None, None, None, None]),
-            tf.TensorShape([None, None, None, None, None]),
+            tf.TensorShape(logits_shape_inv),
             tf.TensorShape([]),
         ],
         back_prop=False,
@@ -857,6 +882,8 @@ def sample(self, features):
       losses: a dictionary: {loss-name (string): floating point `Scalar`}.
     """
     logits, losses = self(features)  # pylint: disable=not-callable
+    if self._target_modality_is_real:
+      return logits, logits, losses  # Raw numbers returned from real modality.
     if self.hparams.sampling_method == "argmax":
       samples = tf.argmax(logits, axis=-1)
     else:

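The common thread in these hunks is tensor layout: for symbol modalities the greedy loop carries int64 ids of shape [batch, time, 1, 1] and logits of shape [batch, time, 1, 1, vocab_size], while for real modalities the predictions are float32 values of shape [batch, time, 1, dim] that double as their own "logits". That is why `sample` can return them unchanged and why the `tf.while_loop` shape invariant for logits drops from rank 5 to rank 4. A small sketch of the two layouts (the sizes are illustrative only):

```python
import tensorflow as tf

batch_size, vocab_size, dim = 2, 100, 5  # illustrative sizes only

# Symbol (discrete) targets: samples are int ids, logits carry a vocab axis.
symbol_samples = tf.zeros((batch_size, 0, 1, 1), dtype=tf.int64)
symbol_logits = tf.zeros((batch_size, 0, 1, 1, vocab_size))        # rank 5

# Real-valued targets: the model output is the prediction itself, so the
# "logits" have no vocab axis and sample() can return them directly.
real_output = tf.zeros((batch_size, 0, 1, dim), dtype=tf.float32)  # rank 4

# The rank difference is why logits_shape_inv is [None] * 4 for real
# modalities and [None] * 5 for symbol modalities in the while_loop above.
print(symbol_logits.shape.ndims, real_output.shape.ndims)  # 5 4
```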