Skip to content

Commit aa99dca

Browse files
committed
add automatic model search for opencv launcher
1 parent 10f54a1 commit aa99dca

File tree

2 files changed

+80
-6
lines changed

2 files changed

+80
-6
lines changed

tools/accuracy_checker/openvino/tools/accuracy_checker/config/config_reader.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -925,7 +925,7 @@ def provide_precision_and_layout(launchers, input_precisions, input_layouts):
925925

926926

927927
def provide_model_type(launcher, arguments):
928-
if 'model_type' in arguments:
928+
if 'model_type' in arguments and arguments.model_type is not None:
929929
launcher['_model_type'] = arguments.model_type
930930
if launcher['framework'] in ['dlsdk', 'openvino', 'g-api'] and 'model_is_blob' in arguments:
931931
launcher['_model_is_blob'] = arguments.model_is_blob

tools/accuracy_checker/openvino/tools/accuracy_checker/launcher/opencv_launcher.py

+79-5
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,12 @@
1818
from collections import OrderedDict
1919
import numpy as np
2020
import cv2
21+
from pathlib import Path
2122

2223
from ..config import PathField, StringField, ConfigError, ListInputsField
2324
from ..logging import print_info
2425
from .launcher import Launcher, LauncherConfigValidator
25-
from ..utils import get_or_parse_value
26+
from ..utils import get_or_parse_value, get_path
2627

2728
DEVICE_REGEX = r'(?P<device>cpu$|gpu|gpu_fp16)?'
2829
BACKEND_REGEX = r'(?P<backend>ocv|ie)?'
@@ -63,8 +64,8 @@ class OpenCVLauncher(Launcher):
6364
def parameters(cls):
6465
parameters = super().parameters()
6566
parameters.update({
66-
'model': PathField(description="Path to model file."),
67-
'weights': PathField(description="Path to weights file.", optional=True, default='', check_exists=False),
67+
'model': PathField(description="Path to model file.", file_or_directory=True),
68+
'weights': PathField(description="Path to weights file.", optional=True, check_exists=False, file_or_directory=True),
6869
'device': StringField(
6970
regex=DEVICE_REGEX, choices=OpenCVLauncher.TARGET_DEVICES.keys(),
7071
description="Device name: {}".format(', '.join(OpenCVLauncher.TARGET_DEVICES.keys()))
@@ -100,8 +101,10 @@ def __init__(self, config_entry: dict, *args, **kwargs):
100101
raise ConfigError('{} is not supported device'.format(selected_device))
101102

102103
if not self._delayed_model_loading:
103-
self.model = self.get_value_from_config('model')
104-
self.weights = self.get_value_from_config('weights')
104+
self.model, self.weights = self.automatic_model_search(self._model_name,
105+
self.get_value_from_config('model'), self.get_value_from_config('weights'),
106+
self.get_value_from_config('_model_type')
107+
)
105108
self.network = self.create_network(self.model, self.weights)
106109
self._inputs_shapes = self.get_inputs_from_config(self.config)
107110
self.network.setInputsNames(list(self._inputs_shapes.keys()))
@@ -130,6 +133,77 @@ def batch(self):
130133
def output_blob(self):
131134
return next(iter(self.output_names))
132135

136+
def automatic_model_search(self, model_name, model_cfg, weights_cfg, model_type=None):
137+
model_type_ext = {
138+
'xml': 'xml',
139+
'blob': 'blob',
140+
'onnx': 'onnx',
141+
'caffe': 'prototxt',
142+
'paddle': 'pdmodel',
143+
'tf': 'pb'
144+
}
145+
def get_model_by_suffix(model_name, model_dir, suffix):
146+
model_list = list(Path(model_dir).glob('{}.{}'.format(model_name, suffix)))
147+
if not model_list:
148+
model_list = list(Path(model_dir).glob('*.{}'.format(suffix)))
149+
if not model_list:
150+
model_list = list(Path(model_dir).parent.rglob('*.{}'.format(suffix)))
151+
return model_list
152+
153+
def get_model():
154+
model = Path(model_cfg)
155+
if not model.is_dir():
156+
accepted_suffixes = list(model_type_ext.values())
157+
if model.suffix[1:] not in accepted_suffixes:
158+
raise ConfigError('Models with following suffixes are allowed: {}'.format(accepted_suffixes))
159+
print_info('Found model {}'.format(model))
160+
return model, model.suffix == '.blob'
161+
model_list = []
162+
if model_type is not None:
163+
model_list = get_model_by_suffix(model_name, model, model_type_ext[model_type])
164+
else:
165+
for ext in model_type_ext.values():
166+
model_list = get_model_by_suffix(model_name, model, ext)
167+
if model_list:
168+
break
169+
if not model_list:
170+
raise ConfigError('suitable model is not found')
171+
if len(model_list) != 1:
172+
raise ConfigError('More than one model matched, please specify explicitly')
173+
model = model_list[0]
174+
print_info('Found model {}'.format(model))
175+
return model, model.suffix == '.blob'
176+
177+
model, is_blob = get_model()
178+
if is_blob:
179+
return model, None
180+
weights = weights_cfg
181+
if model.suffix == '.pdmodel':
182+
weights = self.get_value_from_config('params')
183+
if (weights is None or Path(weights).is_dir()) and model.suffix != '.onnx':
184+
weights_dir = weights or model.parent
185+
weights_list = []
186+
if model.suffix == '.xml':
187+
weights = Path(weights_dir) / model.name.replace('xml', 'bin')
188+
print(weights)
189+
else:
190+
if model.suffix == '.prototxt':
191+
weights_list = list(Path(weights_dir).glob('*.{}'.format('caffemodel')))
192+
elif model.suffix == '.pdmodel':
193+
weights_list = list(Path(weights_dir).glob('*.{}'.format('pdiparams')))
194+
if not weights_list:
195+
raise ConfigError('Suitable weights is not detected')
196+
if len(weights_list) != 1:
197+
raise ConfigError('Several suitable weights found, please specify required explicitly')
198+
weights = weights_list[0]
199+
if weights is not None:
200+
accepted_weights_suffixes = ['.bin', '.caffemodel', '.pdiparams']
201+
if weights.suffix not in accepted_weights_suffixes:
202+
raise ConfigError('Weights with following suffixes are allowed: {}'.format(accepted_weights_suffixes))
203+
print_info('Found weights {}'.format(get_path(weights)))
204+
205+
return model, weights
206+
133207
def predict(self, inputs, metadata=None, **kwargs):
134208
"""
135209
Args:

0 commit comments

Comments
 (0)