This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 12c59a7

nshazeer authored and Ryan Sepassi committed
Massively simplify expert_utils. Breaks checkpoints for models that use experts. Fixed a bug in Parallelism where caching devices were always used, even when none were specified. Fixed a bug in attention_lm and attention_lm_moe by setting the default norm_type to "layer" instead of "none".
PiperOrigin-RevId: 164869403
1 parent d30ec6b commit 12c59a7

File tree

10 files changed: +313, -1177 lines


tensor2tensor/layers/common_hparams.py

Lines changed: 6 additions & 1 deletion

@@ -69,6 +69,11 @@ def basic_params1():
       sampling_method="argmax",  # "argmax" or "random"
       problem_choice="adaptive",  # "uniform", "adaptive", "distributed"
       multiply_embedding_mode="sqrt_depth",
+      # Parameters related to mixtures of experts.
+      moe_hidden_sizes="2048",  # hidden layer sizes (comma-separated)
+      moe_num_experts=64,  # number of experts per layer
+      moe_k=2,  # how many experts to use for each batch element
+      moe_loss_coef=1e-2,
       # Sequences of operations to perform on layer input and layer output.
       # Used by common_layers.layer_preprocess, common_layers.layer_postprocess
       # Each character repsesnts an operation:
@@ -83,7 +88,7 @@ def basic_params1():
       # dropout rate to use during layer_preprocess and layer_postprocess
       layer_prepostprocess_dropout=0.1,
       # What type of normalization to use
-      norm_type="none",  # "batch", layer", "noam", "none".
+      norm_type="layer",  # "batch", layer", "noam", "none".
       # epsilon parameter to normalization function
       norm_epsilon=1e-6,
       symbol_modality_num_shards=16,
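
As a quick orientation to the new defaults, here is a minimal sketch of overriding them in a custom hparams set. The function name my_moe_hparams is hypothetical; only basic_params1 and the moe_* fields come from this commit.

```python
from tensor2tensor.layers import common_hparams


def my_moe_hparams():
  """Hypothetical hparams set tweaking the MoE defaults added above."""
  hparams = common_hparams.basic_params1()
  # moe_hidden_sizes is a comma-separated string, as in the diff above.
  hparams.moe_hidden_sizes = "4096,4096"  # two hidden layers per expert
  hparams.moe_num_experts = 32            # experts per MoE layer
  hparams.moe_k = 2                       # experts consulted per position
  hparams.moe_loss_coef = 1e-2            # weight on the load-balancing loss
  # Downstream code parses the string into integer layer sizes:
  assert [int(s) for s in hparams.moe_hidden_sizes.split(",")] == [4096, 4096]
  return hparams
```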

tensor2tensor/layers/common_layers.py

Lines changed: 1 addition & 66 deletions

@@ -193,7 +193,7 @@ def embedding(x, vocab_size, dense_size, name=None, reuse=None, multiplier=1.0):
     # On the backwards pass, we want to convert the gradient from
     # an indexed-slices to a regular tensor before sending it back to the
     # parameter server. This avoids excess computation on the parameter server.
-    embedding_var = eu.ConvertGradientToTensor(embedding_var)
+    embedding_var = eu.convert_gradient_to_tensor(embedding_var)
     emb_x = tf.gather(embedding_var, x)
     if multiplier != 1.0:
       emb_x *= multiplier
@@ -823,71 +823,6 @@ def decompress_seqcnn(x,
   return tf.layers.dense(outputs, targets_vocab_size)


-def moe_layer(data_parallelism,
-              ps_devices,
-              xs,
-              train,
-              model_hidden_size,
-              expert_hidden_size,
-              n1,
-              n2,
-              loss_coef,
-              autoscale=True,
-              name=None):
-  """A mixture of experts layer.
-
-  Args:
-    data_parallelism: a expert_utils.Parallelism object.
-    ps_devices: a list of strings
-    xs: a list of input tensors.
-    train: a boolean scalar.
-    model_hidden_size: an integer (input/output size for this layer)
-    expert_hidden_size: an integer (size of each expert's hidden layer)
-    n1: an integer - number of experts (or # of groups for hierarchical MoE)
-    n2: optional integer - size of each group of experts for hierarchical MoE
-    loss_coef: a scalar - multiplier on load-balancing losses
-    autoscale: a boolean
-    name: a string
-
-  Returns:
-    ys: a list of tensors:
-    extra_training_loss: a scalar
-  """
-  dp = data_parallelism
-  with tf.variable_scope(name, default_name="moe"):
-    # Set up the hyperparameters for the gating networks.
-    primary_gating_hp = eu.NoisyTopKGatingParams()
-    primary_gating_hp.num_experts = n1
-    if n2:
-      # hierarchical MoE containing moe_n1 groups of moe_n2 experts.
-      assert n2 > 1
-      secondary_gating_hp = eu.NoisyTopKGatingParams()
-      secondary_gating_hp.num_experts = n2
-    else:
-      # flat mixture of moe_n1 experts.
-      secondary_gating_hp = None
-    # Set up the hyperparameters for the expert networks.
-    # Each expert contains a hidden RELU layer of size filter_size
-    expert_hp = eu.FeedForwardExpertParams()
-    expert_hp.autoscale = autoscale
-    expert_hp.hidden_layer_sizes = [expert_hidden_size]
-    # Create the mixture of experts.
-    moe = eu.DistributedMixtureOfExperts(primary_gating_hp, secondary_gating_hp,
-                                         expert_hp, model_hidden_size,
-                                         model_hidden_size, ps_devices, "moe")
-    # MoE expects input tensors to be 2d.
-    # Flatten out spatial dimensions.
-    xs_2d = dp(tf.reshape, xs, [[-1, model_hidden_size]] * dp.n)
-    # Call the MoE
-    moe_out_2d, importance, load, _, _ = moe.Eval(
-        dp.devices, xs_2d, train, identifiers=None)
-    # Reshape the output to the original shape.
-    moe_out = dp(tf.reshape, moe_out_2d, dp(tf.shape, xs))
-    # These losses encourage equal load on the different experts.
-    loss = loss_coef * (eu.CVSquared(importance) + eu.CVSquared(load))
-    return moe_out, loss
-
-
 def simple_attention(target, source, bias=None):
   """A simple attention function.
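
The context comment above explains why the (now renamed) convert_gradient_to_tensor helper is applied to the embedding variable: the gradient of tf.gather is a tf.IndexedSlices, and densifying it avoids extra work on the parameter server. A hedged TF 1.x illustration of that effect follows; it is not the expert_utils implementation.

```python
import tensorflow as tf

# Illustration only: shows that the gradient of tf.gather w.r.t. a variable is
# a tf.IndexedSlices, and that it can be densified into an ordinary tensor.
# This is NOT the expert_utils.convert_gradient_to_tensor implementation.
embedding_var = tf.Variable(tf.random_normal([1000, 64]))
ids = tf.constant([3, 7, 7, 42])
emb = tf.gather(embedding_var, ids)
loss = tf.reduce_sum(emb)

[sparse_grad] = tf.gradients(loss, [embedding_var])
print(type(sparse_grad))                        # tf.IndexedSlices
dense_grad = tf.convert_to_tensor(sparse_grad)  # dense, shape [1000, 64]
```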

tensor2tensor/layers/modalities.py

Lines changed: 1 addition & 1 deletion

@@ -70,7 +70,7 @@ def _get_weights(self):
       ret = shards[0]
     else:
       ret = tf.concat(shards, 0)
-    ret = eu.ConvertGradientToTensor(ret)
+    ret = eu.convert_gradient_to_tensor(ret)
     return ret

   def bottom_simple(self, x, name, reuse):

tensor2tensor/models/attention_lm_moe.py

Lines changed: 30 additions & 18 deletions

@@ -32,6 +32,7 @@
 from tensor2tensor.layers import common_attention
 from tensor2tensor.layers import common_hparams
 from tensor2tensor.layers import common_layers
+from tensor2tensor.utils import expert_utils
 from tensor2tensor.utils import registry
 from tensor2tensor.utils import t2t_model

@@ -61,6 +62,7 @@ def postprocess(x, y):
     x = dp(tf.nn.dropout, decoder_input,
            1.0 - hparams.layer_prepostprocess_dropout)
     extra_loss = 0.0
+    moe_hidden_sizes = [int(s) for s in hparams.moe_hidden_sizes.split(",")]
     for layer in xrange(hparams.num_hidden_layers):
       with tf.variable_scope("layer_%d" % layer):
         with tf.variable_scope("attention"):
@@ -78,11 +80,18 @@ def postprocess(x, y):
           x = postprocess(x, y)
         with tf.variable_scope("ffn"):
           if str(layer) in hparams.moe_layers.split(","):
-            y, loss = common_layers.moe_layer(
-                dp, self._ps_devices, preprocess(x),
+            y, loss = expert_utils.distributed_moe(
+                dp,
+                self._ps_devices,
+                preprocess(x),
                 hparams.mode == tf.contrib.learn.ModeKeys.TRAIN,
-                hparams.hidden_size, hparams.moe_hidden_size, hparams.moe_n1,
-                hparams.moe_n2, hparams.moe_loss_coef)
+                input_size=hparams.hidden_size,
+                expert_fn=expert_utils.ffn_expert_fn(
+                    hparams.hidden_size, moe_hidden_sizes,
+                    hparams.hidden_size),
+                num_experts=hparams.moe_num_experts,
+                k=hparams.moe_k,
+                loss_coef=hparams.moe_loss_coef)
             extra_loss += loss
           else:
             y = dp(
@@ -149,16 +158,7 @@ def attention_lm_moe_base():
   hparams.label_smoothing = 0.0
   hparams.shared_embedding_and_softmax_weights = int(False)
   hparams.add_hparam("filter_size", 2048)  # Add new ones like this.
-  # comma-separated list of layer numbers.
-  # At each of these layers, we replace the ffn with a mixture of experts.
-  hparams.add_hparam("moe_layers", "2")
-  # If moe_n2 is None, then use a flat MoE with moe_n1 experts.
-  # If moe_n2 is an integer, then use a hierarchical MoE
-  # consisting of moe_n1 groups of moe_n2 experts each.
-  hparams.add_hparam("moe_n1", 32)
-  hparams.add_hparam("moe_n2", 0)
-  hparams.add_hparam("moe_hidden_size", 2048)
-  hparams.add_hparam("moe_loss_coef", 1e-2)
+  hparams.moe_num_experts = 32
   # attention-related flags
   hparams.add_hparam("num_heads", 8)
   hparams.add_hparam("attention_key_channels", 0)
@@ -168,6 +168,7 @@ def attention_lm_moe_base():
   hparams.add_hparam("attention_dropout", 0.0)
   hparams.add_hparam("relu_dropout", 0.0)
   hparams.add_hparam("pos", "timing")  # timing, none
+  hparams.add_hparam("moe_layers", "2")  # comma separated list of layer numbers
   return hparams


@@ -188,9 +189,20 @@ def attention_lm_moe_small():
   hparams.num_hidden_layers = 4
   hparams.hidden_size = 512
   hparams.filter_size = 2048
-  hparams.moe_n1 = 128
+  hparams.moe_num_experts = 128
   hparams.moe_layers = "2"
-  hparams.moe_hidden_size = 2048
+  return hparams
+
+
+@registry.register_hparams
+def attention_lm_moe_tiny():
+  """Cheap model for debugging.
+
+  Returns:
+    an hparams object.
+  """
+  hparams = attention_lm_moe_small()
+  hparams.moe_num_experts = 32
   return hparams


@@ -233,7 +245,7 @@ def attention_lm_moe_large():
   hparams.hidden_size = 1024
   hparams.num_heads = 16
   hparams.filter_size = 4096
-  hparams.moe_hidden_size = 4096
-  hparams.moe_n1 = 128
+  hparams.moe_hidden_sizes = "4096"
+  hparams.moe_num_experts = 128
   hparams.layer_prepostprocess_dropout = 0.2
   return hparams
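
For reference, a minimal sketch of the new call pattern from the ffn branch above. The keyword arguments mirror the diff, but the two-CPU Parallelism setup, the tensor shapes, and the assumption that Parallelism accepts a plain list of device strings are illustrative guesses rather than part of this commit.

```python
import tensorflow as tf
from tensor2tensor.utils import expert_utils

# Sketch of the distributed_moe call pattern shown above. In the model, dp and
# ps_devices come from the T2T training harness; here they are stubbed out.
hidden_size = 512
moe_hidden_sizes = [2048]       # parsed from hparams.moe_hidden_sizes
devices = ["/cpu:0", "/cpu:0"]  # pretend data-parallel devices (assumption)

dp = expert_utils.Parallelism(devices)
xs = [tf.random_normal([8, 10, hidden_size]) for _ in devices]  # one shard each

ys, extra_loss = expert_utils.distributed_moe(
    dp,
    devices,                    # stands in for self._ps_devices
    xs,
    True,                       # train
    input_size=hidden_size,
    expert_fn=expert_utils.ffn_expert_fn(
        hidden_size, moe_hidden_sizes, hidden_size),
    num_experts=16,             # hparams.moe_num_experts
    k=2,                        # hparams.moe_k
    loss_coef=1e-2)             # hparams.moe_loss_coef
```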
