Skip to content

Commit eb2f1bf

Browse files
committed
[Dcompute] Initial support for Vulkan
1 parent 46bbe8b commit eb2f1bf

File tree

8 files changed

+225
-7
lines changed

8 files changed

+225
-7
lines changed

driver/dcomputecodegenerator.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,14 @@ DComputeCodeGenManager::~DComputeCodeGenManager() {}
2828

2929
DComputeTarget *
3030
DComputeCodeGenManager::createComputeTarget(const std::string &s) {
31+
if (s.substr(0, 6) == "vulkan") {
32+
#if LDC_LLVM_SUPPORTED_TARGET_SPIRV && LDC_LLVM_VER >= 2100
33+
//TODO version this for vulkan 1.3/1.4
34+
return createVulkanTarget(ctx, 0);
35+
#else
36+
error(Loc(), "LDC was not built with Vulkan DCompute support.");
37+
#endif
38+
}
3139
if (s.substr(0, 4) == "ocl-") {
3240
#if LDC_LLVM_SUPPORTED_TARGET_SPIRV
3341
#define OCL_VALID_VER_INIT 100, 110, 120, 200, 210, 220

gen/abi/spirv.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,12 @@ struct SPIRVTargetABI : TargetABI {
5353
};
5454

5555
TargetABI *createSPIRVABI() { return new SPIRVTargetABI(); }
56+
57+
struct SPIRVVulkanTargetABI : SPIRVTargetABI {
58+
59+
llvm::CallingConv::ID callingConv(FuncDeclaration *fdecl) override {
60+
// The synthesised wrapper is SPIR_KERNEL
61+
return llvm::CallingConv::SPIR_FUNC;
62+
}
63+
};
64+
TargetABI *createSPIRVVulkanABI() { return new SPIRVVulkanTargetABI(); }

gen/abi/targets.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ TargetABI *getRISCV64TargetABI();
3131

3232
TargetABI *createSPIRVABI();
3333

34+
TargetABI *createSPIRVVulkanABI();
35+
3436
TargetABI *getWin64TargetABI();
3537

3638
TargetABI *getX86_64TargetABI();

gen/dcompute/target.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class DComputeTarget {
2727
public:
2828
llvm::LLVMContext &ctx;
2929
int tversion; // OpenCL or CUDA CC version:major*100 + minor*10
30-
enum class ID { Host = 0, OpenCL = 1, CUDA = 2 };
30+
enum class ID { Host = 0, OpenCL = 1, CUDA = 2, Vulkan = 3 };
3131
ID target; // ID for codegen time conditional compilation.
3232
const char *short_name;
3333
const char *binSuffix;
@@ -60,4 +60,7 @@ DComputeTarget *createCUDATarget(llvm::LLVMContext &c, int sm);
6060

6161
#if LDC_LLVM_SUPPORTED_TARGET_SPIRV
6262
DComputeTarget *createOCLTarget(llvm::LLVMContext &c, int oclver);
63+
#if LDC_LLVM_VER >= 2100
64+
DComputeTarget *createVulkanTarget(llvm::LLVMContext &c, int ver);
65+
#endif
6366
#endif

gen/dcompute/targetVulkan.cpp

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
//===-- gen/dcomputetargetOCL.cpp -----------------------------------------===//
2+
//
3+
// LDC – the LLVM D compiler
4+
//
5+
// Parts of this file are adapted from CodeGenFunction.cpp (Clang, LLVM).
6+
// Therefore, this file is distributed under the LLVM license.
7+
// See the LICENSE file for details.
8+
//===----------------------------------------------------------------------===//
9+
10+
#if LDC_LLVM_SUPPORTED_TARGET_SPIRV && LDC_LLVM_VER >= 2100
11+
12+
#include "dmd/id.h"
13+
#include "dmd/identifier.h"
14+
#include "dmd/template.h"
15+
#include "dmd/mangle.h"
16+
#include "dmd/module.h"
17+
#include "gen/abi/targets.h"
18+
#include "gen/dcompute/target.h"
19+
#include "gen/dcompute/druntime.h"
20+
#include "gen/logger.h"
21+
#include "gen/optimizer.h"
22+
#include "driver/targetmachine.h"
23+
#include "llvm/Transforms/Scalar.h"
24+
#include "llvm/Target/TargetMachine.h"
25+
#include <cstring>
26+
#include <string>
27+
28+
using namespace dmd;
29+
30+
namespace {
31+
class TargetVulkan : public DComputeTarget {
32+
public:
33+
TargetVulkan(llvm::LLVMContext &c, int ver)
34+
: DComputeTarget(c, ver, ID::Vulkan, "vulkan", "spv", createSPIRVVulkanABI(),
35+
{{0, 1, 2, 3, 4}}) {
36+
37+
_ir = new IRState("dcomputeTargetVulkan", ctx);
38+
// "spirv-vulkan-foo"? foo = library, pixel, etc
39+
std::string targTriple = "spirv1.6-unknown-vulkan1.3-compute";
40+
_ir->module.setTargetTriple(llvm::Triple(targTriple));
41+
42+
auto floatABI = ::FloatABI::Hard;
43+
targetMachine = createTargetMachine(
44+
targTriple, "spirv", "", {},
45+
ExplicitBitness::None, floatABI,
46+
llvm::Reloc::Static, llvm::CodeModel::Medium, codeGenOptLevel(), false);
47+
48+
_ir->module.setDataLayout(targetMachine->createDataLayout());
49+
50+
_ir->dcomputetarget = this;
51+
}
52+
53+
void addMetadata() override {}
54+
55+
llvm::AttrBuilder buildKernAttrs(StructLiteralExp *kernAttr) {
56+
auto b = llvm::AttrBuilder(ctx);
57+
b.addAttribute("hlsl.shader", "compute");
58+
Expressions* elts = static_cast<ArrayLiteralExp*>((*(kernAttr->elements))[0])->elements;
59+
std::string numthreads = "";
60+
numthreads += std::to_string((*elts)[0]->toInteger()) + ",";
61+
numthreads += std::to_string((*elts)[1]->toInteger()) + ",";
62+
numthreads += std::to_string((*elts)[2]->toInteger());
63+
64+
b.addAttribute("hlsl.numthreads", numthreads);
65+
// ? "hlsl.wavesize"="8,128,64"
66+
// ? "hlsl.export"
67+
return b;
68+
}
69+
llvm::Function *buildFunction(FuncDeclaration *fd) {
70+
auto *void_func_void = llvm::FunctionType::get(llvm::Type::getVoidTy(ctx),{}, false);
71+
auto linkage = llvm::GlobalValue::LinkageTypes::ExternalLinkage;
72+
auto name = llvm::Twine(mangleExact(fd)) + llvm::Twine("_kernel");
73+
auto *f = llvm::Function::Create(void_func_void, linkage, name, _ir->module);
74+
f->setCallingConv(llvm::CallingConv::SPIR_KERNEL);
75+
return f;
76+
}
77+
llvm::Type *buildArgType(llvm::Function *llf, llvm::SmallVector<llvm::Type *, 8> &args, llvm::StringRef name) {
78+
IF_LOG {
79+
Logger::cout() << "buildArgType: " << *llf << std::endl;
80+
}
81+
llvm::FunctionType *tf = llf->getFunctionType();
82+
for (unsigned int i = 0; i < tf->getNumParams(); i++) {
83+
llvm::Type *t = tf->getParamType(i);
84+
if (t->isPointerTy())
85+
t = getI64Type(); // FIXME: 32 bit pointers on 32 but systems?
86+
args[i] = t;
87+
}
88+
89+
IF_LOG {
90+
for (auto *arg : args) {
91+
Logger::cout() << *arg;
92+
}
93+
}
94+
return llvm::StructType::create(ctx, args, name);
95+
}
96+
llvm::TargetExtType *buildTargetType(llvm::Type *argType) {
97+
// TODO: Do we need to bother with a "spirv.Layout" here?
98+
//auto *layout = llvm::TargetExtType::get(ctx, "spirv.Layout", ElemType,{});
99+
auto * ArrayType = llvm::ArrayType::get(argType, 0);
100+
return llvm::TargetExtType::get(ctx, "spirv.VulkanBuffer",
101+
{ArrayType},
102+
{12/*StorageClass*/, 0 /*isWritable*/});
103+
}
104+
105+
llvm::Value *buildIntrinsicCall(IRBuilder<>& builder, llvm::StringRef dbg,llvm::StringRef name,
106+
llvm::ArrayRef<llvm::Type *> types, llvm::ArrayRef<llvm::Value *> args) {
107+
IF_LOG {
108+
Logger::println("buildIntrinsicCall: %s", name.data());
109+
}
110+
LOG_SCOPE
111+
llvm::Function *intrinsic = llvm::Intrinsic::getOrInsertDeclaration(&_ir->module,
112+
llvm::Intrinsic::lookupIntrinsicID(name),
113+
types);
114+
IF_LOG {
115+
Logger::cout() << "intrinsic = " << *intrinsic << std::endl;
116+
Logger::println("args:");
117+
LOG_SCOPE
118+
for (auto* arg : args) {
119+
Logger::cout() << *arg << std::endl;
120+
}
121+
}
122+
123+
return builder.CreateCall(intrinsic->getFunctionType(), intrinsic, args, dbg);
124+
}
125+
126+
void addKernelMetadata(FuncDeclaration *fd, llvm::Function *llf, StructLiteralExp *kernAttr) override {
127+
// Fake being HLSL
128+
llvm::Function *f = buildFunction(fd);
129+
f->addFnAttrs(buildKernAttrs(kernAttr));
130+
131+
llvm::SmallVector<llvm::Type *, 8> argTypes(llf->getFunctionType()->getNumParams());
132+
auto name = llvm::Twine(mangleExact(fd)) + llvm::Twine("_args");
133+
auto *argType = buildArgType(llf, argTypes, name.str());
134+
llvm::Type *targetType = buildTargetType(argType);
135+
136+
auto bb = llvm::BasicBlock::Create(ctx, "", f);
137+
llvm::IRBuilder<> builder(ctx);
138+
builder.SetInsertPoint(bb);
139+
140+
llvm::Value *i32zero = llvm::ConstantInt::get(getI32Type(), 0, false);
141+
llvm::Value *i32one = llvm::ConstantInt::get(getI32Type(), 1, false);
142+
llvm::Value *i1false = llvm::ConstantInt::get(llvm::Type::getInt1Ty(ctx), 0, false);
143+
144+
// We can't use `DtoConstCString` here because it ends up in the wrong address space, So we use
145+
// `getCachedStringLiteral` directly with an explicitly supplied addrspace of `0`.
146+
// FIXME: call should have `notnull` attribute on pointer?
147+
auto *handle = buildIntrinsicCall(builder, "handle","llvm.spv.resource.handlefrombinding",
148+
{targetType},
149+
{i32zero, i32zero, i32one, i32zero, i1false, _ir->getCachedStringLiteral(name.str(), 0) });
150+
auto *p11 = llvm::PointerType::get(ctx, 11);
151+
auto *pointer = buildIntrinsicCall(builder, "pointer", "llvm.spv.resource.getpointer",
152+
{p11, targetType}, {handle, i32one});
153+
llvm::FunctionType *tf = llf->getFunctionType();
154+
IF_LOG {
155+
Logger::cout() << "load pointer: " << *pointer << std::endl;
156+
Logger::cout() << _ir->module.getDataLayout().getABITypeAlign(argType).value() << std::endl;
157+
Logger::cout() << tf->getParamType(0)->getTypeID() << std::endl;
158+
Logger::cout() << "done" << std::endl;
159+
}
160+
LOG_SCOPE
161+
llvm::SmallVector<llvm::Value *, 8> args(tf->getNumParams());
162+
163+
auto *arg = builder.CreateAlignedLoad(argType, pointer, _ir->module.getDataLayout().getABITypeAlign(argType), false);
164+
IF_LOG {
165+
// Logger::cout() << "load elements from " << *arg << std::endl;
166+
// Logger::cout() << "of type " << *argType << std::endl;
167+
}
168+
for (unsigned int i = 0; i < tf->getNumParams(); i++) {
169+
args[i] = builder.CreateExtractValue(arg, {i});
170+
llvm::Type *t = tf->getParamType(i);
171+
if (t->isPointerTy())
172+
args[i] = builder.CreateIntToPtr(args[i],t);
173+
}
174+
175+
builder.CreateCall(llf->getFunctionType(), llf, args);
176+
builder.CreateRetVoid();
177+
IF_LOG Logger::cout() << *f << std::endl;
178+
}
179+
180+
};
181+
} // anonymous namespace.
182+
183+
DComputeTarget *createVulkanTarget(llvm::LLVMContext &c, int ver) {
184+
return new TargetVulkan(c, ver);
185+
}
186+
187+
#endif // LDC_LLVM_SUPPORTED_TARGET_SPIRV

gen/irstate.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,9 @@ template <typename F>
191191
LLGlobalVariable *
192192
getCachedStringLiteralImpl(llvm::Module &module,
193193
llvm::StringMap<LLGlobalVariable *> &cache,
194-
llvm::StringRef key, F initFactory) {
194+
llvm::StringRef key,
195+
std::optional< unsigned > addrspace,
196+
F initFactory) {
195197
auto iter = cache.find(key);
196198
if (iter != cache.end()) {
197199
return iter->second;
@@ -201,7 +203,8 @@ getCachedStringLiteralImpl(llvm::Module &module,
201203

202204
auto gvar =
203205
new LLGlobalVariable(module, constant->getType(), true,
204-
LLGlobalValue::PrivateLinkage, constant, ".str");
206+
LLGlobalValue::PrivateLinkage, constant,
207+
".str", nullptr, llvm::GlobalValue::NotThreadLocal, addrspace);
205208
gvar->setUnnamedAddr(LLGlobalValue::UnnamedAddr::Global);
206209

207210
cache[key] = gvar;
@@ -241,14 +244,19 @@ LLGlobalVariable *IRState::getCachedStringLiteral(StringExp *se) {
241244
const llvm::StringRef key(reinterpret_cast<const char *>(keyData.ptr),
242245
keyData.length);
243246

244-
return getCachedStringLiteralImpl(module, *cache, key, [se]() {
247+
return getCachedStringLiteralImpl(module, *cache, key, std::nullopt, [se]() {
245248
// null-terminate
246249
return buildStringLiteralConstant(se, se->len + 1);
247250
});
248251
}
249252

250-
LLGlobalVariable *IRState::getCachedStringLiteral(llvm::StringRef s) {
251-
return getCachedStringLiteralImpl(module, cachedStringLiterals, s, [&]() {
253+
LLGlobalVariable *IRState::getCachedStringLiteral(llvm::StringRef s,
254+
std::optional< unsigned > addrspace) {
255+
return getCachedStringLiteralImpl(module,
256+
cachedStringLiterals,
257+
s,
258+
addrspace,
259+
[&]() {
252260
return llvm::ConstantDataArray::getString(context(), s, true);
253261
});
254262
}

gen/irstate.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ struct IRState {
248248
// calls with a StringExp with matching data will return the same variable.
249249
// Exception: ulong[]-typed hex strings (not null-terminated either).
250250
llvm::GlobalVariable *getCachedStringLiteral(StringExp *se);
251-
llvm::GlobalVariable *getCachedStringLiteral(llvm::StringRef s);
251+
llvm::GlobalVariable *getCachedStringLiteral(llvm::StringRef s, std::optional< unsigned > = std::nullopt);
252252

253253
// List of functions with cpu or features attributes overriden by user
254254
std::vector<IrFunction *> targetCpuOrFeaturesOverridden;

runtime/druntime/src/ldc/dcompute.d

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ enum ReflectTarget : uint
1414
Host = 0,
1515
OpenCL = 1,
1616
CUDA = 2,
17+
Vulkan = 3,
1718
}
1819
/**
1920
* The pseudo conditional compilation function.

0 commit comments

Comments
 (0)