Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 9515f5f

Browse files
T2T TeamRyan Sepassi
authored andcommitted
Avoid using private APIs for determining if eager execution is enabled.
PiperOrigin-RevId: 193681604
1 parent 15bd9e3 commit 9515f5f

File tree

5 files changed

+16
-23
lines changed

5 files changed

+16
-23
lines changed

tensor2tensor/data_generators/image_utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@
3131

3232
import tensorflow as tf
3333

34-
from tensorflow.python.eager import context
35-
3634

3735
def resize_by_area(img, size):
3836
"""image resize function used by quite a few image problems."""
@@ -159,7 +157,7 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1):
159157

160158

161159
def encode_images_as_png(images):
162-
if context.in_eager_mode():
160+
if tf.contrib.eager.in_eager_mode():
163161
for image in images:
164162
yield tf.image.encode_png(image).numpy()
165163
else:

tensor2tensor/layers/common_layers.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232

3333
import tensorflow as tf
3434

35-
from tensorflow.python.eager import context as tfe_context
3635
from tensorflow.python.framework import function
3736
from tensorflow.python.framework import ops
3837

@@ -265,7 +264,7 @@ def embedding(x,
265264
# On the backwards pass, we want to convert the gradient from
266265
# an indexed-slices to a regular tensor before sending it back to the
267266
# parameter server. This avoids excess computation on the parameter server.
268-
if not tfe_context.in_eager_mode():
267+
if not tf.contrib.eager.in_eager_mode():
269268
embedding_var = eu.convert_gradient_to_tensor(embedding_var)
270269
x = dropout_no_scaling(x, 1.0 - symbol_dropout_rate)
271270
emb_x = gather(embedding_var, x, dtype)
@@ -2541,7 +2540,7 @@ def ones_matrix_band_part(rows, cols, num_lower, num_upper, out_shape=None):
25412540
def reshape_like_all_dims(a, b):
25422541
"""Reshapes a to match the shape of b."""
25432542
ret = tf.reshape(a, tf.shape(b))
2544-
if not tfe_context.in_eager_mode():
2543+
if not tf.contrib.eager.in_eager_mode():
25452544
ret.set_shape(b.get_shape())
25462545
return ret
25472546

tensor2tensor/layers/modalities.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@
2929

3030
import tensorflow as tf
3131

32-
from tensorflow.python.eager import context
33-
3432

3533
@registry.register_symbol_modality("default")
3634
class SymbolModality(modality.Modality):
@@ -97,7 +95,7 @@ def _get_weights(self, hidden_dim=None):
9795
else:
9896
ret = tf.concat(shards, 0)
9997
# Convert ret to tensor.
100-
if not context.in_eager_mode():
98+
if not tf.contrib.eager.in_eager_mode():
10199
ret = eu.convert_gradient_to_tensor(ret)
102100
return ret
103101

@@ -211,13 +209,13 @@ class ImageModality(modality.Modality):
211209
def bottom(self, inputs):
212210
with tf.variable_scope(self.name):
213211
inputs = tf.to_float(inputs)
214-
if not context.in_eager_mode():
212+
if not tf.contrib.eager.in_eager_mode():
215213
tf.summary.image("inputs", inputs, max_outputs=2)
216214
return inputs
217215

218216
def targets_bottom(self, inputs):
219217
with tf.variable_scope(self.name):
220-
if not context.in_eager_mode():
218+
if not tf.contrib.eager.in_eager_mode():
221219
tf.summary.image("targets_bottom",
222220
tf.cast(inputs, tf.uint8), max_outputs=1)
223221
inputs_shape = common_layers.shape_list(inputs)
@@ -466,7 +464,7 @@ def bottom(self, inputs):
466464
raise ValueError("Assuming videos given as tensors in the format "
467465
"[batch, time, height, width, channels] but got one "
468466
"of shape: %s" % str(inputs_shape))
469-
if not context.in_eager_mode():
467+
if not tf.contrib.eager.in_eager_mode():
470468
tf.summary.image("inputs", tf.cast(inputs[:, -1, :, :, :], tf.uint8),
471469
max_outputs=1)
472470
# Standardize frames.
@@ -487,7 +485,7 @@ def targets_bottom(self, inputs):
487485
raise ValueError("Assuming videos given as tensors in the format "
488486
"[batch, time, height, width, channels] but got one "
489487
"of shape: %s" % str(inputs_shape))
490-
if not context.in_eager_mode():
488+
if not tf.contrib.eager.in_eager_mode():
491489
tf.summary.image(
492490
"targets_bottom", tf.cast(inputs[:, -1, :, :, :], tf.uint8),
493491
max_outputs=1)

tensor2tensor/utils/expert_utils.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
from six.moves import zip # pylint: disable=redefined-builtin
3434
import tensorflow as tf
3535

36-
from tensorflow.python.eager import context
3736
from tensorflow.python.framework import function
3837

3938
DEFAULT_DEV_STRING = "existing_device"
@@ -560,7 +559,7 @@ def remove(self, x):
560559
x,
561560
indices=self.nonpad_ids,
562561
)
563-
if not context.in_eager_mode():
562+
if not tf.contrib.eager.in_eager_mode():
564563
# This is a hack but for some reason, gather_nd return a tensor of
565564
# undefined shape, so the shape is set up manually
566565
x.set_shape([None] + x_shape[1:])
@@ -909,15 +908,15 @@ def my_fn(x):
909908
def reshape_like(a, b):
910909
"""Reshapes a to match the shape of b in all but the last dimension."""
911910
ret = tf.reshape(a, tf.concat([tf.shape(b)[:-1], tf.shape(a)[-1:]], 0))
912-
if not context.in_eager_mode():
911+
if not tf.contrib.eager.in_eager_mode():
913912
ret.set_shape(b.get_shape().as_list()[:-1] + a.get_shape().as_list()[-1:])
914913
return ret
915914

916915

917916
def flatten_all_but_last(a):
918917
"""Flatten all dimensions of a except the last."""
919918
ret = tf.reshape(a, [-1, tf.shape(a)[-1]])
920-
if not context.in_eager_mode():
919+
if not tf.contrib.eager.in_eager_mode():
921920
ret.set_shape([None] + a.get_shape().as_list()[-1:])
922921
return ret
923922

tensor2tensor/utils/t2t_model.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@
4444

4545
import tensorflow as tf
4646

47-
from tensorflow.python.eager import context
4847
from tensorflow.python.layers import base
4948
from tensorflow.python.ops import variable_scope
5049

@@ -717,7 +716,7 @@ def _slow_greedy_infer(self, features, decode_length):
717716

718717
def infer_step(recent_output, recent_logits, unused_loss):
719718
"""Inference step."""
720-
if not context.in_eager_mode():
719+
if not tf.contrib.eager.in_eager_mode():
721720
recent_output.set_shape([None, None, None, 1])
722721
padded = tf.pad(recent_output, [[0, 0], [0, 1], [0, 0], [0, 0]])
723722
features["targets"] = padded
@@ -733,7 +732,7 @@ def infer_step(recent_output, recent_logits, unused_loss):
733732
common_layers.shape_list(recent_output)[1], :, :]
734733
cur_sample = tf.to_int64(tf.expand_dims(cur_sample, axis=1))
735734
samples = tf.concat([recent_output, cur_sample], axis=1)
736-
if not context.in_eager_mode():
735+
if not tf.contrib.eager.in_eager_mode():
737736
samples.set_shape([None, None, None, 1])
738737

739738
# Assuming we have one shard for logits.
@@ -765,7 +764,7 @@ def infer_step(recent_output, recent_logits, unused_loss):
765764
result = initial_output
766765
# tensor of shape [batch_size, time, 1, 1, vocab_size]
767766
logits = tf.zeros((batch_size, 0, 1, 1, target_modality.top_dimensionality))
768-
if not context.in_eager_mode():
767+
if not tf.contrib.eager.in_eager_mode():
769768
logits.set_shape([None, None, None, None, None])
770769
loss = 0.0
771770

@@ -1304,7 +1303,7 @@ def as_default(self):
13041303

13051304

13061305
def create_eager_var_store():
1307-
if context.in_eager_mode():
1306+
if tf.contrib.eager.in_eager_mode():
13081307
return variable_scope.EagerVariableStore()
13091308
else:
13101309
return DummyVariableStore()
@@ -1405,7 +1404,7 @@ def summarize_features(features, num_shards=1):
14051404

14061405

14071406
def _eager_log(level, *args):
1408-
if context.in_eager_mode() and args in _already_logged:
1407+
if tf.contrib.eager.in_eager_mode() and args in _already_logged:
14091408
return
14101409
_already_logged.add(args)
14111410
getattr(tf.logging, level)(*args)

0 commit comments

Comments
 (0)