Skip to content

Commit bd745d1

Browse files
authored
Merge pull request #1217 from konradkusiak97/cachHIPcallsEnqueue
[HIP] Cache some of the HIP driver calls from kernel enqueue
2 parents 76a2a9d + 6b96993 commit bd745d1

File tree

2 files changed

+50
-17
lines changed

2 files changed

+50
-17
lines changed

source/adapters/hip/device.hpp

Lines changed: 42 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -26,12 +26,37 @@ struct ur_device_handle_t_ {
2626
ur_platform_handle_t Platform;
2727
hipCtx_t HIPContext;
2828
uint32_t DeviceIndex;
29+
int MaxWorkGroupSize{0};
30+
int MaxBlockDimX{0};
31+
int MaxBlockDimY{0};
32+
int MaxBlockDimZ{0};
33+
int DeviceMaxLocalMem{0};
34+
int ManagedMemSupport{0};
35+
int ConcurrentManagedAccess{0};
2936

3037
public:
3138
ur_device_handle_t_(native_type HipDevice, hipCtx_t Context,
3239
ur_platform_handle_t Platform, uint32_t DeviceIndex)
3340
: HIPDevice(HipDevice), RefCount{1}, Platform(Platform),
34-
HIPContext(Context), DeviceIndex(DeviceIndex) {}
41+
HIPContext(Context), DeviceIndex(DeviceIndex) {
42+
43+
UR_CHECK_ERROR(hipDeviceGetAttribute(
44+
&MaxWorkGroupSize, hipDeviceAttributeMaxThreadsPerBlock, HIPDevice));
45+
UR_CHECK_ERROR(hipDeviceGetAttribute(
46+
&MaxBlockDimX, hipDeviceAttributeMaxBlockDimX, HIPDevice));
47+
UR_CHECK_ERROR(hipDeviceGetAttribute(
48+
&MaxBlockDimY, hipDeviceAttributeMaxBlockDimY, HIPDevice));
49+
UR_CHECK_ERROR(hipDeviceGetAttribute(
50+
&MaxBlockDimZ, hipDeviceAttributeMaxBlockDimZ, HIPDevice));
51+
UR_CHECK_ERROR(hipDeviceGetAttribute(
52+
&DeviceMaxLocalMem, hipDeviceAttributeMaxSharedMemoryPerBlock,
53+
HIPDevice));
54+
UR_CHECK_ERROR(hipDeviceGetAttribute(
55+
&ManagedMemSupport, hipDeviceAttributeManagedMemory, HIPDevice));
56+
UR_CHECK_ERROR(hipDeviceGetAttribute(
57+
&ConcurrentManagedAccess, hipDeviceAttributeConcurrentManagedAccess,
58+
HIPDevice));
59+
}
3560

3661
~ur_device_handle_t_() noexcept(false) {
3762
UR_CHECK_ERROR(hipDevicePrimaryCtxRelease(HIPDevice));
@@ -48,6 +73,22 @@ struct ur_device_handle_t_ {
4873
// Returns the index of the device relative to the other devices in the same
4974
// platform
5075
uint32_t getIndex() const noexcept { return DeviceIndex; };
76+
77+
int getMaxWorkGroupSize() const noexcept { return MaxWorkGroupSize; };
78+
79+
int getMaxBlockDimX() const noexcept { return MaxBlockDimX; };
80+
81+
int getMaxBlockDimY() const noexcept { return MaxBlockDimY; };
82+
83+
int getMaxBlockDimZ() const noexcept { return MaxBlockDimZ; };
84+
85+
int getDeviceMaxLocalMem() const noexcept { return DeviceMaxLocalMem; };
86+
87+
int getManagedMemSupport() const noexcept { return ManagedMemSupport; };
88+
89+
int getConcurrentManagedAccess() const noexcept {
90+
return ConcurrentManagedAccess;
91+
};
5192
};
5293

5394
int getAttribute(ur_device_handle_t Device, hipDeviceAttribute_t Attribute);

source/adapters/hip/enqueue.cpp

Lines changed: 8 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -300,15 +300,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
300300
bool ProvidedLocalWorkGroupSize = (pLocalWorkSize != nullptr);
301301

302302
{
303-
ur_result_t Result = urDeviceGetInfo(
304-
hQueue->Device, UR_DEVICE_INFO_MAX_WORK_ITEM_SIZES,
305-
sizeof(MaxThreadsPerBlock), MaxThreadsPerBlock, nullptr);
306-
UR_ASSERT(Result == UR_RESULT_SUCCESS, Result);
303+
MaxThreadsPerBlock[0] = hQueue->Device->getMaxBlockDimX();
304+
MaxThreadsPerBlock[1] = hQueue->Device->getMaxBlockDimY();
305+
MaxThreadsPerBlock[2] = hQueue->Device->getMaxBlockDimZ();
307306

308-
Result =
309-
urDeviceGetInfo(hQueue->Device, UR_DEVICE_INFO_MAX_WORK_GROUP_SIZE,
310-
sizeof(MaxWorkGroupSize), &MaxWorkGroupSize, nullptr);
311-
UR_ASSERT(Result == UR_RESULT_SUCCESS, Result);
307+
MaxWorkGroupSize = hQueue->Device->getMaxWorkGroupSize();
312308

313309
// The MaxWorkGroupSize = 1024 for AMD GPU
314310
// The MaxThreadsPerBlock = {1024, 1024, 1024}
@@ -423,11 +419,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
423419
: (LocalMemSzPtrPI ? LocalMemSzPtrPI : nullptr);
424420

425421
if (LocalMemSzPtr) {
426-
int DeviceMaxLocalMem = 0;
427-
UR_CHECK_ERROR(hipDeviceGetAttribute(
428-
&DeviceMaxLocalMem, hipDeviceAttributeMaxSharedMemoryPerBlock,
429-
Dev->get()));
430-
422+
int DeviceMaxLocalMem = Dev->getDeviceMaxLocalMem();
431423
static const int EnvVal = std::atoi(LocalMemSzPtr);
432424
if (EnvVal <= 0 || EnvVal > DeviceMaxLocalMem) {
433425
setErrorMessage(LocalMemSzPtrUR ? "Invalid value specified for "
@@ -1484,7 +1476,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
14841476

14851477
// If the device does not support managed memory access, we can't set
14861478
// mem_advise.
1487-
if (!getAttribute(Device, hipDeviceAttributeManagedMemory)) {
1479+
if (!Device->getManagedMemSupport()) {
14881480
releaseEvent();
14891481
setErrorMessage("mem_advise ignored as device does not support "
14901482
"managed memory access",
@@ -1558,7 +1550,7 @@ urEnqueueUSMAdvise(ur_queue_handle_t hQueue, const void *pMem, size_t size,
15581550

15591551
// If the device does not support managed memory access, we can't set
15601552
// mem_advise.
1561-
if (!getAttribute(Device, hipDeviceAttributeManagedMemory)) {
1553+
if (!Device->getManagedMemSupport()) {
15621554
releaseEvent();
15631555
setErrorMessage("mem_advise ignored as device does not support "
15641556
"managed memory access",
@@ -1575,7 +1567,7 @@ urEnqueueUSMAdvise(ur_queue_handle_t hQueue, const void *pMem, size_t size,
15751567
UR_USM_ADVICE_FLAG_SET_ACCESSED_BY_DEVICE |
15761568
UR_USM_ADVICE_FLAG_CLEAR_ACCESSED_BY_DEVICE |
15771569
UR_USM_ADVICE_FLAG_DEFAULT)) {
1578-
if (!getAttribute(Device, hipDeviceAttributeConcurrentManagedAccess)) {
1570+
if (!Device->getConcurrentManagedAccess()) {
15791571
releaseEvent();
15801572
setErrorMessage("mem_advise ignored as device does not support "
15811573
"concurrent managed access",

0 commit comments

Comments
 (0)