@@ -307,7 +307,8 @@ def forward(self,
                 tgt_mask=None,
                 memory_mask=None,
                 use_cache=False,
-                cache=None):
+                cache=None,
+                time_step=None):
         r"""
         Applies a stack of N Transformer decoder layers on inputs. If `norm` is
         provided, also applies layer normalization on the output of last decoder
@@ -317,38 +318,48 @@ def forward(self,
         new_caches = []
         self.checkpoints = []
 
-        for i, mod in enumerate(self.layers):
-            if cache is None:
-                if use_cache:
-                    output, new_cache = mod(output,
-                                            memory,
-                                            tgt_mask=tgt_mask,
-                                            use_cache=use_cache,
-                                            cache=cache)
-                    new_caches.append(new_cache)
-                else:
-                    output = mod(output,
-                                 memory,
-                                 tgt_mask=tgt_mask,
-                                 use_cache=use_cache,
-                                 cache=cache)
+        if isinstance(self.layers, nn.LayerList):
+            for i, mod in enumerate(self.layers):
+                if cache is None:
+                    if use_cache:
+                        output, new_cache = mod(output,
+                                                memory,
+                                                tgt_mask=tgt_mask,
+                                                use_cache=use_cache,
+                                                cache=cache)
+                        new_caches.append(new_cache)
+                    else:
+                        output = mod(output,
+                                     memory,
+                                     tgt_mask=tgt_mask,
+                                     use_cache=use_cache,
+                                     cache=cache)
 
-            else:
-                if use_cache:
-                    output, new_cache = mod(output,
-                                            memory,
-                                            tgt_mask=tgt_mask,
-                                            use_cache=use_cache,
-                                            cache=cache[i])
-                    new_caches.append(new_cache)
                 else:
-                    output = mod(output,
-                                 memory,
-                                 tgt_mask=tgt_mask,
-                                 use_cache=use_cache,
-                                 cache=cache[i])
-
-            self.checkpoints.append(output.name)
+                    if use_cache:
+                        output, new_cache = mod(output,
+                                                memory,
+                                                tgt_mask=tgt_mask,
+                                                use_cache=use_cache,
+                                                cache=cache[i])
+                        new_caches.append(new_cache)
+                    else:
+                        output = mod(output,
+                                     memory,
+                                     tgt_mask=tgt_mask,
+                                     use_cache=use_cache,
+                                     cache=cache[i])
+
+                self.checkpoints.append(output.name)
+        else:
+            # fused_multi_transformer
+            output = self.layers(output,
+                                 attn_mask=tgt_mask,
+                                 caches=cache,
+                                 time_step=time_step)
+            if cache:
+                new_caches = output[1]
+            output = output[0]
 
         if self.norm is not None:
             output = self.norm(output)
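Note on the dispatch above: `self.layers` is now either an `nn.LayerList` (the original per-layer Python loop) or a single fused stack that is invoked once for all layers. The sketch below is a minimal, hedged illustration of the two call shapes the decoder has to handle; `FakeFusedStack` is a stand-in invented for this example, not Paddle's actual `FusedMultiTransformer`.

import paddle.nn as nn


class FakeFusedStack(nn.Layer):
    """Illustrative stand-in: mimics the (output, caches) return convention
    of a fused multi-layer transformer when caches are supplied."""

    def forward(self, x, attn_mask=None, caches=None, time_step=None):
        out = x  # a real fused op would run every decoder layer here
        return (out, caches) if caches is not None else out


def run_stack(layers, x, tgt_mask=None, cache=None, time_step=None):
    if isinstance(layers, nn.LayerList):
        # unfused path: one Python call per decoder layer
        for mod in layers:
            x = mod(x)
        return x
    # fused path: a single call covers the whole stack
    out = layers(x, attn_mask=tgt_mask, caches=cache, time_step=time_step)
    return out[0] if cache is not None else out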
@@ -768,26 +779,61 @@ def __init__(self,
             type_vocab_size, self.initializer_range,
             topo)
 
-        decoder_layers = nn.LayerList()
-        for i in range(num_hidden_layers):
-            DecoderLayer = TransformerDecoderLayer
-            if self.pipline_mode:
-                DecoderLayer = paddlenlp.ops.guard('gpu:{}'.format(
-                    i // self.layer_per_stage))(TransformerDecoderLayer)
-            decoder_layers.append(
-                DecoderLayer(d_model=hidden_size,
-                             nhead=num_attention_heads,
-                             dim_feedforward=intermediate_size,
-                             dropout=hidden_dropout_prob,
-                             activation=hidden_act,
-                             attn_dropout=attention_probs_dropout_prob,
-                             act_dropout=hidden_dropout_prob,
-                             weight_attr=paddle.ParamAttr(
-                                 initializer=nn.initializer.Normal(
-                                     mean=0.0, std=self.initializer_range)),
-                             bias_attr=None,
-                             topo=topo,
-                             fuse=kwargs.get('fuse', False)))
+        if kwargs.get('fuse_mt', False):
+            nranks, ring_id = 1, -1
+            if topo is not None and topo.mp_info.size > 1:
+                nranks = topo.mp_info.size
+                ring_id = 0
+
+            weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
+                mean=0.0, std=self.initializer_range))
+            bias_attr = None
+            decoder_layers = incubate.nn.FusedMultiTransformer(
+                hidden_size,
+                num_attention_heads,
+                intermediate_size,
+                dropout_rate=hidden_dropout_prob,
+                activation=hidden_act,
+                qkv_weight_attrs=_convert_param_attr_to_list(
+                    weight_attr, num_hidden_layers),
+                qkv_bias_attrs=_convert_param_attr_to_list(
+                    bias_attr, num_hidden_layers),
+                linear_weight_attrs=_convert_param_attr_to_list(
+                    weight_attr, num_hidden_layers),
+                linear_bias_attrs=_convert_param_attr_to_list(
+                    bias_attr, num_hidden_layers),
+                ffn1_weight_attrs=_convert_param_attr_to_list(
+                    weight_attr, num_hidden_layers),
+                ffn1_bias_attrs=_convert_param_attr_to_list(
+                    bias_attr, num_hidden_layers),
+                ffn2_weight_attrs=_convert_param_attr_to_list(
+                    weight_attr, num_hidden_layers),
+                ffn2_bias_attrs=_convert_param_attr_to_list(
+                    bias_attr, num_hidden_layers),
+                epsilon=1e-5,
+                nranks=nranks,
+                ring_id=ring_id)
+        else:
+            decoder_layers = nn.LayerList()
+            for i in range(num_hidden_layers):
+                DecoderLayer = TransformerDecoderLayer
+                if self.pipline_mode:
+                    DecoderLayer = paddlenlp.ops.guard('gpu:{}'.format(
+                        i // self.layer_per_stage))(TransformerDecoderLayer)
+                decoder_layers.append(
+                    DecoderLayer(d_model=hidden_size,
+                                 nhead=num_attention_heads,
+                                 dim_feedforward=intermediate_size,
+                                 dropout=hidden_dropout_prob,
+                                 activation=hidden_act,
+                                 attn_dropout=attention_probs_dropout_prob,
+                                 act_dropout=hidden_dropout_prob,
+                                 weight_attr=paddle.ParamAttr(
+                                     initializer=nn.initializer.Normal(
+                                         mean=0.0, std=self.initializer_range)),
+                                 bias_attr=None,
+                                 topo=topo,
+                                 fuse=kwargs.get('fuse', False)))
 
         if self.pipline_mode:
             Decoder = paddlenlp.ops.guard(
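`FusedMultiTransformer` owns the parameters of every layer, so the constructor above expands one `ParamAttr` template into a per-layer list via `_convert_param_attr_to_list` (a helper in `paddle.nn.layer.transformer`). A minimal sketch of that expansion, assuming the helper behaves roughly like the hypothetical `expand_per_layer` below:

import copy

import paddle
import paddle.nn as nn


def expand_per_layer(attr, num_layers):
    # Hypothetical re-implementation for illustration only: give every layer
    # its own copy of the attribute (or None) so parameters are not shared.
    if attr is None:
        return [None] * num_layers
    return [copy.deepcopy(attr) for _ in range(num_layers)]


weight_attr = paddle.ParamAttr(
    initializer=nn.initializer.Normal(mean=0.0, std=0.02))
qkv_weight_attrs = expand_per_layer(weight_attr, num_layers=4)  # 4 attrs
qkv_bias_attrs = expand_per_layer(None, num_layers=4)           # [None] * 4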
@@ -809,7 +855,8 @@ def forward(self,
                 position_ids=None,
                 attention_mask=None,
                 use_cache=False,
-                cache=None):
+                cache=None,
+                time_step=None):
         self.checkpoints = []
         if position_ids is None:
             past_length = 0
@@ -832,7 +879,8 @@ def forward(self,
                                        memory=None,
                                        tgt_mask=tgt_mask,
                                        use_cache=use_cache,
-                                       cache=cache)
+                                       cache=cache,
+                                       time_step=time_step)
         self.checkpoints.extend(self.decoder.checkpoints)
         return encoder_outputs
 
@@ -872,12 +920,14 @@ def forward(self,
                 attention_mask=None,
                 masked_positions=None,
                 use_cache=False,
-                cache=None):
+                cache=None,
+                time_step=None):
         outputs = self.gpt(input_ids,
                            position_ids=position_ids,
                            attention_mask=attention_mask,
                            use_cache=use_cache,
-                           cache=cache)
+                           cache=cache,
+                           time_step=time_step)
         if use_cache:
             encoder_outputs, cached_kvs = outputs[:2]
         else:
@@ -942,10 +992,17 @@ def __init__(self,
         self.topp = top_p
         self._init_gen_cache = False
         self.generation_caches = None
+        # for fused_multi_transformer
+        self.generation_time_step = None
         self._dtype = "float32"
         self._fuse = kwargs.get("fuse", False)
+        self._fuse_mt = kwargs.get("fuse_mt", False)
 
     def _init_generation_caches(self, src_ids):
+        if self._fuse and self._fuse_mt:
+            # output tensor is on CPUPlace
+            self.generation_time_step = paddle.shape(src_ids)[1]
+
         # not fuse, return None
         if self._init_gen_cache or self._fuse is False:
             return self.generation_caches
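On the fused path, `generation_time_step` seeds the decode position with the prompt length. `paddle.shape(src_ids)[1]` yields a 1-element int32 tensor (placed on CPU in static graph, hence the comment above), which is the form later passed as `time_step`. A minimal static-graph sketch of the idea; the variable names are illustrative, not the model's API:

import paddle

paddle.enable_static()
startup, main = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main, startup):
    src_ids = paddle.static.data(name='src_ids', shape=[-1, -1], dtype='int64')
    # paddle.shape returns an int32 tensor; element [1] is the prompt length,
    # used as the first write position in the preallocated fused KV cache.
    time_step = paddle.shape(src_ids)[1]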
@@ -956,15 +1013,23 @@ def _init_generation_caches(self, src_ids):
         mp_n_head = num_heads // self.gpt.topo.mp_info.size
         hidden_size = self.gpt.hidden_size
         head_size = hidden_size // num_heads
+        seq_len = 0
+        if self._fuse_mt:
+            # FIXME(wangxi): dynamic get max_seq_len + dec_len
+            seq_len = 1024
         for i in range(num_layers):
             if self._fuse:
                 kv = layers.fill_constant_batch_size_like(
                     input=src_ids,
-                    shape=[2, -1, mp_n_head, 0, head_size],
+                    shape=[2, -1, mp_n_head, seq_len, head_size],
                     dtype=self._dtype,
                     value=0,
                     output_dim_idx=1)
-                self.generation_caches.append(TransformerDecoderLayer.Cache(kv))
+                if self._fuse_mt:
+                    self.generation_caches.append(kv)
+                else:
+                    self.generation_caches.append(
+                        TransformerDecoderLayer.Cache(kv))
             else:
                 k = layers.fill_constant_batch_size_like(
                     input=src_ids,
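In the fused path the per-layer KV cache is preallocated to a fixed maximum length (hard-coded to 1024 here, as the FIXME notes) with layout `[2, batch, num_heads, max_seq_len, head_size]`; new steps are then written at position `time_step` rather than concatenated onto a growing cache. A rough back-of-the-envelope for the memory this reserves, using illustrative sizes:

# Illustrative numbers only; the real values come from the model config.
batch, num_heads, head_size = 8, 16, 64
max_seq_len = 1024                      # upper bound hard-coded in the diff
num_layers = 24

elems_per_layer = 2 * batch * num_heads * max_seq_len * head_size  # K and V
bytes_per_layer = elems_per_layer * 4                              # float32
total_gib = bytes_per_layer * num_layers / 2**30
print(f"{bytes_per_layer / 2**20:.0f} MiB per layer, {total_gib:.1f} GiB total")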
@@ -1040,12 +1105,14 @@ def model(self,
               attention_mask=None,
               masked_positions=None,
               use_cache=False,
-              cache=None):
+              cache=None,
+              time_step=None):
         outputs = self.gpt(input_ids,
                            position_ids=position_ids,
                            attention_mask=attention_mask,
                            use_cache=use_cache,
-                           cache=cache)
+                           cache=cache,
+                           time_step=time_step)
         if use_cache:
             encoder_outputs, cached_kvs = outputs[:2]
         else:
@@ -1145,11 +1212,13 @@ def forward(self, inputs, use_cache=False, cache=None):
             layers.increment(x=step_idx, value=1.0, in_place=True)
             layers.array_write(placehold_ids, i=step_idx, array=ids)
 
-            logits, decode_cached_kvs = self.model(pre_ids,
-                                                   tgt_pos,
-                                                   att_mask,
-                                                   use_cache=True,
-                                                   cache=cached_kvs)
+            logits, decode_cached_kvs = self.model(
+                pre_ids,
+                tgt_pos,
+                att_mask,
+                use_cache=True,
+                cache=cached_kvs,
+                time_step=self.generation_time_step)
 
             logits = paddle.reshape(logits, shape=(-1, self.vocab_size))
             probs = F.softmax(logits / self.temperature)
@@ -1181,10 +1250,13 @@ def forward(self, inputs, use_cache=False, cache=None):
             paddle.assign(decode_mask, attention_mask)
             for i in range(len(decode_cached_kvs)):
                 if self._fuse:
-                    paddle.assign(decode_cached_kvs[i].kv, cached_kvs[i].kv)
+                    if not self._fuse_mt:
+                        paddle.assign(decode_cached_kvs[i].kv, cached_kvs[i].kv)
                 else:
                     paddle.assign(decode_cached_kvs[i].k, cached_kvs[i].k)
                     paddle.assign(decode_cached_kvs[i].v, cached_kvs[i].v)
+            if self.generation_time_step:
+                paddle.increment(self.generation_time_step, value=1.0)
 
             ids, _ = layers.tensor_array_to_tensor(ids)
             return ids
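With the fused cache updated in place by the kernel, the only per-step bookkeeping left on this path is advancing the decode position, which is why the `.kv` copy is skipped when `_fuse_mt` is set. `paddle.increment` bumps a 1-element tensor by `value`, so `generation_time_step` ends up tracking prompt length plus the number of tokens generated so far. A tiny dygraph sketch of that counter (in the model it runs once per token inside the static while-loop); the starting value is illustrative:

import paddle

time_step = paddle.full(shape=[1], fill_value=16, dtype='int32')  # prompt len
for _ in range(3):                       # pretend we decoded three tokens
    time_step = paddle.increment(time_step, value=1.0)
print(int(time_step))                    # 19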