Skip to content

Commit 8ab2455

Browse files
ZHEQIUSHUIkalcohol
authored andcommitted
fixbug:axcl run use wrong ptr to make numpy array
1 parent 1312a61 commit 8ab2455

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

axengine/_axclrt.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -377,13 +377,13 @@ def run(
377377
ret = axclrt_lib.axclrtEngineGetOutputBufferByIndex(self._io[0], i, dev_prt, dev_size)
378378
if 0 != ret:
379379
raise RuntimeError(f"axclrtEngineGetOutputBufferByIndex failed for output {i}.")
380+
buffer_addr = dev_prt[0]
380381
npy_size = self.get_outputs(shape_group)[i].dtype.itemsize * np.prod(self.get_outputs(shape_group)[i].shape)
381-
npy = np.frombuffer(
382-
axclrt_cffi.buffer(
383-
self._io[0].pOutputs[i].pVirAddr, npy_size
384-
),
385-
dtype=self.get_outputs(shape_group)[i].dtype,
386-
).reshape(self.get_outputs(shape_group)[i].shape)
382+
npy = np.zeros(self.get_outputs(shape_group)[i].shape, dtype=self.get_outputs(shape_group)[i].dtype)
383+
npy_ptr = axclrt_cffi.cast("void *", npy.ctypes.data)
384+
ret = axclrt_lib.axclrtMemcpy(npy_ptr, buffer_addr, npy_size, axclrt_lib.AXCL_MEMCPY_DEVICE_TO_HOST)
385+
if 0 != ret:
386+
raise RuntimeError(f"axclrtMemcpy failed for output {i}.")
387387
name = self.get_outputs(shape_group)[i].name
388388
if name in output_names:
389389
outputs.append(npy)

0 commit comments

Comments
 (0)