@@ -96,6 +96,8 @@ def __init__(
9696 super ().__init__ ()
9797
9898 self ._device_index = 0
99+ self ._io = None
100+ self ._model_id = None
99101
100102 if provider_options is not None and "device_id" in provider_options [0 ]:
101103 self ._device_index = provider_options [0 ].get ("device_id" , 0 )
@@ -214,12 +216,12 @@ def _unload(self):
214216 dev_size = axclrt_cffi .new ("uint64_t *" )
215217 dev_prt = axclrt_cffi .new ("void **" )
216218 for i in range (axclrt_lib .axclrtEngineGetNumInputs (self ._info [0 ])):
217- axclrt_lib .axclrtEngineGetInputBufferByIndex (self ._io , i , dev_prt , dev_size )
219+ axclrt_lib .axclrtEngineGetInputBufferByIndex (self ._io [ 0 ] , i , dev_prt , dev_size )
218220 axclrt_lib .axclrtFree (dev_prt [0 ])
219221 for i in range (axclrt_lib .axclrtEngineGetNumOutputs (self ._info [0 ])):
220- axclrt_lib .axclrtEngineGetOutputBufferByIndex (self ._io , i , dev_prt , dev_size )
222+ axclrt_lib .axclrtEngineGetOutputBufferByIndex (self ._io [ 0 ] , i , dev_prt , dev_size )
221223 axclrt_lib .axclrtFree (dev_prt [0 ])
222- axclrt_lib .axclrtEngineDestroyIO (self ._io )
224+ axclrt_lib .axclrtEngineDestroyIO (self ._io [ 0 ] )
223225 self ._io = None
224226 if self ._model_id [0 ] is not None and self ._model_id [0 ] != 0 :
225227 axclrt_lib .axclrtEngineUnload (self ._model_id [0 ])
@@ -322,7 +324,7 @@ def _prepare_io(self):
322324 ret = axclrt_lib .axclrtEngineSetOutputBufferByIndex (_io [0 ], i , dev_ptr [0 ], max_size )
323325 if 0 != ret :
324326 raise RuntimeError (f"axclrtEngineSetOutputBufferByIndex failed 0x{ ret :08x} for output { i } ." )
325- return _io [ 0 ]
327+ return _io
326328
327329 def run (
328330 self ,
@@ -353,21 +355,21 @@ def run(
353355 if not (npy .flags .c_contiguous or npy .flags .f_contiguous ):
354356 npy = np .ascontiguousarray (npy )
355357 npy_ptr = axclrt_cffi .cast ("void *" , npy .ctypes .data )
356- ret = axclrt_lib .axclrtEngineGetInputBufferByIndex (self ._io , i , dev_prt , dev_size )
358+ ret = axclrt_lib .axclrtEngineGetInputBufferByIndex (self ._io [ 0 ] , i , dev_prt , dev_size )
357359 if 0 != ret :
358360 raise RuntimeError (f"axclrtEngineGetInputBufferByIndex failed for input { i } ." )
359361 ret = axclrt_lib .axclrtMemcpy (dev_prt [0 ], npy_ptr , npy .nbytes , axclrt_lib .AXCL_MEMCPY_HOST_TO_DEVICE )
360362 if 0 != ret :
361363 raise RuntimeError (f"axclrtMemcpy failed for input { i } ." )
362364
363365 # execute model
364- ret = axclrt_lib .axclrtEngineExecute (self ._model_id [0 ], self ._context_id [0 ], 0 , self ._io )
366+ ret = axclrt_lib .axclrtEngineExecute (self ._model_id [0 ], self ._context_id [0 ], 0 , self ._io [ 0 ] )
365367
366368 # get output
367369 outputs = []
368370 if 0 == ret :
369371 for i in range (len (self .get_outputs ())):
370- ret = axclrt_lib .axclrtEngineGetOutputBufferByIndex (self ._io , i , dev_prt , dev_size )
372+ ret = axclrt_lib .axclrtEngineGetOutputBufferByIndex (self ._io [ 0 ] , i , dev_prt , dev_size )
371373 if 0 != ret :
372374 raise RuntimeError (f"axclrtEngineGetOutputBufferByIndex failed for output { i } ." )
373375 npy = np .zeros (self .get_outputs ()[i ].shape , dtype = self .get_outputs ()[i ].dtype )
0 commit comments