@@ -63,7 +63,7 @@ def get_model_info(model, X, y=None):
63
63
64
64
# Most PyTorch models are actually subclasses of torch.nn.Module, so checking module
65
65
# name alone is not sufficient.
66
- elif torch and isinstance (model , torch .nn .Module ):
66
+ if torch and isinstance (model , torch .nn .Module ):
67
67
return PyTorchModelInfo (model , X , y )
68
68
69
69
raise ValueError (f"Unrecognized model type { type (model )} received." )
@@ -200,7 +200,8 @@ class OnnxModelInfo(ModelInfo):
200
200
def __init__ (self , model , X , y = None ):
201
201
if onnx is None :
202
202
raise RuntimeError (
203
- "The onnx package must be installed to work with ONNX models. Please `pip install onnx`."
203
+ "The onnx package must be installed to work with ONNX models. "
204
+ "Please `pip install onnx`."
204
205
)
205
206
206
207
self ._model = model
@@ -214,38 +215,19 @@ def __init__(self, model, X, y=None):
214
215
215
216
if len (inputs ) > 1 :
216
217
warnings .warn (
217
- f"The ONNX model has { len (inputs )} inputs but only the first input will be captured in Model Manager."
218
+ f"The ONNX model has { len (inputs )} inputs but only the first input "
219
+ f"will be captured in Model Manager."
218
220
)
219
221
220
222
if len (outputs ) > 1 :
221
223
warnings .warn (
222
- f"The ONNX model has { len (outputs )} outputs but only the first input will be captured in Model Manager."
224
+ f"The ONNX model has { len (outputs )} outputs but only the first output "
225
+ f"will be captured in Model Manager."
223
226
)
224
227
225
228
self ._X_df = inputs [0 ]
226
229
self ._y_df = outputs [0 ]
227
230
228
- # initializer (static params)
229
-
230
- # for field in model.ListFields():
231
- # doc_string
232
- # domain
233
- # metadata_props
234
- # model_author
235
- # model_license
236
- # model_version
237
- # producer_name
238
- # producer_version
239
- # training_info
240
-
241
- # irVersion
242
- # producerName
243
- # producerVersion
244
- # opsetImport
245
-
246
- # # list of (FieldDescriptor, value)
247
- # fields = model.ListFields()
248
-
249
231
@staticmethod
250
232
def _tensor_to_dataframe (tensor ):
251
233
"""
@@ -272,7 +254,7 @@ def _tensor_to_dataframe(tensor):
272
254
name = tensor .get ("name" , "Var" )
273
255
type_ = tensor ["type" ]
274
256
275
- if not "tensorType" in type_ :
257
+ if "tensorType" not in type_ :
276
258
raise ValueError (f"Received an unexpected ONNX input type: { type_ } ." )
277
259
278
260
dtype = onnx .helper .tensor_dtype_to_np_dtype (type_ ["tensorType" ]["elemType" ])
@@ -374,8 +356,6 @@ def __init__(self, model, X, y=None):
374
356
raise ValueError (
375
357
f"Expected input data to be a numpy array or PyTorch tensor, received { type (X )} ."
376
358
)
377
- # if X.ndim != 2:
378
- # raise ValueError(f"Expected input date with shape (n_samples, n_dim), received shape {X.shape}.")
379
359
380
360
# Ensure each input is a PyTorch Tensor
381
361
X = tuple (x if isinstance (x , torch .Tensor ) else torch .tensor (x ) for x in X )
@@ -395,8 +375,6 @@ def __init__(self, model, X, y=None):
395
375
)
396
376
397
377
self ._model = model
398
-
399
- # TODO: convert X and y to DF with arbitrary names
400
378
self ._X = X
401
379
self ._y = y
402
380
0 commit comments