
Commit 0720afe

FrostML and guoshengCS authored
Transformer supports ft faster encoder (#1430)
* transformer support ft encoder

Co-authored-by: Guo Sheng <whucsgs@163.com>
1 parent 3a24b61 commit 0720afe

File tree: 3 files changed, +59 −16 lines changed

paddlenlp/ops/faster_transformer/sample/encoder_decoding_sample.py
paddlenlp/ops/faster_transformer/transformer/encoder.py
paddlenlp/ops/faster_transformer/transformer/faster_transformer.py

paddlenlp/ops/faster_transformer/sample/encoder_decoding_sample.py

Lines changed: 24 additions & 6 deletions
@@ -24,6 +24,7 @@
 from pprint import pprint
 
 from paddlenlp.ops import FasterTransformer
+from paddlenlp.ops import enable_faster_encoder
 
 from paddlenlp.utils.log import logger
 from paddlenlp.data import Pad
@@ -33,7 +34,7 @@ def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--config",
-        default="./sample/config/decoding.sample.yaml",
+        default="./faster_transformer/sample/config/decoding.sample.yaml",
         type=str,
         help="Path of the config file. ")
     parser.add_argument(
@@ -45,6 +46,15 @@ def parse_args():
         "--use_fp16_decoding",
         action="store_true",
         help="Whether to use fp16 decoding to predict. ")
+    parser.add_argument(
+        "--enable_faster_encoder",
+        action="store_true",
+        help="Whether to use faster version encoder to predict. This is experimental option for now. "
+    )
+    parser.add_argument(
+        "--use_fp16_encoder",
+        action="store_true",
+        help="Whether to use fp16 encoder to predict. ")
     args = parser.parse_args()
     return args
 
@@ -69,7 +79,7 @@ def generate_src_word(batch_size, vocab_size, max_length, eos_idx, pad_idx):
 
 def do_predict(args):
     place = "gpu"
-    paddle.set_device(place)
+    place = paddle.set_device(place)
 
     # Define model
     transformer = FasterTransformer(
@@ -91,11 +101,17 @@ def do_predict(args):
         topp=args.topp,
         max_out_len=args.max_out_len,
         decoding_lib=args.decoding_lib,
-        use_fp16_decoding=args.use_fp16_decoding)
+        use_fp16_decoding=args.use_fp16_decoding,
+        enable_faster_encoder=args.enable_faster_encoder,
+        use_fp16_encoder=args.use_fp16_encoder)
 
     # Set evaluate mode
     transformer.eval()
 
+    if args.enable_faster_encoder:
+        transformer = enable_faster_encoder(
+            transformer, need_build=False, use_fp16=args.use_fp16_encoder)
+
     src_word = generate_src_word(
         batch_size=args.infer_batch_size,
         vocab_size=args.src_vocab_size,
@@ -107,10 +123,10 @@ def do_predict(args):
         for i in range(100):
             # For warmup.
             if 50 == i:
-                paddle.device.cuda.synchronize()
+                paddle.device.cuda.synchronize(place)
                 start = time.time()
             transformer(src_word=src_word)
-        paddle.device.cuda.synchronize()
+        paddle.device.cuda.synchronize(place)
         logger.info("Average test time for encoder-decoding is %f ms" % (
             (time.time() - start) / 50 * 1000))
 
@@ -120,8 +136,10 @@ def do_predict(args):
     yaml_file = ARGS.config
     with open(yaml_file, 'rt') as f:
         args = AttrDict(yaml.safe_load(f))
-    pprint(args)
     args.decoding_lib = ARGS.decoding_lib
     args.use_fp16_decoding = ARGS.use_fp16_decoding
+    args.enable_faster_encoder = ARGS.enable_faster_encoder
+    args.use_fp16_encoder = ARGS.use_fp16_encoder
+    pprint(args)
 
     do_predict(args)
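Taken together, the sample's new flow reduces to the sketch below. This is a minimal reconstruction from the hunks above, not the full script: model_args stands in for the hyperparameter dict the real script loads from its YAML config, and src_word for the synthetic batch that generate_src_word produces.

import paddle
from paddlenlp.ops import FasterTransformer, enable_faster_encoder

place = paddle.set_device("gpu")

# model_args is a placeholder for the hyperparameters read from the YAML
# config in the real script (vocab sizes, d_model, beam size, ...).
transformer = FasterTransformer(
    **model_args,
    use_fp16_decoding=False,
    enable_faster_encoder=True,  # opt in to the experimental fused encoder
    use_fp16_encoder=False)      # only honored when the faster encoder is on

transformer.eval()

# Swap the encoder's forward for the fused one; need_build=False assumes
# the custom op library has already been compiled and loaded.
transformer = enable_faster_encoder(
    transformer, need_build=False, use_fp16=False)

# A toy batch of token ids standing in for generate_src_word()'s output.
src_word = paddle.randint(low=1, high=30000, shape=[4, 32], dtype="int64")
output = transformer(src_word=src_word)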

paddlenlp/ops/faster_transformer/transformer/encoder.py

Lines changed: 4 additions & 4 deletions
@@ -241,7 +241,7 @@ def encoder_forward(self, src, src_mask=None, cache=None):
 def enable_faster_encoder(self,
                           need_build=True,
                           use_fp16=False,
-                          decoding_lib=None):
+                          encoder_lib=None):
     """
     Compiles fusion encoder operator intergrated FasterTransformer using the
     method of JIT(Just-In-Time) and replaces the `forward` function of
@@ -285,13 +285,13 @@ def init_func(layer):
         try:
             # Pass decoding lib to prevent re-building encoder.
             # Todo: check weather decoding lib have contained encoder or not.
-            if decoding_lib is not None:
-                load_op_meta_info_and_register_op(decoding_lib)
+            if encoder_lib is not None:
+                load_op_meta_info_and_register_op(encoder_lib)
             else:
                 load("FasterTransformer", verbose=True)
         except Exception:
             logger.warning(
-                "Exception occurs when using FasterTransformer. " \
+                "Exception occurs when using FasterEncoder. " \
                 "The original forward will be involved. ")
             return self
     for layer in self.children():
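The renamed encoder_lib parameter lets a caller register a pre-built custom-op library via load_op_meta_info_and_register_op instead of triggering a JIT build. A hedged sketch of that path, assuming the loading branch shown above runs; the .so location is hypothetical and should point at whatever your FasterTransformer build produced:

from paddlenlp.ops import enable_faster_encoder

# Hypothetical library path; registering it skips JIT compilation.
transformer = enable_faster_encoder(
    transformer,
    use_fp16=False,
    encoder_lib="/path/to/libdecoding_op.so")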

paddlenlp/ops/faster_transformer/transformer/faster_transformer.py

Lines changed: 31 additions & 6 deletions
@@ -100,7 +100,14 @@ class FasterTransformer(TransformerModel):
             for details. Bigger `diversity_rate` would lead to more diversity.
             if `diversity_rate == 0` is equivalent to naive BeamSearch. Default
             to 0 if not set.
-        use_fp16_decoding(bool, optional): Whether to use fp16 for decoding.
+        use_fp16_decoding(bool, optional):
+            Whether to use fp16 for decoding.
+        enable_faster_encoder(bool, optional):
+            Whether to use the faster version of encoder. This is experimental option for now.
+            Defaults to False.
+        use_fp16_encoder(bool, optional):
+            Whether to use fp16 for encoder. Only works when enable_faster_encoder is True.
+            Defaults to False.
         rel_len(bool, optional):
             Indicating whether `max_out_len` in is the length relative to that
             of source text. Only works in `v2` temporarily. It is suggest to set
@@ -135,6 +142,8 @@ def __init__(self,
                  diversity_rate=0.0,
                  decoding_lib=None,
                  use_fp16_decoding=False,
+                 enable_faster_encoder=False,
+                 use_fp16_encoder=False,
                  rel_len=False,
                  alpha=0.6):
         # if decoding_lib is None:
@@ -154,6 +163,8 @@
         self.diversity_rate = args.pop("diversity_rate")
         self.decoding_lib = args.pop("decoding_lib")
         self.use_fp16_decoding = args.pop("use_fp16_decoding")
+        self.enable_faster_encoder = args.pop("enable_faster_encoder")
+        self.use_fp16_encoder = args.pop("use_fp16_encoder")
         self.rel_len = args.pop("rel_len")
         self.alpha = args.pop("alpha")
         self.dropout = dropout
@@ -164,6 +175,13 @@
         self.max_length = max_length
         super(FasterTransformer, self).__init__(**args)
 
+        if self.enable_faster_encoder:
+            logger.warning(
+                "enable_faster_encoder is an experimental option and subject to change."
+            )
+        elif self.use_fp16_encoder:
+            self.use_fp16_encoder = False
+
         self.decoding_linear = nn.Linear(
             in_features=d_model, out_features=trg_vocab_size)
 
@@ -210,10 +228,16 @@ def forward(self, src_word, trg_word=None):
         enc_input = F.dropout(
             src_emb, p=self.dropout,
             training=False) if self.dropout else src_emb
+
+        if self.enable_faster_encoder and self.use_fp16_encoder:
+            enc_input = paddle.cast(enc_input, dtype="float16")
+
         enc_output = self.transformer.encoder(enc_input, src_slf_attn_bias)
 
-        if self.use_fp16_decoding:
+        if self.use_fp16_decoding and enc_output.dtype != paddle.float16:
             enc_output = paddle.cast(enc_output, dtype="float16")
+        elif not self.use_fp16_decoding and enc_output.dtype != paddle.float32:
+            enc_output = paddle.cast(enc_output, dtype="float32")
 
         mem_seq_lens = paddle.sum(paddle.cast(
             src_word != self.bos_id, dtype="int32"),
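The cast logic above reconciles the encoder's output precision with what decoding expects, since an fp16 fused encoder can now feed an fp32 decoding stage and vice versa. The same rule as a standalone illustration; the helper name is ours, not part of the API:

import paddle

def match_decoding_dtype(enc_output, use_fp16_decoding):
    # Cast only when the encoder output's dtype disagrees with the
    # precision the decoding stage was configured for.
    if use_fp16_decoding and enc_output.dtype != paddle.float16:
        return paddle.cast(enc_output, dtype="float16")
    if not use_fp16_decoding and enc_output.dtype != paddle.float32:
        return paddle.cast(enc_output, dtype="float32")
    return enc_output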
@@ -1104,12 +1128,12 @@ def forward(self,
                 forced_eos_token_id=None,
                 **model_kwargs):
 
-        self.encoder = enable_faster_encoder(self.encoder, need_build=False)
         if encoder_output is None:
+            self.encoder = enable_faster_encoder(self.encoder, need_build=False)
             assert input_ids is not None, "You have to specify either input_ids or encoder_output."
             encoder_output = self.prepare_encoder_decoder_kwargs_for_generation(
                 input_ids, model_kwargs)["encoder_output"]
-        self.encoder = disable_faster_encoder(self.encoder)
+            self.encoder = disable_faster_encoder(self.encoder)
         if seq_len is None:
             assert input_ids is not None, "You have to specify either input_ids when generating seq_len."
             seq_len = paddle.sum(paddle.cast(
@@ -1207,12 +1231,13 @@ def forward(self,
         decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else getattr(
             self._model, 'decoder_start_token_id', None)
 
-        self.encoder = enable_faster_encoder(self.encoder, need_build=False)
+        #(gongenlei) Not enable_faster_encoder temporarily
         if encoder_output is None:
+            self.encoder = enable_faster_encoder(self.encoder, need_build=False)
             assert input_ids is not None, "You have to specify either input_ids or encoder_output."
             encoder_output = self.prepare_encoder_decoder_kwargs_for_generation(
                 input_ids, model_kwargs)["encoder_output"]
-        self.encoder = disable_faster_encoder(self.encoder)
+            self.encoder = disable_faster_encoder(self.encoder)
         batch_size = paddle.shape(encoder_output)[0]
         if seq_len is None:
             assert input_ids is not None, "You have to specify either input_ids when generating seq_len."
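Both generation forward methods previously patched the encoder unconditionally, even when the caller supplied a precomputed encoder_output; the two hunks above scope the patch/restore pair to the branch that actually runs the encoder. The resulting pattern, in isolation:

if encoder_output is None:
    # Patch the encoder only when it will actually be run ...
    self.encoder = enable_faster_encoder(self.encoder, need_build=False)
    encoder_output = self.prepare_encoder_decoder_kwargs_for_generation(
        input_ids, model_kwargs)["encoder_output"]
    # ... and restore the original forward right afterwards.
    self.encoder = disable_faster_encoder(self.encoder)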
