@@ -547,6 +547,9 @@ def __init__(self, *args):
 
         super().__init__(*args)
 
+        # store extractor to convert tokens to ids from sp directly
+        self.extractor = self.SpmExtractor(self.original_tokenizer.vocab_file)
+
         # from .utils import sentencepiece_model_pb2 as model_pb2
         model_pb2 = import_protobuf()
 
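The new `self.extractor` presumably gives the converter direct token-to-id lookups from the SentencePiece model instead of going through the parsed proto. A minimal sketch of what that buys, assuming `SpmExtractor` is a `SentencePieceExtractor`-style wrapper around `sentencepiece.SentencePieceProcessor` like the helper already defined in this module (the model path below is a placeholder, not from the PR):

```python
from sentencepiece import SentencePieceProcessor

sp = SentencePieceProcessor()
sp.Load("tokenizer.model")  # placeholder for original_tokenizer.vocab_file

# token <-> id straight from the SentencePiece model, no proto walking needed
print(sp.piece_to_id("▁hello"))  # id is model-dependent
print(sp.id_to_piece(3))         # piece stored at id 3
print(sp.get_piece_size())       # vocabulary size
```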
@@ -1320,6 +1323,59 @@ def decoder(self, replacement, add_prefix_space):
             ]
         )
 
+class GeneralSPMConverter(SpmConverter):
+    handle_byte_fallback = True
+
+    def vocab(self, proto):
+        vocab = [
+            (self.original_tokenizer.convert_ids_to_tokens(0), 0.0),
+            (self.original_tokenizer.convert_ids_to_tokens(1), 0.0),
+            (self.original_tokenizer.convert_ids_to_tokens(2), 0.0),
+        ]
+        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
+        return vocab
+
+    def unk_id(self, proto):
+        unk_id = 0
+        return unk_id
+
+    def decoder(self, replacement, add_prefix_space):
+        sequence = [
+            decoders.Replace("▁", " "),
+            decoders.ByteFallback(),
+            decoders.Fuse(),
+        ]
+        if add_prefix_space:
+            sequence += [decoders.Strip(content=" ", left=1)]
+        return decoders.Sequence(sequence)
+
+    def normalizer(self, proto):
+        if getattr(self.original_tokenizer, "legacy", True):
+            sequence = []
+            if getattr(self.original_tokenizer, "add_prefix_space", True):
+                sequence += [normalizers.Prepend(prepend="▁")]
+            sequence += [normalizers.Replace(pattern=" ", content="▁")]
+            return normalizers.Sequence(sequence)
+        return None  # non-legacy, no normalizer
+
+    def pre_tokenizer(self, replacement, add_prefix_space):
+        if not getattr(self.original_tokenizer, "legacy", True):  # non-legacy, we need a replace
+            prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
+            return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme, split=False)
+        return None
+
+    def post_processor(self):
+        # return None
+        single = f"{(self.original_tokenizer.bos_token + ':0 ') if self.original_tokenizer.add_bos_token else ''}$A:0{(' ' + self.original_tokenizer.eos_token + ':0') if self.original_tokenizer.add_eos_token else ''}"
+        pair = f"{single}{(' ' + self.original_tokenizer.bos_token + ':1') if self.original_tokenizer.add_bos_token else ''} $B:1{(' ' + self.original_tokenizer.eos_token + ':1') if self.original_tokenizer.add_eos_token else ''}"
+        return processors.TemplateProcessing(
+            single=single,
+            pair=pair,
+            special_tokens=[
+                ("<bos>", self.original_tokenizer.convert_tokens_to_ids("<bos>")),
+                ("<eos>", self.original_tokenizer.convert_tokens_to_ids("<eos>")),
+            ],
+        )
 
 class LlamaConverter(SpmConverter):
     handle_byte_fallback = True
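As a quick sanity check on the pieces `GeneralSPMConverter` wires together, here is a sketch that exercises the legacy normalizer and the byte-fallback decoder directly with the `tokenizers` primitives the diff uses; nothing here is part of the PR itself:

```python
from tokenizers import decoders, normalizers

# legacy normalizer: prepend the SentencePiece prefix and map spaces to ▁
normalizer = normalizers.Sequence(
    [normalizers.Prepend(prepend="▁"), normalizers.Replace(pattern=" ", content="▁")]
)
print(normalizer.normalize_str("Hello world"))  # '▁Hello▁world'

# decoder: undo ▁, resolve byte-fallback tokens, fuse, and drop the prefix space
decoder = decoders.Sequence(
    [
        decoders.Replace("▁", " "),
        decoders.ByteFallback(),
        decoders.Fuse(),
        decoders.Strip(content=" ", left=1),  # only added when add_prefix_space is True
    ]
)
print(decoder.decode(["▁Hello", "▁world"]))  # 'Hello world'
```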
@@ -1363,8 +1419,17 @@ def pre_tokenizer(self, replacement, add_prefix_space):
         return None
 
     def post_processor(self):
-        # the processor is defined in the LlamaTokenizerFast class.
-        return None
+        # return None
+        single = f"{(self.original_tokenizer.bos_token + ':0 ') if self.original_tokenizer.add_bos_token else ''}$A:0{(' ' + self.original_tokenizer.eos_token + ':0') if self.original_tokenizer.add_eos_token else ''}"
+        pair = f"{single}{(' ' + self.original_tokenizer.bos_token + ':1') if self.original_tokenizer.add_bos_token else ''} $B:1{(' ' + self.original_tokenizer.eos_token + ':1') if self.original_tokenizer.add_eos_token else ''}"
+        return processors.TemplateProcessing(
+            single=single,
+            pair=pair,
+            special_tokens=[
+                ("<bos>", self.original_tokenizer.convert_tokens_to_ids("<bos>")),
+                ("<eos>", self.original_tokenizer.convert_tokens_to_ids("<eos>")),
+            ],
+        )
 
 
 class MarkupLMConverter(Converter):
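Both `post_processor` overrides build the same `TemplateProcessing` that was previously handled on the fast-tokenizer side. A sketch of what the template resolves to for a tokenizer with `bos_token="<bos>"`, `add_bos_token=True`, `add_eos_token=False`; the token id below is illustrative, not taken from a real checkpoint:

```python
from tokenizers import processors

single = "<bos>:0 $A:0"             # what the single-sequence f-string resolves to
pair = "<bos>:0 $A:0 <bos>:1 $B:1"  # and the pair template
post_processor = processors.TemplateProcessing(
    single=single,
    pair=pair,
    special_tokens=[("<bos>", 2)],  # every special token named in the template needs an id
)
```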
@@ -1685,6 +1750,7 @@ def converted(self) -> Tokenizer:
     "RobertaTokenizer": RobertaConverter,
     "RoFormerTokenizer": RoFormerConverter,
     "SeamlessM4TTokenizer": SeamlessM4TConverter,
+    "SPMTokenizer": GeneralSPMConverter,
     "SqueezeBertTokenizer": BertConverter,
     "T5Tokenizer": T5Converter,
     "UdopTokenizer": UdopConverter,