Skip to content

Commit d4a821a

Browse files
Added support for ldmatrix migration
1 parent c1cb3cc commit d4a821a

File tree

9 files changed

+260
-24
lines changed

9 files changed

+260
-24
lines changed

clang/lib/DPCT/RulesAsm/AsmMigration.cpp

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@
88

99
#include "AsmMigration.h"
1010
#include "AnalysisInfo.h"
11+
#include "Diagnostics/Diagnostics.h"
12+
#include "ErrorHandle/CrashRecovery.h"
13+
#include "RuleInfra/MapNames.h"
1114
#include "RulesAsm/Parser/AsmNodes.h"
1215
#include "RulesAsm/Parser/AsmParser.h"
1316
#include "RulesAsm/Parser/AsmTokenKinds.h"
14-
#include "ErrorHandle/CrashRecovery.h"
15-
#include "Diagnostics/Diagnostics.h"
16-
#include "RuleInfra/MapNames.h"
1717
#include "TextModification.h"
1818
#include "Utility.h"
1919
#include "clang/AST/Expr.h"
@@ -557,12 +557,15 @@ bool SYCLGenBase::emitVectorType(const InlineAsmVectorType *T) {
557557
OS() << ", ";
558558
switch (T->getKind()) {
559559
case InlineAsmVectorType::v2:
560+
case InlineAsmVectorType::x1:
560561
OS() << 2;
561562
break;
562563
case InlineAsmVectorType::v4:
564+
case InlineAsmVectorType::x2:
563565
OS() << 4;
564566
break;
565567
case InlineAsmVectorType::v8:
568+
case InlineAsmVectorType::x4:
566569
OS() << 8;
567570
break;
568571
}
@@ -589,9 +592,9 @@ bool SYCLGenBase::emitVariableDeclaration(const InlineAsmVarDecl *D) {
589592

590593
bool SYCLGenBase::emitAddressExpr(const InlineAsmAddressExpr *Dst) {
591594
// Address expressions are only supported for ld/st/atom/prefetch/red/cp/ldmatrix instructions.
592-
if (!CurrInst ||
593-
!CurrInst->is(asmtok::op_st, asmtok::op_ld, asmtok::op_atom,
594-
asmtok::op_prefetch, asmtok::op_red, asmtok::op_cp)) {
595+
if (!CurrInst || !CurrInst->is(asmtok::op_st, asmtok::op_ld, asmtok::op_atom,
596+
asmtok::op_prefetch, asmtok::op_red,
597+
asmtok::op_cp, asmtok::op_ldmatrix)) {
595598
return SYCLGenError();
596599
}
597600
std::string Type;
@@ -624,6 +627,8 @@ bool SYCLGenBase::emitAddressExpr(const InlineAsmAddressExpr *Dst) {
624627
if (CurrInst->is(asmtok::op_prefetch, asmtok::op_red) ||
625628
CanSuppressCast(Dst->getSymbol()))
626629
OS() << llvm::formatv("{0}", Reg);
630+
else if (CurrInst->is(asmtok::op_ldmatrix))
631+
OS() << llvm::formatv("(uintptr_t){0}", Reg);
627632
else
628633
OS() << llvm::formatv("(({0} *)(uintptr_t){1})", Type, Reg);
629634
break;
@@ -1301,6 +1306,46 @@ class SYCLGen : public SYCLGenBase {
13011306
return SYCLGenSuccess();
13021307
}
13031308

1309+
bool handle_ldmatrix(const InlineAsmInstruction *Inst) override {
1310+
if (Inst->getNumInputOperands() != 1)
1311+
return SYCLGenError();
1312+
1313+
llvm::SaveAndRestore<const InlineAsmInstruction *> Store(CurrInst);
1314+
CurrInst = Inst;
1315+
const auto *Src =
1316+
dyn_cast_or_null<InlineAsmAddressExpr>(Inst->getInputOperand(0));
1317+
if (!Src)
1318+
return false;
1319+
1320+
OS() << MapNames::getDpctNamespace() << "experimental::matrix::ldmatrix(";
1321+
if (emitStmt(Src)) {
1322+
return SYCLGenError();
1323+
}
1324+
OS() << ", ";
1325+
const auto *VE = dyn_cast<InlineAsmVectorExpr>(Inst->getOutputOperand());
1326+
for (unsigned Inst = 0, E = VE->getNumElements(); Inst != E; ++Inst) {
1327+
if (isa<InlineAsmDiscardExpr>(VE->getElement(Inst)))
1328+
continue;
1329+
OS() << "&";
1330+
if (emitStmt(VE->getElement(Inst)))
1331+
return SYCLGenError();
1332+
OS() << ", ";
1333+
}
1334+
OS() << DpctGlobalInfo::getItem(GAS);
1335+
if (Inst->hasAttr(InstAttr::trans))
1336+
OS() << ", true";
1337+
OS() << ");";
1338+
const auto *KernelDecl = getImmediateOuterFuncDecl(GAS);
1339+
if (KernelDecl) {
1340+
auto FuncInfo = DeviceFunctionDecl::LinkRedecls(KernelDecl);
1341+
if (FuncInfo)
1342+
FuncInfo->addSubGroupSizeRequest(32, GAS->getBeginLoc(),
1343+
DpctGlobalInfo::getSubGroup(GAS));
1344+
}
1345+
1346+
return SYCLGenSuccess();
1347+
}
1348+
13041349
bool handle_prefetch(const InlineAsmInstruction *Inst) override {
13051350
if (!DpctGlobalInfo::useExtPrefetch() || Inst->getNumInputOperands() != 1)
13061351
return SYCLGenError();
@@ -2715,6 +2760,7 @@ class SYCLGen : public SYCLGenBase {
27152760
bool handle_ld(const InlineAsmInstruction *Inst) override {
27162761
if (Inst->getNumInputOperands() != 1)
27172762
return SYCLGenError();
2763+
27182764
llvm::SaveAndRestore<const InlineAsmInstruction *> Store(CurrInst);
27192765
CurrInst = Inst;
27202766
const auto *Src =

clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ class InlineAsmBuiltinType : public InlineAsmType {
116116
// This class is used for device asm vector types.
117117
class InlineAsmVectorType : public InlineAsmType {
118118
public:
119-
enum VecKind { v2, v4, v8 };
119+
enum VecKind { v2, v4, v8, x1, x2, x4 };
120120

121121
private:
122122
VecKind Kind;
@@ -340,6 +340,8 @@ class InlineAsmInstruction : public InlineAsmStmt {
340340
/// the rest are input operands.
341341
SmallVector<InlineAsmExpr *, 4> InputOps;
342342

343+
SmallVector<InlineAsmExpr *, 4> OutputOps;
344+
343345
public:
344346
InlineAsmInstruction(InlineAsmIdentifierInfo *Op,
345347
SmallVector<AsmStateSpace, 4> AsmStateSpaces,

clang/lib/DPCT/RulesAsm/Parser/AsmParser.cpp

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ InlineAsmStmtResult InlineAsmParser::ParseInstruction() {
327327
if (!Tok.getIdentifier() || !Tok.getIdentifier()->isInstruction())
328328
return AsmStmtError();
329329

330-
InlineAsmIdentifierInfo *Opcode = Tok.getIdentifier();
330+
Opcode = Tok.getIdentifier();
331331
ConsumeToken();
332332

333333
SmallVector<InstAttr, 4> Attrs;
@@ -736,20 +736,38 @@ InlineAsmExprResult InlineAsmParser::ActOnParenExpr(InlineAsmExpr *SubExpr) {
736736
InlineAsmExprResult
737737
InlineAsmParser::ActOnVectorExpr(ArrayRef<InlineAsmExpr *> Vec) {
738738

739-
// Vector size must be 2, 4, or 8.
739+
// Vector sizes for ldmatrix are 1, 2, or 4.
740+
// size(x) = 2 * sizeof(v).
740741
InlineAsmVectorType::VecKind Kind;
741-
switch (Vec.size()) {
742-
case 2:
743-
Kind = InlineAsmVectorType::v2;
744-
break;
745-
case 4:
746-
Kind = InlineAsmVectorType::v4;
747-
break;
748-
case 8:
749-
Kind = InlineAsmVectorType::v8;
750-
break;
751-
default:
752-
return AsmExprError();
742+
if (Opcode->getTokenID() == asmtok::op_ldmatrix) {
743+
switch (Vec.size()) {
744+
case 1:
745+
Kind = InlineAsmVectorType::x1;
746+
break;
747+
case 2:
748+
Kind = InlineAsmVectorType::x2;
749+
break;
750+
case 4:
751+
Kind = InlineAsmVectorType::x4;
752+
break;
753+
default:
754+
return AsmExprError();
755+
}
756+
} else {
757+
// Vector size must be 2, 4, or 8.
758+
switch (Vec.size()) {
759+
case 2:
760+
Kind = InlineAsmVectorType::v2;
761+
break;
762+
case 4:
763+
Kind = InlineAsmVectorType::v4;
764+
break;
765+
case 8:
766+
Kind = InlineAsmVectorType::v8;
767+
break;
768+
default:
769+
return AsmExprError();
770+
}
753771
}
754772

755773
InlineAsmBuiltinType *ElementType = nullptr;

clang/lib/DPCT/RulesAsm/Parser/AsmParser.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,8 @@ class InlineAsmParser {
247247
};
248248

249249
public:
250+
InlineAsmIdentifierInfo *Opcode;
251+
250252
InlineAsmParser(InlineAsmContext &Ctx, SourceMgr &Mgr)
251253
: Lexer(*Mgr.getMemoryBuffer(Mgr.getMainFileID())), Context(Ctx),
252254
SrcMgr(Mgr), CurScope(nullptr) {

clang/lib/DPCT/RulesAsm/Parser/AsmTokenKinds.def

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,14 @@ MODIFIER(v2, ".v2")
274274
MODIFIER(v4, ".v4")
275275
MODIFIER(v8, ".v8")
276276

277+
// Matrix modifiers
278+
MODIFIER(x1, ".x1")
279+
MODIFIER(x2, ".x2")
280+
MODIFIER(x4, ".x4")
281+
282+
// Matrix shape
283+
MODIFIER(m8n8, ".m8n8")
284+
277285
STATE_SPACE(reg, ".reg")
278286
STATE_SPACE(sreg, ".sreg")
279287
STATE_SPACE(const, ".const")
@@ -418,6 +426,8 @@ MODIFIER(rc8, ".rc8")
418426
MODIFIER(ecl, ".ecl")
419427
MODIFIER(ecr, ".ecr")
420428
MODIFIER(rc16, ".rc16")
429+
MODIFIER(aligned, ".aligned")
430+
MODIFIER(trans, ".trans")
421431

422432
#undef LINKAGE
423433
#undef TARGET

clang/lib/DPCT/SrcAPI/APINames_ASM.inc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ ENTRY("griddepcontrol", "griddepcontrol", false, NO_FLAG, P1, "Comment")
7575
ENTRY("isspacep", "isspacep", false, NO_FLAG, P1, "Comment")
7676
ENTRY("istypep", "istypep", false, NO_FLAG, P1, "Comment")
7777
ENTRY("ld", "ld", true, NO_FLAG, P1, "Partial")
78-
ENTRY("ldmatrix", "ldmatrix", false, NO_FLAG, P1, "Comment")
78+
ENTRY("ldmatrix", "ldmatrix", true, NO_FLAG, P1, "Successful")
7979
ENTRY("ldu", "ldu", false, NO_FLAG, P1, "Comment")
8080
ENTRY("lg2", "lg2", true, NO_FLAG, P1, "Successful")
8181
ENTRY("lop3", "lop3", true, NO_FLAG, P1, "Successful")

clang/runtime/dpct-rt/include/dpct/math.hpp

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
#ifndef __DPCT_MATH_HPP__
1010
#define __DPCT_MATH_HPP__
1111

12-
#include <limits>
1312
#include <climits>
13+
#include <limits>
1414
#include <sycl/sycl.hpp>
1515
#include <type_traits>
1616

@@ -2055,6 +2055,64 @@ class joint_matrix {
20552055
matrix_accessor x;
20562056
const size_t num_elements;
20572057
};
2058+
2059+
template <typename T>
2060+
void ldmatrix(uintptr_t addr, T *m, const sycl::nd_item<3> &item_ct1,
2061+
bool trans = false, unsigned mat = 0) {
2062+
int lane = item_ct1.get_local_id(2);
2063+
2064+
int group = lane / 8;
2065+
int sub = lane % 8;
2066+
int src_base = group * 2;
2067+
2068+
if (!trans) {
2069+
// calculate the source lane
2070+
int src_lane = (sub / 4) ? (src_base + 1) : src_base;
2071+
2072+
// Broadcast the address from the source lane
2073+
auto recv_addr_uintp = dpct::select_from_sub_group(
2074+
item_ct1.get_sub_group(), addr, mat * 8 + src_lane);
2075+
auto recv_addr = reinterpret_cast<T *>(recv_addr_uintp);
2076+
2077+
// Non-transposed load
2078+
*m = recv_addr[sub % 4];
2079+
} else {
2080+
// calculate the source lane
2081+
int src_lane = (lane % 4) * 2;
2082+
2083+
// Broadcast the address from the source lane:
2084+
auto recv_addr_uintp_1 = dpct::select_from_sub_group(
2085+
item_ct1.get_sub_group(), addr, mat * 8 + src_lane);
2086+
auto recv_addr_uintp_2 = dpct::select_from_sub_group(
2087+
item_ct1.get_sub_group(), addr, mat * 8 + src_lane + 1);
2088+
auto recv_addr_1 = reinterpret_cast<sycl::half *>(recv_addr_uintp_1);
2089+
auto recv_addr_2 = reinterpret_cast<sycl::half *>(recv_addr_uintp_2);
2090+
2091+
// Transposed load
2092+
int index = (lane / 4);
2093+
sycl::half val0 = recv_addr_1[index];
2094+
sycl::half val1 = recv_addr_2[index];
2095+
sycl::half2 val = sycl::half2(val0, val1);
2096+
*m = *reinterpret_cast<T *>(&val);
2097+
}
2098+
}
2099+
2100+
template <typename T>
2101+
void ldmatrix(uintptr_t addr, T *m1, T *m2, const sycl::nd_item<3> &item_ct1,
2102+
bool trans = false) {
2103+
ldmatrix(addr, m1, item_ct1, trans, 0);
2104+
ldmatrix(addr, m2, item_ct1, trans, 1);
2105+
}
2106+
2107+
template <typename T>
2108+
void ldmatrix(uintptr_t addr, T *m1, T *m2, T *m3, T *m4,
2109+
const sycl::nd_item<3> &item_ct1, bool trans = false) {
2110+
ldmatrix(addr, m1, item_ct1, trans, 0);
2111+
ldmatrix(addr, m2, item_ct1, trans, 1);
2112+
ldmatrix(addr, m3, item_ct1, trans, 2);
2113+
ldmatrix(addr, m4, item_ct1, trans, 3);
2114+
}
2115+
20582116
} // namespace matrix
20592117
} // namespace experimental
20602118

0 commit comments

Comments
 (0)