@@ -330,7 +330,8 @@ def run(
330330 self ,
331331 output_names : list [str ],
332332 input_feed : dict [str , np .ndarray ],
333- run_options = None
333+ run_options = None ,
334+ shape_group : int = 0
334335 ):
335336 self ._validate_input (input_feed )
336337 self ._validate_output (output_names )
@@ -340,13 +341,16 @@ def run(
340341 raise RuntimeError ("axclrtSetCurrentContext failed" )
341342
342343 if None is output_names :
343- output_names = [o .name for o in self .get_outputs ()]
344+ output_names = [o .name for o in self .get_outputs (shape_group )]
345+
346+ if (shape_group > self ._shape_count - 1 ) or (shape_group < 0 ):
347+ raise ValueError (f"Invalid shape group: { shape_group } " )
344348
345349 # fill model io
346350 dev_prt = axclrt_cffi .new ("void **" )
347351 dev_size = axclrt_cffi .new ("uint64_t *" )
348352 for key , npy in input_feed .items ():
349- for i , one in enumerate (self .get_inputs ()):
353+ for i , one in enumerate (self .get_inputs (shape_group )):
350354 if one .name == key :
351355 assert (
352356 list (one .shape ) == list (npy .shape ) and one .dtype == npy .dtype
@@ -363,21 +367,23 @@ def run(
363367 raise RuntimeError (f"axclrtMemcpy failed for input { i } ." )
364368
365369 # execute model
366- ret = axclrt_lib .axclrtEngineExecute (self ._model_id [0 ], self ._context_id [0 ], 0 , self ._io [0 ])
370+ ret = axclrt_lib .axclrtEngineExecute (self ._model_id [0 ], self ._context_id [0 ], shape_group , self ._io [0 ])
367371
368372 # get output
369373 outputs = []
370374 if 0 == ret :
371- for i in range (len (self .get_outputs ())):
375+ for i in range (len (self .get_outputs (shape_group ))):
372376 ret = axclrt_lib .axclrtEngineGetOutputBufferByIndex (self ._io [0 ], i , dev_prt , dev_size )
373377 if 0 != ret :
374378 raise RuntimeError (f"axclrtEngineGetOutputBufferByIndex failed for output { i } ." )
375- npy = np .zeros (self .get_outputs ()[i ].shape , dtype = self .get_outputs ()[i ].dtype )
376- npy_ptr = axclrt_cffi .cast ("void *" , npy .ctypes .data )
377- ret = axclrt_lib .axclrtMemcpy (npy_ptr , dev_prt [0 ], npy .nbytes , axclrt_lib .AXCL_MEMCPY_DEVICE_TO_HOST )
378- if 0 != ret :
379- raise RuntimeError (f"axclrtMemcpy failed for output { i } ." )
380- name = self .get_outputs ()[i ].name
379+ npy_size = self .get_outputs (shape_group )[i ].dtype .itemsize * np .prod (self .get_outputs (shape_group )[i ].shape )
380+ npy = np .frombuffer (
381+ axclrt_cffi .buffer (
382+ self ._io [0 ].pOutputs [i ].pVirAddr , npy_size
383+ ),
384+ dtype = self .get_outputs (shape_group )[i ].dtype ,
385+ ).reshape (self .get_outputs (shape_group )[i ].shape )
386+ name = self .get_outputs (shape_group )[i ].name
381387 if name in output_names :
382388 outputs .append (npy )
383389 return outputs
0 commit comments