Skip to content

Commit 7c34714

Browse files
committed
Add a feature for training a lemmatizer ignoring blank lemmas. Helps make lemmatizers that work on partially finished treebanks (such as Sindhi)
1 parent 2912c8c commit 7c34714

File tree

3 files changed

+38
-8
lines changed

3 files changed

+38
-8
lines changed

stanza/models/lemma/data.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,17 +121,19 @@ def __iter__(self):
121121
yield self.__getitem__(i)
122122

123123
def raw_data(self):
124-
return self.load_doc(self.doc, self.args.get('caseless', False), self.eval)
124+
return self.load_doc(self.doc, self.args.get('caseless', False), self.args.get('skip_blank_lemmas', False), self.eval)
125125

126126
@staticmethod
127-
def load_doc(doc, caseless, evaluation):
127+
def load_doc(doc, caseless, skip_blank_lemmas, evaluation):
128128
if evaluation:
129129
data = doc.get([TEXT, UPOS, LEMMA])
130130
else:
131131
data = doc.get([TEXT, UPOS, LEMMA, HEAD, DEPREL, MISC], as_sentences=True)
132132
data = DataLoader.remove_goeswith(data)
133133
data = DataLoader.extract_correct_forms(data)
134134
data = DataLoader.resolve_none(data)
135+
if not evaluation and skip_blank_lemmas:
136+
data = DataLoader.skip_blank_lemmas(data)
135137
if caseless:
136138
data = DataLoader.lowercase_data(data)
137139
return data
@@ -202,6 +204,11 @@ def lowercase_data(data):
202204
token[0] = token[0].lower()
203205
return data
204206

207+
@staticmethod
208+
def skip_blank_lemmas(data):
209+
data = [x for x in data if x[2] != '_']
210+
return data
211+
205212
@staticmethod
206213
def resolve_none(data):
207214
# replace None to '_'

stanza/models/lemmatizer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def build_argparse():
7878
parser.add_argument('--save_name', type=str, default="{shorthand}_{embedding}_lemmatizer.pt", help="File name to save the model")
7979

8080
parser.add_argument('--caseless', default=False, action='store_true', help='Lowercase everything first before processing. This will happen automatically if 100%% of the data is caseless')
81+
parser.add_argument('--skip_blank_lemmas', default=False, action='store_true', help='Skip blank entries in the data files. Useful for training a lemmatizer from a partially annotated dataset')
8182

8283
parser.add_argument('--seed', type=int, default=1234)
8384
utils.add_device_args(parser)

stanza/tests/lemma/test_data.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -69,38 +69,60 @@
6969
4 ambulances ambulance NOUN NNS Number=Plur 3 obj 3:obj SpaceAfter=No
7070
"""
7171

72+
BLANKS_DATA = """
73+
# sent_id = weblog-juancole.com_juancole_20051126063000_ENG_20051126_063000-0018
74+
# text = Guerrillas killed an engineer, Asi Ali, from Tikrit.
75+
1 Guerrillas _ NOUN NNS Number=Plur 2 nsubj 2:nsubj _
76+
2 killed _ VERB VBD Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin 0 root 0:root _
77+
3 an a DET DT Definite=Ind|PronType=Art 4 det 4:det _
78+
4 engineer _ NOUN NN Number=Sing 2 obj 2:obj SpaceAfter=No
79+
80+
""".lstrip()
81+
7282

7383
def test_load_document():
7484
train_doc = CoNLL.conll2doc(input_str=TRAIN_DATA)
75-
data = DataLoader.load_doc(train_doc, caseless=False, evaluation=True)
85+
data = DataLoader.load_doc(train_doc, caseless=False, skip_blank_lemmas=False, evaluation=True)
7686
assert len(data) == 33 # meticulously counted by hand
7787
assert all(len(x) == 3 for x in data)
7888

79-
data = DataLoader.load_doc(train_doc, caseless=False, evaluation=False)
89+
data = DataLoader.load_doc(train_doc, caseless=False, skip_blank_lemmas=False, evaluation=False)
8090
assert len(data) == 33
8191
assert all(len(x) == 3 for x in data)
8292

8393
def test_load_goeswith():
8494
raw_data = TRAIN_DATA + GOESWITH_DATA
8595
train_doc = CoNLL.conll2doc(input_str=raw_data)
86-
data = DataLoader.load_doc(train_doc, caseless=False, evaluation=True)
96+
data = DataLoader.load_doc(train_doc, caseless=False, skip_blank_lemmas=False, evaluation=True)
8797
assert len(data) == 36 # will be the same as in test_load_document with three additional words
8898
assert all(len(x) == 3 for x in data)
8999

90-
data = DataLoader.load_doc(train_doc, caseless=False, evaluation=False)
100+
data = DataLoader.load_doc(train_doc, caseless=False, skip_blank_lemmas=False, evaluation=False)
91101
assert len(data) == 33 # will be the same as in test_load_document, but with the trailing 3 GOESWITH removed
92102
assert all(len(x) == 3 for x in data)
93103

94104
def test_correct_form():
95105
raw_data = TRAIN_DATA + CORRECT_FORM_DATA
96106
train_doc = CoNLL.conll2doc(input_str=raw_data)
97-
data = DataLoader.load_doc(train_doc, caseless=False, evaluation=True)
107+
data = DataLoader.load_doc(train_doc, caseless=False, skip_blank_lemmas=False, evaluation=True)
98108
assert len(data) == 37
99109
# the 'targeting' correction should not be applied if evaluation=True
100110
# when evaluation=False, then the CorrectForms will be applied
101111
assert not any(x[0] == 'targeting' for x in data)
102112

103-
data = DataLoader.load_doc(train_doc, caseless=False, evaluation=False)
113+
data = DataLoader.load_doc(train_doc, caseless=False, skip_blank_lemmas=False, evaluation=False)
104114
assert len(data) == 38 # the same, but with an extra row so the model learns both 'targetting' and 'targeting'
105115
assert any(x[0] == 'targeting' for x in data)
106116
assert any(x[0] == 'targetting' for x in data)
117+
118+
def test_load_blank():
119+
raw_data = TRAIN_DATA + BLANKS_DATA
120+
train_doc = CoNLL.conll2doc(input_str=raw_data)
121+
data = DataLoader.load_doc(train_doc, caseless=False, skip_blank_lemmas=False, evaluation=False)
122+
assert len(data) == 37 # will be the same as in test_load_document with FOUR additional words
123+
assert all(len(x) == 3 for x in data)
124+
125+
data = DataLoader.load_doc(train_doc, caseless=False, skip_blank_lemmas=True, evaluation=False)
126+
assert len(data) == 34 # will be the same as in test_load_document, but one extra word is added. others were blank
127+
assert all(len(x) == 3 for x in data)
128+

0 commit comments

Comments
 (0)