@@ -556,6 +556,9 @@ bool SYCLGenBase::emitVectorType(const InlineAsmVectorType *T) {
556
556
return SYCLGenError ();
557
557
OS () << " , " ;
558
558
switch (T->getKind ()) {
559
+ case InlineAsmVectorType::v1:
560
+ OS () << 1 ;
561
+ break ;
559
562
case InlineAsmVectorType::v2:
560
563
OS () << 2 ;
561
564
break ;
@@ -1342,7 +1345,7 @@ class SYCLGen : public SYCLGenBase {
1342
1345
return SYCLGenError ();
1343
1346
1344
1347
// Register sizes for vector elements of A, B, C & D matrices
1345
- int NumVecElements[4 ] = {0 };
1348
+ unsigned NumVecElements[4 ] = {0 };
1346
1349
1347
1350
// Data type used to multiply A & B matrices
1348
1351
std::string MulType;
@@ -1351,8 +1354,8 @@ class SYCLGen : public SYCLGenBase {
1351
1354
if (AType->getKind () == InlineAsmBuiltinType::f16) {
1352
1355
// If A matrix type is f16, then C&D matrix types can only be f16
1353
1356
if (CType->getKind () == AType->getKind ()) {
1354
- NumVecElements[0 ] = 4 ; // A
1355
- NumVecElements[1 ] = 2 ; // B
1357
+ NumVecElements[0 ] = 2 ; // A
1358
+ NumVecElements[1 ] = 4 ; // B
1356
1359
NumVecElements[2 ] = 4 ; // C
1357
1360
NumVecElements[3 ] = 4 ; // D
1358
1361
} else
@@ -1364,23 +1367,23 @@ class SYCLGen : public SYCLGenBase {
1364
1367
if (AType->getKind () == InlineAsmBuiltinType::f16) {
1365
1368
// If A matrix type is f16, then C&D matrix types can only be f16/f32
1366
1369
if (CType->getKind () == AType->getKind ()) {
1367
- NumVecElements[0 ] = 4 ; // A
1370
+ NumVecElements[0 ] = 2 ; // A
1368
1371
NumVecElements[1 ] = 2 ; // B
1369
- NumVecElements[2 ] = 2 ; // C
1372
+ NumVecElements[2 ] = 4 ; // C
1370
1373
NumVecElements[3 ] = 4 ; // D
1371
1374
} else if (CType->getKind () == InlineAsmBuiltinType::f32) {
1372
- NumVecElements[0 ] = 8 ; // A
1375
+ NumVecElements[0 ] = 2 ; // A
1373
1376
NumVecElements[1 ] = 2 ; // B
1374
- NumVecElements[2 ] = 2 ; // C
1377
+ NumVecElements[2 ] = 8 ; // C
1375
1378
NumVecElements[3 ] = 8 ; // D
1376
1379
} else
1377
1380
return SYCLGenError ();
1378
1381
} else if (AType->getKind () == InlineAsmBuiltinType::f64) {
1379
1382
// If A matrix type is f64, then C&D matrix types can only be f64
1380
1383
if (CType->getKind () == AType->getKind ()) {
1381
- NumVecElements[0 ] = 2 ; // A
1384
+ NumVecElements[0 ] = 1 ; // A
1382
1385
NumVecElements[1 ] = 1 ; // B
1383
- NumVecElements[2 ] = 1 ; // C
1386
+ NumVecElements[2 ] = 2 ; // C
1384
1387
NumVecElements[3 ] = 2 ; // D
1385
1388
} else
1386
1389
return SYCLGenError ();
@@ -1411,15 +1414,16 @@ class SYCLGen : public SYCLGenBase {
1411
1414
if (isa<InlineAsmDiscardExpr>(DMatVE->getElement (Inst)))
1412
1415
continue ;
1413
1416
OS () << " &" ;
1414
- if (emitStmt (VE ->getElement (Inst)))
1417
+ if (emitStmt (DMatVE ->getElement (Inst)))
1415
1418
return SYCLGenError ();
1416
1419
OS () << " , " ;
1417
1420
}
1418
1421
1419
1422
// Add A, B & C matrix values to compute MAD
1420
1423
for (unsigned InputOp = 0 ; InputOp < Inst->getNumInputOperands ();
1421
1424
InputOp++) {
1422
- if (VE = dyn_cast<InlineAsmVectorExpr>(Inst->getInputOperand (InputOp))) {
1425
+ if (auto VE =
1426
+ dyn_cast<InlineAsmVectorExpr>(Inst->getInputOperand (InputOp))) {
1423
1427
for (unsigned Inst = 0 ; Inst != VE->getNumElements (); ++Inst) {
1424
1428
if (isa<InlineAsmDiscardExpr>(VE->getElement (Inst)))
1425
1429
continue ;
0 commit comments