1 file changed, 12 insertions(+), 5 deletions(-)

@@ -228,12 +228,19 @@ def basic_params1():
       force_full_predict=False,
       # Set this for pure model parallelism. There is only one data shard.
       no_data_parallelism=False,
-      # Set this to the dtype used for activation. Variables will still be
-      # stored in float32.
+      # dtype used for activations. - "float32" or "bfloat16"
+      # activation_dtype="bfloat16" currently only works on TPU.
+      # It lowers activation-memory usage
+      # and does not appear to affect quality.
+      # You can train on TPU with activation_dtype="bfloat16" and evaluate
+      # on CPU/GPU with activation_dtype="float32"
       activation_dtype="float32",
-      # Experimental: set weight_dtype="bfloat16" to use bfloat16 for both
-      # weights and activations. Model quality may be worse. Model quality
-      # appears to be close to baseline with large batch sizes (>4k).
+      # dtype used for parameters: "float32" or "bfloat16"
+      # bfloat16 currently only works with optimizer="adafactor".
+      # The savings in memory allow for training larger models.
+      # Weights are encoded as (w*128)^8, using pseudostochastic
+      # roundoff. Initial experiments show that model quality is similar
+      # to baseline for about 3M training steps, but worse thereafter.
       weight_dtype="float32",
   )
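The new comments amount to a small recipe: train with bfloat16 activations on TPU (and bfloat16 weights under Adafactor if memory is tight), then evaluate with float32 activations on CPU/GPU. Below is a minimal sketch of overriding these hyperparameters, assuming tensor2tensor's usual import path (tensor2tensor.layers.common_hparams); only the hparam names and values come from the diff above, the surrounding setup is illustrative.

# Sketch only: overriding the dtype hyperparameters documented in the diff.
# Assumes basic_params1() is importable from tensor2tensor.layers.common_hparams.
from tensor2tensor.layers import common_hparams

# TPU training: bfloat16 activations lower activation-memory usage and,
# per the comments above, do not appear to affect quality.
train_hparams = common_hparams.basic_params1()
train_hparams.activation_dtype = "bfloat16"

# bfloat16 weights currently require the Adafactor optimizer (the
# optimizer string is taken from the comment above); the memory savings
# allow for training larger models.
train_hparams.optimizer = "adafactor"
train_hparams.weight_dtype = "bfloat16"

# CPU/GPU evaluation: keep float32 activations, as the comments recommend.
eval_hparams = common_hparams.basic_params1()
eval_hparams.activation_dtype = "float32"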