This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 95ee9e5

nshazeer authored and Ryan Sepassi committed
added transformer_moe - a transformer model with mixtures-of-experts.
PiperOrigin-RevId: 164190826
1 parent 34a961f commit 95ee9e5

2 files changed, 217 insertions(+), 0 deletions(-)

tensor2tensor/models/models.py

Lines changed: 1 addition & 0 deletions
@@ -37,5 +37,6 @@
 from tensor2tensor.models import slicenet
 from tensor2tensor.models import transformer
 from tensor2tensor.models import transformer_alternative
+from tensor2tensor.models import transformer_moe
 from tensor2tensor.models import xception
 # pylint: enable=unused-import
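
This one-line import is what exposes the new model: importing transformer_moe from models.py runs its @registry.register_model and @registry.register_hparams decorators. As a minimal sketch (not part of this commit, and assuming the registry.model getter that tensor2tensor's registry exposed around this time), the registered class can then be looked up by name:

# Sketch only, not part of this commit: assumes tensor2tensor.utils.registry
# provides a model() getter, as it did in this era.
from tensor2tensor.models import models  # noqa: F401 -- import triggers registration
from tensor2tensor.utils import registry

model_cls = registry.model("transformer_moe")  # resolves to TransformerMoe
print(model_cls.__name__)

In practice the same lookup happens inside the trainer when the model name transformer_moe is passed on the command line.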
tensor2tensor/models/transformer_moe.py

Lines changed: 216 additions & 0 deletions
@@ -0,0 +1,216 @@
# coding=utf-8
# Copyright 2017 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""transformer (attention seq-seq model) with mixtures of experts.

"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports

from six.moves import xrange  # pylint: disable=redefined-builtin

from tensor2tensor.layers import common_attention
from tensor2tensor.layers import common_hparams
from tensor2tensor.layers import common_layers
from tensor2tensor.models import transformer
from tensor2tensor.utils import registry
from tensor2tensor.utils import t2t_model

import tensorflow as tf


@registry.register_model
class TransformerMoe(t2t_model.T2TModel):
  """Attention net.  See file docstring."""

  def model_fn_body_sharded(self, sharded_features):
    hparams = self._hparams
    dp = self._data_parallelism
    targets = sharded_features["targets"]
    inputs = sharded_features["inputs"]
    target_space = sharded_features["target_space_id"]

    inputs = dp(common_layers.flatten4d3d, inputs)
    targets = dp(common_layers.flatten4d3d, targets)

    (encoder_input, encoder_self_attention_bias,
     encoder_decoder_attention_bias) = dp(
         transformer.transformer_prepare_encoder,
         inputs, target_space, hparams)
    (decoder_input, decoder_self_attention_bias) = dp(
        transformer.transformer_prepare_decoder, targets, hparams)
    residual_fn = transformer.get_residual_fn(hparams)
    encoder_input = dp(tf.nn.dropout, encoder_input,
                       1.0 - hparams.residual_dropout)
    decoder_input = dp(tf.nn.dropout, decoder_input,
                       1.0 - hparams.residual_dropout)
    extra_loss = 0
    x = encoder_input
    for layer in xrange(hparams.num_hidden_layers):
      with tf.variable_scope("encoder_layer_%d" % layer):
        with tf.variable_scope("encoder_self_attention"):
          y = dp(
              common_attention.multihead_attention,
              x,
              None,
              encoder_self_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout)
          x = dp(residual_fn, x, y)
        with tf.variable_scope("ffn"):
          if str(layer) in hparams.moe_layers_encoder.split(","):
            y, loss = common_layers.moe_layer(
                dp, self._ps_devices, x,
                hparams.mode == tf.contrib.learn.ModeKeys.TRAIN,
                hparams.hidden_size, hparams.moe_hidden_size, hparams.moe_n1,
                hparams.moe_n2, hparams.moe_loss_coef)
            extra_loss += loss
          else:
            y = dp(
                common_layers.conv_hidden_relu,
                x,
                hparams.filter_size,
                hparams.hidden_size,
                dropout=hparams.relu_dropout)
          x = dp(residual_fn, x, y)
    encoder_output = x
    x = decoder_input
    for layer in xrange(hparams.num_hidden_layers):
      with tf.variable_scope("decoder_layer_%d" % layer):
        with tf.variable_scope("decoder_self_attention"):
          y = dp(
              common_attention.multihead_attention,
              x,
              None,
              decoder_self_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout)
          x = dp(residual_fn, x, y)
        with tf.variable_scope("encoder_decoder_attention"):
          y = dp(
              common_attention.multihead_attention,
              x,
              encoder_output,
              encoder_decoder_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout)
          x = dp(residual_fn, x, y)
        with tf.variable_scope("ffn"):
          if str(layer) in hparams.moe_layers_decoder.split(","):
            y, loss = common_layers.moe_layer(
                dp, self._ps_devices, x,
                hparams.mode == tf.contrib.learn.ModeKeys.TRAIN,
                hparams.hidden_size, hparams.moe_hidden_size, hparams.moe_n1,
                hparams.moe_n2, hparams.moe_loss_coef)
            extra_loss += loss
          else:
            y = dp(
                common_layers.conv_hidden_relu,
                x,
                hparams.filter_size,
                hparams.hidden_size,
                dropout=hparams.relu_dropout)
          x = dp(residual_fn, x, y)
    decoder_output = dp(tf.expand_dims, x, 2)
    return decoder_output, extra_loss


@registry.register_hparams
def transformer_moe_base():
  """Set of hyperparameters."""
  hparams = common_hparams.basic_params1()
  hparams.norm_type = "layer"
  hparams.hidden_size = 512
  hparams.batch_size = 4096
  hparams.max_length = 2001
  hparams.max_input_seq_length = 2000
  hparams.max_target_seq_length = 2000
  hparams.dropout = 0.0
  hparams.clip_grad_norm = 0.  # i.e. no gradient clipping
  hparams.optimizer_adam_epsilon = 1e-9
  hparams.learning_rate_decay_scheme = "noam"
  hparams.learning_rate = 0.1
  hparams.learning_rate_warmup_steps = 4000
  hparams.initializer_gain = 1.0
  hparams.num_hidden_layers = 5
  hparams.initializer = "uniform_unit_scaling"
  hparams.weight_decay = 0.0
  hparams.optimizer_adam_beta1 = 0.9
  hparams.optimizer_adam_beta2 = 0.98
  hparams.num_sampled_classes = 0
  hparams.label_smoothing = 0.0
  hparams.shared_embedding_and_softmax_weights = int(True)

  hparams.add_hparam("filter_size", 2048)  # Add new ones like this.
  # attention-related flags
  hparams.add_hparam("num_heads", 8)
  hparams.add_hparam("attention_key_channels", 0)
  hparams.add_hparam("attention_value_channels", 0)
  hparams.add_hparam("ffn_layer", "conv_hidden_relu")
  hparams.add_hparam("parameter_attention_key_channels", 0)
  hparams.add_hparam("parameter_attention_value_channels", 0)
  # All hyperparameters ending in "dropout" are automatically set to 0.0
  # when not in training mode.
  hparams.add_hparam("attention_dropout", 0.0)
  hparams.add_hparam("relu_dropout", 0.0)
  hparams.add_hparam("residual_dropout", 0.1)
  hparams.add_hparam("pos", "timing")  # timing, none
  hparams.add_hparam("nbr_decoder_problems", 1)
  hparams.add_hparam("proximity_bias", int(False))
  # FLAGS RELATED TO MIXTURE-OF-EXPERTS
  # comma-separated list of layer numbers.
  # At each of these layers, we replace the ffn with a mixture of experts.
  hparams.add_hparam("moe_layers_encoder", "2")
  hparams.add_hparam("moe_layers_decoder", "2")
  # If moe_n2 is None, then use a flat MoE with moe_n1 experts.
  # If moe_n2 is an integer, then use a hierarchical MoE
  # consisting of moe_n1 groups of moe_n2 experts each.
  hparams.add_hparam("moe_n1", 32)
  hparams.add_hparam("moe_n2", 0)
  hparams.add_hparam("moe_hidden_size", 2048)
  hparams.add_hparam("moe_loss_coef", 1e-2)
  return hparams


@registry.register_hparams
def transformer_no_moe():
  """Without the mixture of experts (for comparison)."""
  hparams = transformer_moe_base()
  hparams.moe_layers_encoder = ""
  hparams.moe_layers_decoder = ""
  return hparams


@registry.register_hparams
def transformer_moe_1b():
  """1-billion parameter model - requires multi-gpu sync training."""
  hparams = transformer_moe_base()
  hparams.moe_n1 = 128
  hparams.moe_layers_encoder = "1,3"
  hparams.moe_layers_decoder = "1,3"
  return hparams
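
The three registered hparams sets above differ only in where MoE layers are placed and how many experts they use. Below is a minimal sketch of how a user could define a further variant on top of transformer_moe_base, switching to the hierarchical MoE described in the moe_n1/moe_n2 comments; the set name transformer_moe_hierarchical and the specific expert counts are hypothetical, not part of this commit.

# Hypothetical example, not part of this commit: a hierarchical-MoE variant
# built on transformer_moe_base. All hparam names come from the file above.
from tensor2tensor.models import transformer_moe
from tensor2tensor.utils import registry


@registry.register_hparams
def transformer_moe_hierarchical():
  """Hypothetical variant: 8 groups of 16 experts at layers 1 and 3."""
  hparams = transformer_moe.transformer_moe_base()
  hparams.moe_n1 = 8                   # number of expert groups
  hparams.moe_n2 = 16                  # experts per group; 0 keeps the flat MoE
  hparams.moe_layers_encoder = "1,3"   # replace the ffn at these encoder layers
  hparams.moe_layers_decoder = "1,3"   # and at these decoder layers
  return hparams

With tensor2tensor's command-line trainer, such a set would then be selected by name through the hparams-set flag, alongside the model name transformer_moe.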
