Skip to content

Commit e622ab9

Browse files
Added support for m8n8k4
1 parent 03f180c commit e622ab9

File tree

6 files changed

+37
-28
lines changed

6 files changed

+37
-28
lines changed

clang/lib/DPCT/RulesAsm/AsmMigration.cpp

+15-11
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,9 @@ bool SYCLGenBase::emitVectorType(const InlineAsmVectorType *T) {
556556
return SYCLGenError();
557557
OS() << ", ";
558558
switch (T->getKind()) {
559+
case InlineAsmVectorType::v1:
560+
OS() << 1;
561+
break;
559562
case InlineAsmVectorType::v2:
560563
OS() << 2;
561564
break;
@@ -1342,7 +1345,7 @@ class SYCLGen : public SYCLGenBase {
13421345
return SYCLGenError();
13431346

13441347
// Register sizes for vector elements of A, B, C & D matrices
1345-
int NumVecElements[4] = {0};
1348+
unsigned NumVecElements[4] = {0};
13461349

13471350
// Data type used to multiply A & B matrices
13481351
std::string MulType;
@@ -1351,8 +1354,8 @@ class SYCLGen : public SYCLGenBase {
13511354
if (AType->getKind() == InlineAsmBuiltinType::f16) {
13521355
// If A matrix type is f16, then C&D matrix types can only be f16
13531356
if (CType->getKind() == AType->getKind()) {
1354-
NumVecElements[0] = 4; // A
1355-
NumVecElements[1] = 2; // B
1357+
NumVecElements[0] = 2; // A
1358+
NumVecElements[1] = 4; // B
13561359
NumVecElements[2] = 4; // C
13571360
NumVecElements[3] = 4; // D
13581361
} else
@@ -1364,23 +1367,23 @@ class SYCLGen : public SYCLGenBase {
13641367
if (AType->getKind() == InlineAsmBuiltinType::f16) {
13651368
// If A matrix type is f16, then C&D matrix types can only be f16/f32
13661369
if (CType->getKind() == AType->getKind()) {
1367-
NumVecElements[0] = 4; // A
1370+
NumVecElements[0] = 2; // A
13681371
NumVecElements[1] = 2; // B
1369-
NumVecElements[2] = 2; // C
1372+
NumVecElements[2] = 4; // C
13701373
NumVecElements[3] = 4; // D
13711374
} else if (CType->getKind() == InlineAsmBuiltinType::f32) {
1372-
NumVecElements[0] = 8; // A
1375+
NumVecElements[0] = 2; // A
13731376
NumVecElements[1] = 2; // B
1374-
NumVecElements[2] = 2; // C
1377+
NumVecElements[2] = 8; // C
13751378
NumVecElements[3] = 8; // D
13761379
} else
13771380
return SYCLGenError();
13781381
} else if (AType->getKind() == InlineAsmBuiltinType::f64) {
13791382
// If A matrix type is f64, then C&D matrix types can only be f64
13801383
if (CType->getKind() == AType->getKind()) {
1381-
NumVecElements[0] = 2; // A
1384+
NumVecElements[0] = 1; // A
13821385
NumVecElements[1] = 1; // B
1383-
NumVecElements[2] = 1; // C
1386+
NumVecElements[2] = 2; // C
13841387
NumVecElements[3] = 2; // D
13851388
} else
13861389
return SYCLGenError();
@@ -1411,15 +1414,16 @@ class SYCLGen : public SYCLGenBase {
14111414
if (isa<InlineAsmDiscardExpr>(DMatVE->getElement(Inst)))
14121415
continue;
14131416
OS() << "&";
1414-
if (emitStmt(VE->getElement(Inst)))
1417+
if (emitStmt(DMatVE->getElement(Inst)))
14151418
return SYCLGenError();
14161419
OS() << ", ";
14171420
}
14181421

14191422
// Add A, B & C matrix values to compute MAD
14201423
for (unsigned InputOp = 0; InputOp < Inst->getNumInputOperands();
14211424
InputOp++) {
1422-
if (VE = dyn_cast<InlineAsmVectorExpr>(Inst->getInputOperand(InputOp))) {
1425+
if (auto VE =
1426+
dyn_cast<InlineAsmVectorExpr>(Inst->getInputOperand(InputOp))) {
14231427
for (unsigned Inst = 0; Inst != VE->getNumElements(); ++Inst) {
14241428
if (isa<InlineAsmDiscardExpr>(VE->getElement(Inst)))
14251429
continue;

clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ class InlineAsmBuiltinType : public InlineAsmType {
116116
// This class is used for device asm vector types.
117117
class InlineAsmVectorType : public InlineAsmType {
118118
public:
119-
enum VecKind { v2, v4, v8 };
119+
enum VecKind { v1, v2, v4, v8 };
120120

121121
private:
122122
VecKind Kind;

clang/lib/DPCT/RulesAsm/Parser/AsmParser.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -739,6 +739,9 @@ InlineAsmParser::ActOnVectorExpr(ArrayRef<InlineAsmExpr *> Vec) {
739739
// Vector size must be 1, 2, 4, or 8.
740740
InlineAsmVectorType::VecKind Kind;
741741
switch (Vec.size()) {
742+
case 1:
743+
Kind = InlineAsmVectorType::v1;
744+
break;
742745
case 2:
743746
Kind = InlineAsmVectorType::v2;
744747
break;

clang/lib/DPCT/RulesAsm/Parser/AsmParser.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -496,7 +496,7 @@ class InlineAsmParser {
496496
/// .reg .sreg .const .local .param .shared .tex
497497
///
498498
/// vector-specifier: one of
499-
/// .v2 .v4 .v8
499+
/// .v1 .v2 .v4 .v8
500500
///
501501
/// type-specifier: one of
502502
/// .b8 .b16 .b32 .b64 .s8 .s16 .s32 .s64

clang/lib/DPCT/RulesAsm/Parser/AsmTokenKinds.def

+1
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ BUILTIN_TYPE(s16x2, ".s16x2")
270270
BUILTIN_TYPE(u16x2, ".u16x2")
271271

272272
// Vector modifiers
273+
MODIFIER(v1, ".v1")
273274
MODIFIER(v2, ".v2")
274275
MODIFIER(v4, ".v4")
275276
MODIFIER(v8, ".v8")

clang/runtime/dpct-rt/include/dpct/math.hpp

+16-15
Original file line numberDiff line numberDiff line change
@@ -2086,7 +2086,8 @@ void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1,
20862086

20872087
short COL_LOAD_OFFSET = 4 * ((lane % 16) / 4);
20882088

2089-
ABType recv_a[2] recv_a[0] = a0;
2089+
ABType recv_a[2];
2090+
recv_a[0] = a0;
20902091
recv_a[1] = a1;
20912092

20922093
MulType *ra = reinterpret_cast<MulType *>(recv_a);
@@ -2117,10 +2118,10 @@ void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1,
21172118
}
21182119
}
21192120

2120-
d0 = c0;
2121-
d1 = c1;
2122-
d2 = c2;
2123-
d3 = c3;
2121+
*d0 = c0;
2122+
*d1 = c1;
2123+
*d2 = c2;
2124+
*d3 = c3;
21242125
}
21252126

21262127
/// Multiplies 2 8x4 & 4x8 matrices and accumulates the result to an 8x8 b32
@@ -2202,14 +2203,14 @@ void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, CDType *d4, CDType *d5,
22022203
c7 += static_cast<CDType>(ra[i + 4]) * static_cast<CDType>(rb[i + 12]);
22032204
}
22042205

2205-
d0 = c0;
2206-
d1 = c1;
2207-
d2 = c2;
2208-
d3 = c3;
2209-
d4 = c4;
2210-
d5 = c5;
2211-
d6 = c6;
2212-
d7 = c7;
2206+
*d0 = c0;
2207+
*d1 = c1;
2208+
*d2 = c2;
2209+
*d3 = c3;
2210+
*d4 = c4;
2211+
*d5 = c5;
2212+
*d6 = c6;
2213+
*d7 = c7;
22132214
}
22142215

22152216
/// Multiplies 2 8x4 & 4x8 matrices and accumulates the result to an 8x8 b32
@@ -2246,8 +2247,8 @@ void mma(CDType *d0, CDType *d1, ABType a0, ABType b0, CDType c0, CDType c1,
22462247
c1 += recv_a * recv_b;
22472248
}
22482249

2249-
d0 = c0;
2250-
d1 = c1;
2250+
*d0 = c0;
2251+
*d1 = c1;
22512252
}
22522253

22532254
/// Multiplies 2 16x16 & 16x8 matrices and accumulates the result to a 16x8 b32

0 commit comments

Comments
 (0)