@@ -141,6 +141,12 @@ def _custom_getter(self):
141
141
else :
142
142
return None
143
143
144
@property
def _target_modality_is_real(self):
    """Tell whether the problem's target modality is real-valued.

    The decision is made purely from the modality's name: it counts as
    real-valued when that name begins with the "real_" prefix.

    Returns:
        bool: True if the target modality name starts with "real_".
    """
    modality_name = self._problem_hparams.target_modality.name
    return modality_name.startswith("real_")
149
+
144
150
def call (self , inputs , ** kwargs ):
145
151
del kwargs
146
152
features = inputs
@@ -732,7 +738,11 @@ def _slow_greedy_infer(self, features, decode_length):
732
738
def infer_step (recent_output , recent_logits , unused_loss ):
733
739
"""Inference step."""
734
740
if not tf .contrib .eager .in_eager_mode ():
735
- recent_output .set_shape ([None , None , None , 1 ])
741
+ if self ._target_modality_is_real :
742
+ dim = self ._problem_hparams .target_modality .top_dimensionality
743
+ recent_output .set_shape ([None , None , None , dim ])
744
+ else :
745
+ recent_output .set_shape ([None , None , None , 1 ])
736
746
padded = tf .pad (recent_output , [[0 , 0 ], [0 , 1 ], [0 , 0 ], [0 , 0 ]])
737
747
features ["targets" ] = padded
738
748
# This is inefficient in that it generates samples at all timesteps,
@@ -745,10 +755,14 @@ def infer_step(recent_output, recent_logits, unused_loss):
745
755
else :
746
756
cur_sample = samples [:,
747
757
common_layers .shape_list (recent_output )[1 ], :, :]
748
- cur_sample = tf .to_int64 (tf .expand_dims (cur_sample , axis = 1 ))
749
- samples = tf .concat ([recent_output , cur_sample ], axis = 1 )
750
- if not tf .contrib .eager .in_eager_mode ():
751
- samples .set_shape ([None , None , None , 1 ])
758
+ if self ._target_modality_is_real :
759
+ cur_sample = tf .expand_dims (cur_sample , axis = 1 )
760
+ samples = tf .concat ([recent_output , cur_sample ], axis = 1 )
761
+ else :
762
+ cur_sample = tf .to_int64 (tf .expand_dims (cur_sample , axis = 1 ))
763
+ samples = tf .concat ([recent_output , cur_sample ], axis = 1 )
764
+ if not tf .contrib .eager .in_eager_mode ():
765
+ samples .set_shape ([None , None , None , 1 ])
752
766
753
767
# Assuming we have one shard for logits.
754
768
logits = tf .concat ([recent_logits , logits [:, - 1 :]], 1 )
@@ -764,7 +778,11 @@ def infer_step(recent_output, recent_logits, unused_loss):
764
778
batch_size = common_layers .shape_list (initial_output )[0 ]
765
779
else :
766
780
batch_size = common_layers .shape_list (features ["inputs" ])[0 ]
767
- initial_output = tf .zeros ((batch_size , 0 , 1 , 1 ), dtype = tf .int64 )
781
+ if self ._target_modality_is_real :
782
+ dim = self ._problem_hparams .target_modality .top_dimensionality
783
+ initial_output = tf .zeros ((batch_size , 0 , 1 , dim ), dtype = tf .float32 )
784
+ else :
785
+ initial_output = tf .zeros ((batch_size , 0 , 1 , 1 ), dtype = tf .int64 )
768
786
# Hack: foldl complains when the output shape is less specified than the
769
787
# input shape, so we confuse it about the input shape.
770
788
initial_output = tf .slice (initial_output , [0 , 0 , 0 , 0 ],
@@ -783,10 +801,17 @@ def infer_step(recent_output, recent_logits, unused_loss):
783
801
784
802
# Initial values of result, logits and loss.
785
803
result = initial_output
786
- # tensor of shape [batch_size, time, 1, 1, vocab_size]
787
- logits = tf .zeros ((batch_size , 0 , 1 , 1 , target_modality .top_dimensionality ))
804
+ if self ._target_modality_is_real :
805
+ logits = tf .zeros ((batch_size , 0 , 1 , target_modality .top_dimensionality ))
806
+ logits_shape_inv = [None , None , None , None ]
807
+ else :
808
+ # tensor of shape [batch_size, time, 1, 1, vocab_size]
809
+ logits = tf .zeros ((batch_size , 0 , 1 , 1 ,
810
+ target_modality .top_dimensionality ))
811
+ logits_shape_inv = [None , None , None , None , None ]
788
812
if not tf .contrib .eager .in_eager_mode ():
789
- logits .set_shape ([None , None , None , None , None ])
813
+ logits .set_shape (logits_shape_inv )
814
+
790
815
loss = 0.0
791
816
792
817
def while_exit_cond (result , logits , loss ): # pylint: disable=unused-argument
@@ -822,7 +847,7 @@ def fn_not_eos():
822
847
infer_step , [result , logits , loss ],
823
848
shape_invariants = [
824
849
tf .TensorShape ([None , None , None , None ]),
825
- tf .TensorShape ([ None , None , None , None , None ] ),
850
+ tf .TensorShape (logits_shape_inv ),
826
851
tf .TensorShape ([]),
827
852
],
828
853
back_prop = False ,
@@ -857,6 +882,8 @@ def sample(self, features):
857
882
losses: a dictionary: {loss-name (string): floating point `Scalar`}.
858
883
"""
859
884
logits , losses = self (features ) # pylint: disable=not-callable
885
+ if self ._target_modality_is_real :
886
+ return logits , logits , losses # Raw numbers returned from real modality.
860
887
if self .hparams .sampling_method == "argmax" :
861
888
samples = tf .argmax (logits , axis = - 1 )
862
889
else :
0 commit comments