Skip to content

Commit c52860e

Browse files
Added comments for the helper functions
1 parent d4a821a commit c52860e

File tree

1 file changed

+47
-8
lines changed
  • clang/runtime/dpct-rt/include/dpct

1 file changed

+47
-8
lines changed

clang/runtime/dpct-rt/include/dpct/math.hpp

+47-8
Original file line numberDiff line numberDiff line change
@@ -2056,60 +2056,99 @@ class joint_matrix {
20562056
const size_t num_elements;
20572057
};
20582058

2059+
/// Loads 1 8x8 b16 matrix from shared memory to local memory (32-bits per wi)
2060+
/// Requires the sub-group size of kernel calling this function to be 32
2061+
/// \tparam [in] T The type of result variable
2062+
/// \param [in] addr The address of the matrix in shared memory
2063+
/// \param [in] m The local memory to store the matrix
2064+
/// \param [in] item_ct1 The sycl::nd_item object
2065+
/// \param [in] trans Indicates whether the matrix to be loaded transposed
2066+
/// \param [in] mat The matrix index to be loaded
20592067
template <typename T>
20602068
void ldmatrix(uintptr_t addr, T *m, const sycl::nd_item<3> &item_ct1,
20612069
bool trans = false, unsigned mat = 0) {
2062-
int lane = item_ct1.get_local_id(2);
2070+
int lane = item_ct1.get_local_id(2) % 32;
20632071

2064-
int group = lane / 8;
2065-
int sub = lane % 8;
2066-
int src_base = group * 2;
2072+
int lane_group8_row = lane / 8;
2073+
int lane_group8_col = lane % 8;
20672074

20682075
if (!trans) {
20692076
// calculate the source lane
2070-
int src_lane = (sub / 4) ? (src_base + 1) : src_base;
2077+
int src_lane = 2 * lane_group8_row;
2078+
if (lane_group8_col >= 4)
2079+
src_lane += 1;
20712080

20722081
// Broadcast the address from the source lane
20732082
auto recv_addr_uintp = dpct::select_from_sub_group(
20742083
item_ct1.get_sub_group(), addr, mat * 8 + src_lane);
2084+
2085+
// Cast the received address from uintptr_t to the type of 'm'
20752086
auto recv_addr = reinterpret_cast<T *>(recv_addr_uintp);
20762087

20772088
// Non-transposed load
2078-
*m = recv_addr[sub % 4];
2089+
*m = recv_addr[lane_group8_col % 4];
20792090
} else {
20802091
// calculate the source lane
20812092
int src_lane = (lane % 4) * 2;
20822093

2083-
// Broadcast the address from the source lane:
2094+
// Broadcast the address from the source lane
20842095
auto recv_addr_uintp_1 = dpct::select_from_sub_group(
20852096
item_ct1.get_sub_group(), addr, mat * 8 + src_lane);
20862097
auto recv_addr_uintp_2 = dpct::select_from_sub_group(
20872098
item_ct1.get_sub_group(), addr, mat * 8 + src_lane + 1);
2099+
2100+
// Cast the received address from uintptr_t to 'half *'
20882101
auto recv_addr_1 = reinterpret_cast<sycl::half *>(recv_addr_uintp_1);
20892102
auto recv_addr_2 = reinterpret_cast<sycl::half *>(recv_addr_uintp_2);
20902103

20912104
// Transposed load
2092-
int index = (lane / 4);
2105+
int index = lane / 4;
20932106
sycl::half val0 = recv_addr_1[index];
20942107
sycl::half val1 = recv_addr_2[index];
2108+
2109+
// Combine the two 16-bits into one 32-bit value
20952110
sycl::half2 val = sycl::half2(val0, val1);
20962111
*m = *reinterpret_cast<T *>(&val);
20972112
}
20982113
}
20992114

2115+
/// Loads 2 8x8 b16 matrix from shared memory to local memory (32-bits per wi)
2116+
/// Requires the sub-group size of kernel calling this function to be 32
2117+
/// \tparam [in] T The type of result variable
2118+
/// \param [in] addr The address of the matrix in shared memory
2119+
/// \param [in] m1 The local memory to store data of 1st matrix
2120+
/// \param [in] m2 The local memory to store data of 2nd matrix
2121+
/// \param [in] item_ct1 The sycl::nd_item object
2122+
/// \param [in] trans Indicates whether the matrix to be loaded transposed
21002123
template <typename T>
21012124
void ldmatrix(uintptr_t addr, T *m1, T *m2, const sycl::nd_item<3> &item_ct1,
21022125
bool trans = false) {
2126+
// Load 1st matrix
21032127
ldmatrix(addr, m1, item_ct1, trans, 0);
2128+
// Load 2nd matrix
21042129
ldmatrix(addr, m2, item_ct1, trans, 1);
21052130
}
21062131

2132+
/// Loads 4 8x8 b16 matrix from shared memory to local memory (32-bits per wi)
2133+
/// Requires the sub-group size of kernel calling this function to be 32
2134+
/// \tparam [in] T The type of result variable
2135+
/// \param [in] addr The address of the matrix in shared memory
2136+
/// \param [in] m1 The local memory to store data of 1st matrix
2137+
/// \param [in] m2 The local memory to store data of 2nd matrix
2138+
/// \param [in] m3 The local memory to store data of 3rd matrix
2139+
/// \param [in] m4 The local memory to store data of 4th matrix
2140+
/// \param [in] item_ct1 The sycl::nd_item object
2141+
/// \param [in] trans Indicates whether the matrix to be loaded transposed
21072142
template <typename T>
21082143
void ldmatrix(uintptr_t addr, T *m1, T *m2, T *m3, T *m4,
21092144
const sycl::nd_item<3> &item_ct1, bool trans = false) {
2145+
// Load 1st matrix
21102146
ldmatrix(addr, m1, item_ct1, trans, 0);
2147+
// Load 2nd matrix
21112148
ldmatrix(addr, m2, item_ct1, trans, 1);
2149+
// Load 3rd matrix
21122150
ldmatrix(addr, m3, item_ct1, trans, 2);
2151+
// Load 4th matrix
21132152
ldmatrix(addr, m4, item_ct1, trans, 3);
21142153
}
21152154

0 commit comments

Comments
 (0)