This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit cd6b6c5

nshazeer authored and Ryan Sepassi committed
update comments on dtype hparams.
PiperOrigin-RevId: 193547090
1 parent 88eb6b0 commit cd6b6c5

File tree

1 file changed (+12, -5 lines)


tensor2tensor/layers/common_hparams.py

Lines changed: 12 additions & 5 deletions
@@ -228,12 +228,19 @@ def basic_params1():
       force_full_predict=False,
       # Set this for pure model parallelism. There is only one data shard.
       no_data_parallelism=False,
-      # Set this to the dtype used for activation. Variables will still be
-      # stored in float32.
+      # dtype used for activations. - "float32" or "bfloat16"
+      # activation_dtype="bfloat16" currently only works on TPU.
+      # It lowers activation-memory usage
+      # and does not appear to affect quality.
+      # You can train on TPU with activation_dtype="bfloat16" and evaluate
+      # on CPU/GPU with activation_dtype="float32"
       activation_dtype="float32",
-      # Experimental: set weight_dtype="bfloat16" to use bfloat16 for both
-      # weights and activations. Model quality may be worse. Model quality
-      # appears to be close to baseline with large batch sizes (>4k).
+      # dtype used for parameters: "float32" or "bfloat16"
+      # bfloat16 currently only works with optimizer="adafactor".
+      # The savings in memory allow for training larger models.
+      # Weights are encoded as (w*128)^8, using pseudostochastic
+      # roundoff. Initial experiments show that model quality is similar
+      # to baseline for about 3M training steps, but worse thereafter.
       weight_dtype="float32",
   )

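For readers of this commit, a minimal sketch of how these hparams might be overridden on top of basic_params1(). It only restates the guidance in the comments above; the adafactor optimizer name and the train/evaluate split are illustrative assumptions, not part of this commit.

from tensor2tensor.layers import common_hparams

hparams = common_hparams.basic_params1()

# TPU training: bfloat16 activations lower activation-memory usage and,
# per the comment above, do not appear to affect quality.
hparams.activation_dtype = "bfloat16"

# More experimental: bfloat16 weights save memory for larger models, but
# per the comment above this currently requires the adafactor optimizer
# and quality degrades after roughly 3M training steps.
hparams.weight_dtype = "bfloat16"
hparams.optimizer = "adafactor"  # assumed optimizer name; see comment above

# To evaluate a TPU-trained model on CPU/GPU, switch activations back:
# hparams.activation_dtype = "float32"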