From eb2f1bfe180e4257a9c41feebb80eb435b612eab Mon Sep 17 00:00:00 2001 From: Nicholas Wilson Date: Sun, 8 Jun 2025 16:13:51 +0800 Subject: [PATCH] [Dcompute] Initial support for Vulkan --- driver/dcomputecodegenerator.cpp | 8 ++ gen/abi/spirv.cpp | 9 ++ gen/abi/targets.h | 2 + gen/dcompute/target.h | 5 +- gen/dcompute/targetVulkan.cpp | 187 ++++++++++++++++++++++++++++ gen/irstate.cpp | 18 ++- gen/irstate.h | 2 +- runtime/druntime/src/ldc/dcompute.d | 1 + 8 files changed, 225 insertions(+), 7 deletions(-) create mode 100644 gen/dcompute/targetVulkan.cpp diff --git a/driver/dcomputecodegenerator.cpp b/driver/dcomputecodegenerator.cpp index 534757d4533..bba0dbd50ec 100644 --- a/driver/dcomputecodegenerator.cpp +++ b/driver/dcomputecodegenerator.cpp @@ -28,6 +28,14 @@ DComputeCodeGenManager::~DComputeCodeGenManager() {} DComputeTarget * DComputeCodeGenManager::createComputeTarget(const std::string &s) { + if (s.substr(0, 6) == "vulkan") { +#if LDC_LLVM_SUPPORTED_TARGET_SPIRV && LDC_LLVM_VER >= 2100 + //TODO version this for vulkan 1.3/1.4 + return createVulkanTarget(ctx, 0); +#else + error(Loc(), "LDC was not built with Vulkan DCompute support."); +#endif + } if (s.substr(0, 4) == "ocl-") { #if LDC_LLVM_SUPPORTED_TARGET_SPIRV #define OCL_VALID_VER_INIT 100, 110, 120, 200, 210, 220 diff --git a/gen/abi/spirv.cpp b/gen/abi/spirv.cpp index 60a3eada37d..2e95c598f47 100644 --- a/gen/abi/spirv.cpp +++ b/gen/abi/spirv.cpp @@ -53,3 +53,12 @@ struct SPIRVTargetABI : TargetABI { }; TargetABI *createSPIRVABI() { return new SPIRVTargetABI(); } + +struct SPIRVVulkanTargetABI : SPIRVTargetABI { + + llvm::CallingConv::ID callingConv(FuncDeclaration *fdecl) override { + // The synthesised wrapper is SPIR_KERNEL + return llvm::CallingConv::SPIR_FUNC; + } +}; +TargetABI *createSPIRVVulkanABI() { return new SPIRVVulkanTargetABI(); } diff --git a/gen/abi/targets.h b/gen/abi/targets.h index 49098fe2579..cfb3d25ea9d 100644 --- a/gen/abi/targets.h +++ b/gen/abi/targets.h @@ -31,6 +31,8 @@ TargetABI *getRISCV64TargetABI(); TargetABI *createSPIRVABI(); +TargetABI *createSPIRVVulkanABI(); + TargetABI *getWin64TargetABI(); TargetABI *getX86_64TargetABI(); diff --git a/gen/dcompute/target.h b/gen/dcompute/target.h index 6ffdbea7678..34a156b3607 100644 --- a/gen/dcompute/target.h +++ b/gen/dcompute/target.h @@ -27,7 +27,7 @@ class DComputeTarget { public: llvm::LLVMContext &ctx; int tversion; // OpenCL or CUDA CC version:major*100 + minor*10 - enum class ID { Host = 0, OpenCL = 1, CUDA = 2 }; + enum class ID { Host = 0, OpenCL = 1, CUDA = 2, Vulkan = 3 }; ID target; // ID for codegen time conditional compilation. const char *short_name; const char *binSuffix; @@ -60,4 +60,7 @@ DComputeTarget *createCUDATarget(llvm::LLVMContext &c, int sm); #if LDC_LLVM_SUPPORTED_TARGET_SPIRV DComputeTarget *createOCLTarget(llvm::LLVMContext &c, int oclver); +#if LDC_LLVM_VER >= 2100 +DComputeTarget *createVulkanTarget(llvm::LLVMContext &c, int ver); +#endif #endif diff --git a/gen/dcompute/targetVulkan.cpp b/gen/dcompute/targetVulkan.cpp new file mode 100644 index 00000000000..75bf74995b5 --- /dev/null +++ b/gen/dcompute/targetVulkan.cpp @@ -0,0 +1,187 @@ +//===-- gen/dcomputetargetOCL.cpp -----------------------------------------===// +// +// LDC – the LLVM D compiler +// +// Parts of this file are adapted from CodeGenFunction.cpp (Clang, LLVM). +// Therefore, this file is distributed under the LLVM license. +// See the LICENSE file for details. +//===----------------------------------------------------------------------===// + +#if LDC_LLVM_SUPPORTED_TARGET_SPIRV && LDC_LLVM_VER >= 2100 + +#include "dmd/id.h" +#include "dmd/identifier.h" +#include "dmd/template.h" +#include "dmd/mangle.h" +#include "dmd/module.h" +#include "gen/abi/targets.h" +#include "gen/dcompute/target.h" +#include "gen/dcompute/druntime.h" +#include "gen/logger.h" +#include "gen/optimizer.h" +#include "driver/targetmachine.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Target/TargetMachine.h" +#include +#include + +using namespace dmd; + +namespace { +class TargetVulkan : public DComputeTarget { +public: + TargetVulkan(llvm::LLVMContext &c, int ver) + : DComputeTarget(c, ver, ID::Vulkan, "vulkan", "spv", createSPIRVVulkanABI(), + {{0, 1, 2, 3, 4}}) { + + _ir = new IRState("dcomputeTargetVulkan", ctx); + // "spirv-vulkan-foo"? foo = library, pixel, etc + std::string targTriple = "spirv1.6-unknown-vulkan1.3-compute"; + _ir->module.setTargetTriple(llvm::Triple(targTriple)); + + auto floatABI = ::FloatABI::Hard; + targetMachine = createTargetMachine( + targTriple, "spirv", "", {}, + ExplicitBitness::None, floatABI, + llvm::Reloc::Static, llvm::CodeModel::Medium, codeGenOptLevel(), false); + + _ir->module.setDataLayout(targetMachine->createDataLayout()); + + _ir->dcomputetarget = this; + } + + void addMetadata() override {} + + llvm::AttrBuilder buildKernAttrs(StructLiteralExp *kernAttr) { + auto b = llvm::AttrBuilder(ctx); + b.addAttribute("hlsl.shader", "compute"); + Expressions* elts = static_cast((*(kernAttr->elements))[0])->elements; + std::string numthreads = ""; + numthreads += std::to_string((*elts)[0]->toInteger()) + ","; + numthreads += std::to_string((*elts)[1]->toInteger()) + ","; + numthreads += std::to_string((*elts)[2]->toInteger()); + + b.addAttribute("hlsl.numthreads", numthreads); + // ? "hlsl.wavesize"="8,128,64" + // ? "hlsl.export" + return b; + } + llvm::Function *buildFunction(FuncDeclaration *fd) { + auto *void_func_void = llvm::FunctionType::get(llvm::Type::getVoidTy(ctx),{}, false); + auto linkage = llvm::GlobalValue::LinkageTypes::ExternalLinkage; + auto name = llvm::Twine(mangleExact(fd)) + llvm::Twine("_kernel"); + auto *f = llvm::Function::Create(void_func_void, linkage, name, _ir->module); + f->setCallingConv(llvm::CallingConv::SPIR_KERNEL); + return f; + } + llvm::Type *buildArgType(llvm::Function *llf, llvm::SmallVector &args, llvm::StringRef name) { + IF_LOG { + Logger::cout() << "buildArgType: " << *llf << std::endl; + } + llvm::FunctionType *tf = llf->getFunctionType(); + for (unsigned int i = 0; i < tf->getNumParams(); i++) { + llvm::Type *t = tf->getParamType(i); + if (t->isPointerTy()) + t = getI64Type(); // FIXME: 32 bit pointers on 32 but systems? + args[i] = t; + } + + IF_LOG { + for (auto *arg : args) { + Logger::cout() << *arg; + } + } + return llvm::StructType::create(ctx, args, name); + } + llvm::TargetExtType *buildTargetType(llvm::Type *argType) { + // TODO: Do we need to bother with a "spirv.Layout" here? + //auto *layout = llvm::TargetExtType::get(ctx, "spirv.Layout", ElemType,{}); + auto * ArrayType = llvm::ArrayType::get(argType, 0); + return llvm::TargetExtType::get(ctx, "spirv.VulkanBuffer", + {ArrayType}, + {12/*StorageClass*/, 0 /*isWritable*/}); + } + + llvm::Value *buildIntrinsicCall(IRBuilder<>& builder, llvm::StringRef dbg,llvm::StringRef name, + llvm::ArrayRef types, llvm::ArrayRef args) { + IF_LOG { + Logger::println("buildIntrinsicCall: %s", name.data()); + } + LOG_SCOPE + llvm::Function *intrinsic = llvm::Intrinsic::getOrInsertDeclaration(&_ir->module, + llvm::Intrinsic::lookupIntrinsicID(name), + types); + IF_LOG { + Logger::cout() << "intrinsic = " << *intrinsic << std::endl; + Logger::println("args:"); + LOG_SCOPE + for (auto* arg : args) { + Logger::cout() << *arg << std::endl; + } + } + + return builder.CreateCall(intrinsic->getFunctionType(), intrinsic, args, dbg); + } + + void addKernelMetadata(FuncDeclaration *fd, llvm::Function *llf, StructLiteralExp *kernAttr) override { + // Fake being HLSL + llvm::Function *f = buildFunction(fd); + f->addFnAttrs(buildKernAttrs(kernAttr)); + + llvm::SmallVector argTypes(llf->getFunctionType()->getNumParams()); + auto name = llvm::Twine(mangleExact(fd)) + llvm::Twine("_args"); + auto *argType = buildArgType(llf, argTypes, name.str()); + llvm::Type *targetType = buildTargetType(argType); + + auto bb = llvm::BasicBlock::Create(ctx, "", f); + llvm::IRBuilder<> builder(ctx); + builder.SetInsertPoint(bb); + + llvm::Value *i32zero = llvm::ConstantInt::get(getI32Type(), 0, false); + llvm::Value *i32one = llvm::ConstantInt::get(getI32Type(), 1, false); + llvm::Value *i1false = llvm::ConstantInt::get(llvm::Type::getInt1Ty(ctx), 0, false); + + // We can't use `DtoConstCString` here because it ends up in the wrong address space, So we use + // `getCachedStringLiteral` directly with an explicitly supplied addrspace of `0`. + // FIXME: call should have `notnull` attribute on pointer? + auto *handle = buildIntrinsicCall(builder, "handle","llvm.spv.resource.handlefrombinding", + {targetType}, + {i32zero, i32zero, i32one, i32zero, i1false, _ir->getCachedStringLiteral(name.str(), 0) }); + auto *p11 = llvm::PointerType::get(ctx, 11); + auto *pointer = buildIntrinsicCall(builder, "pointer", "llvm.spv.resource.getpointer", + {p11, targetType}, {handle, i32one}); + llvm::FunctionType *tf = llf->getFunctionType(); + IF_LOG { + Logger::cout() << "load pointer: " << *pointer << std::endl; + Logger::cout() << _ir->module.getDataLayout().getABITypeAlign(argType).value() << std::endl; + Logger::cout() << tf->getParamType(0)->getTypeID() << std::endl; + Logger::cout() << "done" << std::endl; + } + LOG_SCOPE + llvm::SmallVector args(tf->getNumParams()); + + auto *arg = builder.CreateAlignedLoad(argType, pointer, _ir->module.getDataLayout().getABITypeAlign(argType), false); + IF_LOG { + // Logger::cout() << "load elements from " << *arg << std::endl; + // Logger::cout() << "of type " << *argType << std::endl; + } + for (unsigned int i = 0; i < tf->getNumParams(); i++) { + args[i] = builder.CreateExtractValue(arg, {i}); + llvm::Type *t = tf->getParamType(i); + if (t->isPointerTy()) + args[i] = builder.CreateIntToPtr(args[i],t); + } + + builder.CreateCall(llf->getFunctionType(), llf, args); + builder.CreateRetVoid(); + IF_LOG Logger::cout() << *f << std::endl; + } + +}; +} // anonymous namespace. + +DComputeTarget *createVulkanTarget(llvm::LLVMContext &c, int ver) { + return new TargetVulkan(c, ver); +} + +#endif // LDC_LLVM_SUPPORTED_TARGET_SPIRV diff --git a/gen/irstate.cpp b/gen/irstate.cpp index 4a0d9fc35f4..76193cb3f82 100644 --- a/gen/irstate.cpp +++ b/gen/irstate.cpp @@ -191,7 +191,9 @@ template LLGlobalVariable * getCachedStringLiteralImpl(llvm::Module &module, llvm::StringMap &cache, - llvm::StringRef key, F initFactory) { + llvm::StringRef key, + std::optional< unsigned > addrspace, + F initFactory) { auto iter = cache.find(key); if (iter != cache.end()) { return iter->second; @@ -201,7 +203,8 @@ getCachedStringLiteralImpl(llvm::Module &module, auto gvar = new LLGlobalVariable(module, constant->getType(), true, - LLGlobalValue::PrivateLinkage, constant, ".str"); + LLGlobalValue::PrivateLinkage, constant, + ".str", nullptr, llvm::GlobalValue::NotThreadLocal, addrspace); gvar->setUnnamedAddr(LLGlobalValue::UnnamedAddr::Global); cache[key] = gvar; @@ -241,14 +244,19 @@ LLGlobalVariable *IRState::getCachedStringLiteral(StringExp *se) { const llvm::StringRef key(reinterpret_cast(keyData.ptr), keyData.length); - return getCachedStringLiteralImpl(module, *cache, key, [se]() { + return getCachedStringLiteralImpl(module, *cache, key, std::nullopt, [se]() { // null-terminate return buildStringLiteralConstant(se, se->len + 1); }); } -LLGlobalVariable *IRState::getCachedStringLiteral(llvm::StringRef s) { - return getCachedStringLiteralImpl(module, cachedStringLiterals, s, [&]() { +LLGlobalVariable *IRState::getCachedStringLiteral(llvm::StringRef s, + std::optional< unsigned > addrspace) { + return getCachedStringLiteralImpl(module, + cachedStringLiterals, + s, + addrspace, + [&]() { return llvm::ConstantDataArray::getString(context(), s, true); }); } diff --git a/gen/irstate.h b/gen/irstate.h index 39cf7c7a5dd..75b8f11abab 100644 --- a/gen/irstate.h +++ b/gen/irstate.h @@ -248,7 +248,7 @@ struct IRState { // calls with a StringExp with matching data will return the same variable. // Exception: ulong[]-typed hex strings (not null-terminated either). llvm::GlobalVariable *getCachedStringLiteral(StringExp *se); - llvm::GlobalVariable *getCachedStringLiteral(llvm::StringRef s); + llvm::GlobalVariable *getCachedStringLiteral(llvm::StringRef s, std::optional< unsigned > = std::nullopt); // List of functions with cpu or features attributes overriden by user std::vector targetCpuOrFeaturesOverridden; diff --git a/runtime/druntime/src/ldc/dcompute.d b/runtime/druntime/src/ldc/dcompute.d index 396a0a47b01..e91bcf5c5c3 100644 --- a/runtime/druntime/src/ldc/dcompute.d +++ b/runtime/druntime/src/ldc/dcompute.d @@ -14,6 +14,7 @@ enum ReflectTarget : uint Host = 0, OpenCL = 1, CUDA = 2, + Vulkan = 3, } /** * The pseudo conditional compilation function.