Skip to content

Commit a87c723

Browse files
authored
Fix trailing whitespace token handling (#64)
1 parent ab388a8 commit a87c723

File tree

2 files changed: 16 additions and 3 deletions.

2 files changed

+16
-3
lines changed

spacy_stanza/tokenizer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,9 @@ def __call__(self, text):
108108
)
109109
offset = 0
110110
for i, word in enumerate(words):
111-
if word.isspace() and word != snlp_tokens[i + offset].text:
111+
if word.isspace() and (
112+
i + offset >= len(snlp_tokens) or word != snlp_tokens[i + offset].text
113+
):
112114
# insert a space token
113115
pos.append("SPACE")
114116
tags.append("_SP")

tests/test_language.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def test_spacy_stanza_english():
5555
assert doc.ents[1].label_ == "GPE"
5656

5757
# Test whitespace alignment
58-
doc = nlp(" Barack Obama was born\n\nin Hawaii.")
58+
doc = nlp(" Barack Obama was born\n\nin Hawaii.\n")
5959
assert [t.pos_ for t in doc] == [
6060
"SPACE",
6161
"PROPN",
@@ -69,6 +69,7 @@ def test_spacy_stanza_english():
6969
"ADP",
7070
"PROPN",
7171
"PUNCT",
72+
"SPACE",
7273
]
7374
assert [t.dep_ for t in doc] == [
7475
"",
@@ -83,14 +84,24 @@ def test_spacy_stanza_english():
8384
"case",
8485
"root",
8586
"punct",
87+
"",
8688
]
87-
assert [t.head.i for t in doc] == [0, 7, 2, 1, 4, 7, 6, 7, 8, 10, 10, 10]
89+
assert [t.head.i for t in doc] == [0, 7, 2, 1, 4, 7, 6, 7, 8, 10, 10, 10, 12]
8890
assert len(doc.ents) == 2
8991
assert doc.ents[0].text == "Barack Obama"
9092
assert doc.ents[0].label_ == "PERSON"
9193
assert doc.ents[1].text == "Hawaii"
9294
assert doc.ents[1].label_ == "GPE"
9395

96+
# Test trailing whitespace handling
97+
doc = nlp("a ")
98+
doc = nlp("a ")
99+
doc = nlp("a \n")
100+
doc = nlp("\n ")
101+
doc = nlp("\t ")
102+
doc = nlp("a\n ")
103+
doc = nlp("a \t ")
104+
94105
# Test serialization
95106
reloaded_nlp = spacy_stanza.load_pipeline(lang).from_bytes(nlp.to_bytes())
96107
assert reloaded_nlp.config.to_str() == nlp.config.to_str()

0 commit comments

Comments (0)