
Commit 3366ba9

add util
1 parent 8d227da commit 3366ba9

1 file changed: 101 additions, 0 deletions
from transformers import PreTrainedTokenizerFast
from transformers.models.llama.tokenization_spm import SPMTokenizer
from transformers.convert_slow_tokenizer import convert_slow_tokenizer


def load_spm_tokenizer(model_path: str) -> SPMTokenizer:
    """
    Load a slow SentencePiece tokenizer from the specified model path.
    """
    return SPMTokenizer.from_pretrained(
        model_path,
        unk_token="<unk>",
        pad_token="<pad>",
        bos_token="<bos>",
        eos_token="<eos>",
    )


def load_fast_spm_tokenizer(model_path: str) -> PreTrainedTokenizerFast:
    """
    Load a fast tokenizer by converting the slow SPMTokenizer.
    """
    slow_tokenizer = SPMTokenizer.from_pretrained(
        model_path,
        unk_token="<unk>",
        pad_token="<pad>",
        bos_token="<bos>",
        eos_token="<eos>",
        do_lower_case=False,
        add_bos_token=True,
    )
    return PreTrainedTokenizerFast(
        tokenizer_object=convert_slow_tokenizer(slow_tokenizer)
    )


def compare_tokenizers(sp_tokenizer, fast_tokenizer, text: str):
    """
    Assert that tokenization, encoding, and decoding results are identical
    between the slow and fast tokenizers.
    """
    sp_tokens = sp_tokenizer.tokenize(text)
    fast_tokens = fast_tokenizer.tokenize(text)
    assert sp_tokens == fast_tokens, (
        f"\nToken mismatch for input: {repr(text)}\n"
        f"SPM tokens : {sp_tokens}\n"
        f"Fast tokens: {fast_tokens}"
    )

    sp_ids = sp_tokenizer.encode(text)
    fast_ids = fast_tokenizer.encode(text)
    assert sp_ids == fast_ids, (
        f"\nID mismatch for input: {repr(text)}\n"
        f"SPM IDs : {sp_ids}\n"
        f"Fast IDs: {fast_ids}"
    )

    sp_decoded = sp_tokenizer.decode(sp_ids)
    fast_decoded = fast_tokenizer.decode(fast_ids)
    assert sp_decoded == fast_decoded, (
        f"\nDecoded output mismatch for input: {repr(text)}\n"
        f"SPM decoded : {sp_decoded}\n"
        f"Fast decoded: {fast_decoded}"
    )


# Inputs exercising whitespace handling, special tokens, and non-ASCII text.
TEST_STRINGS = [
    "Hey<eos>. \t\t \n\nyou é @#😈 🤗! , 1234 15 5,61",
    "The following string should be properly encoded: Hello.",
    "But ird and ปี ird ด",
    "This is a test.",
    "Hello world! Multiple spaces here.",
    "Hi Hello with double space.",
    " Leading spaces.",
    "Trailing spaces",
    "<s>Special token at start",
    "Text with <s> special token in the middle",
    "Text ending with special token <s>",
    "<s> Special token with spaces",
    "<s>I immediately after special token",
    "Hello, <s>, with commas",
    "生活的真谛是 Chinese characters",
    "áéíóúñ Accented characters",
    "ا العربية Arabic text",
    "Numbers 12345 and symbols !@#$%^&*()",
    "Line with\nmultiple\nbreaks",
]


def main():
    model_path = "../../../local-gemma-7b/tokenizer.model"  # Adjust to your local path
    sp_tokenizer = load_spm_tokenizer(model_path)
    fast_tokenizer = load_fast_spm_tokenizer(model_path)

    for text in TEST_STRINGS:
        compare_tokenizers(sp_tokenizer, fast_tokenizer, text)

    print("All tokenizer outputs match ✔️")


if __name__ == "__main__":
    main()
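
For a quick spot-check without running the full suite, the helpers can also be called directly. A minimal sketch, assuming the same placeholder tokenizer.model path used in main() (adjust it to your local checkout):

    sp_tok = load_spm_tokenizer("../../../local-gemma-7b/tokenizer.model")
    fast_tok = load_fast_spm_tokenizer("../../../local-gemma-7b/tokenizer.model")
    # Raises AssertionError with a detailed message on the first token, ID, or decode mismatch.
    compare_tokenizers(sp_tok, fast_tok, "This is a test.")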
