
Commit 016db3c

[GPT-3] Add fused_multi_transformer for inference (#2500)
1 parent 7c06b02 commit 016db3c

File tree

3 files changed: +150 -67 lines


examples/language_model/gpt-3/static/args.py

+7
@@ -296,6 +296,13 @@ def parse_args(MODEL_CLASSES):
         type=str2bool,
         default=False,
         help="Whether to enable fused_attention and fused_feedforward.")
+    parser.add_argument(
+        "--fuse_mt",
+        type=str2bool,
+        default=False,
+        help=
+        "Whether to enable fused_multi_transformer; requires --fuse to be enabled. Only used for inference."
+    )
 
     args = parser.parse_args()
     args.test_iters = args.eval_iters * 10
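The help text notes that --fuse_mt only takes effect when --fuse is also enabled. A minimal, hypothetical guard (not part of this commit) that would make that dependency explicit right after parsing:

import argparse

def check_fuse_flags(args: argparse.Namespace) -> None:
    # Hypothetical helper, not in the commit: fuse_mt builds on the fused
    # attention/FFN path, so reject fuse_mt=True without fuse=True.
    if getattr(args, "fuse_mt", False) and not getattr(args, "fuse", False):
        raise ValueError("--fuse_mt requires --fuse to be enabled")

Calling check_fuse_flags(args) after parse_args() would surface the misconfiguration before any program is built.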

examples/language_model/gpt-3/static/modeling.py

+137 -65
@@ -307,7 +307,8 @@ def forward(self,
                 tgt_mask=None,
                 memory_mask=None,
                 use_cache=False,
-                cache=None):
+                cache=None,
+                time_step=None):
         r"""
         Applies a stack of N Transformer decoder layers on inputs. If `norm` is
         provided, also applies layer normalization on the output of last decoder
@@ -317,38 +318,48 @@ def forward(self,
         new_caches = []
         self.checkpoints = []
 
-        for i, mod in enumerate(self.layers):
-            if cache is None:
-                if use_cache:
-                    output, new_cache = mod(output,
-                                            memory,
-                                            tgt_mask=tgt_mask,
-                                            use_cache=use_cache,
-                                            cache=cache)
-                    new_caches.append(new_cache)
-                else:
-                    output = mod(output,
-                                 memory,
-                                 tgt_mask=tgt_mask,
-                                 use_cache=use_cache,
-                                 cache=cache)
+        if isinstance(self.layers, nn.LayerList):
+            for i, mod in enumerate(self.layers):
+                if cache is None:
+                    if use_cache:
+                        output, new_cache = mod(output,
+                                                memory,
+                                                tgt_mask=tgt_mask,
+                                                use_cache=use_cache,
+                                                cache=cache)
+                        new_caches.append(new_cache)
+                    else:
+                        output = mod(output,
+                                     memory,
+                                     tgt_mask=tgt_mask,
+                                     use_cache=use_cache,
+                                     cache=cache)
 
-            else:
-                if use_cache:
-                    output, new_cache = mod(output,
-                                            memory,
-                                            tgt_mask=tgt_mask,
-                                            use_cache=use_cache,
-                                            cache=cache[i])
-                    new_caches.append(new_cache)
                 else:
-                    output = mod(output,
-                                 memory,
-                                 tgt_mask=tgt_mask,
-                                 use_cache=use_cache,
-                                 cache=cache[i])
-
-            self.checkpoints.append(output.name)
+                    if use_cache:
+                        output, new_cache = mod(output,
+                                                memory,
+                                                tgt_mask=tgt_mask,
+                                                use_cache=use_cache,
+                                                cache=cache[i])
+                        new_caches.append(new_cache)
+                    else:
+                        output = mod(output,
+                                     memory,
+                                     tgt_mask=tgt_mask,
+                                     use_cache=use_cache,
+                                     cache=cache[i])
+
+                self.checkpoints.append(output.name)
+        else:
+            # fused_multi_transformer
+            output = self.layers(output,
+                                 attn_mask=tgt_mask,
+                                 caches=cache,
+                                 time_step=time_step)
+            if cache:
+                new_caches = output[1]
+                output = output[0]
 
         if self.norm is not None:
             output = self.norm(output)
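When fuse_mt is enabled, self.layers is a single paddle.incubate.nn.FusedMultiTransformer instead of an nn.LayerList, so one call runs every decoder layer. A minimal dynamic-graph sketch of that calling convention (illustrative sizes, not the commit's code; a CUDA build of Paddle is assumed because the fused kernels are GPU-only):

import paddle
from paddle.incubate.nn import FusedMultiTransformer

paddle.set_device("gpu")                 # fused kernels only run on GPU
hidden, heads, ffn_dim, n_layers = 64, 4, 256, 2
block = FusedMultiTransformer(hidden, heads, ffn_dim, num_layers=n_layers)

x = paddle.randn([2, 8, hidden])         # [batch, seq_len, hidden]
out = block(x)                           # all n_layers execute in one call
print(out.shape)                         # [2, 8, 64]
# When preallocated `caches` (and, at decode time, an int32 `time_step`) are
# passed, the call returns (output, caches), which is the tuple the
# `if cache:` branch above unpacks.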
@@ -768,26 +779,61 @@ def __init__(self,
                                            type_vocab_size, self.initializer_range,
                                            topo)
 
-        decoder_layers = nn.LayerList()
-        for i in range(num_hidden_layers):
-            DecoderLayer = TransformerDecoderLayer
-            if self.pipline_mode:
-                DecoderLayer = paddlenlp.ops.guard('gpu:{}'.format(
-                    i // self.layer_per_stage))(TransformerDecoderLayer)
-            decoder_layers.append(
-                DecoderLayer(d_model=hidden_size,
-                             nhead=num_attention_heads,
-                             dim_feedforward=intermediate_size,
-                             dropout=hidden_dropout_prob,
-                             activation=hidden_act,
-                             attn_dropout=attention_probs_dropout_prob,
-                             act_dropout=hidden_dropout_prob,
-                             weight_attr=paddle.ParamAttr(
-                                 initializer=nn.initializer.Normal(
-                                     mean=0.0, std=self.initializer_range)),
-                             bias_attr=None,
-                             topo=topo,
-                             fuse=kwargs.get('fuse', False)))
+        if kwargs.get('fuse_mt', False):
+            nranks, ring_id = 1, -1
+            if topo is not None and topo.mp_info.size > 1:
+                nranks = topo.mp_info.size
+                ring_id = 0
+
+            weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
+                mean=0.0, std=self.initializer_range))
+            bias_attr = None
+            decoder_layers = incubate.nn.FusedMultiTransformer(
+                hidden_size,
+                num_attention_heads,
+                intermediate_size,
+                dropout_rate=hidden_dropout_prob,
+                activation=hidden_act,
+                qkv_weight_attrs=_convert_param_attr_to_list(
+                    weight_attr, num_hidden_layers),
+                qkv_bias_attrs=_convert_param_attr_to_list(
+                    bias_attr, num_hidden_layers),
+                linear_weight_attrs=_convert_param_attr_to_list(
+                    weight_attr, num_hidden_layers),
+                linear_bias_attrs=_convert_param_attr_to_list(
+                    bias_attr, num_hidden_layers),
+                ffn1_weight_attrs=_convert_param_attr_to_list(
+                    weight_attr, num_hidden_layers),
+                ffn1_bias_attrs=_convert_param_attr_to_list(
+                    bias_attr, num_hidden_layers),
+                ffn2_weight_attrs=_convert_param_attr_to_list(
+                    weight_attr, num_hidden_layers),
+                ffn2_bias_attrs=_convert_param_attr_to_list(
+                    bias_attr, num_hidden_layers),
+                epsilon=1e-5,
+                nranks=nranks,
+                ring_id=ring_id)
+        else:
+            decoder_layers = nn.LayerList()
+            for i in range(num_hidden_layers):
+                DecoderLayer = TransformerDecoderLayer
+                if self.pipline_mode:
+                    DecoderLayer = paddlenlp.ops.guard('gpu:{}'.format(
+                        i // self.layer_per_stage))(TransformerDecoderLayer)
+                decoder_layers.append(
+                    DecoderLayer(d_model=hidden_size,
+                                 nhead=num_attention_heads,
+                                 dim_feedforward=intermediate_size,
+                                 dropout=hidden_dropout_prob,
+                                 activation=hidden_act,
+                                 attn_dropout=attention_probs_dropout_prob,
+                                 act_dropout=hidden_dropout_prob,
+                                 weight_attr=paddle.ParamAttr(
+                                     initializer=nn.initializer.Normal(
+                                         mean=0.0, std=self.initializer_range)),
+                                 bias_attr=None,
+                                 topo=topo,
+                                 fuse=kwargs.get('fuse', False)))
 
         if self.pipline_mode:
             Decoder = paddlenlp.ops.guard(
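_convert_param_attr_to_list expands a single weight or bias attribute into one entry per layer, which is the form FusedMultiTransformer expects for its qkv/linear/ffn parameter lists. A hypothetical stand-in with the same behaviour (shown only to illustrate the shape of those arguments, not the helper actually imported here):

import copy

def convert_param_attr_to_list(param_attr, n_layers):
    # Accept an explicit per-layer list as-is; otherwise replicate the single
    # ParamAttr (or None, meaning framework defaults) once per layer.
    if isinstance(param_attr, (list, tuple)):
        assert len(param_attr) == n_layers, "need exactly one attr per layer"
        return list(param_attr)
    return [copy.deepcopy(param_attr) for _ in range(n_layers)]

The nranks/ring_id values taken from topo.mp_info are what let the fused layer participate in tensor model parallelism when mp_info.size > 1.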
@@ -809,7 +855,8 @@ def forward(self,
                 position_ids=None,
                 attention_mask=None,
                 use_cache=False,
-                cache=None):
+                cache=None,
+                time_step=None):
         self.checkpoints = []
         if position_ids is None:
             past_length = 0
@@ -832,7 +879,8 @@ def forward(self,
                                        memory=None,
                                        tgt_mask=tgt_mask,
                                        use_cache=use_cache,
-                                       cache=cache)
+                                       cache=cache,
+                                       time_step=time_step)
         self.checkpoints.extend(self.decoder.checkpoints)
         return encoder_outputs
 
@@ -872,12 +920,14 @@ def forward(self,
                 attention_mask=None,
                 masked_positions=None,
                 use_cache=False,
-                cache=None):
+                cache=None,
+                time_step=None):
         outputs = self.gpt(input_ids,
                            position_ids=position_ids,
                            attention_mask=attention_mask,
                            use_cache=use_cache,
-                           cache=cache)
+                           cache=cache,
+                           time_step=time_step)
         if use_cache:
             encoder_outputs, cached_kvs = outputs[:2]
         else:
@@ -942,10 +992,17 @@ def __init__(self,
         self.topp = top_p
         self._init_gen_cache = False
         self.generation_caches = None
+        # for fused_multi_transformer
+        self.generation_time_step = None
         self._dtype = "float32"
         self._fuse = kwargs.get("fuse", False)
+        self._fuse_mt = kwargs.get("fuse_mt", False)
 
     def _init_generation_caches(self, src_ids):
+        if self._fuse and self._fuse_mt:
+            # output tensor is on CPUPlace
+            self.generation_time_step = paddle.shape(src_ids)[1]
+
         # not fuse, return None
         if self._init_gen_cache or self._fuse is False:
             return self.generation_caches
@@ -956,15 +1013,23 @@ def _init_generation_caches(self, src_ids):
         mp_n_head = num_heads // self.gpt.topo.mp_info.size
         hidden_size = self.gpt.hidden_size
         head_size = hidden_size // num_heads
+        seq_len = 0
+        if self._fuse_mt:
+            # FIXME(wangxi): dynamic get max_seq_len + dec_len
+            seq_len = 1024
         for i in range(num_layers):
             if self._fuse:
                 kv = layers.fill_constant_batch_size_like(
                     input=src_ids,
-                    shape=[2, -1, mp_n_head, 0, head_size],
+                    shape=[2, -1, mp_n_head, seq_len, head_size],
                     dtype=self._dtype,
                     value=0,
                     output_dim_idx=1)
-                self.generation_caches.append(TransformerDecoderLayer.Cache(kv))
+                if self._fuse_mt:
+                    self.generation_caches.append(kv)
+                else:
+                    self.generation_caches.append(
+                        TransformerDecoderLayer.Cache(kv))
             else:
                 k = layers.fill_constant_batch_size_like(
                     input=src_ids,
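The shape change above is the main difference between the two fused cache layouts: the Cache(kv) objects used by fused_attention start with a zero-length sequence axis and are reassigned every step, while fused_multi_transformer expects a fixed-capacity buffer (hard-coded to 1024 here, per the FIXME) that the kernel fills in place. An illustrative dynamic-graph equivalent with made-up sizes (not the commit's static-graph code):

import paddle

batch, mp_n_head, head_size = 4, 16, 64
max_seq_len, num_layers = 1024, 2        # mirrors the FIXME'd constant above

# fuse_mt path: one [2, batch, heads, max_seq_len, head_size] buffer per
# layer, preallocated once and then written in place at each decode step.
caches = [
    paddle.zeros([2, batch, mp_n_head, max_seq_len, head_size], dtype="float32")
    for _ in range(num_layers)
]
print(caches[0].shape)                   # [2, 4, 16, 1024, 64]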
@@ -1040,12 +1105,14 @@ def model(self,
              attention_mask=None,
              masked_positions=None,
              use_cache=False,
-             cache=None):
+             cache=None,
+             time_step=None):
        outputs = self.gpt(input_ids,
                           position_ids=position_ids,
                           attention_mask=attention_mask,
                           use_cache=use_cache,
-                          cache=cache)
+                          cache=cache,
+                          time_step=time_step)
        if use_cache:
            encoder_outputs, cached_kvs = outputs[:2]
        else:
@@ -1145,11 +1212,13 @@ def forward(self, inputs, use_cache=False, cache=None):
                 layers.increment(x=step_idx, value=1.0, in_place=True)
                 layers.array_write(placehold_ids, i=step_idx, array=ids)
 
-                logits, decode_cached_kvs = self.model(pre_ids,
-                                                       tgt_pos,
-                                                       att_mask,
-                                                       use_cache=True,
-                                                       cache=cached_kvs)
+                logits, decode_cached_kvs = self.model(
+                    pre_ids,
+                    tgt_pos,
+                    att_mask,
+                    use_cache=True,
+                    cache=cached_kvs,
+                    time_step=self.generation_time_step)
 
                 logits = paddle.reshape(logits, shape=(-1, self.vocab_size))
                 probs = F.softmax(logits / self.temperature)
@@ -1181,10 +1250,13 @@ def forward(self, inputs, use_cache=False, cache=None):
                 paddle.assign(decode_mask, attention_mask)
                 for i in range(len(decode_cached_kvs)):
                     if self._fuse:
-                        paddle.assign(decode_cached_kvs[i].kv, cached_kvs[i].kv)
+                        if not self._fuse_mt:
+                            paddle.assign(decode_cached_kvs[i].kv, cached_kvs[i].kv)
                     else:
                         paddle.assign(decode_cached_kvs[i].k, cached_kvs[i].k)
                         paddle.assign(decode_cached_kvs[i].v, cached_kvs[i].v)
+                if self.generation_time_step:
+                    paddle.increment(self.generation_time_step, value=1.0)
 
         ids, _ = layers.tensor_array_to_tensor(ids)
         return ids
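Because the fused kernels write each step's K/V into the preallocated caches at offset time_step, the fuse_mt path no longer copies cache tensors inside the decode loop; it only advances the counter. A small sketch of that bookkeeping in dynamic graph (the commit does the equivalent in static graph via paddle.shape and paddle.increment):

import paddle

src_ids = paddle.randint(0, 50304, shape=[1, 7])   # made-up prompt ids
time_step = paddle.shape(src_ids)[1]               # int32 step counter
for _ in range(3):                                  # three decode steps
    # fuse_mt path: the kernel stores this step's K/V at `time_step`,
    # so only the counter has to move between steps.
    paddle.increment(time_step, value=1.0)
print(int(time_step))                               # prompt length + 3 = 10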

examples/language_model/gpt-3/static/run_generation.py

+6 -2
@@ -175,7 +175,8 @@ def do_generation(args):
     startup_program = paddle.static.default_startup_program()
     with paddle.static.program_guard(main_program, startup_program):
         with paddle.utils.unique_name.guard():
-            with paddle.static.device_guard('gpu:0'):
+            with paddle.static.device_guard(
+                    'gpu:0' if topo.pp_info.size > 1 else None):
                 feeds = create_data_holder(args)
                 tokenizer = tokenizer_class.from_pretrained(
                     args.model_name_or_path)
@@ -202,6 +203,7 @@ def do_generation(args):
                     "attention_probs_dropout_prob"] = args.attention_probs_dropout_prob
                 model_config["topo"] = topo
                 model_config["fuse"] = args.fuse
+                model_config["fuse_mt"] = args.fuse_mt
                 model = GPTForGeneration(
                     GPTModel(**model_config),
                     max_length=args.max_dec_len,
@@ -210,7 +212,8 @@ def do_generation(args):
                     top_k=args.topk,
                     top_p=args.topp,
                     eos_id=eos_id,
-                    fuse=args.fuse)
+                    fuse=args.fuse,
+                    fuse_mt=args.fuse_mt)
             else:
                 logger.error("No checkpoint load.")
             model.eval()
@@ -221,6 +224,7 @@ def do_generation(args):
     exe = paddle.static.Executor(place)
     exe.run(startup_program)
     main_program = main_program.clone(for_test=True)
+    #debug_program('main_program', main_program)
 
     model_urls = model.pretrained_resource_files_map['model_state']
     model_path = args.model_name_or_path
