diff --git a/sycl/include/sycl/ext/oneapi/experimental/enqueue_functions.hpp b/sycl/include/sycl/ext/oneapi/experimental/enqueue_functions.hpp index f599078a6769e..36f14b23845c1 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/enqueue_functions.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/enqueue_functions.hpp @@ -12,6 +12,7 @@ #include #include +#include #include #include #include @@ -369,15 +370,17 @@ void fill(sycl::queue Q, T *Ptr, const T &Pattern, size_t Count, CodeLoc); } -inline void prefetch(handler &CGH, void *Ptr, size_t NumBytes) { - CGH.prefetch(Ptr, NumBytes); +inline void prefetch(handler &CGH, void *Ptr, size_t NumBytes, + prefetch_type Type = prefetch_type::device) { + CGH.prefetch(Ptr, NumBytes, Type); } inline void prefetch(queue Q, void *Ptr, size_t NumBytes, + prefetch_type Type = prefetch_type::device, const sycl::detail::code_location &CodeLoc = sycl::detail::code_location::current()) { submit( - std::move(Q), [&](handler &CGH) { prefetch(CGH, Ptr, NumBytes); }, + std::move(Q), [&](handler &CGH) { prefetch(CGH, Ptr, NumBytes, Type); }, CodeLoc); } diff --git a/sycl/include/sycl/ext/oneapi/experimental/enqueue_types.hpp b/sycl/include/sycl/ext/oneapi/experimental/enqueue_types.hpp new file mode 100644 index 0000000000000..dacd45126a7fb --- /dev/null +++ b/sycl/include/sycl/ext/oneapi/experimental/enqueue_types.hpp @@ -0,0 +1,33 @@ +//==--------------- enqueue_types.hpp ---- SYCL enqueue types --------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +namespace sycl { +inline namespace _V1 { +namespace ext::oneapi::experimental { + +/// @brief Indicates the destination device for USM data to be prefetched to. +enum class prefetch_type { device, host }; + +inline std::string prefetchTypeToString(prefetch_type value) { + switch (value) { + case sycl::ext::oneapi::experimental::prefetch_type::device: + return "prefetch_type::device"; + case sycl::ext::oneapi::experimental::prefetch_type::host: + return "prefetch_type::host"; + default: + return "prefetch_type::unknown"; + } +} + +} // namespace ext::oneapi::experimental +} // namespace _V1 +} // namespace sycl diff --git a/sycl/include/sycl/handler.hpp b/sycl/include/sycl/handler.hpp index bd91b9dc755ec..cf44c59d08354 100644 --- a/sycl/include/sycl/handler.hpp +++ b/sycl/include/sycl/handler.hpp @@ -149,6 +149,8 @@ namespace ext ::oneapi ::experimental { template class work_group_memory; template class dynamic_work_group_memory; struct image_descriptor; +enum class prefetch_type; + __SYCL_EXPORT void async_free(sycl::handler &h, void *ptr); __SYCL_EXPORT void *async_malloc(sycl::handler &h, sycl::usm::alloc kind, size_t size); @@ -2627,6 +2629,16 @@ class __SYCL_EXPORT handler { /// \param Count is a number of bytes to be prefetched. void prefetch(const void *Ptr, size_t Count); + /// Provides hints to the runtime library that data should be made available + /// on a device earlier than Unified Shared Memory would normally require it + /// to be available. + /// + /// \param Ptr is a USM pointer to the memory to be prefetched to the device. + /// \param Count is a number of bytes to be prefetched. + /// \param Type is type of prefetch, i.e. fetch to device or fetch to host. + void prefetch(const void *Ptr, size_t Count, + ext::oneapi::experimental::prefetch_type Type); + /// Provides additional information to the underlying runtime about how /// different allocations are used. /// diff --git a/sycl/source/detail/cg.hpp b/sycl/source/detail/cg.hpp index 2eb44926f7381..1861c64ad4bc1 100644 --- a/sycl/source/detail/cg.hpp +++ b/sycl/source/detail/cg.hpp @@ -398,14 +398,19 @@ class CGFillUSM : public CG { class CGPrefetchUSM : public CG { void *MDst; size_t MLength; + ext::oneapi::experimental::prefetch_type MPrefetchType; public: CGPrefetchUSM(void *DstPtr, size_t Length, CG::StorageInitHelper CGData, + ext::oneapi::experimental::prefetch_type PrefetchType, detail::code_location loc = {}) : CG(CGType::PrefetchUSM, std::move(CGData), std::move(loc)), - MDst(DstPtr), MLength(Length) {} - void *getDst() { return MDst; } - size_t getLength() { return MLength; } + MDst(DstPtr), MLength(Length), MPrefetchType(PrefetchType) {} + void *getDst() const { return MDst; } + size_t getLength() const { return MLength; } + ext::oneapi::experimental::prefetch_type getPrefetchType() const { + return MPrefetchType; + } }; /// "Advise USM" command group class. diff --git a/sycl/source/detail/graph/node_impl.hpp b/sycl/source/detail/graph/node_impl.hpp index bfcdb18f63a4f..fdcae10a5ceb4 100644 --- a/sycl/source/detail/graph/node_impl.hpp +++ b/sycl/source/detail/graph/node_impl.hpp @@ -15,7 +15,8 @@ #include // for CGType #include // for kernel_param_kind_t -#include // for node +#include // for prefetchType +#include // for node #include #include @@ -655,7 +656,10 @@ class node_impl : public std::enable_shared_from_this { sycl::detail::CGPrefetchUSM *Prefetch = static_cast(MCommandGroup.get()); Stream << "Dst: " << Prefetch->getDst() - << " Length: " << Prefetch->getLength() << "\\n"; + << " Length: " << Prefetch->getLength() << " PrefetchType: " + << sycl::ext::oneapi::experimental::prefetchTypeToString( + Prefetch->getPrefetchType()) + << "\\n"; } break; case sycl::detail::CGType::AdviseUSM: diff --git a/sycl/source/detail/handler_impl.hpp b/sycl/source/detail/handler_impl.hpp index 0fda3dd4f2769..23ce36d691dc2 100644 --- a/sycl/source/detail/handler_impl.hpp +++ b/sycl/source/detail/handler_impl.hpp @@ -12,6 +12,7 @@ #include #include #include +#include namespace sycl { inline namespace _V1 { @@ -91,6 +92,10 @@ class handler_impl { /// property. bool MIsDeviceImageScoped = false; + /// Direction of USM prefetch / destination device. + sycl::ext::oneapi::experimental::prefetch_type MPrefetchType = + sycl::ext::oneapi::experimental::prefetch_type::device; + // Program scope pipe information. // Pipe name that uniquely identifies a pipe. diff --git a/sycl/source/detail/memory_manager.cpp b/sycl/source/detail/memory_manager.cpp index 9903278e807bc..e09969fba057c 100644 --- a/sycl/source/detail/memory_manager.cpp +++ b/sycl/source/detail/memory_manager.cpp @@ -925,13 +925,18 @@ void MemoryManager::fill_usm(void *Mem, queue_impl &Queue, size_t Length, DepEvents.size(), DepEvents.data(), OutEvent); } -void MemoryManager::prefetch_usm(void *Mem, queue_impl &Queue, size_t Length, - std::vector DepEvents, - ur_event_handle_t *OutEvent) { +void MemoryManager::prefetch_usm( + void *Mem, queue_impl &Queue, size_t Length, + std::vector DepEvents, ur_event_handle_t *OutEvent, + sycl::ext::oneapi::experimental::prefetch_type Dest) { adapter_impl &Adapter = Queue.getAdapter(); - Adapter.call(Queue.getHandleRef(), Mem, - Length, 0u, DepEvents.size(), - DepEvents.data(), OutEvent); + ur_usm_migration_flags_t MigrationFlag = + (Dest == sycl::ext::oneapi::experimental::prefetch_type::device) + ? UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE + : UR_USM_MIGRATION_FLAG_DEVICE_TO_HOST; + Adapter.call( + Queue.getHandleRef(), Mem, Length, MigrationFlag, DepEvents.size(), + DepEvents.data(), OutEvent); } void MemoryManager::advise_usm(const void *Mem, queue_impl &Queue, @@ -1542,11 +1547,16 @@ void MemoryManager::ext_oneapi_prefetch_usm_cmd_buffer( sycl::detail::context_impl *Context, ur_exp_command_buffer_handle_t CommandBuffer, void *Mem, size_t Length, std::vector Deps, - ur_exp_command_buffer_sync_point_t *OutSyncPoint) { + ur_exp_command_buffer_sync_point_t *OutSyncPoint, + sycl::ext::oneapi::experimental::prefetch_type Dest) { adapter_impl &Adapter = Context->getAdapter(); + ur_usm_migration_flags_t MigrationFlag = + (Dest == sycl::ext::oneapi::experimental::prefetch_type::device) + ? UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE + : UR_USM_MIGRATION_FLAG_DEVICE_TO_HOST; Adapter.call( - CommandBuffer, Mem, Length, ur_usm_migration_flags_t(0), Deps.size(), - Deps.data(), 0u, nullptr, OutSyncPoint, nullptr, nullptr); + CommandBuffer, Mem, Length, MigrationFlag, Deps.size(), Deps.data(), 0, + nullptr, OutSyncPoint, nullptr, nullptr); } void MemoryManager::ext_oneapi_advise_usm_cmd_buffer( diff --git a/sycl/source/detail/memory_manager.hpp b/sycl/source/detail/memory_manager.hpp index 7c87e9252d157..02b0c7d673433 100644 --- a/sycl/source/detail/memory_manager.hpp +++ b/sycl/source/detail/memory_manager.hpp @@ -11,6 +11,7 @@ #include #include #include +#include // for prefetch_type #include #include #include @@ -146,9 +147,12 @@ class MemoryManager { std::vector DepEvents, ur_event_handle_t *OutEvent); - static void prefetch_usm(void *Ptr, queue_impl &Queue, size_t Len, - std::vector DepEvents, - ur_event_handle_t *OutEvent); + static void + prefetch_usm(void *Ptr, queue_impl &Queue, size_t Len, + std::vector DepEvents, + ur_event_handle_t *OutEvent, + sycl::ext::oneapi::experimental::prefetch_type Dest = + sycl::ext::oneapi::experimental::prefetch_type::device); static void advise_usm(const void *Ptr, queue_impl &Queue, size_t Len, ur_usm_advice_flags_t Advice, @@ -245,7 +249,9 @@ class MemoryManager { sycl::detail::context_impl *Context, ur_exp_command_buffer_handle_t CommandBuffer, void *Mem, size_t Length, std::vector Deps, - ur_exp_command_buffer_sync_point_t *OutSyncPoint); + ur_exp_command_buffer_sync_point_t *OutSyncPoint, + sycl::ext::oneapi::experimental::prefetch_type Dest = + sycl::ext::oneapi::experimental::prefetch_type::device); static void ext_oneapi_advise_usm_cmd_buffer( sycl::detail::context_impl *Context, diff --git a/sycl/source/detail/scheduler/commands.cpp b/sycl/source/detail/scheduler/commands.cpp index 70f12d0a59ef7..69f1a724c0c12 100644 --- a/sycl/source/detail/scheduler/commands.cpp +++ b/sycl/source/detail/scheduler/commands.cpp @@ -3087,7 +3087,8 @@ ur_result_t ExecCGCommand::enqueueImpCommandBuffer() { if (auto Result = callMemOpHelper( MemoryManager::ext_oneapi_prefetch_usm_cmd_buffer, &MQueue->getContextImpl(), MCommandBuffer, Prefetch->getDst(), - Prefetch->getLength(), std::move(MSyncPointDeps), &OutSyncPoint); + Prefetch->getLength(), std::move(MSyncPointDeps), &OutSyncPoint, + Prefetch->getPrefetchType()); Result != UR_RESULT_SUCCESS) return Result; @@ -3398,7 +3399,8 @@ ur_result_t ExecCGCommand::enqueueImpQueue() { CGPrefetchUSM *Prefetch = (CGPrefetchUSM *)MCommandGroup.get(); if (auto Result = callMemOpHelper( MemoryManager::prefetch_usm, Prefetch->getDst(), *MQueue, - Prefetch->getLength(), std::move(RawEvents), Event); + Prefetch->getLength(), std::move(RawEvents), Event, + Prefetch->getPrefetchType()); Result != UR_RESULT_SUCCESS) return Result; diff --git a/sycl/source/handler.cpp b/sycl/source/handler.cpp index 33f50af815a5b..f0ecf11affc99 100644 --- a/sycl/source/handler.cpp +++ b/sycl/source/handler.cpp @@ -36,6 +36,7 @@ #include #include +#include #include #include #include @@ -719,8 +720,9 @@ event handler::finalize() { MCodeLoc)); break; case detail::CGType::PrefetchUSM: - CommandGroup.reset(new detail::CGPrefetchUSM( - MDstPtr, MLength, std::move(impl->CGData), MCodeLoc)); + CommandGroup.reset( + new detail::CGPrefetchUSM(MDstPtr, MLength, std::move(impl->CGData), + impl->MPrefetchType, MCodeLoc)); break; case detail::CGType::AdviseUSM: CommandGroup.reset(new detail::CGAdviseUSM(MDstPtr, MLength, impl->MAdvice, @@ -1473,6 +1475,16 @@ void handler::prefetch(const void *Ptr, size_t Count) { throwIfActionIsCreated(); MDstPtr = const_cast(Ptr); MLength = Count; + impl->MPrefetchType = ext::oneapi::experimental::prefetch_type::device; + setType(detail::CGType::PrefetchUSM); +} + +void handler::prefetch(const void *Ptr, size_t Count, + ext::oneapi::experimental::prefetch_type Type) { + throwIfActionIsCreated(); + MDstPtr = const_cast(Ptr); + MLength = Count; + impl->MPrefetchType = Type; setType(detail::CGType::PrefetchUSM); } diff --git a/sycl/test-e2e/Graph/RecordReplay/ext_oneapi_enqueue_functions_prefetch.cpp b/sycl/test-e2e/Graph/RecordReplay/ext_oneapi_enqueue_functions_prefetch.cpp new file mode 100644 index 0000000000000..d998f6d9a21c6 --- /dev/null +++ b/sycl/test-e2e/Graph/RecordReplay/ext_oneapi_enqueue_functions_prefetch.cpp @@ -0,0 +1,88 @@ +// REQUIRES: aspect-usm_shared_allocations +// +// RUN: %{build} -o %t.out +// RUN: %{run} %t.out +// +// UNSUPPORTED: opencl +// UNSUPPORTED-INTENDED: OpenCL currently has limited support for command +// buffers +// +// RUN: %if level_zero %{%{l0_leak_check} %{run} %t.out 2>&1 | FileCheck %s --implicit-check-not=LEAK %} + +// Tests prefetch functionality in enqueue functions + +#include "../graph_common.hpp" +#include + +static constexpr int N = 100; +static constexpr int Pattern = 42; + +int main() { + queue Q{}; + + int *Src = malloc_shared(N, Q); + int *Dst = malloc_shared(N, Q); + for (int i = 0; i < N; i++) + Src[i] = Pattern; + + { + exp_ext::command_graph Graph{Q.get_context(), Q.get_device(), {}}; + + Graph.begin_recording(Q); + + // Test submitting host-to-device prefetch + event TestH2D = exp_ext::submit_with_event( + Q, [&](handler &CGH) { exp_ext::prefetch(CGH, Src, sizeof(int) * N); }); + + exp_ext::submit(Q, [&](handler &CGH) { + CGH.depends_on(TestH2D); + exp_ext::parallel_for(CGH, range<1>(N), + [=](id<1> i) { Dst[i] = Src[i] * 2; }); + }); + + Graph.end_recording(); + + auto GraphExec = Graph.finalize(); + + exp_ext::execute_graph(Q, GraphExec); + Q.wait_and_throw(); + } + + // Check host-to-device prefetch results + for (int i = 0; i < N; i++) + assert(check_value(i, Pattern * 2, Dst[i], "Dst")); + + { + exp_ext::command_graph Graph{Q.get_context(), Q.get_device(), {}}; + + Graph.begin_recording(Q); + + // Test submitting device-to-host prefetch + event TestD2H = exp_ext::submit_with_event(Q, [&](handler &CGH) { + exp_ext::parallel_for(CGH, range<1>(N), + [=](id<1> i) { Dst[i] = Src[i] + 1; }); + }); + + exp_ext::submit(Q, [&](handler &CGH) { + CGH.depends_on(TestD2H); + exp_ext::prefetch(CGH, Dst, sizeof(int) * N, + exp_ext::prefetch_type::host); + }); + + Graph.end_recording(); + + auto GraphExec = Graph.finalize(); + + exp_ext::execute_graph(Q, GraphExec); + Q.wait_and_throw(); + } + + // Check device-to-host prefetch results + for (int i = 0; i < N; i++) + assert(check_value(i, Pattern + 1, Dst[i], "Dst")); + + free(Src, Q); + free(Dst, Q); + + return 0; +} diff --git a/sycl/test-e2e/USM/prefetch_exp.cpp b/sycl/test-e2e/USM/prefetch_exp.cpp new file mode 100644 index 0000000000000..56fa19c527814 --- /dev/null +++ b/sycl/test-e2e/USM/prefetch_exp.cpp @@ -0,0 +1,111 @@ +//==-------- prefetch_exp.cpp - Experimental 2-way USM prefetch test -------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// REQUIRES: aspect-usm_shared_allocations +// +// RUN: %{build} -o %t1.out +// RUN: %{run} %t1.out + +#include +#include +#include + +using namespace sycl; + +static constexpr int Count = 100; + +int main() { + queue q([](exception_list el) { + for (auto &e : el) + throw e; + }); + + float *Src = malloc_shared(Count, q); + float *Dest = malloc_shared(Count, q); + for (int i = 0; i < Count; i++) + Src[i] = i; + + { + // Test host-to-device prefetch via prefetch(handler ...). + event InitPrefetch = + ext::oneapi::experimental::submit_with_event(q, [&](handler &CGH) { + ext::oneapi::experimental::prefetch(CGH, Src, sizeof(float) * Count); + }); + + q.submit([&](handler &CGH) { + CGH.depends_on(InitPrefetch); + CGH.single_task([=]() { + for (int i = 0; i < Count; i++) + Dest[i] = 2 * Src[i]; + }); + }); + q.wait_and_throw(); + + for (int i = 0; i < Count; i++) { + assert(Dest[i] == i * 2); + } + + // Test device-to-host prefetch via prefetch(handler ...). + event InitPrefetchBack = q.submit([&](handler &CGH) { + CGH.single_task([=]() { + for (int i = 0; i < Count; i++) + Dest[i] = 4 * Src[i]; + }); + }); + + ext::oneapi::experimental::submit(q, [&](handler &CGH) { + CGH.depends_on(InitPrefetchBack); + ext::oneapi::experimental::prefetch( + CGH, Dest, sizeof(float) * Count, + ext::oneapi::experimental::prefetch_type::host); + }); + q.wait_and_throw(); + + for (int i = 0; i < Count; i++) { + assert(Dest[i] == i * 4); + } + } + + { + // Test host-to-device prefetch via prefetch(queue ...). + ext::oneapi::experimental::prefetch( + q, Src, sizeof(float) * Count, + ext::oneapi::experimental::prefetch_type::device); + q.wait_and_throw(); + q.submit([&](handler &CGH) { + CGH.single_task([=]() { + for (int i = 0; i < Count; i++) + Dest[i] = 3 * Src[i]; + }); + }); + q.wait_and_throw(); + + for (int i = 0; i < Count; i++) { + assert(Dest[i] == i * 3); + } + + // Test device-to-host prefetch via prefetch(queue ...). + q.submit([&](handler &CGH) { + CGH.single_task([=]() { + for (int i = 0; i < Count; i++) + Dest[i] = 6 * Src[i]; + }); + }); + q.wait_and_throw(); + ext::oneapi::experimental::prefetch( + q, Src, sizeof(float) * Count, + ext::oneapi::experimental::prefetch_type::host); + q.wait_and_throw(); + + for (int i = 0; i < Count; i++) { + assert(Dest[i] == i * 6); + } + } + free(Src, q); + free(Dest, q); +} diff --git a/sycl/test/abi/sycl_symbols_linux.dump b/sycl/test/abi/sycl_symbols_linux.dump index dd392cf315b88..62d363b8d4189 100644 --- a/sycl/test/abi/sycl_symbols_linux.dump +++ b/sycl/test/abi/sycl_symbols_linux.dump @@ -3632,6 +3632,7 @@ _ZN4sycl3_V17handler7setTypeENS0_6detail6CGTypeE _ZN4sycl3_V17handler8finalizeEv _ZN4sycl3_V17handler8getQueueEv _ZN4sycl3_V17handler8prefetchEPKvm +_ZN4sycl3_V17handler8prefetchEPKvmNS0_3ext6oneapi12experimental13prefetch_typeE _ZN4sycl3_V17handler9clearArgsEv _ZN4sycl3_V17handler9fill_implEPvPKvmm _ZN4sycl3_V17handlerC1EOSt10unique_ptrINS0_6detail12handler_implESt14default_deleteIS4_EE diff --git a/sycl/test/abi/sycl_symbols_windows.dump b/sycl/test/abi/sycl_symbols_windows.dump index 748d74482a1a5..2bf236e2a5d35 100644 --- a/sycl/test/abi/sycl_symbols_windows.dump +++ b/sycl/test/abi/sycl_symbols_windows.dump @@ -4358,6 +4358,7 @@ ?postProcess@HandlerAccess@detail@_V1@sycl@@SAXAEAVhandler@34@Vtype_erased_cgfo_ty@234@@Z ?preProcess@HandlerAccess@detail@_V1@sycl@@SAXAEAVhandler@34@Vtype_erased_cgfo_ty@234@@Z ?prefetch@handler@_V1@sycl@@QEAAXPEBX_K@Z +?prefetch@handler@_V1@sycl@@QEAAXPEBX_KW4prefetch_type@experimental@oneapi@ext@23@@Z ?prefetch@queue@_V1@sycl@@QEAA?AVevent@23@PEBX_KAEBUcode_location@detail@23@@Z ?prefetch@queue@_V1@sycl@@QEAA?AVevent@23@PEBX_KAEBV?$vector@Vevent@_V1@sycl@@V?$allocator@Vevent@_V1@sycl@@@std@@@std@@AEBUcode_location@detail@23@@Z ?prefetch@queue@_V1@sycl@@QEAA?AVevent@23@PEBX_KV423@AEBUcode_location@detail@23@@Z diff --git a/sycl/unittests/Extensions/CMakeLists.txt b/sycl/unittests/Extensions/CMakeLists.txt index b82c9f798a94c..59d57f0851ec1 100644 --- a/sycl/unittests/Extensions/CMakeLists.txt +++ b/sycl/unittests/Extensions/CMakeLists.txt @@ -12,6 +12,7 @@ add_sycl_unittest(ExtensionsTests OBJECT CompositeDevice.cpp OneAPIProd.cpp EnqueueFunctionsEvents.cpp + EnqueueFunctionsPrefetch.cpp ProfilingTag.cpp KernelProperties.cpp NoDeviceIPVersion.cpp @@ -22,6 +23,7 @@ add_sycl_unittest(ExtensionsTests OBJECT EventMode.cpp DeviceInfo.cpp RootGroup.cpp + USMPrefetch.cpp ) add_subdirectory(CommandGraph) diff --git a/sycl/unittests/Extensions/EnqueueFunctionsPrefetch.cpp b/sycl/unittests/Extensions/EnqueueFunctionsPrefetch.cpp new file mode 100644 index 0000000000000..18eac0a916bd4 --- /dev/null +++ b/sycl/unittests/Extensions/EnqueueFunctionsPrefetch.cpp @@ -0,0 +1,80 @@ +//==------------------- EnqueueFunctionsPrefetch.cpp -----------------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// Tests enqueue_functions prefetch calls UR functions with the right arguments. + +#include + +#include +#include +#include +#include +#include +#include + +using namespace sycl; + +namespace oneapiext = ext::oneapi::experimental; + +namespace { + +static ur_usm_migration_flags_t SubmittedPrefetchType = + UR_USM_MIGRATION_FLAG_FORCE_UINT32; + +inline ur_result_t replace_urUSMEnqueuePrefetch(void *pParams) { + auto params = *static_cast(pParams); + SubmittedPrefetchType = *params.pflags; + return UR_RESULT_SUCCESS; +} + +static constexpr size_t N = 1024; +class EnqueueFunctionsPrefetchTests : public ::testing::Test { +public: + EnqueueFunctionsPrefetchTests() + : Mock{}, Q{context(sycl::platform()), default_selector_v, + property::queue::in_order{}} {} + +protected: + void SetUp() override { + SubmittedPrefetchType = UR_USM_MIGRATION_FLAG_FORCE_UINT32; + Dst = malloc_shared(N, Q); + } + + unittest::UrMock<> Mock; + queue Q; + int *Dst; +}; + +TEST_F(EnqueueFunctionsPrefetchTests, SubmitHostToDevicePrefetch) { + mock::getCallbacks().set_replace_callback("urEnqueueUSMPrefetch", + replace_urUSMEnqueuePrefetch); + + oneapiext::submit(Q, [&](handler &CGH) { + oneapiext::prefetch(CGH, Dst, sizeof(int) * N, + oneapiext::prefetch_type::device); + }); + + ASSERT_EQ(SubmittedPrefetchType, UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE); + + free(Dst, Q); +} + +TEST_F(EnqueueFunctionsPrefetchTests, SubmitDeviceToHostPrefetch) { + mock::getCallbacks().set_replace_callback("urEnqueueUSMPrefetch", + replace_urUSMEnqueuePrefetch); + + oneapiext::submit(Q, [&](handler &CGH) { + oneapiext::prefetch(CGH, Dst, sizeof(int) * N, + oneapiext::prefetch_type::host); + }); + + ASSERT_EQ(SubmittedPrefetchType, UR_USM_MIGRATION_FLAG_DEVICE_TO_HOST); + + free(Dst, Q); +} + +} // namespace diff --git a/sycl/unittests/Extensions/USMPrefetch.cpp b/sycl/unittests/Extensions/USMPrefetch.cpp new file mode 100644 index 0000000000000..512eb638d51ef --- /dev/null +++ b/sycl/unittests/Extensions/USMPrefetch.cpp @@ -0,0 +1,69 @@ +//==------------------------- USMPrefetch.cpp ------------------------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// Test SYCL prefetch calls UR prefetch functions with the correct arguments. + +#include + +#include +#include +#include +#include +#include + +using namespace sycl; + +namespace { + +static ur_usm_migration_flags_t SubmittedPrefetchType = + UR_USM_MIGRATION_FLAG_FORCE_UINT32; + +inline ur_result_t replace_urUSMEnqueuePrefetch(void *pParams) { + auto params = *static_cast(pParams); + SubmittedPrefetchType = *params.pflags; + return UR_RESULT_SUCCESS; +} + +static constexpr size_t N = 1024; +class USMPrefetchTests : public ::testing::Test { +public: + USMPrefetchTests() + : Mock{}, Q{context(sycl::platform()), default_selector_v, + property::queue::in_order{}} {} + +protected: + void SetUp() override { + SubmittedPrefetchType = UR_USM_MIGRATION_FLAG_FORCE_UINT32; + Dst = malloc_shared(N, Q); + } + + unittest::UrMock<> Mock; + queue Q; + int *Dst; +}; + +TEST_F(USMPrefetchTests, QueuePrefetch) { + mock::getCallbacks().set_replace_callback("urEnqueueUSMPrefetch", + replace_urUSMEnqueuePrefetch); + + Q.prefetch(Dst, sizeof(int) * N); + ASSERT_EQ(SubmittedPrefetchType, UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE); + + free(Dst, Q); +} + +TEST_F(USMPrefetchTests, HandlerPrefetch) { + mock::getCallbacks().set_replace_callback("urEnqueueUSMPrefetch", + replace_urUSMEnqueuePrefetch); + + Q.submit([&](handler &CGH) { CGH.prefetch(Dst, sizeof(int) * N); }); + ASSERT_EQ(SubmittedPrefetchType, UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE); + + free(Dst, Q); +} + +} // namespace diff --git a/unified-runtime/include/ur_api.h b/unified-runtime/include/ur_api.h index 5882015c55972..7f236c89fc7de 100644 --- a/unified-runtime/include/ur_api.h +++ b/unified-runtime/include/ur_api.h @@ -8620,18 +8620,20 @@ typedef enum ur_map_flag_t { #define UR_MAP_FLAGS_MASK 0xfffffff8 /////////////////////////////////////////////////////////////////////////////// -/// @brief Map flags +/// @brief USM migration flags, indicating the direction data is migrated in typedef uint32_t ur_usm_migration_flags_t; typedef enum ur_usm_migration_flag_t { - /// Default migration TODO: Add more enums! - UR_USM_MIGRATION_FLAG_DEFAULT = UR_BIT(0), + /// Migrate data from host to device + UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE = UR_BIT(0), + /// Migrate data from device to host + UR_USM_MIGRATION_FLAG_DEVICE_TO_HOST = UR_BIT(1), /// @cond UR_USM_MIGRATION_FLAG_FORCE_UINT32 = 0x7fffffff /// @endcond } ur_usm_migration_flag_t; /// @brief Bit Mask for validating ur_usm_migration_flags_t -#define UR_USM_MIGRATION_FLAGS_MASK 0xfffffffe +#define UR_USM_MIGRATION_FLAGS_MASK 0xfffffffc /////////////////////////////////////////////////////////////////////////////// /// @brief Enqueue a command to map a region of the buffer object into the host @@ -11897,7 +11899,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp( const void *pMemory, /// [in] size in bytes to be fetched. size_t size, - /// [in] USM prefetch flags + /// [in] USM migration flags ur_usm_migration_flags_t flags, /// [in] The number of sync points in the provided dependency list. uint32_t numSyncPointsInWaitList, diff --git a/unified-runtime/include/ur_print.hpp b/unified-runtime/include/ur_print.hpp index 93cc0d5f2b6fb..216025e665be8 100644 --- a/unified-runtime/include/ur_print.hpp +++ b/unified-runtime/include/ur_print.hpp @@ -11051,8 +11051,11 @@ inline ur_result_t printFlag(std::ostream &os, uint32_t flag) { inline std::ostream &operator<<(std::ostream &os, enum ur_usm_migration_flag_t value) { switch (value) { - case UR_USM_MIGRATION_FLAG_DEFAULT: - os << "UR_USM_MIGRATION_FLAG_DEFAULT"; + case UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE: + os << "UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE"; + break; + case UR_USM_MIGRATION_FLAG_DEVICE_TO_HOST: + os << "UR_USM_MIGRATION_FLAG_DEVICE_TO_HOST"; break; default: os << "unknown enumerator"; @@ -11070,15 +11073,26 @@ inline ur_result_t printFlag(std::ostream &os, uint32_t val = flag; bool first = true; - if ((val & UR_USM_MIGRATION_FLAG_DEFAULT) == - (uint32_t)UR_USM_MIGRATION_FLAG_DEFAULT) { - val ^= (uint32_t)UR_USM_MIGRATION_FLAG_DEFAULT; + if ((val & UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE) == + (uint32_t)UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE) { + val ^= (uint32_t)UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE; + if (!first) { + os << " | "; + } else { + first = false; + } + os << UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE; + } + + if ((val & UR_USM_MIGRATION_FLAG_DEVICE_TO_HOST) == + (uint32_t)UR_USM_MIGRATION_FLAG_DEVICE_TO_HOST) { + val ^= (uint32_t)UR_USM_MIGRATION_FLAG_DEVICE_TO_HOST; if (!first) { os << " | "; } else { first = false; } - os << UR_USM_MIGRATION_FLAG_DEFAULT; + os << UR_USM_MIGRATION_FLAG_DEVICE_TO_HOST; } if (val != 0) { std::bitset<32> bits(val); diff --git a/unified-runtime/scripts/core/enqueue.yml b/unified-runtime/scripts/core/enqueue.yml index 20d7d7bc2ab3f..a6148bd366f64 100644 --- a/unified-runtime/scripts/core/enqueue.yml +++ b/unified-runtime/scripts/core/enqueue.yml @@ -915,13 +915,16 @@ etors: value: "$X_BIT(2)" --- #-------------------------------------------------------------------------- type: enum -desc: "Map flags" -class: $xDevice +desc: "USM migration flags, indicating the direction data is migrated in" +class: $xEnqueue name: $x_usm_migration_flags_t etors: - - name: DEFAULT - desc: "Default migration TODO: Add more enums! " + - name: HOST_TO_DEVICE + desc: "Migrate data from host to device" value: "$X_BIT(0)" + - name: DEVICE_TO_HOST + desc: "Migrate data from device to host" + value: "$X_BIT(1)" --- #-------------------------------------------------------------------------- type: function desc: "Enqueue a command to map a region of the buffer object into the host address space and return a pointer to the mapped region" diff --git a/unified-runtime/scripts/core/exp-command-buffer.yml b/unified-runtime/scripts/core/exp-command-buffer.yml index e8f2caa15d59d..a194777f9e40b 100644 --- a/unified-runtime/scripts/core/exp-command-buffer.yml +++ b/unified-runtime/scripts/core/exp-command-buffer.yml @@ -1025,7 +1025,7 @@ params: desc: "[in] size in bytes to be fetched." - type: $x_usm_migration_flags_t name: flags - desc: "[in] USM prefetch flags" + desc: "[in] USM migration flags" - type: uint32_t name: numSyncPointsInWaitList desc: "[in] The number of sync points in the provided dependency list." diff --git a/unified-runtime/source/adapters/cuda/enqueue.cpp b/unified-runtime/source/adapters/cuda/enqueue.cpp index 091e8e9d53d44..d308638aa0caf 100644 --- a/unified-runtime/source/adapters/cuda/enqueue.cpp +++ b/unified-runtime/source/adapters/cuda/enqueue.cpp @@ -1558,14 +1558,28 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy( UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch( ur_queue_handle_t hQueue, const void *pMem, size_t size, - ur_usm_migration_flags_t /*flags*/, uint32_t numEventsInWaitList, + ur_usm_migration_flags_t flags, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) { + ur_device_handle_t Device = hQueue->getDevice(); + int dstDevice; + switch (flags) { + case UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE: + dstDevice = Device->get(); + break; + case UR_USM_MIGRATION_FLAG_DEVICE_TO_HOST: + dstDevice = CU_DEVICE_CPU; + break; + default: + setErrorMessage("Invalid USM migration flag", + UR_RESULT_ERROR_INVALID_ENUMERATION); + return UR_RESULT_ERROR_INVALID_ENUMERATION; + } + size_t PointerRangeSize = 0; UR_CHECK_ERROR(cuPointerGetAttribute( &PointerRangeSize, CU_POINTER_ATTRIBUTE_RANGE_SIZE, (CUdeviceptr)pMem)); UR_ASSERT(size <= PointerRangeSize, UR_RESULT_ERROR_INVALID_SIZE); - ur_device_handle_t Device = hQueue->getDevice(); std::unique_ptr EventPtr{nullptr}; try { @@ -1606,7 +1620,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch( } UR_CHECK_ERROR( - cuMemPrefetchAsync((CUdeviceptr)pMem, size, Device->get(), CuStream)); + cuMemPrefetchAsync((CUdeviceptr)pMem, size, dstDevice, CuStream)); } catch (ur_result_t Err) { return Err; } diff --git a/unified-runtime/source/adapters/hip/enqueue.cpp b/unified-runtime/source/adapters/hip/enqueue.cpp index 54ea1ca91a71d..89b45d9d29b2a 100644 --- a/unified-runtime/source/adapters/hip/enqueue.cpp +++ b/unified-runtime/source/adapters/hip/enqueue.cpp @@ -1379,11 +1379,24 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy( UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch( ur_queue_handle_t hQueue, const void *pMem, size_t size, - ur_usm_migration_flags_t /*flags*/, uint32_t numEventsInWaitList, + ur_usm_migration_flags_t flags, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) { - void *HIPDevicePtr = const_cast(pMem); ur_device_handle_t Device = hQueue->getDevice(); + hipDevice_t TargetDevice; + switch (flags) { + case UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE: + TargetDevice = Device->get(); + break; + case UR_USM_MIGRATION_FLAG_DEVICE_TO_HOST: + TargetDevice = hipCpuDeviceId; + break; + default: + setErrorMessage("Invalid USM migration flag", + UR_RESULT_ERROR_INVALID_ENUMERATION); + return UR_RESULT_ERROR_INVALID_ENUMERATION; + } + void *HIPDevicePtr = const_cast(pMem); // HIP_POINTER_ATTRIBUTE_RANGE_SIZE is not an attribute in ROCM < 5, // so we can't perform this check for such cases. @@ -1440,8 +1453,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch( return UR_RESULT_SUCCESS; } - UR_CHECK_ERROR( - hipMemPrefetchAsync(pMem, size, hQueue->getDevice()->get(), HIPStream)); + UR_CHECK_ERROR(hipMemPrefetchAsync(pMem, size, TargetDevice, HIPStream)); releaseEvent(); } catch (ur_result_t Err) { return Err; diff --git a/unified-runtime/source/adapters/level_zero/adapter.cpp b/unified-runtime/source/adapters/level_zero/adapter.cpp index 362a2479cdf47..9de4138f7e433 100644 --- a/unified-runtime/source/adapters/level_zero/adapter.cpp +++ b/unified-runtime/source/adapters/level_zero/adapter.cpp @@ -506,72 +506,71 @@ ur_adapter_handle_t_::ur_adapter_handle_t_() bool forceLoadedAdapter = ur_getenv("UR_ADAPTERS_FORCE_LOAD").has_value(); if (!forceLoadedAdapter) { #ifdef UR_ADAPTER_LEVEL_ZERO_V2 - auto [useV2, reason] = shouldUseV2Adapter(); - if (!useV2) { - UR_LOG(INFO, "Skipping L0 V2 adapter: {}", reason); - return; - } + auto [useV2, reason] = shouldUseV2Adapter(); + if (!useV2) { + UR_LOG(INFO, "Skipping L0 V2 adapter: {}", reason); + return; + } #else - auto [useV1, reason] = shouldUseV1Adapter(); - if (!useV1) { - UR_LOG(INFO, "Skipping L0 V1 adapter: {}", reason); - return; - } + auto [useV1, reason] = shouldUseV1Adapter(); + if (!useV1) { + UR_LOG(INFO, "Skipping L0 V1 adapter: {}", reason); + return; + } #endif } - // Check if the user has enabled the default L0 SysMan initialization. - const int UrSysmanZesinitEnable = [&UserForcedSysManInit] { - const char *UrRet = std::getenv("UR_L0_ENABLE_ZESINIT_DEFAULT"); - if (!UrRet) - return 0; - UserForcedSysManInit &= 2; - return std::atoi(UrRet); - }(); - - bool ZesInitNeeded = UrSysmanZesinitEnable && !UrSysManEnvInitEnabled; - // Unless the user has forced the SysMan init, we will check the device - // version to see if the zesInit is needed. - if (UserForcedSysManInit == 0 && checkDeviceIntelGPUIpVersionOrNewer( - 0x05004000) == UR_RESULT_SUCCESS) { - if (UrSysManEnvInitEnabled) { - setEnvVar("ZES_ENABLE_SYSMAN", "0"); - } - ZesInitNeeded = true; - } - if (ZesInitNeeded) { + // Check if the user has enabled the default L0 SysMan initialization. + const int UrSysmanZesinitEnable = [&UserForcedSysManInit] { + const char *UrRet = std::getenv("UR_L0_ENABLE_ZESINIT_DEFAULT"); + if (!UrRet) + return 0; + UserForcedSysManInit &= 2; + return std::atoi(UrRet); + }(); + + bool ZesInitNeeded = UrSysmanZesinitEnable && !UrSysManEnvInitEnabled; + // Unless the user has forced the SysMan init, we will check the device + // version to see if the zesInit is needed. + if (UserForcedSysManInit == 0 && + checkDeviceIntelGPUIpVersionOrNewer(0x05004000) == UR_RESULT_SUCCESS) { + if (UrSysManEnvInitEnabled) { + setEnvVar("ZES_ENABLE_SYSMAN", "0"); + } + ZesInitNeeded = true; + } + if (ZesInitNeeded) { #ifdef UR_STATIC_LEVEL_ZERO - getDeviceByUUIdFunctionPtr = zesDriverGetDeviceByUuidExp; - getSysManDriversFunctionPtr = zesDriverGet; - sysManInitFunctionPtr = zesInit; + getDeviceByUUIdFunctionPtr = zesDriverGetDeviceByUuidExp; + getSysManDriversFunctionPtr = zesDriverGet; + sysManInitFunctionPtr = zesInit; #else - getDeviceByUUIdFunctionPtr = (zes_pfnDriverGetDeviceByUuidExp_t) - ur_loader::LibLoader::getFunctionPtr(processHandle, - "zesDriverGetDeviceByUuidExp"); - getSysManDriversFunctionPtr = - (zes_pfnDriverGet_t)ur_loader::LibLoader::getFunctionPtr( - processHandle, "zesDriverGet"); - sysManInitFunctionPtr = - (zes_pfnInit_t)ur_loader::LibLoader::getFunctionPtr(processHandle, - "zesInit"); + getDeviceByUUIdFunctionPtr = + (zes_pfnDriverGetDeviceByUuidExp_t)ur_loader::LibLoader::getFunctionPtr( + processHandle, "zesDriverGetDeviceByUuidExp"); + getSysManDriversFunctionPtr = + (zes_pfnDriverGet_t)ur_loader::LibLoader::getFunctionPtr( + processHandle, "zesDriverGet"); + sysManInitFunctionPtr = (zes_pfnInit_t)ur_loader::LibLoader::getFunctionPtr( + processHandle, "zesInit"); #endif - } - if (getDeviceByUUIdFunctionPtr && getSysManDriversFunctionPtr && - sysManInitFunctionPtr) { - ze_init_flags_t L0ZesInitFlags = 0; - UR_LOG(DEBUG, "\nzesInit with flags value of {}\n", - static_cast(L0ZesInitFlags)); - ZesResult = ZE_CALL_NOCHECK(sysManInitFunctionPtr, (L0ZesInitFlags)); - } else { - ZesResult = ZE_RESULT_ERROR_UNINITIALIZED; - } + } + if (getDeviceByUUIdFunctionPtr && getSysManDriversFunctionPtr && + sysManInitFunctionPtr) { + ze_init_flags_t L0ZesInitFlags = 0; + UR_LOG(DEBUG, "\nzesInit with flags value of {}\n", + static_cast(L0ZesInitFlags)); + ZesResult = ZE_CALL_NOCHECK(sysManInitFunctionPtr, (L0ZesInitFlags)); + } else { + ZesResult = ZE_RESULT_ERROR_UNINITIALIZED; + } - ur_result_t err = initPlatforms(this, platforms, ZesResult); - if (err == UR_RESULT_SUCCESS) { - Platforms = std::move(platforms); - } else { - throw err; - } + ur_result_t err = initPlatforms(this, platforms, ZesResult); + if (err == UR_RESULT_SUCCESS) { + Platforms = std::move(platforms); + } else { + throw err; + } } void globalAdapterOnDemandCleanup() { diff --git a/unified-runtime/source/adapters/level_zero/command_buffer.cpp b/unified-runtime/source/adapters/level_zero/command_buffer.cpp index 1e68069db51b2..687c905417d8b 100644 --- a/unified-runtime/source/adapters/level_zero/command_buffer.cpp +++ b/unified-runtime/source/adapters/level_zero/command_buffer.cpp @@ -1313,7 +1313,7 @@ ur_result_t urCommandBufferAppendMemBufferReadRectExp( ur_result_t urCommandBufferAppendUSMPrefetchExp( ur_exp_command_buffer_handle_t CommandBuffer, const void *Mem, size_t Size, - ur_usm_migration_flags_t /*Flags*/, uint32_t NumSyncPointsInWaitList, + ur_usm_migration_flags_t Flags, uint32_t NumSyncPointsInWaitList, const ur_exp_command_buffer_sync_point_t *SyncPointWaitList, uint32_t /*NumEventsInWaitList*/, const ur_event_handle_t * /*EventWaitList*/, @@ -1327,6 +1327,17 @@ ur_result_t urCommandBufferAppendUSMPrefetchExp( UR_COMMAND_USM_PREFETCH, CommandBuffer, CommandBuffer->ZeComputeCommandList, NumSyncPointsInWaitList, SyncPointWaitList, true, RetSyncPoint, ZeEventList, ZeLaunchEvent)); + switch (Flags) { + case UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE: + break; + case UR_USM_MIGRATION_FLAG_DEVICE_TO_HOST: + UR_LOG(WARN, "commandBufferAppendUSMPrefetch: L0 does not support prefetch " + "to host yet"); + break; + default: + UR_LOG(ERR, "commandBufferAppendUSMPrefetch: invalid USM migration flag"); + return UR_RESULT_ERROR_INVALID_ENUMERATION; + } if (!ZeEventList.empty()) { ZE2UR_CALL(zeCommandListAppendWaitOnEvents, @@ -1335,9 +1346,11 @@ ur_result_t urCommandBufferAppendUSMPrefetchExp( } // Add the prefetch command to the command-buffer. - // Note that L0 does not handle migration flags. - ZE2UR_CALL(zeCommandListAppendMemoryPrefetch, - (CommandBuffer->ZeComputeCommandList, Mem, Size)); + // TODO Support migration flags after L0 backend support is added. + if (Flags == UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE) { + ZE2UR_CALL(zeCommandListAppendMemoryPrefetch, + (CommandBuffer->ZeComputeCommandList, Mem, Size)); + } if (!CommandBuffer->IsInOrderCmdList) { // Level Zero does not have a completion "event" with the prefetch API, diff --git a/unified-runtime/source/adapters/level_zero/memory.cpp b/unified-runtime/source/adapters/level_zero/memory.cpp index 3b1158645e77a..107fcc2d1c2f5 100644 --- a/unified-runtime/source/adapters/level_zero/memory.cpp +++ b/unified-runtime/source/adapters/level_zero/memory.cpp @@ -1265,7 +1265,7 @@ ur_result_t urEnqueueUSMPrefetch( /// [in] size in bytes to be fetched size_t Size, /// [in] USM prefetch flags - ur_usm_migration_flags_t /*Flags*/, + ur_usm_migration_flags_t Flags, /// [in] size of the event wait list uint32_t NumEventsInWaitList, /// [in][optional][range(0, numEventsInWaitList)] pointer to a list of @@ -1276,6 +1276,18 @@ ur_result_t urEnqueueUSMPrefetch( /// [in,out][optional] return an event object that identifies this /// particular command instance. ur_event_handle_t *OutEvent) { + switch (Flags) { + case UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE: + break; + case UR_USM_MIGRATION_FLAG_DEVICE_TO_HOST: + UR_LOG(WARN, + "enqueueUSMPrefetch: L0 does not support prefetch to host yet"); + break; + default: + UR_LOG(ERR, "enqueueUSMPrefetch: invalid USM migration flag"); + return UR_RESULT_ERROR_INVALID_ENUMERATION; + } + // Lock automatically releases when this goes out of scope. std::scoped_lock lock(Queue->Mutex); @@ -1315,8 +1327,10 @@ ur_result_t urEnqueueUSMPrefetch( ZE2UR_CALL(zeCommandListAppendWaitOnEvents, (ZeCommandList, WaitList.Length, WaitList.ZeEventList)); } - // TODO: figure out how to translate "flags" - ZE2UR_CALL(zeCommandListAppendMemoryPrefetch, (ZeCommandList, Mem, Size)); + // TODO: Support migration flags after L0 backend support is added + if (Flags == UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE) { + ZE2UR_CALL(zeCommandListAppendMemoryPrefetch, (ZeCommandList, Mem, Size)); + } // TODO: Level Zero does not have a completion "event" with the prefetch API, // so manually add command to signal our event. diff --git a/unified-runtime/source/adapters/level_zero/v2/command_list_manager.cpp b/unified-runtime/source/adapters/level_zero/v2/command_list_manager.cpp index 04e202265d05c..0fcd030c699e5 100644 --- a/unified-runtime/source/adapters/level_zero/v2/command_list_manager.cpp +++ b/unified-runtime/source/adapters/level_zero/v2/command_list_manager.cpp @@ -284,11 +284,23 @@ ur_result_t ur_command_list_manager::appendUSMFill( } ur_result_t ur_command_list_manager::appendUSMPrefetch( - const void *pMem, size_t size, ur_usm_migration_flags_t /*flags*/, + const void *pMem, size_t size, ur_usm_migration_flags_t flags, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t phEvent) { TRACK_SCOPE_LATENCY("ur_command_list_manager::appendUSMPrefetch"); + switch (flags) { + case UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE: + break; + case UR_USM_MIGRATION_FLAG_DEVICE_TO_HOST: + UR_LOG(WARN, + "appendUSMPrefetch: L0v2 does not support prefetch to host yet"); + break; + default: + UR_LOG(ERR, "appendUSMPrefetch: invalid USM migration flag"); + return UR_RESULT_ERROR_INVALID_ENUMERATION; + } + auto zeSignalEvent = getSignalEvent(phEvent, UR_COMMAND_USM_PREFETCH); auto [pWaitEvents, numWaitEvents] = getWaitListView(phEventWaitList, numEventsInWaitList); @@ -297,9 +309,11 @@ ur_result_t ur_command_list_manager::appendUSMPrefetch( ZE2UR_CALL(zeCommandListAppendWaitOnEvents, (zeCommandList.get(), numWaitEvents, pWaitEvents)); } - // TODO: figure out how to translate "flags" - ZE2UR_CALL(zeCommandListAppendMemoryPrefetch, - (zeCommandList.get(), pMem, size)); + // TODO: Support migration flags after L0 backend support is added + if (flags == UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE) { + ZE2UR_CALL(zeCommandListAppendMemoryPrefetch, + (zeCommandList.get(), pMem, size)); + } if (zeSignalEvent) { ZE2UR_CALL(zeCommandListAppendSignalEvent, (zeCommandList.get(), zeSignalEvent)); diff --git a/unified-runtime/source/adapters/mock/ur_mockddi.cpp b/unified-runtime/source/adapters/mock/ur_mockddi.cpp index 74cb1accfa448..992ca2a16a385 100644 --- a/unified-runtime/source/adapters/mock/ur_mockddi.cpp +++ b/unified-runtime/source/adapters/mock/ur_mockddi.cpp @@ -10681,7 +10681,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp( const void *pMemory, /// [in] size in bytes to be fetched. size_t size, - /// [in] USM prefetch flags + /// [in] USM migration flags ur_usm_migration_flags_t flags, /// [in] The number of sync points in the provided dependency list. uint32_t numSyncPointsInWaitList, diff --git a/unified-runtime/source/adapters/opencl/common.hpp b/unified-runtime/source/adapters/opencl/common.hpp index 0cfa916e49273..fc335186fce2a 100644 --- a/unified-runtime/source/adapters/opencl/common.hpp +++ b/unified-runtime/source/adapters/opencl/common.hpp @@ -187,6 +187,7 @@ CONSTFIX char CreateBufferWithPropertiesName[] = CONSTFIX char SetKernelArgMemPointerName[] = "clSetKernelArgMemPointerINTEL"; CONSTFIX char EnqueueMemFillName[] = "clEnqueueMemFillINTEL"; CONSTFIX char EnqueueMemcpyName[] = "clEnqueueMemcpyINTEL"; +CONSTFIX char EnqueueMigrateMemName[] = "clEnqueueMigrateMemINTEL"; CONSTFIX char GetMemAllocInfoName[] = "clGetMemAllocInfoINTEL"; CONSTFIX char SetProgramSpecializationConstantName[] = "clSetProgramSpecializationConstant"; diff --git a/unified-runtime/source/adapters/opencl/extension_functions.def b/unified-runtime/source/adapters/opencl/extension_functions.def index c7b4861807d98..47e85f918a222 100644 --- a/unified-runtime/source/adapters/opencl/extension_functions.def +++ b/unified-runtime/source/adapters/opencl/extension_functions.def @@ -8,6 +8,7 @@ CL_EXTENSION_FUNC(clMemBlockingFreeINTEL) CL_EXTENSION_FUNC(clSetKernelArgMemPointerINTEL) CL_EXTENSION_FUNC(clEnqueueMemFillINTEL) CL_EXTENSION_FUNC(clEnqueueMemcpyINTEL) +CL_EXTENSION_FUNC(clEnqueueMigrateMemINTEL) CL_EXTENSION_FUNC(clGetMemAllocInfoINTEL) CL_EXTENSION_FUNC(clEnqueueWriteGlobalVariable) CL_EXTENSION_FUNC(clEnqueueReadGlobalVariable) diff --git a/unified-runtime/source/adapters/opencl/usm.cpp b/unified-runtime/source/adapters/opencl/usm.cpp index e3c510c745766..09cf31aee8645 100644 --- a/unified-runtime/source/adapters/opencl/usm.cpp +++ b/unified-runtime/source/adapters/opencl/usm.cpp @@ -524,36 +524,60 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch( [[maybe_unused]] ur_usm_migration_flags_t flags, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) { - cl_event Event; + // TODO: Uncomment implementation when issues with impl are resolved. + + // cl_mem_migration_flags MigrationFlag; + switch (flags) { + case UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE: + // Note: currently opencl:cpu will break with this value, but opencl:gpu + // will work just fine. A spec change has been made to address this issue, + // and is waiting to be implemented: + // https://github.com/KhronosGroup/OpenCL-Docs/pull/1412/files#diff-7e4c12789cfc81c40637d32b7113b0cca2c3ee0beabaabb9acd9da743f7b5780R974 + + // MigrationFlag = 0; // OpenCL spec stipulates 0 as host + break; + case UR_USM_MIGRATION_FLAG_DEVICE_TO_HOST: + // Note: there is currently no driver support for this. + + // MigrationFlag = CL_MIGRATE_MEM_OBJECT_HOST; + break; + default: + cl_adapter::setErrorMessage("Invalid USM migration flag", + UR_RESULT_ERROR_INVALID_ENUMERATION); + return UR_RESULT_ERROR_INVALID_ENUMERATION; + } + + /* + // Have to look up the context from the kernel + cl_context CLContext = hQueue->Context->CLContext; + + clEnqueueMigrateMemINTEL_fn EnqueueMigrateMem = nullptr; + UR_RETURN_ON_FAILURE( + cl_ext::getExtFuncFromContext( + CLContext, ur::cl::getAdapter()->fnCache.clEnqueueMigrateMemINTELCache, + cl_ext::EnqueueMigrateMemName, &EnqueueMigrateMem)); + */ + + cl_event Event = nullptr; std::vector CLWaitEvents(numEventsInWaitList); for (uint32_t i = 0; i < numEventsInWaitList; i++) { CLWaitEvents[i] = phEventWaitList[i]->CLEvent; } + + /* + CL_RETURN_ON_FAILURE(EnqueueMigrateMem( + hQueue->CLQueue, pMem, size, MigrationFlag, numEventsInWaitList, + CLWaitEvents.data(), ifUrEvent(phEvent, Event))); + */ + + // TODO: when issues with impl are fully resolved, delete this and use + // waitlisting from EnqueueMigrateMem instead. CL_RETURN_ON_FAILURE(clEnqueueMarkerWithWaitList( hQueue->CLQueue, numEventsInWaitList, CLWaitEvents.data(), ifUrEvent(phEvent, Event))); + UR_RETURN_ON_FAILURE(createUREvent(Event, hQueue->Context, hQueue, phEvent)); return UR_RESULT_SUCCESS; - /* - // Use this once impls support it. - // Have to look up the context from the kernel - cl_context CLContext = hQueue->Context; - - clEnqueueMigrateMemINTEL_fn FuncPtr; - ur_result_t Err = cl_ext::getExtFuncFromContext( - CLContext, "clEnqueueMigrateMemINTEL", &FuncPtr); - - ur_result_t RetVal; - if (Err != UR_RESULT_SUCCESS) { - RetVal = Err; - } else { - RetVal = map_cl_error_to_ur( - FuncPtr(hQueue->CLQueue, pMem, size, flags, - numEventsInWaitList, - reinterpret_cast(phEventWaitList), - reinterpret_cast(phEvent))); - } - */ } UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMAdvise( diff --git a/unified-runtime/source/loader/layers/tracing/ur_trcddi.cpp b/unified-runtime/source/loader/layers/tracing/ur_trcddi.cpp index 1cac607be8559..bfdfd6096470c 100644 --- a/unified-runtime/source/loader/layers/tracing/ur_trcddi.cpp +++ b/unified-runtime/source/loader/layers/tracing/ur_trcddi.cpp @@ -9039,7 +9039,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp( const void *pMemory, /// [in] size in bytes to be fetched. size_t size, - /// [in] USM prefetch flags + /// [in] USM migration flags ur_usm_migration_flags_t flags, /// [in] The number of sync points in the provided dependency list. uint32_t numSyncPointsInWaitList, diff --git a/unified-runtime/source/loader/layers/validation/ur_valddi.cpp b/unified-runtime/source/loader/layers/validation/ur_valddi.cpp index 9dd572ecd315a..d19d6c4e47010 100644 --- a/unified-runtime/source/loader/layers/validation/ur_valddi.cpp +++ b/unified-runtime/source/loader/layers/validation/ur_valddi.cpp @@ -9822,7 +9822,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp( const void *pMemory, /// [in] size in bytes to be fetched. size_t size, - /// [in] USM prefetch flags + /// [in] USM migration flags ur_usm_migration_flags_t flags, /// [in] The number of sync points in the provided dependency list. uint32_t numSyncPointsInWaitList, diff --git a/unified-runtime/source/loader/ur_ldrddi.cpp b/unified-runtime/source/loader/ur_ldrddi.cpp index 0a09a3072cd48..c4f68dd426c90 100644 --- a/unified-runtime/source/loader/ur_ldrddi.cpp +++ b/unified-runtime/source/loader/ur_ldrddi.cpp @@ -5152,7 +5152,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp( const void *pMemory, /// [in] size in bytes to be fetched. size_t size, - /// [in] USM prefetch flags + /// [in] USM migration flags ur_usm_migration_flags_t flags, /// [in] The number of sync points in the provided dependency list. uint32_t numSyncPointsInWaitList, diff --git a/unified-runtime/source/loader/ur_libapi.cpp b/unified-runtime/source/loader/ur_libapi.cpp index 59edc89920e92..4e207537cf2d7 100644 --- a/unified-runtime/source/loader/ur_libapi.cpp +++ b/unified-runtime/source/loader/ur_libapi.cpp @@ -9458,7 +9458,7 @@ ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp( const void *pMemory, /// [in] size in bytes to be fetched. size_t size, - /// [in] USM prefetch flags + /// [in] USM migration flags ur_usm_migration_flags_t flags, /// [in] The number of sync points in the provided dependency list. uint32_t numSyncPointsInWaitList, diff --git a/unified-runtime/source/ur_api.cpp b/unified-runtime/source/ur_api.cpp index 771e27c3b8d6f..35f154f96cc3f 100644 --- a/unified-runtime/source/ur_api.cpp +++ b/unified-runtime/source/ur_api.cpp @@ -8239,7 +8239,7 @@ ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp( const void *pMemory, /// [in] size in bytes to be fetched. size_t size, - /// [in] USM prefetch flags + /// [in] USM migration flags ur_usm_migration_flags_t flags, /// [in] The number of sync points in the provided dependency list. uint32_t numSyncPointsInWaitList, diff --git a/unified-runtime/test/conformance/enqueue/urEnqueueUSMPrefetch.cpp b/unified-runtime/test/conformance/enqueue/urEnqueueUSMPrefetch.cpp index e0cb371ff09ac..88ef85cd93c4d 100644 --- a/unified-runtime/test/conformance/enqueue/urEnqueueUSMPrefetch.cpp +++ b/unified-runtime/test/conformance/enqueue/urEnqueueUSMPrefetch.cpp @@ -20,7 +20,8 @@ struct urEnqueueUSMPrefetchWithParamTest UUR_DEVICE_TEST_SUITE_WITH_PARAM( urEnqueueUSMPrefetchWithParamTest, - ::testing::Values(UR_USM_MIGRATION_FLAG_DEFAULT), + ::testing::Values(UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE, + UR_USM_MIGRATION_FLAG_DEVICE_TO_HOST), uur::deviceTestWithParamPrinter); TEST_P(urEnqueueUSMPrefetchWithParamTest, Success) { @@ -102,14 +103,14 @@ UUR_INSTANTIATE_DEVICE_TEST_SUITE(urEnqueueUSMPrefetchTest); TEST_P(urEnqueueUSMPrefetchTest, InvalidNullHandleQueue) { ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_NULL_HANDLE, urEnqueueUSMPrefetch(nullptr, ptr, allocation_size, - UR_USM_MIGRATION_FLAG_DEFAULT, 0, + UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE, 0, nullptr, nullptr)); } TEST_P(urEnqueueUSMPrefetchTest, InvalidNullPointerMem) { ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_NULL_POINTER, urEnqueueUSMPrefetch(queue, nullptr, allocation_size, - UR_USM_MIGRATION_FLAG_DEFAULT, 0, + UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE, 0, nullptr, nullptr)); } @@ -123,7 +124,7 @@ TEST_P(urEnqueueUSMPrefetchTest, InvalidEnumeration) { TEST_P(urEnqueueUSMPrefetchTest, InvalidSizeZero) { ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_SIZE, urEnqueueUSMPrefetch(queue, ptr, 0, - UR_USM_MIGRATION_FLAG_DEFAULT, 0, + UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE, 0, nullptr, nullptr)); } @@ -132,14 +133,14 @@ TEST_P(urEnqueueUSMPrefetchTest, InvalidSizeTooLarge) { ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_SIZE, urEnqueueUSMPrefetch(queue, ptr, allocation_size * 2, - UR_USM_MIGRATION_FLAG_DEFAULT, 0, + UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE, 0, nullptr, nullptr)); } TEST_P(urEnqueueUSMPrefetchTest, InvalidEventWaitList) { ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST, urEnqueueUSMPrefetch(queue, ptr, allocation_size, - UR_USM_MIGRATION_FLAG_DEFAULT, 1, + UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE, 1, nullptr, nullptr)); ur_event_handle_t validEvent; @@ -147,12 +148,12 @@ TEST_P(urEnqueueUSMPrefetchTest, InvalidEventWaitList) { ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST, urEnqueueUSMPrefetch(queue, ptr, allocation_size, - UR_USM_MIGRATION_FLAG_DEFAULT, 0, + UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE, 0, &validEvent, nullptr)); ur_event_handle_t inv_evt = nullptr; ASSERT_EQ_RESULT(urEnqueueUSMPrefetch(queue, ptr, allocation_size, - UR_USM_MIGRATION_FLAG_DEFAULT, 1, + UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE, 1, &inv_evt, nullptr), UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST); diff --git a/unified-runtime/test/conformance/exp_command_buffer/commands.cpp b/unified-runtime/test/conformance/exp_command_buffer/commands.cpp index 22ac628c2726f..4e0ced3502943 100644 --- a/unified-runtime/test/conformance/exp_command_buffer/commands.cpp +++ b/unified-runtime/test/conformance/exp_command_buffer/commands.cpp @@ -143,8 +143,21 @@ TEST_P(urCommandBufferCommandsTest, urCommandBufferAppendUSMPrefetchExp) { UUR_KNOWN_FAILURE_ON(uur::OpenCL{}); ASSERT_SUCCESS(urCommandBufferAppendUSMPrefetchExp( - cmd_buf_handle, device_ptrs[0], allocation_size, 0, 0, nullptr, 0, - nullptr, nullptr, nullptr, nullptr)); + cmd_buf_handle, device_ptrs[0], allocation_size, + UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE, 0, nullptr, 0, nullptr, nullptr, + nullptr, nullptr)); +} + +TEST_P(urCommandBufferCommandsTest, + urCommandBufferAppendUSMPrefetchExpDeviceToHost) { + // No Prefetch command in cl_khr_command_buffer + // No driver support for prefetching from device to host on Intel GPUs + UUR_KNOWN_FAILURE_ON(uur::OpenCL{}, uur::LevelZero{}); + + ASSERT_SUCCESS(urCommandBufferAppendUSMPrefetchExp( + cmd_buf_handle, device_ptrs[0], allocation_size, + UR_USM_MIGRATION_FLAG_DEVICE_TO_HOST, 0, nullptr, 0, nullptr, nullptr, + nullptr, nullptr)); } TEST_P(urCommandBufferCommandsTest, urCommandBufferAppendUSMAdviseExp) { diff --git a/unified-runtime/test/conformance/exp_command_buffer/event_sync.cpp b/unified-runtime/test/conformance/exp_command_buffer/event_sync.cpp index ba592053876cd..26ec26b2a05bd 100644 --- a/unified-runtime/test/conformance/exp_command_buffer/event_sync.cpp +++ b/unified-runtime/test/conformance/exp_command_buffer/event_sync.cpp @@ -426,9 +426,9 @@ TEST_P(CommandEventSyncTest, USMPrefetchExp) { // Test prefetch command waiting on queue event ASSERT_SUCCESS(urCommandBufferAppendUSMPrefetchExp( - cmd_buf_handle, device_ptrs[1], allocation_size, 0 /* migration flags*/, - 0, nullptr, 1, &external_events[0], nullptr, &external_events[1], - nullptr)); + cmd_buf_handle, device_ptrs[1], allocation_size, + UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE, 0, nullptr, 1, &external_events[0], + nullptr, &external_events[1], nullptr)); ASSERT_SUCCESS(urCommandBufferFinalizeExp(cmd_buf_handle)); ASSERT_SUCCESS( urEnqueueCommandBufferExp(queue, cmd_buf_handle, 0, nullptr, nullptr)); diff --git a/unified-runtime/test/conformance/exp_command_buffer/in-order.cpp b/unified-runtime/test/conformance/exp_command_buffer/in-order.cpp index fd6335197cdf0..45357340ed6df 100644 --- a/unified-runtime/test/conformance/exp_command_buffer/in-order.cpp +++ b/unified-runtime/test/conformance/exp_command_buffer/in-order.cpp @@ -101,7 +101,7 @@ struct urInOrderUSMCommandBufferExpTest : urInOrderCommandBufferExpTest { if (hints) { ASSERT_SUCCESS(urCommandBufferAppendUSMPrefetchExp( in_order_cb, device_ptrs[0], allocation_size, - UR_USM_MIGRATION_FLAG_DEFAULT, 0, nullptr, 0, nullptr, nullptr, + UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE, 0, nullptr, 0, nullptr, nullptr, nullptr, nullptr)); } @@ -124,7 +124,7 @@ struct urInOrderUSMCommandBufferExpTest : urInOrderCommandBufferExpTest { if (hints) { ASSERT_SUCCESS(urCommandBufferAppendUSMPrefetchExp( in_order_cb, device_ptrs[0], allocation_size, - UR_USM_MIGRATION_FLAG_DEFAULT, 0, nullptr, 0, nullptr, nullptr, + UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE, 0, nullptr, 0, nullptr, nullptr, nullptr, nullptr)); } diff --git a/unified-runtime/test/conformance/exp_command_buffer/update/event_sync.cpp b/unified-runtime/test/conformance/exp_command_buffer/update/event_sync.cpp index fe0dc03728545..16763eaf0c15d 100644 --- a/unified-runtime/test/conformance/exp_command_buffer/update/event_sync.cpp +++ b/unified-runtime/test/conformance/exp_command_buffer/update/event_sync.cpp @@ -723,8 +723,8 @@ TEST_P(CommandEventSyncUpdateTest, USMPrefetchExp) { // Test prefetch command waiting on queue event ASSERT_SUCCESS(urCommandBufferAppendUSMPrefetchExp( updatable_cmd_buf_handle, device_ptrs[1], allocation_size, - 0 /* migration flags*/, 0, nullptr, 1, &external_events[0], nullptr, - &external_events[1], &command_handles[0])); + UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE, 0, nullptr, 1, &external_events[0], + nullptr, &external_events[1], &command_handles[0])); ASSERT_NE(nullptr, command_handles[0]); ASSERT_SUCCESS(urCommandBufferFinalizeExp(updatable_cmd_buf_handle)); ASSERT_SUCCESS(urEnqueueCommandBufferExp(queue, updatable_cmd_buf_handle, 0,