Skip to content

Commit ea65402

Browse files
Removed item_ct1 in favor of free functions
1 parent 3051e84 commit ea65402

File tree

3 files changed

+27
-35
lines changed

3 files changed

+27
-35
lines changed

clang/lib/DPCT/RulesAsm/AsmMigration.cpp

+1-4
Original file line numberDiff line numberDiff line change
@@ -1344,16 +1344,13 @@ class SYCLGen : public SYCLGenBase {
13441344
if (emitStmt(Src)) {
13451345
return SYCLGenError();
13461346
}
1347-
OS() << ", ";
13481347
for (unsigned Inst = 0; Inst != VE->getNumElements(); ++Inst) {
13491348
if (isa<InlineAsmDiscardExpr>(VE->getElement(Inst)))
13501349
continue;
1351-
OS() << "&";
1350+
OS() << ", &";
13521351
if (emitStmt(VE->getElement(Inst)))
13531352
return SYCLGenError();
1354-
OS() << ", ";
13551353
}
1356-
OS() << DpctGlobalInfo::getItem(GAS);
13571354
if (Inst->hasAttr(InstAttr::trans))
13581355
OS() << ", true";
13591356
OS() << ");";

clang/runtime/dpct-rt/include/dpct/math.hpp

+20-25
Original file line numberDiff line numberDiff line change
@@ -2061,13 +2061,12 @@ class joint_matrix {
20612061
/// \tparam [in] T The type of result variable
20622062
/// \param [in] addr The address of the matrix in local memory
20632063
/// \param [in] m The private memory to store the matrix
2064-
/// \param [in] item The sycl::nd_item index space class
20652064
/// \param [in] trans Indicates whether the matrix is to be loaded transposed
20662065
/// \param [in] mat The matrix index to be loaded
2067-
template <typename T, typename ItemT>
2068-
void ldmatrix(uintptr_t addr, T *m, const ItemT &item, bool trans = false,
2069-
unsigned mat = 0) {
2070-
int lane = item.get_sub_group().get_local_linear_id();
2066+
template <typename T>
2067+
void ldmatrix(uintptr_t addr, T *m, bool trans = false, unsigned mat = 0) {
2068+
auto sg = sycl::ext::oneapi::this_work_item::get_sub_group();
2069+
int lane = sg.get_local_linear_id();
20712070

20722071
int lane_group8_row = lane / 8;
20732072
int lane_group8_col = lane % 8;
@@ -2079,8 +2078,8 @@ void ldmatrix(uintptr_t addr, T *m, const ItemT &item, bool trans = false,
20792078
src_lane += 1;
20802079

20812080
// Broadcast the address from the source lane
2082-
auto recv_addr_uintp = dpct::select_from_sub_group(
2083-
item.get_sub_group(), addr, mat * 8 + src_lane);
2081+
auto recv_addr_uintp =
2082+
dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane);
20842083

20852084
// Cast the received address from uintptr_t to the type of 'm'
20862085
auto recv_addr = reinterpret_cast<T *>(recv_addr_uintp);
@@ -2092,10 +2091,10 @@ void ldmatrix(uintptr_t addr, T *m, const ItemT &item, bool trans = false,
20922091
int src_lane = (lane % 4) * 2;
20932092

20942093
// Broadcast the address from the source lane
2095-
auto recv_addr_uintp_1 = dpct::select_from_sub_group(
2096-
item.get_sub_group(), addr, mat * 8 + src_lane);
2097-
auto recv_addr_uintp_2 = dpct::select_from_sub_group(
2098-
item.get_sub_group(), addr, mat * 8 + src_lane + 1);
2094+
auto recv_addr_uintp_1 =
2095+
dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane);
2096+
auto recv_addr_uintp_2 =
2097+
dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane + 1);
20992098

21002099
// Cast the received address from uintptr_t to 'half *'
21012100
auto recv_addr_1 = reinterpret_cast<sycl::half *>(recv_addr_uintp_1);
@@ -2118,15 +2117,13 @@ void ldmatrix(uintptr_t addr, T *m, const ItemT &item, bool trans = false,
21182117
/// \param [in] addr The address of the matrix in local memory
21192118
/// \param [in] m1 The private memory to store data of 1st matrix
21202119
/// \param [in] m2 The private memory to store data of 2nd matrix
2121-
/// \param [in] item The sycl::nd_item index space class
21222120
/// \param [in] trans Indicates whether the matrices are to be loaded transposed
2123-
template <typename T, typename ItemT>
2124-
void ldmatrix(uintptr_t addr, T *m1, T *m2, const ItemT &item,
2125-
bool trans = false) {
2121+
template <typename T>
2122+
void ldmatrix(uintptr_t addr, T *m1, T *m2, bool trans = false) {
21262123
// Load 1st matrix
2127-
ldmatrix(addr, m1, item, trans, 0);
2124+
ldmatrix(addr, m1, trans, 0);
21282125
// Load 2nd matrix
2129-
ldmatrix(addr, m2, item, trans, 1);
2126+
ldmatrix(addr, m2, trans, 1);
21302127
}
21312128

21322129
/// Loads four 8x8 b16 matrices from local memory to private memory (32 bits per work-item)
@@ -2137,19 +2134,17 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, const ItemT &item,
21372134
/// \param [in] m2 The private memory to store data of 2nd matrix
21382135
/// \param [in] m3 The private memory to store data of 3rd matrix
21392136
/// \param [in] m4 The private memory to store data of 4th matrix
2140-
/// \param [in] item The sycl::nd_item index space class
21412137
/// \param [in] trans Indicates whether the matrices are to be loaded transposed
2142-
template <typename T, typename ItemT>
2143-
void ldmatrix(uintptr_t addr, T *m1, T *m2, T *m3, T *m4, const ItemT &item,
2144-
bool trans = false) {
2138+
template <typename T>
2139+
void ldmatrix(uintptr_t addr, T *m1, T *m2, T *m3, T *m4, bool trans = false) {
21452140
// Load 1st matrix
2146-
ldmatrix(addr, m1, item, trans, 0);
2141+
ldmatrix(addr, m1, trans, 0);
21472142
// Load 2nd matrix
2148-
ldmatrix(addr, m2, item, trans, 1);
2143+
ldmatrix(addr, m2, trans, 1);
21492144
// Load 3rd matrix
2150-
ldmatrix(addr, m3, item, trans, 2);
2145+
ldmatrix(addr, m3, trans, 2);
21512146
// Load 4th matrix
2152-
ldmatrix(addr, m4, item, trans, 3);
2147+
ldmatrix(addr, m4, trans, 3);
21532148
}
21542149

21552150
} // namespace matrix

clang/test/dpct/asm/ldmatrix.cu

+6-6
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ __device__ void load_matrix_x1(void *sh_r_addr, int *r) {
2222
// CHECK: auto addr = sh_r_addr;
2323
uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(sh_r_addr));
2424

25-
// CHECK: dpct::experimental::matrix::ldmatrix((uintptr_t)addr, &r[0], item_ct1);
25+
// CHECK: dpct::experimental::matrix::ldmatrix((uintptr_t)addr, &r[0]);
2626
asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n"
2727
: "=r"(r[0])
2828
: "r"(addr));
@@ -32,7 +32,7 @@ __device__ void load_matrix_x2(void *sh_r_addr, int *r) {
3232
// CHECK: auto addr = sh_r_addr;
3333
uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(sh_r_addr));
3434

35-
// CHECK: dpct::experimental::matrix::ldmatrix((uintptr_t)addr, &r[0], &r[1], item_ct1);
35+
// CHECK: dpct::experimental::matrix::ldmatrix((uintptr_t)addr, &r[0], &r[1]);
3636
asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0, %1}, [%2];\n"
3737
: "=r"(r[0]), "=r"(r[1])
3838
: "r"(addr));
@@ -42,7 +42,7 @@ __device__ void load_matrix_x4(void *sh_r_addr, int *r) {
4242
// CHECK: auto addr = sh_r_addr;
4343
uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(sh_r_addr));
4444

45-
// CHECK: dpct::experimental::matrix::ldmatrix((uintptr_t)addr, &r[0], &r[1], &r[2], &r[3], item_ct1);
45+
// CHECK: dpct::experimental::matrix::ldmatrix((uintptr_t)addr, &r[0], &r[1], &r[2], &r[3]);
4646
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n"
4747
: "=r"(r[0]), "=r"(r[1]), "=r"(r[2]), "=r"(r[3])
4848
: "r"(addr));
@@ -52,7 +52,7 @@ __device__ void load_matrix_x1_trans(void *sh_r_addr, int *r) {
5252
// CHECK: auto addr = sh_r_addr;
5353
uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(sh_r_addr));
5454

55-
// CHECK: dpct::experimental::matrix::ldmatrix((uintptr_t)addr, &r[0], item_ct1, true);
55+
// CHECK: dpct::experimental::matrix::ldmatrix((uintptr_t)addr, &r[0], true);
5656
asm volatile("ldmatrix.sync.aligned.m8n8.x1.trans.shared.b16 {%0}, [%1];\n"
5757
: "=r"(r[0])
5858
: "r"(addr));
@@ -62,7 +62,7 @@ __device__ void load_matrix_x2_trans(void *sh_r_addr, int *r) {
6262
// CHECK: auto addr = sh_r_addr;
6363
uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(sh_r_addr));
6464

65-
// CHECK: dpct::experimental::matrix::ldmatrix((uintptr_t)addr, &r[0], &r[1], item_ct1, true);
65+
// CHECK: dpct::experimental::matrix::ldmatrix((uintptr_t)addr, &r[0], &r[1], true);
6666
asm volatile("ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16 {%0, %1}, [%2];\n"
6767
: "=r"(r[0]), "=r"(r[1])
6868
: "r"(addr));
@@ -72,7 +72,7 @@ __device__ void load_matrix_x4_trans(void *sh_r_addr, int *r) {
7272
// CHECK: auto addr = sh_r_addr;
7373
uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(sh_r_addr));
7474

75-
// CHECK: dpct::experimental::matrix::ldmatrix((uintptr_t)addr, &r[0], &r[1], &r[2], &r[3], item_ct1, true);
75+
// CHECK: dpct::experimental::matrix::ldmatrix((uintptr_t)addr, &r[0], &r[1], &r[2], &r[3], true);
7676
asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\n"
7777
: "=r"(r[0]), "=r"(r[1]), "=r"(r[2]), "=r"(r[3])
7878
: "r"(addr));

0 commit comments

Comments
 (0)