Commit ef82edc

general spm converter
1 parent af6d275 commit ef82edc

File tree

2 files changed: +483 -2 lines

src/transformers/convert_slow_tokenizer.py (+68 -2)
@@ -547,6 +547,9 @@ def __init__(self, *args):
 
         super().__init__(*args)
 
+        # store extractor to convert tokens to ids from sp directly
+        self.extractor = self.SpmExtractor(self.original_tokenizer.vocab_file)
+
         # from .utils import sentencepiece_model_pb2 as model_pb2
         model_pb2 = import_protobuf()
 
@@ -1320,6 +1323,59 @@ def decoder(self, replacement, add_prefix_space):
             ]
         )
 
+class GeneralSPMConverter(SpmConverter):
+    handle_byte_fallback = True
+
+    def vocab(self, proto):
+        vocab = [
+            (self.original_tokenizer.convert_ids_to_tokens(0), 0.0),
+            (self.original_tokenizer.convert_ids_to_tokens(1), 0.0),
+            (self.original_tokenizer.convert_ids_to_tokens(2), 0.0),
+        ]
+        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
+        return vocab
+
+    def unk_id(self, proto):
+        unk_id = 0
+        return unk_id
+
+    def decoder(self, replacement, add_prefix_space):
+        sequence = [
+            decoders.Replace("▁", " "),
+            decoders.ByteFallback(),
+            decoders.Fuse(),
+        ]
+        if add_prefix_space:
+            sequence += [decoders.Strip(content=" ", left=1)]
+        return decoders.Sequence(sequence)
+
+    def normalizer(self, proto):
+        if getattr(self.original_tokenizer, "legacy", True):
+            sequence = []
+            if getattr(self.original_tokenizer, "add_prefix_space", True):
+                sequence += [normalizers.Prepend(prepend="▁")]
+            sequence += [normalizers.Replace(pattern=" ", content="▁")]
+            return normalizers.Sequence(sequence)
+        return None  # non-legacy, no normalizer
+
+    def pre_tokenizer(self, replacement, add_prefix_space):
+        if not getattr(self.original_tokenizer, "legacy", True):  # non-legacy, we need a replace
+            prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
+            return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme, split=False)
+        return None
+
+    def post_processor(self):
+        # return None
+        single = f"{(self.original_tokenizer.bos_token + ':0 ') if self.original_tokenizer.add_bos_token else ''}$A:0{(' ' + self.original_tokenizer.eos_token + ':0') if self.original_tokenizer.add_eos_token else ''}"
+        pair = f"{single}{(' ' + self.original_tokenizer.bos_token + ':1') if self.original_tokenizer.add_bos_token else ''} $B:1{(' ' + self.original_tokenizer.eos_token + ':1') if self.original_tokenizer.add_eos_token else ''}"
+        return processors.TemplateProcessing(
+            single=single,
+            pair=pair,
+            special_tokens=[
+                ("<bos>", self.original_tokenizer.convert_tokens_to_ids("<bos>")),
+                ("</eos>", self.original_tokenizer.convert_tokens_to_ids("</eos>")),
+            ],
+        )
 
 class LlamaConverter(SpmConverter):
     handle_byte_fallback = True
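
Editor's note: a minimal standalone sketch (not part of this commit) of what the legacy-mode normalizer built in `GeneralSPMConverter.normalizer` does to raw text, assuming `legacy=True` and `add_prefix_space=True`; it uses the `tokenizers` library directly.

# illustrative sketch only: the legacy normalizer prepends a metaspace and
# replaces every space with "▁" before the SentencePiece model sees the text
from tokenizers import normalizers

legacy_normalizer = normalizers.Sequence(
    [
        normalizers.Prepend(prepend="▁"),                # add the leading metaspace
        normalizers.Replace(pattern=" ", content="▁"),   # map spaces to "▁"
    ]
)
print(legacy_normalizer.normalize_str("Hello world"))    # -> "▁Hello▁world"
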
@@ -1363,8 +1419,17 @@ def pre_tokenizer(self, replacement, add_prefix_space):
         return None
 
     def post_processor(self):
-        # the processor is defined in the LlamaTokenizerFast class.
-        return None
+        # return None
+        single = f"{(self.original_tokenizer.bos_token + ':0 ') if self.original_tokenizer.add_bos_token else ''}$A:0{(' ' + self.original_tokenizer.eos_token + ':0') if self.original_tokenizer.add_eos_token else ''}"
+        pair = f"{single}{(' ' + self.original_tokenizer.bos_token + ':1') if self.original_tokenizer.add_bos_token else ''} $B:1{(' ' + self.original_tokenizer.eos_token + ':1') if self.original_tokenizer.add_eos_token else ''}"
+        return processors.TemplateProcessing(
+            single=single,
+            pair=pair,
+            special_tokens=[
+                ("<bos>", self.original_tokenizer.convert_tokens_to_ids("<bos>")),
+                ("</eos>", self.original_tokenizer.convert_tokens_to_ids("</eos>")),
+            ],
+        )
 
 
 class MarkupLMConverter(Converter):
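
Editor's note: to make the `single`/`pair` f-strings above easier to read, here is a small standalone sketch of the template strings handed to `TemplateProcessing`. The token values and flags below are assumptions for the example, not values taken from this commit.

# standalone illustration of how the post_processor template strings resolve
bos_token, add_bos_token = "<bos>", True   # assumed example values
eos_token, add_eos_token = "<eos>", False  # assumed example values

single = f"{(bos_token + ':0 ') if add_bos_token else ''}$A:0{(' ' + eos_token + ':0') if add_eos_token else ''}"
pair = f"{single}{(' ' + bos_token + ':1') if add_bos_token else ''} $B:1{(' ' + eos_token + ':1') if add_eos_token else ''}"

print(single)  # "<bos>:0 $A:0"
print(pair)    # "<bos>:0 $A:0 <bos>:1 $B:1"
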
@@ -1685,6 +1750,7 @@ def converted(self) -> Tokenizer:
     "RobertaTokenizer": RobertaConverter,
     "RoFormerTokenizer": RoFormerConverter,
     "SeamlessM4TTokenizer": SeamlessM4TConverter,
+    "SPMTokenizer": GeneralSPMConverter,
     "SqueezeBertTokenizer": BertConverter,
     "T5Tokenizer": T5Converter,
     "UdopTokenizer": UdopConverter,
