Skip to content

[Dcompute] Initial support for Vulkan #4958

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions driver/dcomputecodegenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@ DComputeCodeGenManager::~DComputeCodeGenManager() {}

DComputeTarget *
DComputeCodeGenManager::createComputeTarget(const std::string &s) {
if (s.substr(0, 6) == "vulkan") {
#if LDC_LLVM_SUPPORTED_TARGET_SPIRV && LDC_LLVM_VER >= 2100
//TODO version this for vulkan 1.3/1.4
return createVulkanTarget(ctx, 0);
#else
error(Loc(), "LDC was not built with Vulkan DCompute support.");
#endif
}
if (s.substr(0, 4) == "ocl-") {
#if LDC_LLVM_SUPPORTED_TARGET_SPIRV
#define OCL_VALID_VER_INIT 100, 110, 120, 200, 210, 220
Expand Down
9 changes: 9 additions & 0 deletions gen/abi/spirv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,12 @@ struct SPIRVTargetABI : TargetABI {
};

TargetABI *createSPIRVABI() { return new SPIRVTargetABI(); }

struct SPIRVVulkanTargetABI : SPIRVTargetABI {
  /// Calling convention for D functions emitted for the Vulkan target.
  ///
  /// Always SPIR_FUNC, whatever `fdecl` is: the SPIR_KERNEL entry point is
  /// the synthesised wrapper, not the function itself.
  llvm::CallingConv::ID callingConv(FuncDeclaration *fdecl) override {
    const llvm::CallingConv::ID conv = llvm::CallingConv::SPIR_FUNC;
    return conv;
  }
};
/// Allocates a new ABI object for the Vulkan flavour of the SPIR-V target.
/// Every call returns a distinct heap object.
TargetABI *createSPIRVVulkanABI() {
  auto *abi = new SPIRVVulkanTargetABI();
  return abi;
}
2 changes: 2 additions & 0 deletions gen/abi/targets.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ TargetABI *getRISCV64TargetABI();

TargetABI *createSPIRVABI();

TargetABI *createSPIRVVulkanABI();

TargetABI *getWin64TargetABI();

TargetABI *getX86_64TargetABI();
Expand Down
5 changes: 4 additions & 1 deletion gen/dcompute/target.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class DComputeTarget {
public:
llvm::LLVMContext &ctx;
int tversion; // OpenCL or CUDA CC version:major*100 + minor*10
enum class ID { Host = 0, OpenCL = 1, CUDA = 2 };
enum class ID { Host = 0, OpenCL = 1, CUDA = 2, Vulkan = 3 };
ID target; // ID for codegen time conditional compilation.
const char *short_name;
const char *binSuffix;
Expand Down Expand Up @@ -60,4 +60,7 @@ DComputeTarget *createCUDATarget(llvm::LLVMContext &c, int sm);

#if LDC_LLVM_SUPPORTED_TARGET_SPIRV
DComputeTarget *createOCLTarget(llvm::LLVMContext &c, int oclver);
#if LDC_LLVM_VER >= 2100
DComputeTarget *createVulkanTarget(llvm::LLVMContext &c, int ver);
#endif
#endif
187 changes: 187 additions & 0 deletions gen/dcompute/targetVulkan.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
//===-- gen/dcompute/targetVulkan.cpp -------------------------------------===//
//
// LDC – the LLVM D compiler
//
// Parts of this file are adapted from CodeGenFunction.cpp (Clang, LLVM).
// Therefore, this file is distributed under the LLVM license.
// See the LICENSE file for details.
//===----------------------------------------------------------------------===//

#if LDC_LLVM_SUPPORTED_TARGET_SPIRV && LDC_LLVM_VER >= 2100

#include "dmd/id.h"
#include "dmd/identifier.h"
#include "dmd/template.h"
#include "dmd/mangle.h"
#include "dmd/module.h"
#include "gen/abi/targets.h"
#include "gen/dcompute/target.h"
#include "gen/dcompute/druntime.h"
#include "gen/logger.h"
#include "gen/optimizer.h"
#include "driver/targetmachine.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Target/TargetMachine.h"
#include <cstring>
#include <string>

using namespace dmd;

namespace {
class TargetVulkan : public DComputeTarget {
public:
TargetVulkan(llvm::LLVMContext &c, int ver)
: DComputeTarget(c, ver, ID::Vulkan, "vulkan", "spv", createSPIRVVulkanABI(),
{{0, 1, 2, 3, 4}}) {

_ir = new IRState("dcomputeTargetVulkan", ctx);
// "spirv-vulkan-foo"? foo = library, pixel, etc
std::string targTriple = "spirv1.6-unknown-vulkan1.3-compute";
_ir->module.setTargetTriple(llvm::Triple(targTriple));

auto floatABI = ::FloatABI::Hard;
targetMachine = createTargetMachine(
targTriple, "spirv", "", {},
ExplicitBitness::None, floatABI,
llvm::Reloc::Static, llvm::CodeModel::Medium, codeGenOptLevel(), false);

_ir->module.setDataLayout(targetMachine->createDataLayout());

_ir->dcomputetarget = this;
}

void addMetadata() override {}

llvm::AttrBuilder buildKernAttrs(StructLiteralExp *kernAttr) {
auto b = llvm::AttrBuilder(ctx);
b.addAttribute("hlsl.shader", "compute");
Expressions* elts = static_cast<ArrayLiteralExp*>((*(kernAttr->elements))[0])->elements;
std::string numthreads = "";
numthreads += std::to_string((*elts)[0]->toInteger()) + ",";
numthreads += std::to_string((*elts)[1]->toInteger()) + ",";
numthreads += std::to_string((*elts)[2]->toInteger());

b.addAttribute("hlsl.numthreads", numthreads);
// ? "hlsl.wavesize"="8,128,64"
// ? "hlsl.export"
return b;
}
llvm::Function *buildFunction(FuncDeclaration *fd) {
auto *void_func_void = llvm::FunctionType::get(llvm::Type::getVoidTy(ctx),{}, false);
auto linkage = llvm::GlobalValue::LinkageTypes::ExternalLinkage;
auto name = llvm::Twine(mangleExact(fd)) + llvm::Twine("_kernel");
auto *f = llvm::Function::Create(void_func_void, linkage, name, _ir->module);
f->setCallingConv(llvm::CallingConv::SPIR_KERNEL);
return f;
}
llvm::Type *buildArgType(llvm::Function *llf, llvm::SmallVector<llvm::Type *, 8> &args, llvm::StringRef name) {
IF_LOG {
Logger::cout() << "buildArgType: " << *llf << std::endl;
}
llvm::FunctionType *tf = llf->getFunctionType();
for (unsigned int i = 0; i < tf->getNumParams(); i++) {
llvm::Type *t = tf->getParamType(i);
if (t->isPointerTy())
t = getI64Type(); // FIXME: 32 bit pointers on 32 but systems?
args[i] = t;
}

IF_LOG {
for (auto *arg : args) {
Logger::cout() << *arg;
}
}
return llvm::StructType::create(ctx, args, name);
}
llvm::TargetExtType *buildTargetType(llvm::Type *argType) {
// TODO: Do we need to bother with a "spirv.Layout" here?
//auto *layout = llvm::TargetExtType::get(ctx, "spirv.Layout", ElemType,{});
auto * ArrayType = llvm::ArrayType::get(argType, 0);
return llvm::TargetExtType::get(ctx, "spirv.VulkanBuffer",
{ArrayType},
{12/*StorageClass*/, 0 /*isWritable*/});
}

llvm::Value *buildIntrinsicCall(IRBuilder<>& builder, llvm::StringRef dbg,llvm::StringRef name,
llvm::ArrayRef<llvm::Type *> types, llvm::ArrayRef<llvm::Value *> args) {
IF_LOG {
Logger::println("buildIntrinsicCall: %s", name.data());
}
LOG_SCOPE
llvm::Function *intrinsic = llvm::Intrinsic::getOrInsertDeclaration(&_ir->module,
llvm::Intrinsic::lookupIntrinsicID(name),
types);
IF_LOG {
Logger::cout() << "intrinsic = " << *intrinsic << std::endl;
Logger::println("args:");
LOG_SCOPE
for (auto* arg : args) {
Logger::cout() << *arg << std::endl;
}
}

return builder.CreateCall(intrinsic->getFunctionType(), intrinsic, args, dbg);
}

void addKernelMetadata(FuncDeclaration *fd, llvm::Function *llf, StructLiteralExp *kernAttr) override {
// Fake being HLSL
llvm::Function *f = buildFunction(fd);
f->addFnAttrs(buildKernAttrs(kernAttr));

llvm::SmallVector<llvm::Type *, 8> argTypes(llf->getFunctionType()->getNumParams());
auto name = llvm::Twine(mangleExact(fd)) + llvm::Twine("_args");
auto *argType = buildArgType(llf, argTypes, name.str());
llvm::Type *targetType = buildTargetType(argType);

auto bb = llvm::BasicBlock::Create(ctx, "", f);
llvm::IRBuilder<> builder(ctx);
builder.SetInsertPoint(bb);

llvm::Value *i32zero = llvm::ConstantInt::get(getI32Type(), 0, false);
llvm::Value *i32one = llvm::ConstantInt::get(getI32Type(), 1, false);
llvm::Value *i1false = llvm::ConstantInt::get(llvm::Type::getInt1Ty(ctx), 0, false);

// We can't use `DtoConstCString` here because it ends up in the wrong address space, So we use
// `getCachedStringLiteral` directly with an explicitly supplied addrspace of `0`.
// FIXME: call should have `notnull` attribute on pointer?
auto *handle = buildIntrinsicCall(builder, "handle","llvm.spv.resource.handlefrombinding",
{targetType},
{i32zero, i32zero, i32one, i32zero, i1false, _ir->getCachedStringLiteral(name.str(), 0) });
auto *p11 = llvm::PointerType::get(ctx, 11);
auto *pointer = buildIntrinsicCall(builder, "pointer", "llvm.spv.resource.getpointer",
{p11, targetType}, {handle, i32one});
llvm::FunctionType *tf = llf->getFunctionType();
IF_LOG {
Logger::cout() << "load pointer: " << *pointer << std::endl;
Logger::cout() << _ir->module.getDataLayout().getABITypeAlign(argType).value() << std::endl;
Logger::cout() << tf->getParamType(0)->getTypeID() << std::endl;
Logger::cout() << "done" << std::endl;
}
LOG_SCOPE
llvm::SmallVector<llvm::Value *, 8> args(tf->getNumParams());

auto *arg = builder.CreateAlignedLoad(argType, pointer, _ir->module.getDataLayout().getABITypeAlign(argType), false);
IF_LOG {
// Logger::cout() << "load elements from " << *arg << std::endl;
// Logger::cout() << "of type " << *argType << std::endl;
}
for (unsigned int i = 0; i < tf->getNumParams(); i++) {
args[i] = builder.CreateExtractValue(arg, {i});
llvm::Type *t = tf->getParamType(i);
if (t->isPointerTy())
args[i] = builder.CreateIntToPtr(args[i],t);
}

builder.CreateCall(llf->getFunctionType(), llf, args);
builder.CreateRetVoid();
IF_LOG Logger::cout() << *f << std::endl;
}

};
} // anonymous namespace.

/// Factory for the Vulkan DCompute target.
///
/// @param c   LLVM context the target generates IR into.
/// @param ver Target version, handed to the TargetVulkan constructor.
/// @return A newly heap-allocated target.
DComputeTarget *createVulkanTarget(llvm::LLVMContext &c, int ver) {
  DComputeTarget *target = new TargetVulkan(c, ver);
  return target;
}

#endif // LDC_LLVM_SUPPORTED_TARGET_SPIRV && LDC_LLVM_VER >= 2100
18 changes: 13 additions & 5 deletions gen/irstate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,9 @@ template <typename F>
LLGlobalVariable *
getCachedStringLiteralImpl(llvm::Module &module,
llvm::StringMap<LLGlobalVariable *> &cache,
llvm::StringRef key, F initFactory) {
llvm::StringRef key,
std::optional< unsigned > addrspace,
F initFactory) {
auto iter = cache.find(key);
if (iter != cache.end()) {
return iter->second;
Expand All @@ -201,7 +203,8 @@ getCachedStringLiteralImpl(llvm::Module &module,

auto gvar =
new LLGlobalVariable(module, constant->getType(), true,
LLGlobalValue::PrivateLinkage, constant, ".str");
LLGlobalValue::PrivateLinkage, constant,
".str", nullptr, llvm::GlobalValue::NotThreadLocal, addrspace);
gvar->setUnnamedAddr(LLGlobalValue::UnnamedAddr::Global);

cache[key] = gvar;
Expand Down Expand Up @@ -241,14 +244,19 @@ LLGlobalVariable *IRState::getCachedStringLiteral(StringExp *se) {
const llvm::StringRef key(reinterpret_cast<const char *>(keyData.ptr),
keyData.length);

return getCachedStringLiteralImpl(module, *cache, key, [se]() {
return getCachedStringLiteralImpl(module, *cache, key, std::nullopt, [se]() {
// null-terminate
return buildStringLiteralConstant(se, se->len + 1);
});
}

LLGlobalVariable *IRState::getCachedStringLiteral(llvm::StringRef s) {
return getCachedStringLiteralImpl(module, cachedStringLiterals, s, [&]() {
LLGlobalVariable *IRState::getCachedStringLiteral(llvm::StringRef s,
std::optional< unsigned > addrspace) {
return getCachedStringLiteralImpl(module,
cachedStringLiterals,
s,
addrspace,
[&]() {
return llvm::ConstantDataArray::getString(context(), s, true);
});
}
Expand Down
2 changes: 1 addition & 1 deletion gen/irstate.h
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ struct IRState {
// calls with a StringExp with matching data will return the same variable.
// Exception: ulong[]-typed hex strings (not null-terminated either).
llvm::GlobalVariable *getCachedStringLiteral(StringExp *se);
llvm::GlobalVariable *getCachedStringLiteral(llvm::StringRef s);
llvm::GlobalVariable *getCachedStringLiteral(llvm::StringRef s, std::optional< unsigned > = std::nullopt);

// List of functions with cpu or features attributes overriden by user
std::vector<IrFunction *> targetCpuOrFeaturesOverridden;
Expand Down
1 change: 1 addition & 0 deletions runtime/druntime/src/ldc/dcompute.d
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ enum ReflectTarget : uint
Host = 0,
OpenCL = 1,
CUDA = 2,
Vulkan = 3,
}
/**
* The pseudo conditional compilation function.
Expand Down
Loading