|
18 | 18 | from collections import OrderedDict
|
19 | 19 | import numpy as np
|
20 | 20 | import cv2
|
| 21 | +from pathlib import Path |
21 | 22 |
|
22 | 23 | from ..config import PathField, StringField, ConfigError, ListInputsField
|
23 | 24 | from ..logging import print_info
|
24 | 25 | from .launcher import Launcher, LauncherConfigValidator
|
25 |
| -from ..utils import get_or_parse_value |
| 26 | +from ..utils import get_or_parse_value, get_path |
26 | 27 |
|
27 | 28 | DEVICE_REGEX = r'(?P<device>cpu$|gpu|gpu_fp16)?'
|
28 | 29 | BACKEND_REGEX = r'(?P<backend>ocv|ie)?'
|
@@ -63,8 +64,8 @@ class OpenCVLauncher(Launcher):
|
63 | 64 | def parameters(cls):
|
64 | 65 | parameters = super().parameters()
|
65 | 66 | parameters.update({
|
66 |
| - 'model': PathField(description="Path to model file."), |
67 |
| - 'weights': PathField(description="Path to weights file.", optional=True, default='', check_exists=False), |
| 67 | + 'model': PathField(description="Path to model file.", file_or_directory=True), |
| 68 | + 'weights': PathField(description="Path to weights file.", optional=True, check_exists=False, file_or_directory=True), |
68 | 69 | 'device': StringField(
|
69 | 70 | regex=DEVICE_REGEX, choices=OpenCVLauncher.TARGET_DEVICES.keys(),
|
70 | 71 | description="Device name: {}".format(', '.join(OpenCVLauncher.TARGET_DEVICES.keys()))
|
@@ -100,8 +101,10 @@ def __init__(self, config_entry: dict, *args, **kwargs):
|
100 | 101 | raise ConfigError('{} is not supported device'.format(selected_device))
|
101 | 102 |
|
102 | 103 | if not self._delayed_model_loading:
|
103 |
| - self.model = self.get_value_from_config('model') |
104 |
| - self.weights = self.get_value_from_config('weights') |
| 104 | + self.model, self.weights = self.automatic_model_search(self._model_name, |
| 105 | + self.get_value_from_config('model'), self.get_value_from_config('weights'), |
| 106 | + self.get_value_from_config('_model_type') |
| 107 | + ) |
105 | 108 | self.network = self.create_network(self.model, self.weights)
|
106 | 109 | self._inputs_shapes = self.get_inputs_from_config(self.config)
|
107 | 110 | self.network.setInputsNames(list(self._inputs_shapes.keys()))
|
@@ -130,6 +133,77 @@ def batch(self):
|
130 | 133 | def output_blob(self):
|
131 | 134 | return next(iter(self.output_names))
|
132 | 135 |
|
| 136 | + def automatic_model_search(self, model_name, model_cfg, weights_cfg, model_type=None): |
| 137 | + model_type_ext = { |
| 138 | + 'xml': 'xml', |
| 139 | + 'blob': 'blob', |
| 140 | + 'onnx': 'onnx', |
| 141 | + 'caffe': 'prototxt', |
| 142 | + 'paddle': 'pdmodel', |
| 143 | + 'tf': 'pb' |
| 144 | + } |
| 145 | + def get_model_by_suffix(model_name, model_dir, suffix): |
| 146 | + model_list = list(Path(model_dir).glob('{}.{}'.format(model_name, suffix))) |
| 147 | + if not model_list: |
| 148 | + model_list = list(Path(model_dir).glob('*.{}'.format(suffix))) |
| 149 | + if not model_list: |
| 150 | + model_list = list(Path(model_dir).parent.rglob('*.{}'.format(suffix))) |
| 151 | + return model_list |
| 152 | + |
| 153 | + def get_model(): |
| 154 | + model = Path(model_cfg) |
| 155 | + if not model.is_dir(): |
| 156 | + accepted_suffixes = list(model_type_ext.values()) |
| 157 | + if model.suffix[1:] not in accepted_suffixes: |
| 158 | + raise ConfigError('Models with following suffixes are allowed: {}'.format(accepted_suffixes)) |
| 159 | + print_info('Found model {}'.format(model)) |
| 160 | + return model, model.suffix == '.blob' |
| 161 | + model_list = [] |
| 162 | + if model_type is not None: |
| 163 | + model_list = get_model_by_suffix(model_name, model, model_type_ext[model_type]) |
| 164 | + else: |
| 165 | + for ext in model_type_ext.values(): |
| 166 | + model_list = get_model_by_suffix(model_name, model, ext) |
| 167 | + if model_list: |
| 168 | + break |
| 169 | + if not model_list: |
| 170 | + raise ConfigError('suitable model is not found') |
| 171 | + if len(model_list) != 1: |
| 172 | + raise ConfigError('More than one model matched, please specify explicitly') |
| 173 | + model = model_list[0] |
| 174 | + print_info('Found model {}'.format(model)) |
| 175 | + return model, model.suffix == '.blob' |
| 176 | + |
| 177 | + model, is_blob = get_model() |
| 178 | + if is_blob: |
| 179 | + return model, None |
| 180 | + weights = weights_cfg |
| 181 | + if model.suffix == '.pdmodel': |
| 182 | + weights = self.get_value_from_config('params') |
| 183 | + if (weights is None or Path(weights).is_dir()) and model.suffix != '.onnx': |
| 184 | + weights_dir = weights or model.parent |
| 185 | + weights_list = [] |
| 186 | + if model.suffix == '.xml': |
| 187 | + weights = Path(weights_dir) / model.name.replace('xml', 'bin') |
| 188 | + print(weights) |
| 189 | + else: |
| 190 | + if model.suffix == '.prototxt': |
| 191 | + weights_list = list(Path(weights_dir).glob('*.{}'.format('caffemodel'))) |
| 192 | + elif model.suffix == '.pdmodel': |
| 193 | + weights_list = list(Path(weights_dir).glob('*.{}'.format('pdiparams'))) |
| 194 | + if not weights_list: |
| 195 | + raise ConfigError('Suitable weights is not detected') |
| 196 | + if len(weights_list) != 1: |
| 197 | + raise ConfigError('Several suitable weights found, please specify required explicitly') |
| 198 | + weights = weights_list[0] |
| 199 | + if weights is not None: |
| 200 | + accepted_weights_suffixes = ['.bin', '.caffemodel', '.pdiparams'] |
| 201 | + if weights.suffix not in accepted_weights_suffixes: |
| 202 | + raise ConfigError('Weights with following suffixes are allowed: {}'.format(accepted_weights_suffixes)) |
| 203 | + print_info('Found weights {}'.format(get_path(weights))) |
| 204 | + |
| 205 | + return model, weights |
| 206 | + |
133 | 207 | def predict(self, inputs, metadata=None, **kwargs):
|
134 | 208 | """
|
135 | 209 | Args:
|
|
0 commit comments