66import collections
77
88from coremltools .converters .mil .input_types import InputType , ClassifierConfig
9- from coremltools .converters .mil .converter import _convert
9+ from coremltools .converters .mil .converter import mil_convert
1010from coremltools .converters .mil .mil import Program
1111from coremltools ._deps import _HAS_TORCH , _HAS_TF_1 , _HAS_TF_2
1212from coremltools .converters ._profile_utils import _profile
@@ -39,6 +39,7 @@ def convert(
3939 outputs = None ,
4040 classifier_config = None ,
4141 minimum_deployment_target = None ,
42+ convert_to = 'nn_proto' ,
4243 ** kwargs
4344):
4445 """
@@ -68,7 +69,7 @@ def convert(
6869 - Path to a `.pt` file
6970
7071 source: str (optional)
71- One of `auto`, `tensorflow`, or `pytorch`. `auto` determines the
72+ One of [ `auto`, `tensorflow`, `pytorch`, `mil`] . `auto` determines the
7273 framework automatically for most cases. Raise ValueError if it fails
7374 to determine the source framework.
7475
@@ -108,12 +109,22 @@ def convert(
108109
109110 minimum_deployment_target: coremltools.target enumeration (optional)
110111 - one of the members of enum "coremltools.target."
111- - When not-specified or None, converter aims for as minimum of a deployment target as possible
112+ - When not-specified or None, converter aims for as minimum of a
113+ deployment target as possible
114+
115+ convert_to: str (optional)
116+ - Must be one of ['nn_proto', 'mil'].
117+ - 'nn_proto': Returns MLModel containing a NeuralNetwork
118+ proto
119+ - 'mil': Returns MIL program object. MIL program is primarily used
120+ for debugging purpose and currently cannot be compiled to
121+ executable.
112122
113123 Returns
114124 -------
115- model: MLModel
116- A Core ML MLModel object
125+ model: `coremltools.models.MLModel` or
126+ `coremltools.converters.mil.Program`
127+ A Core ML MLModel object or MIL Program object (see `convert_to`)
117128
118129 Examples
119130 --------
@@ -157,24 +168,55 @@ def convert(
157168 See `here <https://coremltools.readme.io/docs/neural-network-conversion>`_ for
158169 more advanced options
159170 """
160- if minimum_deployment_target is not None and not isinstance (
161- minimum_deployment_target , AvailableTarget
162- ):
171+ _check_deployment_target (minimum_deployment_target )
172+ exact_source = _determine_source (model , source , outputs )
173+ _validate_inputs (model , exact_source , inputs , outputs , classifier_config ,
174+ ** kwargs )
175+
176+ mlmodel = mil_convert (
177+ model ,
178+ convert_from = exact_source ,
179+ convert_to = convert_to ,
180+ inputs = inputs ,
181+ outputs = outputs ,
182+ classifier_config = classifier_config ,
183+ ** kwargs
184+ )
185+
186+ if convert_to == 'mil' :
187+ return mlmodel # Returns the MIL program
188+
189+ if minimum_deployment_target is not None :
190+ check_deployment_compatibility (
191+ spec = mlmodel .get_spec (),
192+ representation = convert_to ,
193+ deployment_target = minimum_deployment_target ,
194+ )
195+
196+ gc .collect ()
197+
198+ mlmodel = _record_src_version (mlmodel , exact_source )
199+ mlmodel .user_defined_metadata [_METADATA_VERSION ] = ct_version
200+
201+ return mlmodel
202+
203+
204+ def _check_deployment_target (minimum_deployment_target ):
205+ if minimum_deployment_target is not None and \
206+ not isinstance (minimum_deployment_target , AvailableTarget ):
163207 msg = (
164208 "Unrecognized value of argument 'minimum_deployment_target': {}. "
165209 "It needs to be a member of 'coremltools.target' enumeration. "
166210 "For example, coremltools.target.iOS13"
167211 )
168212 raise TypeError (msg .format (minimum_deployment_target ))
169213
170- source = source .lower ()
171- if source not in {"auto" , "tensorflow" , "pytorch" }:
172- msg = (
173- 'Unrecognized value of argument "source": {}. '
174- 'It must be one of ["auto", "tensorflow", "pytorch"].'
175- )
176- raise ValueError (msg .format (source ))
177-
214+ def _validate_inputs (model , exact_source , inputs , outputs , classifier_config ,
215+ ** kwargs ):
216+ """
217+ Validate and process model, inputs, outputs, classifier_config based on
218+ `exact_source` (which cannot be `auto`)
219+ """
178220 def raise_if_duplicated (input_list ):
179221 # Detect duplicated inputs
180222 input_names = [t .name for t in input_list if t .name is not None ]
@@ -196,56 +238,11 @@ def raise_if_duplicated(input_list):
196238 msg = '"classifier_config" must be of type ClassifierConfig'
197239 raise ValueError (msg )
198240
199- if source == "tensorflow" and _HAS_TF_2 :
200- source = "tensorflow2"
201-
202- if source == "auto" and _HAS_TF_1 :
203- try :
204- loader = TF1Loader (model , outputs = outputs )
205- loader ._graph_def_from_model (outputs = outputs )
206- source = "tensorflow"
207- except :
208- pass
209-
210- if source == "auto" and _HAS_TF_2 :
211- try :
212- loader = TF2Loader (model , outputs = outputs )
213- loader ._graph_def_from_model (outputs = outputs )
214- source = "tensorflow2"
215- except :
216- pass
217-
218- if source == "auto" and _HAS_TORCH :
219- try :
220- pytorch_load (model )
221- source = "pytorch"
222- except :
223- pass
224-
225- if source == "auto" and isinstance (model , Program ):
226- source = "mil"
227-
228- convert_to = kwargs .get ("convert_to" , "nn_proto" )
229- kwargs .pop ("convert_to" , None )
230-
231- if source == "auto" :
232- msg = (
233- "Unable to determine the type of the model, i.e. the source framework. "
234- 'Please provide the value of argument "source", from one of '
235- '["tensorflow", "pytorch"]. Note that model conversion requires the '
236- "source package that generates the model. Please make sure you have "
237- "the appropriate version of source package installed. E.g., if you're "
238- "converting model originally trained with TensorFlow 1.14, make sure "
239- "you have `tensorflow==1.14` installed."
240- )
241- raise ValueError (msg )
242-
243- elif source in {"tensorflow" , "tensorflow2" }:
244-
245- if source == "tensorflow" and not _HAS_TF_1 :
246- raise ValueError (
247- 'Converter was called with source="tensorflow", but missing tensorflow package'
248- )
241+ if exact_source in {"tensorflow" , "tensorflow2" }:
242+ if exact_source == "tensorflow" and not _HAS_TF_1 :
243+ msg = 'Converter was called with source="tensorflow", ' + \
244+ 'but missing tensorflow package'
245+ raise ValueError (msg )
249246
250247 if inputs is not None :
251248 raise_if_duplicated (inputs )
@@ -255,17 +252,7 @@ def raise_if_duplicated(input_list):
255252 ):
256253 raise ValueError ("Input should be a list of TensorType or ImageType" )
257254
258- proto_spec = _convert (
259- model ,
260- convert_from = source ,
261- convert_to = convert_to ,
262- inputs = inputs ,
263- outputs = outputs ,
264- classifier_config = classifier_config ,
265- ** kwargs
266- )
267-
268- elif source == "pytorch" :
255+ elif exact_source == "pytorch" :
269256 if "example_inputs" in kwargs :
270257 msg = 'Unexpected argument "example_inputs" found'
271258 raise ValueError (msg )
@@ -300,55 +287,81 @@ def _flatten_list(_inputs):
300287 if outputs is not None :
301288 raise ValueError ("outputs must not be specified for PyTorch" )
302289
303- proto_spec = _convert (
304- model ,
305- convert_from = "torch" ,
306- convert_to = convert_to ,
307- inputs = inputs ,
308- outputs = outputs ,
309- classifier_config = classifier_config ,
310- ** kwargs
311- )
312-
313- elif source == "mil" :
290+ elif exact_source == "mil" :
314291 if not isinstance (model , Program ):
315292 msg = "Converter was asked to convert MIL input, but input is not a MIL program!"
316293 raise ValueError (msg )
317294
318- proto_spec = _convert (
319- model ,
320- convert_from = "mil" ,
321- convert_to = convert_to ,
322- example_inputs = inputs ,
323- classifier_config = classifier_config ,
324- ** kwargs
295+
296+ def _determine_source (model , source , outputs ):
297+ """
298+ Infer source (which can be auto) to the precise framework.
299+ """
300+ source = source .lower ()
301+ if source not in {"auto" , "tensorflow" , "pytorch" , "mil" }:
302+ msg = (
303+ 'Unrecognized value of argument "source": {}. '
304+ 'It must be one of ["auto", "tensorflow", "pytorch"].'
325305 )
306+ raise ValueError (msg .format (source ))
326307
327- if convert_to == 'mil' :
328- return proto_spec # Returns the MIL program
329308
330- useCPUOnly = kwargs .get ("useCPUOnly" , True )
331- model = coremltools .models .MLModel (proto_spec , useCPUOnly = useCPUOnly )
309+ # Determine tensorflow version
310+ if source == "tensorflow" and _HAS_TF_2 :
311+ return "tensorflow2"
332312
333- if minimum_deployment_target is not None :
334- check_deployment_compatibility (
335- spec = proto_spec ,
336- representation = convert_to ,
337- deployment_target = minimum_deployment_target ,
338- )
313+ if source != 'auto' :
314+ return source
339315
340- del proto_spec
341- gc .collect ()
316+ # Determine `auto` source
317+ if source == "auto" and _HAS_TF_1 :
318+ try :
319+ loader = TF1Loader (model , outputs = outputs )
320+ loader ._graph_def_from_model (outputs = outputs )
321+ return "tensorflow"
322+ except :
323+ pass
342324
325+ if source == "auto" and _HAS_TF_2 :
326+ try :
327+ loader = TF2Loader (model , outputs = outputs )
328+ loader ._graph_def_from_model (outputs = outputs )
329+ return "tensorflow2"
330+ except :
331+ pass
332+
333+ if source == "auto" and _HAS_TORCH :
334+ try :
335+ pytorch_load (model )
336+ return "pytorch"
337+ except :
338+ pass
339+
340+ if source == "auto" and isinstance (model , Program ):
341+ return "mil"
342+
343+ msg = (
344+ "Unable to determine the type of the model, i.e. the source framework. "
345+ 'Please provide the value of argument "source", from one of '
346+ '["tensorflow", "pytorch", "mil"]. Note that model conversion requires the '
347+ "source package that generates the model. Please make sure you have "
348+ "the appropriate version of source package installed. E.g., if you're "
349+ "converting model originally trained with TensorFlow 1.14, make sure "
350+ "you have `tensorflow==1.14` installed."
351+ )
352+ raise ValueError (msg )
353+
354+
355+ def _record_src_version (mlmodel , exact_source ):
343356 # recording metadata: coremltools version, source framework and version
344- if source in {"tensorflow" , "tensorflow2" } and (_HAS_TF_1 or _HAS_TF_2 ):
357+ if exact_source in {"tensorflow" , "tensorflow2" } and (_HAS_TF_1 or _HAS_TF_2 ):
345358 src_pkg_version = "tensorflow=={0}" .format (tf .__version__ )
346- elif source == "pytorch" and _HAS_TORCH :
359+ elif exact_source == "pytorch" and _HAS_TORCH :
347360 src_pkg_version = "torch=={0}" .format (torch .__version__ )
361+ elif exact_source == 'mil' :
362+ src_pkg_version = "mil"
348363 else :
349- src_pkg_version = "unknown"
350-
351- model .user_defined_metadata [_METADATA_VERSION ] = ct_version
352- model .user_defined_metadata [_METADATA_SOURCE ] = src_pkg_version
364+ raise ValueError ('Unsupported source {}' .format (exact_source ))
353365
354- return model
366+ mlmodel .user_defined_metadata [_METADATA_SOURCE ] = src_pkg_version
367+ return mlmodel
0 commit comments