Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 0a48409

Browse files
author
Ryan Sepassi
committed
Cloud fixes and TF version bumps
PiperOrigin-RevId: 197457246
1 parent 4af78a7 commit 0a48409

File tree

4 files changed

+9
-7
lines changed

4 files changed

+9
-7
lines changed

tensor2tensor/utils/cloud_mlengine.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
FLAGS = tf.flags.FLAGS
3434

3535
CONSOLE_URL = "https://console.cloud.google.com/mlengine/jobs/"
36+
RUNTIME_VERSION = "1.8"
3637

3738
# TODO(rsepassi):
3839
# * Enable multi-machine sync/async training
@@ -86,7 +87,7 @@ def flags_as_args():
8687
continue
8788
if name.startswith("autotune"):
8889
continue
89-
args.extend(["--%s" % name, str(val)])
90+
args.extend(["--%s=%s" % (name, str(val))])
9091
return args
9192

9293

@@ -113,7 +114,7 @@ def configure_job():
113114
"pythonModule": "tensor2tensor.bin.t2t_trainer",
114115
"args": flags_as_args(),
115116
"region": text_encoder.native_to_unicode(cloud.default_region()),
116-
"runtimeVersion": "1.5",
117+
"runtimeVersion": RUNTIME_VERSION,
117118
"pythonVersion": "3.5" if sys.version_info.major == 3 else "2.7",
118119
"jobDir": FLAGS.output_dir,
119120
"scaleTier": "CUSTOM",

tensor2tensor/utils/cloud_tpu.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,8 @@ def cloud_tpu(vm_name, tpu_name, delete_on_done=False, skip_confirmation=False):
162162
class Gcloud(object):
163163
"""gcloud command strings."""
164164
# Note these can be modified by set_versions
165-
VM_VERSION = "tf-1-7"
166-
TPU_VERSION = "1.7"
165+
VM_VERSION = "tf-1-8"
166+
TPU_VERSION = "1.8"
167167

168168
@classmethod
169169
def set_versions(cls, vm, tpu):

tensor2tensor/utils/optimize.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,8 @@ def get_variable_initializer(hparams):
215215
if not hparams.initializer:
216216
return None
217217

218-
tf.logging.info("Using variable initializer: %s", hparams.initializer)
218+
if not tf.contrib.eager.in_eager_mode():
219+
tf.logging.info("Using variable initializer: %s", hparams.initializer)
219220
if hparams.initializer == "orthogonal":
220221
return tf.orthogonal_initializer(gain=hparams.initializer_gain)
221222
elif hparams.initializer == "uniform":

tensor2tensor/utils/registry.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ def parse_problem_name(problem_name):
263263
base_name, was_reversed, was_copy = parse_problem_name(name)
264264

265265
if base_name not in _PROBLEMS:
266-
all_problem_names = sorted(list_problems())
266+
all_problem_names = list_problems()
267267
error_lines = ["%s not in the set of supported problems:" % base_name
268268
] + all_problem_names
269269
error_msg = "\n * ".join(error_lines)
@@ -272,7 +272,7 @@ def parse_problem_name(problem_name):
272272

273273

274274
def list_problems():
275-
return list(_PROBLEMS)
275+
return sorted(list(_PROBLEMS))
276276

277277

278278
def _internal_get_modality(name, mod_collection, collection_str):

0 commit comments

Comments
 (0)