1717import importlib .util
1818import json
1919import platform
20+ from collections import defaultdict
2021from functools import lru_cache
2122from typing import Any , Dict , List , Literal , Optional , Tuple , Union
2223
3031 get_paddle_version ,
3132)
3233from ...utils .flags import USE_PIR_TRT
34+ from .misc import is_mkldnn_available
3335from .model_paths import ModelPaths
3436
3537
@@ -186,24 +188,23 @@ def suggest_inference_backend_and_config(
186188 hpi_config .pdx_model_name
187189 ].copy ()
188190
191+ if not is_mkldnn_available ():
192+ if "paddle_mkldnn" in supported_pseudo_backends :
193+ supported_pseudo_backends .remove ("paddle_mkldnn" )
194+
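189195 # XXX: Unless `USE_PIR_TRT` is set and both the `tensorrt` package and the `nvinfer` library can be found, drop the Paddle-TensorRT pseudo backends.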
190196 if not (
191197 USE_PIR_TRT
192198 and importlib .util .find_spec ("tensorrt" )
193199 and ctypes .util .find_library ("nvinfer" )
194200 ):
195- if (
196- "paddle_tensorrt" in supported_pseudo_backends
197- or "paddle_tensorrt_fp16" in supported_pseudo_backends
198- ):
199- supported_pseudo_backends .append ("paddle" )
200201 if "paddle_tensorrt" in supported_pseudo_backends :
201202 supported_pseudo_backends .remove ("paddle_tensorrt" )
202203 if "paddle_tensorrt_fp16" in supported_pseudo_backends :
203204 supported_pseudo_backends .remove ("paddle_tensorrt_fp16" )
204205
205- candidate_backends = []
206- backend_to_pseudo_backend = {}
206+ supported_backends = []
207+ backend_to_pseudo_backends = defaultdict ( list )
207208 for pb in supported_pseudo_backends :
208209 if pb .startswith ("paddle" ):
209210 backend = "paddle"
@@ -213,34 +214,38 @@ def suggest_inference_backend_and_config(
213214 backend = pb
214215 if available_backends is not None and backend not in available_backends :
215216 continue
216- candidate_backends .append (backend )
217- backend_to_pseudo_backend [backend ] = pb
217+ supported_backends .append (backend )
218+ backend_to_pseudo_backends [backend ]. append ( pb )
218219
219- if not candidate_backends :
220+ if not supported_backends :
220221 return None , "No inference backend can be selected."
221222
222223 if hpi_config .backend is not None :
223- if hpi_config .backend not in candidate_backends :
224+ if hpi_config .backend not in supported_backends :
224225 return (
225226 None ,
226227 f"{ repr (hpi_config .backend )} is not a supported inference backend." ,
227228 )
228229 suggested_backend = hpi_config .backend
230+ pseudo_backends = backend_to_pseudo_backends [suggested_backend ]
231+ pseudo_backend = pseudo_backends [0 ]
229232 else :
230- # The first backend is the preferred one.
231- suggested_backend = candidate_backends [0 ]
233+ # Prefer the first one.
234+ suggested_backend = supported_backends [0 ]
235+ pseudo_backend = supported_pseudo_backends [0 ]
232236
233237 suggested_backend_config = {}
234238 if suggested_backend == "paddle" :
235- pseudo_backend = backend_to_pseudo_backend ["paddle" ]
236239 assert pseudo_backend in (
237240 "paddle" ,
238241 "paddle_fp16" ,
239242 "paddle_mkldnn" ,
240243 "paddle_tensorrt" ,
241244 "paddle_tensorrt_fp16" ,
242245 ), pseudo_backend
243- if pseudo_backend == "paddle_fp16" :
246+ if pseudo_backend == "paddle" :
247+ suggested_backend_config .update ({"run_mode" : "paddle" })
248+ elif pseudo_backend == "paddle_fp16" :
244249 suggested_backend_config .update ({"run_mode" : "paddle_fp16" })
245250 elif pseudo_backend == "paddle_mkldnn" :
246251 suggested_backend_config .update ({"run_mode" : "mkldnn" })
@@ -250,7 +255,6 @@ def suggest_inference_backend_and_config(
250255 # TODO: Check if the target device supports FP16.
251256 suggested_backend_config .update ({"run_mode" : "trt_fp16" })
252257 elif suggested_backend == "tensorrt" :
253- pseudo_backend = backend_to_pseudo_backend ["tensorrt" ]
254258 assert pseudo_backend in ("tensorrt" , "tensorrt_fp16" ), pseudo_backend
255259 if pseudo_backend == "tensorrt_fp16" :
256260 suggested_backend_config .update ({"precision" : "fp16" })
0 commit comments