Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 9094fc1

Browse files
rasmirsepassi
authored andcommitted
Fixed small issues with ML Engine script. (#670)
* Fixed typo in flags. * Fixed indentation. * Clarified machine type documentation. * Disable cache discovery. See googleapis/google-api-python-client#299
1 parent 2872bd0 commit 9094fc1

File tree

2 files changed

+13
-8
lines changed

2 files changed

+13
-8
lines changed

docs/cloud_mlengine.md

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,12 @@ machines with 4 or 8 GPUs.
2828
You can additionally pass the `--cloud_mlengine_master_type` to select another
2929
kind of machine (see the [docs for
3030
`masterType`](https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs#traininginput)
31-
for your options). If you provide this flag yourself, make sure you pass the
32-
correct value for `--worker_gpu`.
31+
for options, including
32+
[ML Engine machine types](https://cloud.google.com/ml-engine/docs/training-overview)
33+
and their
34+
[specs](https://cloud.google.com/compute/docs/machine-types)).
35+
If you provide this flag yourself, make sure you pass the
36+
correct value for `--worker_gpu` (for non-GPU machines, you must explicitly pass `--worker_gpu=0`).
3337

3438
**Note**: `t2t-trainer` only currently supports launching with single machines,
3539
possibly with multiple GPUs. Multi-machine setups are not yet supported out of

tensor2tensor/utils/cloud_mlengine.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,8 @@ def launch_job(job_spec):
140140
"""Launch job on ML Engine."""
141141
project_id = 'projects/{}'.format(cloud.default_project())
142142
credentials = GoogleCredentials.get_application_default()
143-
cloudml = discovery.build('ml', 'v1', credentials=credentials)
143+
cloudml = discovery.build(
144+
'ml', 'v1', credentials=credentials, cache_discovery=False)
144145
request = cloudml.projects().jobs().create(body=job_spec, parent=project_id)
145146
request.execute()
146147

@@ -275,13 +276,13 @@ def validate_flags():
275276
assert FLAGS.cloud_mlengine_master_type == 'standard_tpu'
276277
elif FLAGS.worker_gpu:
277278
if FLAGS.worker_gpu == 1:
278-
assert FLAGS.cloud_ml_engine_master_type in ['standard_gpu',
279-
'standard_p100']
279+
assert FLAGS.cloud_mlengine_master_type in ['standard_gpu',
280+
'standard_p100']
280281
elif FLAGS.worker_gpu == 4:
281-
assert FLAGS.cloud_ml_engine_master_type in ['complex_model_m_gpu',
282-
'complex_model_m_p100']
282+
assert FLAGS.cloud_mlengine_master_type in ['complex_model_m_gpu',
283+
'complex_model_m_p100']
283284
else:
284-
assert FLAGS.cloud_ml_engine_master_type == 'complex_model_l_gpu'
285+
assert FLAGS.cloud_mlengine_master_type == 'complex_model_l_gpu'
285286
else:
286287
assert FLAGS.cloud_mlengine_master_type in ['standard', 'large_model',
287288
'complex_model_s',

0 commit comments

Comments
 (0)