@@ -100,7 +100,14 @@ class FasterTransformer(TransformerModel):
             for details. Bigger `diversity_rate` would lead to more diversity.
             if `diversity_rate == 0` is equivalent to naive BeamSearch. Default
             to 0 if not set.
-        use_fp16_decoding(bool, optional): Whether to use fp16 for decoding.
+        use_fp16_decoding(bool, optional):
+            Whether to use fp16 for decoding.
+        enable_faster_encoder(bool, optional):
+            Whether to use the faster version of encoder. This is experimental option for now.
+            Defaults to False.
+        use_fp16_encoder(bool, optional):
+            Whether to use fp16 for encoder. Only works when enable_faster_encoder is True.
+            Defaults to False.
         rel_len(bool, optional):
             Indicating whether `max_out_len` in is the length relative to that
             of source text. Only works in `v2` temporarily. It is suggest to set
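
Not part of the patch: a minimal usage sketch of the two new flags. The import path and the surrounding constructor arguments are assumptions based on the usual paddlenlp.ops.FasterTransformer signature; all values are illustrative.

# Hypothetical usage sketch; arguments other than the two new flags follow
# the usual TransformerModel/FasterTransformer signature.
from paddlenlp.ops import FasterTransformer

model = FasterTransformer(
    src_vocab_size=30000,
    trg_vocab_size=30000,
    max_length=256,
    num_encoder_layers=6,
    num_decoder_layers=6,
    n_head=8,
    d_model=512,
    d_inner_hid=2048,
    dropout=0.1,
    weight_sharing=True,
    use_fp16_decoding=True,
    enable_faster_encoder=True,  # experimental faster encoder path
    use_fp16_encoder=True)       # only honoured when enable_faster_encoder is True
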
@@ -135,6 +142,8 @@ def __init__(self,
                  diversity_rate=0.0,
                  decoding_lib=None,
                  use_fp16_decoding=False,
+                 enable_faster_encoder=False,
+                 use_fp16_encoder=False,
                  rel_len=False,
                  alpha=0.6):
         # if decoding_lib is None:
@@ -154,6 +163,8 @@ def __init__(self,
         self.diversity_rate = args.pop("diversity_rate")
         self.decoding_lib = args.pop("decoding_lib")
         self.use_fp16_decoding = args.pop("use_fp16_decoding")
+        self.enable_faster_encoder = args.pop("enable_faster_encoder")
+        self.use_fp16_encoder = args.pop("use_fp16_encoder")
         self.rel_len = args.pop("rel_len")
         self.alpha = args.pop("alpha")
         self.dropout = dropout
@@ -164,6 +175,13 @@ def __init__(self,
         self.max_length = max_length
         super(FasterTransformer, self).__init__(**args)
 
+        if self.enable_faster_encoder:
+            logger.warning(
+                "enable_faster_encoder is an experimental option and subject to change."
+            )
+        elif self.use_fp16_encoder:
+            self.use_fp16_encoder = False
+
         self.decoding_linear = nn.Linear(
             in_features=d_model, out_features=trg_vocab_size)
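
The constructor block above warns when the experimental encoder is enabled and silently turns use_fp16_encoder off when it is requested without enable_faster_encoder. A standalone sketch of that flag resolution (hypothetical helper, not part of the patch):

# Hypothetical helper mirroring the __init__ logic above: fp16 on the encoder
# is only honoured when the faster encoder itself is enabled.
def resolve_encoder_flags(enable_faster_encoder, use_fp16_encoder):
    if not enable_faster_encoder:
        use_fp16_encoder = False
    return enable_faster_encoder, use_fp16_encoder

print(resolve_encoder_flags(False, True))  # (False, False)
print(resolve_encoder_flags(True, True))   # (True, True)
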
@@ -210,10 +228,16 @@ def forward(self, src_word, trg_word=None):
         enc_input = F.dropout(
             src_emb, p=self.dropout,
             training=False) if self.dropout else src_emb
+
+        if self.enable_faster_encoder and self.use_fp16_encoder:
+            enc_input = paddle.cast(enc_input, dtype="float16")
+
         enc_output = self.transformer.encoder(enc_input, src_slf_attn_bias)
 
-        if self.use_fp16_decoding:
+        if self.use_fp16_decoding and enc_output.dtype != paddle.float16:
             enc_output = paddle.cast(enc_output, dtype="float16")
+        elif not self.use_fp16_decoding and enc_output.dtype != paddle.float32:
+            enc_output = paddle.cast(enc_output, dtype="float32")
 
         mem_seq_lens = paddle.sum(paddle.cast(
             src_word != self.bos_id, dtype="int32"),
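
The new cast logic keeps the encoder output's dtype consistent with the decoding precision, whichever encoder path produced it. A standalone sketch of that dtype alignment (plain paddle, illustrative shapes):

# Standalone illustration of the dtype alignment above: cast only when the
# encoder output does not already match the precision decoding expects.
import paddle

def align_enc_output(enc_output, use_fp16_decoding):
    if use_fp16_decoding and enc_output.dtype != paddle.float16:
        enc_output = paddle.cast(enc_output, dtype="float16")
    elif not use_fp16_decoding and enc_output.dtype != paddle.float32:
        enc_output = paddle.cast(enc_output, dtype="float32")
    return enc_output

x = paddle.rand([2, 8, 512], dtype="float32")
print(align_enc_output(x, use_fp16_decoding=True).dtype)   # paddle.float16
print(align_enc_output(x, use_fp16_decoding=False).dtype)  # paddle.float32
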
@@ -1104,12 +1128,12 @@ def forward(self,
                 forced_eos_token_id=None,
                 **model_kwargs):
 
-        self.encoder = enable_faster_encoder(self.encoder, need_build=False)
         if encoder_output is None:
+            self.encoder = enable_faster_encoder(self.encoder, need_build=False)
             assert input_ids is not None, "You have to specify either input_ids or encoder_output."
             encoder_output = self.prepare_encoder_decoder_kwargs_for_generation(
                 input_ids, model_kwargs)["encoder_output"]
-        self.encoder = disable_faster_encoder(self.encoder)
+            self.encoder = disable_faster_encoder(self.encoder)
         if seq_len is None:
             assert input_ids is not None, "You have to specify either input_ids when generating seq_len."
             seq_len = paddle.sum(paddle.cast(
@@ -1207,12 +1231,13 @@ def forward(self,
         decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else getattr(
             self._model, 'decoder_start_token_id', None)
 
-        self.encoder = enable_faster_encoder(self.encoder, need_build=False)
+        #(gongenlei) Not enable_faster_encoder temporarily
         if encoder_output is None:
+            self.encoder = enable_faster_encoder(self.encoder, need_build=False)
             assert input_ids is not None, "You have to specify either input_ids or encoder_output."
             encoder_output = self.prepare_encoder_decoder_kwargs_for_generation(
                 input_ids, model_kwargs)["encoder_output"]
-        self.encoder = disable_faster_encoder(self.encoder)
+            self.encoder = disable_faster_encoder(self.encoder)
         batch_size = paddle.shape(encoder_output)[0]
         if seq_len is None:
             assert input_ids is not None, "You have to specify either input_ids when generating seq_len."
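
In both forward passes above, enable_faster_encoder/disable_faster_encoder now wrap only the branch that actually runs the encoder, so callers that pass a precomputed encoder_output never trigger the encoder swap. Rough shape of the resulting control flow (condensed from the diff, not a verbatim copy):

# Condensed control flow after the change (names as in the diff): the encoder
# is only swapped to its faster version when it will actually be run.
if encoder_output is None:
    self.encoder = enable_faster_encoder(self.encoder, need_build=False)
    encoder_output = self.prepare_encoder_decoder_kwargs_for_generation(
        input_ids, model_kwargs)["encoder_output"]
    self.encoder = disable_faster_encoder(self.encoder)
# else: the precomputed encoder_output is used as-is.
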