File tree 1 file changed +21
-8
lines changed
1 file changed +21
-8
lines changed Original file line number Diff line number Diff line change @@ -128,15 +128,28 @@ def _construct_input_spec(self):
128
128
"""
129
129
130
130
def _get_static_model_name (self ):
131
- names = []
131
+ model_candidates = []
132
132
for file_name in os .listdir (self ._task_path ):
133
- if PADDLE_INFERENCE_MODEL_SUFFIX in file_name :
134
- names .append (file_name [: - len (PADDLE_INFERENCE_MODEL_SUFFIX )])
135
- if len (names ) == 0 :
136
- raise IOError (f"{ self ._task_path } should include '{ PADDLE_INFERENCE_MODEL_SUFFIX } ' file." )
137
- if len (names ) > 1 :
138
- logger .warning (f"{ self ._task_path } includes more than one '{ PADDLE_INFERENCE_MODEL_SUFFIX } ' file." )
139
- return names [0 ]
133
+ if file_name .endswith (PADDLE_INFERENCE_MODEL_SUFFIX ):
134
+ prefix = file_name [: - len (PADDLE_INFERENCE_MODEL_SUFFIX )]
135
+ param_file = prefix + PADDLE_INFERENCE_WEIGHTS_SUFFIX
136
+ if os .path .exists (os .path .join (self ._task_path , param_file )):
137
+ model_candidates .append (prefix )
138
+
139
+ if not model_candidates :
140
+ raise IOError (
141
+ f"{ self ._task_path } should include at least one valid model structure file "
142
+ f"({ PADDLE_INFERENCE_MODEL_SUFFIX } ) with corresponding { PADDLE_INFERENCE_WEIGHTS_SUFFIX } ."
143
+ )
144
+
145
+ for preferred in ["inference" , "model" ]:
146
+ if preferred in model_candidates :
147
+ return preferred
148
+
149
+ if len (model_candidates ) > 1 :
150
+ logger .warning (f"{ self ._task_path } includes multiple model pairs. Defaulting to: { model_candidates [0 ]} " )
151
+
152
+ return model_candidates [0 ]
140
153
141
154
def _check_task_files (self ):
142
155
"""
You can’t perform that action at this time.
0 commit comments