Skip to content

Commit 3e68984

Browse files
authored
Fix WordTag decode bug (#1642)
* fix wordtag decode * Update README.md * Update README.md * Update codestyle
1 parent 28734ab commit 3e68984

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

paddlenlp/taskflow/knowledge_mining.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,6 @@ class WordTagTask(Task):
191191
def __init__(self,
192192
model,
193193
task,
194-
batch_size=1,
195194
params_path=None,
196195
tag_path=None,
197196
term_schema_path=None,
@@ -419,11 +418,12 @@ def _reset_offset(self, pred_words):
419418
def _decode(self, batch_texts, batch_pred_tags):
420419
batch_results = []
421420
for sent_index in range(len(batch_texts)):
421+
sent = batch_texts[sent_index]
422422
tags = [
423423
self._index_to_tags[index]
424-
for index in batch_pred_tags[sent_index][self.summary_num:-1]
424+
for index in batch_pred_tags[sent_index][self.summary_num:len(
425+
sent) + self.summary_num]
425426
]
426-
sent = batch_texts[sent_index]
427427
if self._custom:
428428
self._custom.parse_customization(sent, tags, prefix=True)
429429
sent_out = []

paddlenlp/taskflow/named_entity_recognition.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,12 @@ def __init__(self, model, task, **kwargs):
8787
def _decode(self, batch_texts, batch_pred_tags):
8888
batch_results = []
8989
for sent_index in range(len(batch_texts)):
90+
sent = batch_texts[sent_index]
9091
tags = [
9192
self._index_to_tags[index]
92-
for index in batch_pred_tags[sent_index][self.summary_num:-1]
93+
for index in batch_pred_tags[sent_index][self.summary_num:len(
94+
sent) + self.summary_num]
9395
]
94-
sent = batch_texts[sent_index]
9596
if self._custom:
9697
self._custom.parse_customization(sent, tags, prefix=True)
9798
sent_out = []
@@ -100,12 +101,12 @@ def _decode(self, batch_texts, batch_pred_tags):
100101
for ind, tag in enumerate(tags):
101102
if partial_word == "":
102103
partial_word = sent[ind]
103-
tags_out.append(tag.split('-')[1])
104+
tags_out.append(tag.split('-')[-1])
104105
continue
105106
if tag.startswith("B") or tag.startswith("S") or tag.startswith(
106107
"O"):
107108
sent_out.append(partial_word)
108-
tags_out.append(tag.split('-')[1])
109+
tags_out.append(tag.split('-')[-1])
109110
partial_word = sent[ind]
110111
continue
111112
partial_word += sent[ind]

0 commit comments

Comments
 (0)