Skip to content

Commit 868f6ee

Browse files
authored
[Taskflow] Fix the recognition bug of json format with both PIR suffix and id2label (#10487)
1 parent 38828ba commit 868f6ee

File tree

1 file changed

+21
-8
lines changed

1 file changed

+21
-8
lines changed

paddlenlp/taskflow/task.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -128,15 +128,28 @@ def _construct_input_spec(self):
128128
"""
129129

130130
def _get_static_model_name(self):
131-
names = []
131+
model_candidates = []
132132
for file_name in os.listdir(self._task_path):
133-
if PADDLE_INFERENCE_MODEL_SUFFIX in file_name:
134-
names.append(file_name[: -len(PADDLE_INFERENCE_MODEL_SUFFIX)])
135-
if len(names) == 0:
136-
raise IOError(f"{self._task_path} should include '{PADDLE_INFERENCE_MODEL_SUFFIX}' file.")
137-
if len(names) > 1:
138-
logger.warning(f"{self._task_path} includes more than one '{PADDLE_INFERENCE_MODEL_SUFFIX}' file.")
139-
return names[0]
133+
if file_name.endswith(PADDLE_INFERENCE_MODEL_SUFFIX):
134+
prefix = file_name[: -len(PADDLE_INFERENCE_MODEL_SUFFIX)]
135+
param_file = prefix + PADDLE_INFERENCE_WEIGHTS_SUFFIX
136+
if os.path.exists(os.path.join(self._task_path, param_file)):
137+
model_candidates.append(prefix)
138+
139+
if not model_candidates:
140+
raise IOError(
141+
f"{self._task_path} should include at least one valid model structure file "
142+
f"({PADDLE_INFERENCE_MODEL_SUFFIX}) with corresponding {PADDLE_INFERENCE_WEIGHTS_SUFFIX}."
143+
)
144+
145+
for preferred in ["inference", "model"]:
146+
if preferred in model_candidates:
147+
return preferred
148+
149+
if len(model_candidates) > 1:
150+
logger.warning(f"{self._task_path} includes multiple model pairs. Defaulting to: {model_candidates[0]}")
151+
152+
return model_candidates[0]
140153

141154
def _check_task_files(self):
142155
"""

0 commit comments

Comments
 (0)