From e7dcf6d1c1a0c4c09af006be4eb65a5714897a02 Mon Sep 17 00:00:00 2001 From: "Zheng, Zhaoqiong" Date: Mon, 9 Jun 2025 15:55:05 +0800 Subject: [PATCH 1/9] update sycl extension for muladd op --- extension_cpp/csrc/sycl/muladd.sycl | 189 ++++++++++++++++++++++++++++ setup.py | 114 ++++++++++++----- test/test_extension.py | 23 +++- 3 files changed, 291 insertions(+), 35 deletions(-) create mode 100644 extension_cpp/csrc/sycl/muladd.sycl diff --git a/extension_cpp/csrc/sycl/muladd.sycl b/extension_cpp/csrc/sycl/muladd.sycl new file mode 100644 index 0000000..d494728 --- /dev/null +++ b/extension_cpp/csrc/sycl/muladd.sycl @@ -0,0 +1,189 @@ +#include +#include +#include +#include +#include + +namespace extension_cpp { + + +// MulAdd Kernel: result = a * b + c +static void muladd_kernel( + int numel, const float* a, const float* b, float c, float* result, + const sycl::nd_item<1>& item) { + int idx = item.get_global_id(0); + if (idx < numel) { + result[idx] = a[idx] * b[idx] + c; + } +} + +// Mul Kernel: result = a * b +static void mul_kernel( + int numel, const float* a, const float* b, float* result, + const sycl::nd_item<1>& item) { + int idx = item.get_global_id(0); + if (idx < numel) { + result[idx] = a[idx] * b[idx]; + } +} + +// Add Kernel: result = a + b +static void add_kernel( + int numel, const float* a, const float* b, float* result, + const sycl::nd_item<1>& item) { + int idx = item.get_global_id(0); + if (idx < numel) { + result[idx] = a[idx] + b[idx]; + } +} + + +class MulAddKernelFunctor { +public: + MulAddKernelFunctor(int _numel, const float* _a, const float* _b, float _c, float* _result) + : numel(_numel), a(_a), b(_b), c(_c), result(_result) {} + + void operator()(const sycl::nd_item<1>& item) const { + muladd_kernel(numel, a, b, c, result, item); + } + +private: + int numel; + const float* a; + const float* b; + float c; + float* result; +}; + +class MulKernelFunctor { +public: + MulKernelFunctor(int _numel, const float* _a, const float* _b, float* _result) + : numel(_numel), a(_a), b(_b), result(_result) {} + + void operator()(const sycl::nd_item<1>& item) const { + mul_kernel(numel, a, b, result, item); + } + +private: + int numel; + const float* a; + const float* b; + float* result; +}; + +class AddKernelFunctor { +public: + AddKernelFunctor(int _numel, const float* _a, const float* _b, float* _result) + : numel(_numel), a(_a), b(_b), result(_result) {} + + void operator()(const sycl::nd_item<1>& item) const { + add_kernel(numel, a, b, result, item); + } + +private: + int numel; + const float* a; + const float* b; + float* result; +}; + + +at::Tensor mymuladd_xpu(const at::Tensor& a, const at::Tensor& b, double c) { + TORCH_CHECK(a.sizes() == b.sizes(), "a and b must have the same shape"); + TORCH_CHECK(a.dtype() == at::kFloat, "a must be a float tensor"); + TORCH_CHECK(b.dtype() == at::kFloat, "b must be a float tensor"); + TORCH_CHECK(a.device().is_xpu(), "a must be an XPU tensor"); + TORCH_CHECK(b.device().is_xpu(), "b must be an XPU tensor"); + + at::Tensor a_contig = a.contiguous(); + at::Tensor b_contig = b.contiguous(); + at::Tensor result = at::empty_like(a_contig); + + const float* a_ptr = a_contig.data_ptr(); + const float* b_ptr = b_contig.data_ptr(); + float* res_ptr = result.data_ptr(); + int numel = a_contig.numel(); + + sycl::queue& queue = c10::xpu::getCurrentXPUStream().queue(); + constexpr int threads = 256; + int blocks = (numel + threads - 1) / threads; + + queue.submit([&](sycl::handler& cgh) { + cgh.parallel_for( + sycl::nd_range<1>(blocks * threads, threads), + MulAddKernelFunctor(numel, a_ptr, b_ptr, static_cast(c), res_ptr) + ); + }); + return result; +} + +at::Tensor mymul_xpu(const at::Tensor& a, const at::Tensor& b) { + TORCH_CHECK(a.sizes() == b.sizes(), "a and b must have the same shape"); + TORCH_CHECK(a.dtype() == at::kFloat, "a must be a float tensor"); + TORCH_CHECK(b.dtype() == at::kFloat, "b must be a float tensor"); + TORCH_CHECK(a.device().is_xpu(), "a must be an XPU tensor"); + TORCH_CHECK(b.device().is_xpu(), "b must be an XPU tensor"); + + at::Tensor a_contig = a.contiguous(); + at::Tensor b_contig = b.contiguous(); + at::Tensor result = at::empty_like(a_contig); + + const float* a_ptr = a_contig.data_ptr(); + const float* b_ptr = b_contig.data_ptr(); + float* res_ptr = result.data_ptr(); + int numel = a_contig.numel(); + + sycl::queue& queue = c10::xpu::getCurrentXPUStream().queue(); + constexpr int threads = 256; + int blocks = (numel + threads - 1) / threads; + + queue.submit([&](sycl::handler& cgh) { + cgh.parallel_for( + sycl::nd_range<1>(blocks * threads, threads), + MulKernelFunctor(numel, a_ptr, b_ptr, res_ptr) + ); + }); + return result; +} + +void myadd_out_xpu(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) { + TORCH_CHECK(a.sizes() == b.sizes(), "a and b must have the same shape"); + TORCH_CHECK(b.sizes() == out.sizes(), "b and out must have the same shape"); + TORCH_CHECK(a.dtype() == at::kFloat, "a must be a float tensor"); + TORCH_CHECK(b.dtype() == at::kFloat, "b must be a float tensor"); + TORCH_CHECK(out.is_contiguous(), "out must be contiguous"); + TORCH_CHECK(a.device().is_xpu(), "a must be an XPU tensor"); + TORCH_CHECK(b.device().is_xpu(), "b must be an XPU tensor"); + TORCH_CHECK(out.device().is_xpu(), "out must be an XPU tensor"); + + at::Tensor a_contig = a.contiguous(); + at::Tensor b_contig = b.contiguous(); + + const float* a_ptr = a_contig.data_ptr(); + const float* b_ptr = b_contig.data_ptr(); + float* out_ptr = out.data_ptr(); + int numel = a_contig.numel(); + + sycl::queue& queue = c10::xpu::getCurrentXPUStream().queue(); + constexpr int threads = 256; + int blocks = (numel + threads - 1) / threads; + + queue.submit([&](sycl::handler& cgh) { + cgh.parallel_for( + sycl::nd_range<1>(blocks * threads, threads), + AddKernelFunctor(numel, a_ptr, b_ptr, out_ptr) + ); + }); +} + +// ================================================== +// Register Sycl Implementations to Torch Library +// ================================================== + +TORCH_LIBRARY_IMPL(extension_cpp, XPU, m) { + m.impl("mymuladd", mymuladd_xpu); + m.impl("mymul", mymul_xpu); + m.impl("myadd_out", myadd_out_xpu); +} + +} // namespace extension_cpp diff --git a/setup.py b/setup.py index 0dde1e4..db7b84b 100644 --- a/setup.py +++ b/setup.py @@ -2,63 +2,119 @@ # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. - import os import torch import glob - from setuptools import find_packages, setup - from torch.utils.cpp_extension import ( CppExtension, CUDAExtension, BuildExtension, CUDA_HOME, ) +# Conditional import for SyclExtension +try: + from torch.utils.cpp_extension import SyclExtension +except ImportError: + SyclExtension = None library_name = "extension_cpp" +# Configure Py_LIMITED_API based on PyTorch version if torch.__version__ >= "2.6.0": - py_limited_api = True + py_limited_api = False else: py_limited_api = False - def get_extensions(): debug_mode = os.getenv("DEBUG", "0") == "1" - use_cuda = os.getenv("USE_CUDA", "1") == "1" - if debug_mode: - print("Compiling in debug mode") - use_cuda = use_cuda and torch.cuda.is_available() and CUDA_HOME is not None - extension = CUDAExtension if use_cuda else CppExtension + # Determine backend (CUDA, SYCL, or C++) + use_cuda = os.getenv("USE_CUDA", "auto") + use_sycl = os.getenv("USE_SYCL", "auto") + + # Auto-detect CUDA + if use_cuda == "auto": + use_cuda = torch.cuda.is_available() and CUDA_HOME is not None + else: + use_cuda = use_cuda.lower() == "true" or use_cuda == "1" + + # Auto-detect SYCL + if use_sycl == "auto": + use_sycl = SyclExtension is not None and torch.xpu.is_available() + else: + use_sycl = use_sycl.lower() == "true" or use_sycl == "1" + if use_cuda and use_sycl: + raise RuntimeError("Cannot enable both CUDA and SYCL backends simultaneously.") + + print("use cuda & use sycl",use_cuda, use_sycl) + + extension = None + if use_cuda: + extension = CUDAExtension + print("Building with CUDA backend") + elif use_sycl and SyclExtension is not None: + extension = SyclExtension + print("Building with SYCL backend") + else: + extension = CppExtension + print("Building with C++ backend") + + # Compilation arguments extra_link_args = [] - extra_compile_args = { - "cxx": [ - "-O3" if not debug_mode else "-O0", - "-fdiagnostics-color=always", - "-DPy_LIMITED_API=0x03090000", # min CPython version 3.9 - ], - "nvcc": [ + extra_compile_args = {"cxx": []} + if extension == CUDAExtension: + extra_compile_args = { + "cxx": ["-O3" if not debug_mode else "-O0", + "-fdiagnostics-color=always", + "-DPy_LIMITED_API=0x03090000"], + "nvcc": ["-O3" if not debug_mode else "-O0"] + } + elif extension == SyclExtension: + print("SYCLExtension branch, set extra_compile_args") + extra_compile_args = { + "cxx": ["-O3" if not debug_mode else "-O0"], + "sycl": ["-O3" if not debug_mode else "-O0"] + } + # extra_compile_args = { + # "cxx": ["-O3" if not debug_mode else "-O0", + # "-fdiagnostics-color=always", + # "-DPy_LIMITED_API=0x03090000"], + # "sycl": ["-O3" if not debug_mode else "-O0"] + # } + else: + extra_compile_args["cxx"] = [ "-O3" if not debug_mode else "-O0", - ], - } + "-DPy_LIMITED_API=0x03090000"] + if debug_mode: extra_compile_args["cxx"].append("-g") - extra_compile_args["nvcc"].append("-g") - extra_link_args.extend(["-O0", "-g"]) - + if extension == CUDAExtension: + extra_compile_args["nvcc"].append("-g") + extra_link_args.extend(["-O0", "-g"]) + elif extension == SYCLExtension: + extra_compile_args["sycl"].append("-g") + extra_link_args.extend(["-O0", "-g"]) + + # Source files collection this_dir = os.path.dirname(os.path.curdir) extensions_dir = os.path.join(this_dir, library_name, "csrc") sources = list(glob.glob(os.path.join(extensions_dir, "*.cpp"))) - extensions_cuda_dir = os.path.join(extensions_dir, "cuda") - cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "*.cu"))) + backend_sources = [] + if extension == CUDAExtension: + backend_dir = os.path.join(extensions_dir, "cuda") + backend_sources = glob.glob(os.path.join(backend_dir, "*.cu")) + elif extension == SyclExtension: + backend_dir = os.path.join(extensions_dir, "sycl") + backend_sources = glob.glob(os.path.join(backend_dir, "*.sycl")) - if use_cuda: - sources += cuda_sources + sources += backend_sources + print("sources",sources) + print(len(sources)) + # Construct extension ext_modules = [ extension( f"{library_name}._C", @@ -71,17 +127,13 @@ def get_extensions(): return ext_modules - setup( name=library_name, version="0.0.1", packages=find_packages(), ext_modules=get_extensions(), install_requires=["torch"], - description="Example of PyTorch C++ and CUDA extensions", - long_description=open("README.md").read(), - long_description_content_type="text/markdown", - url="https://github.com/pytorch/extension-cpp", + description="Hybrid PyTorch extension supporting CUDA/SYCL/C++", cmdclass={"build_ext": BuildExtension}, options={"bdist_wheel": {"py_limited_api": "cp39"}} if py_limited_api else {}, ) diff --git a/test/test_extension.py b/test/test_extension.py index 618f00b..8cc38a4 100644 --- a/test/test_extension.py +++ b/test/test_extension.py @@ -37,10 +37,14 @@ def _test_correctness(self, device): def test_correctness_cpu(self): self._test_correctness("cpu") - @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") + @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA") def test_correctness_cuda(self): self._test_correctness("cuda") + @unittest.skipIf(not torch.xpu.is_available(), "requires Intel GPU") + def test_correctness_xpu(self): + self._test_correctness("xpu") + def _test_gradients(self, device): samples = self.sample_inputs(device, requires_grad=True) for args in samples: @@ -57,10 +61,14 @@ def _test_gradients(self, device): def test_gradients_cpu(self): self._test_gradients("cpu") - @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") + @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA") def test_gradients_cuda(self): self._test_gradients("cuda") + @unittest.skipIf(not torch.xpu.is_available(), "requires Intel GPU") + def test_gradients_xpu(self): + self._test_gradients("xpu") + def _opcheck(self, device): # Use opcheck to check for incorrect usage of operator registration APIs samples = self.sample_inputs(device, requires_grad=True) @@ -71,10 +79,13 @@ def _opcheck(self, device): def test_opcheck_cpu(self): self._opcheck("cpu") - @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") + @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA") def test_opcheck_cuda(self): self._opcheck("cuda") + @unittest.skipIf(not torch.xpu.is_available(), "requires xpu") + def test_opcheck_xpu(self): + self._opcheck("xpu") class TestMyAddOut(TestCase): def sample_inputs(self, device, *, requires_grad=False): @@ -107,10 +118,14 @@ def _opcheck(self, device): def test_opcheck_cpu(self): self._opcheck("cpu") - @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") + @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA") def test_opcheck_cuda(self): self._opcheck("cuda") + @unittest.skipIf(not torch.xpu.is_available(), "requires xpu") + def test_opcheck_xpu(self): + self._opcheck("xpu") + if __name__ == "__main__": unittest.main() From e77e626723ffd01a35fbc8a893000e605e2afc5b Mon Sep 17 00:00:00 2001 From: "Zheng, Zhaoqiong" Date: Mon, 9 Jun 2025 16:00:54 +0800 Subject: [PATCH 2/9] update setup for cuda/sycl extension --- setup.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/setup.py b/setup.py index db7b84b..c48865a 100644 --- a/setup.py +++ b/setup.py @@ -22,7 +22,7 @@ # Configure Py_LIMITED_API based on PyTorch version if torch.__version__ >= "2.6.0": - py_limited_api = False + py_limited_api = True else: py_limited_api = False @@ -74,15 +74,11 @@ def get_extensions(): elif extension == SyclExtension: print("SYCLExtension branch, set extra_compile_args") extra_compile_args = { - "cxx": ["-O3" if not debug_mode else "-O0"], + "cxx": ["-O3" if not debug_mode else "-O0", + "-fdiagnostics-color=always", + "-DPy_LIMITED_API=0x03090000"], "sycl": ["-O3" if not debug_mode else "-O0"] } - # extra_compile_args = { - # "cxx": ["-O3" if not debug_mode else "-O0", - # "-fdiagnostics-color=always", - # "-DPy_LIMITED_API=0x03090000"], - # "sycl": ["-O3" if not debug_mode else "-O0"] - # } else: extra_compile_args["cxx"] = [ "-O3" if not debug_mode else "-O0", @@ -93,7 +89,7 @@ def get_extensions(): if extension == CUDAExtension: extra_compile_args["nvcc"].append("-g") extra_link_args.extend(["-O0", "-g"]) - elif extension == SYCLExtension: + elif extension == SyclExtension: extra_compile_args["sycl"].append("-g") extra_link_args.extend(["-O0", "-g"]) From f9a38a7ae2d23d61fc8b3f7b06e98d983bd4753d Mon Sep 17 00:00:00 2001 From: "Zheng, Zhaoqiong" Date: Mon, 9 Jun 2025 16:11:19 +0800 Subject: [PATCH 3/9] update setup and README for SyclExtension examples --- README.md | 6 +++--- setup.py | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index d523814..a4134c4 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,11 @@ # C++/CUDA Extensions in PyTorch -An example of writing a C++/CUDA extension for PyTorch. See +An example of writing a C++/CUDA/Sycl extension for PyTorch. See [here](https://pytorch.org/tutorials/advanced/cpp_custom_ops.html) for the accompanying tutorial. This repo demonstrates how to write an example `extension_cpp.ops.mymuladd` -custom op that has both custom CPU and CUDA kernels. +custom op that has both custom CPU and CUDA/Sycl kernels. -The examples in this repo work with PyTorch 2.4+. +The examples in this repo work with PyTorch 2.4 or later for C++/CUDA & PyTorch 2.8 or later for Sycl. To build: ``` diff --git a/setup.py b/setup.py index c48865a..0a78112 100644 --- a/setup.py +++ b/setup.py @@ -65,6 +65,7 @@ def get_extensions(): extra_link_args = [] extra_compile_args = {"cxx": []} if extension == CUDAExtension: + print("CUDA is available, compile using CUDAExtension") extra_compile_args = { "cxx": ["-O3" if not debug_mode else "-O0", "-fdiagnostics-color=always", @@ -72,7 +73,7 @@ def get_extensions(): "nvcc": ["-O3" if not debug_mode else "-O0"] } elif extension == SyclExtension: - print("SYCLExtension branch, set extra_compile_args") + print("XPU is available, compile using SyclExtension") extra_compile_args = { "cxx": ["-O3" if not debug_mode else "-O0", "-fdiagnostics-color=always", From ebaeb319bf35685a9f65daf3f69b4644513e840a Mon Sep 17 00:00:00 2001 From: ZhaoqiongZ <106125927+ZhaoqiongZ@users.noreply.github.com> Date: Tue, 10 Jun 2025 09:38:43 +0800 Subject: [PATCH 4/9] Update extension_cpp/csrc/sycl/muladd.sycl Co-authored-by: Dmitry Rogozhkin --- extension_cpp/csrc/sycl/muladd.sycl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/extension_cpp/csrc/sycl/muladd.sycl b/extension_cpp/csrc/sycl/muladd.sycl index d494728..7300f88 100644 --- a/extension_cpp/csrc/sycl/muladd.sycl +++ b/extension_cpp/csrc/sycl/muladd.sycl @@ -1,3 +1,5 @@ +// Copyright (c) 2025 Intel Corporation + #include #include #include From 1e186ec9bc034fd6207ee794b2545af91783975b Mon Sep 17 00:00:00 2001 From: "Zheng, Zhaoqiong" Date: Tue, 10 Jun 2025 10:53:20 +0800 Subject: [PATCH 5/9] add back long description --- setup.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 0a78112..63f21cb 100644 --- a/setup.py +++ b/setup.py @@ -130,7 +130,10 @@ def get_extensions(): packages=find_packages(), ext_modules=get_extensions(), install_requires=["torch"], - description="Hybrid PyTorch extension supporting CUDA/SYCL/C++", + description="Example of PyTorch C++ and CUDA/Sycl extensions", + long_description=open("README.md").read(), + long_description_content_type="text/markdown", + url="https://github.com/pytorch/extension-cpp", cmdclass={"build_ext": BuildExtension}, options={"bdist_wheel": {"py_limited_api": "cp39"}} if py_limited_api else {}, ) From 156bb4bfd81011fc4e1ffb01d1e2ba85121de276 Mon Sep 17 00:00:00 2001 From: "Zheng, Zhaoqiong" Date: Tue, 10 Jun 2025 10:54:25 +0800 Subject: [PATCH 6/9] remove debug printout --- setup.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/setup.py b/setup.py index 63f21cb..552d9b2 100644 --- a/setup.py +++ b/setup.py @@ -109,8 +109,6 @@ def get_extensions(): sources += backend_sources - print("sources",sources) - print(len(sources)) # Construct extension ext_modules = [ extension( From 0dd3d29f1f21f8f5dc9363b869e1aa577ca3b661 Mon Sep 17 00:00:00 2001 From: "Zheng, Zhaoqiong" Date: Tue, 10 Jun 2025 11:02:54 +0800 Subject: [PATCH 7/9] update comments for py_limited_api compatibility --- setup.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 552d9b2..5e83f9f 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,11 @@ library_name = "extension_cpp" -# Configure Py_LIMITED_API based on PyTorch version +# NOTE: PyTorch versions < 2.6 use torch.extension.h which depends on pybind11, +# and pybind11 requires full access to Python's C API (including internal +# structures like PyObject). This makes it incompatible with Py_LIMITED_API +# which restricts access to only stable Python C API symbols. +# For Py_LIMITED_API compatibility, use torch.library.h instead (PyTorch 2.6+). if torch.__version__ >= "2.6.0": py_limited_api = True else: From a3d54f3243d401390f76fdd77f7244605b5bf138 Mon Sep 17 00:00:00 2001 From: ZhaoqiongZ <106125927+ZhaoqiongZ@users.noreply.github.com> Date: Fri, 13 Jun 2025 09:15:41 +0800 Subject: [PATCH 8/9] Update README.md update title with SYCL --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index a4134c4..b8b5701 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# C++/CUDA Extensions in PyTorch +# C++/CUDA/SYCL Extensions in PyTorch An example of writing a C++/CUDA/Sycl extension for PyTorch. See [here](https://pytorch.org/tutorials/advanced/cpp_custom_ops.html) for the accompanying tutorial. From 39c442385acf00bddf831369b3cbdc373ee4ed7e Mon Sep 17 00:00:00 2001 From: ZhaoqiongZ <106125927+ZhaoqiongZ@users.noreply.github.com> Date: Fri, 13 Jun 2025 09:55:07 +0800 Subject: [PATCH 9/9] Update README.md with sycl description --- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md b/README.md index b8b5701..121ce44 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,11 @@ An example of writing a C++/CUDA/Sycl extension for PyTorch. See This repo demonstrates how to write an example `extension_cpp.ops.mymuladd` custom op that has both custom CPU and CUDA/Sycl kernels. + +> **Note:** + `SYCL` serves as the backend programming language for Intel GPUs (device label `xpu`). For configuration details, see: + [Getting Started on Intel GPUs](https://docs.pytorch.org/docs/main/notes/get_start_xpu.html). + The examples in this repo work with PyTorch 2.4 or later for C++/CUDA & PyTorch 2.8 or later for Sycl. To build: