Skip to content

Commit b8a936f

Browse files
Updated logic
1 parent c52860e commit b8a936f

File tree

2 files changed

+38
-32
lines changed

2 files changed

+38
-32
lines changed

clang/lib/DPCT/RulesAsm/AsmMigration.cpp

+17-11
Original file line numberDiff line numberDiff line change
@@ -556,16 +556,18 @@ bool SYCLGenBase::emitVectorType(const InlineAsmVectorType *T) {
556556
return SYCLGenError();
557557
OS() << ", ";
558558
switch (T->getKind()) {
559-
case InlineAsmVectorType::v2:
560559
case InlineAsmVectorType::x1:
560+
OS() << 1;
561+
break;
562+
case InlineAsmVectorType::v2:
563+
case InlineAsmVectorType::x2:
561564
OS() << 2;
562565
break;
563566
case InlineAsmVectorType::v4:
564-
case InlineAsmVectorType::x2:
567+
case InlineAsmVectorType::x4:
565568
OS() << 4;
566569
break;
567570
case InlineAsmVectorType::v8:
568-
case InlineAsmVectorType::x4:
569571
OS() << 8;
570572
break;
571573
}
@@ -1322,14 +1324,18 @@ class SYCLGen : public SYCLGenBase {
13221324
return SYCLGenError();
13231325
}
13241326
OS() << ", ";
1325-
const auto *VE = dyn_cast<InlineAsmVectorExpr>(Inst->getOutputOperand());
1326-
for (unsigned Inst = 0, E = VE->getNumElements(); Inst != E; ++Inst) {
1327-
if (isa<InlineAsmDiscardExpr>(VE->getElement(Inst)))
1328-
continue;
1329-
OS() << "&";
1330-
if (emitStmt(VE->getElement(Inst)))
1331-
return SYCLGenError();
1332-
OS() << ", ";
1327+
if (const auto *VE =
1328+
dyn_cast<InlineAsmVectorExpr>(Inst->getOutputOperand())) {
1329+
for (unsigned Inst = 0; Inst != VE->getNumElements(); ++Inst) {
1330+
if (isa<InlineAsmDiscardExpr>(VE->getElement(Inst)))
1331+
continue;
1332+
OS() << "&";
1333+
if (emitStmt(VE->getElement(Inst)))
1334+
return SYCLGenError();
1335+
OS() << ", ";
1336+
}
1337+
} else {
1338+
return SYCLGenError();
13331339
}
13341340
OS() << DpctGlobalInfo::getItem(GAS);
13351341
if (Inst->hasAttr(InstAttr::trans))

clang/runtime/dpct-rt/include/dpct/math.hpp

+21-21
Original file line numberDiff line numberDiff line change
@@ -2061,13 +2061,13 @@ class joint_matrix {
20612061
/// \tparam [in] T The type of result variable
20622062
/// \param [in] addr The address of the matrix in shared memory
20632063
/// \param [in] m The local memory to store the matrix
2064-
/// \param [in] item_ct1 The sycl::nd_item object
2064+
/// \param [in] item The sycl::nd_item index space class
20652065
/// \param [in] trans Indicates whether the matrix to be loaded transposed
20662066
/// \param [in] mat The matrix index to be loaded
2067-
template <typename T>
2068-
void ldmatrix(uintptr_t addr, T *m, const sycl::nd_item<3> &item_ct1,
2069-
bool trans = false, unsigned mat = 0) {
2070-
int lane = item_ct1.get_local_id(2) % 32;
2067+
template <typename T, typename ItemT>
2068+
void ldmatrix(uintptr_t addr, T *m, const ItemT &item, bool trans = false,
2069+
unsigned mat = 0) {
2070+
int lane = item.get_sub_group().get_local_linear_id();
20712071

20722072
int lane_group8_row = lane / 8;
20732073
int lane_group8_col = lane % 8;
@@ -2080,7 +2080,7 @@ void ldmatrix(uintptr_t addr, T *m, const sycl::nd_item<3> &item_ct1,
20802080

20812081
// Broadcast the address from the source lane
20822082
auto recv_addr_uintp = dpct::select_from_sub_group(
2083-
item_ct1.get_sub_group(), addr, mat * 8 + src_lane);
2083+
item.get_sub_group(), addr, mat * 8 + src_lane);
20842084

20852085
// Cast the received address from uintptr_t to the type of 'm'
20862086
auto recv_addr = reinterpret_cast<T *>(recv_addr_uintp);
@@ -2093,9 +2093,9 @@ void ldmatrix(uintptr_t addr, T *m, const sycl::nd_item<3> &item_ct1,
20932093

20942094
// Broadcast the address from the source lane
20952095
auto recv_addr_uintp_1 = dpct::select_from_sub_group(
2096-
item_ct1.get_sub_group(), addr, mat * 8 + src_lane);
2096+
item.get_sub_group(), addr, mat * 8 + src_lane);
20972097
auto recv_addr_uintp_2 = dpct::select_from_sub_group(
2098-
item_ct1.get_sub_group(), addr, mat * 8 + src_lane + 1);
2098+
item.get_sub_group(), addr, mat * 8 + src_lane + 1);
20992099

21002100
// Cast the received address from uintptr_t to 'half *'
21012101
auto recv_addr_1 = reinterpret_cast<sycl::half *>(recv_addr_uintp_1);
@@ -2118,15 +2118,15 @@ void ldmatrix(uintptr_t addr, T *m, const sycl::nd_item<3> &item_ct1,
21182118
/// \param [in] addr The address of the matrix in shared memory
21192119
/// \param [in] m1 The local memory to store data of 1st matrix
21202120
/// \param [in] m2 The local memory to store data of 2nd matrix
2121-
/// \param [in] item_ct1 The sycl::nd_item object
2121+
/// \param [in] item The sycl::nd_item index space class
21222122
/// \param [in] trans Indicates whether the matrix to be loaded transposed
2123-
template <typename T>
2124-
void ldmatrix(uintptr_t addr, T *m1, T *m2, const sycl::nd_item<3> &item_ct1,
2123+
template <typename T, typename ItemT>
2124+
void ldmatrix(uintptr_t addr, T *m1, T *m2, const ItemT &item,
21252125
bool trans = false) {
21262126
// Load 1st matrix
2127-
ldmatrix(addr, m1, item_ct1, trans, 0);
2127+
ldmatrix(addr, m1, item, trans, 0);
21282128
// Load 2nd matrix
2129-
ldmatrix(addr, m2, item_ct1, trans, 1);
2129+
ldmatrix(addr, m2, item, trans, 1);
21302130
}
21312131

21322132
/// Loads 4 8x8 b16 matrix from shared memory to local memory (32-bits per wi)
@@ -2137,19 +2137,19 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, const sycl::nd_item<3> &item_ct1,
21372137
/// \param [in] m2 The local memory to store data of 2nd matrix
21382138
/// \param [in] m3 The local memory to store data of 3rd matrix
21392139
/// \param [in] m4 The local memory to store data of 4th matrix
2140-
/// \param [in] item_ct1 The sycl::nd_item object
2140+
/// \param [in] item The sycl::nd_item index space class
21412141
/// \param [in] trans Indicates whether the matrix to be loaded transposed
2142-
template <typename T>
2143-
void ldmatrix(uintptr_t addr, T *m1, T *m2, T *m3, T *m4,
2144-
const sycl::nd_item<3> &item_ct1, bool trans = false) {
2142+
template <typename T, typename ItemT>
2143+
void ldmatrix(uintptr_t addr, T *m1, T *m2, T *m3, T *m4, const ItemT &item,
2144+
bool trans = false) {
21452145
// Load 1st matrix
2146-
ldmatrix(addr, m1, item_ct1, trans, 0);
2146+
ldmatrix(addr, m1, item, trans, 0);
21472147
// Load 2nd matrix
2148-
ldmatrix(addr, m2, item_ct1, trans, 1);
2148+
ldmatrix(addr, m2, item, trans, 1);
21492149
// Load 3rd matrix
2150-
ldmatrix(addr, m3, item_ct1, trans, 2);
2150+
ldmatrix(addr, m3, item, trans, 2);
21512151
// Load 4th matrix
2152-
ldmatrix(addr, m4, item_ct1, trans, 3);
2152+
ldmatrix(addr, m4, item, trans, 3);
21532153
}
21542154

21552155
} // namespace matrix

0 commit comments

Comments
 (0)