@@ -2061,13 +2061,12 @@ class joint_matrix {
2061
2061
/// \tparam [in] T The type of result variable
2062
2062
/// \param [in] addr The address of the matrix in local memory
2063
2063
/// \param [in] m The private memory to store the matrix
2064
- /// \param [in] item The sycl::nd_item index space class
2065
2064
/// \param [in] trans Indicates whether the matrix is to be loaded transposed
2066
2065
/// \param [in] mat The index of the matrix to be loaded
2067
- template <typename T, typename ItemT >
2068
- void ldmatrix (uintptr_t addr, T *m, const ItemT &item, bool trans = false ,
2069
- unsigned mat = 0 ) {
2070
- int lane = item. get_sub_group () .get_local_linear_id ();
2066
+ template <typename T>
2067
+ void ldmatrix (uintptr_t addr, T *m, bool trans = false , unsigned mat = 0 ) {
2068
+ auto sg = sycl::ext::oneapi::this_work_item::get_sub_group ();
2069
+ int lane = sg .get_local_linear_id ();
2071
2070
2072
2071
int lane_group8_row = lane / 8 ;
2073
2072
int lane_group8_col = lane % 8 ;
@@ -2079,8 +2078,8 @@ void ldmatrix(uintptr_t addr, T *m, const ItemT &item, bool trans = false,
2079
2078
src_lane += 1 ;
2080
2079
2081
2080
// Broadcast the address from the source lane
2082
- auto recv_addr_uintp = dpct::select_from_sub_group (
2083
- item. get_sub_group () , addr, mat * 8 + src_lane);
2081
+ auto recv_addr_uintp =
2082
+ dpct::select_from_sub_group (sg , addr, mat * 8 + src_lane);
2084
2083
2085
2084
// Cast the received address from uintptr_t to the type of 'm'
2086
2085
auto recv_addr = reinterpret_cast <T *>(recv_addr_uintp);
@@ -2092,10 +2091,10 @@ void ldmatrix(uintptr_t addr, T *m, const ItemT &item, bool trans = false,
2092
2091
int src_lane = (lane % 4 ) * 2 ;
2093
2092
2094
2093
// Broadcast the address from the source lane
2095
- auto recv_addr_uintp_1 = dpct::select_from_sub_group (
2096
- item. get_sub_group () , addr, mat * 8 + src_lane);
2097
- auto recv_addr_uintp_2 = dpct::select_from_sub_group (
2098
- item. get_sub_group () , addr, mat * 8 + src_lane + 1 );
2094
+ auto recv_addr_uintp_1 =
2095
+ dpct::select_from_sub_group (sg , addr, mat * 8 + src_lane);
2096
+ auto recv_addr_uintp_2 =
2097
+ dpct::select_from_sub_group (sg , addr, mat * 8 + src_lane + 1 );
2099
2098
2100
2099
// Cast the received address from uintptr_t to 'half *'
2101
2100
auto recv_addr_1 = reinterpret_cast <sycl::half *>(recv_addr_uintp_1);
@@ -2118,15 +2117,13 @@ void ldmatrix(uintptr_t addr, T *m, const ItemT &item, bool trans = false,
2118
2117
/// \param [in] addr The address of the matrix in local memory
2119
2118
/// \param [in] m1 The private memory to store data of 1st matrix
2120
2119
/// \param [in] m2 The private memory to store data of 2nd matrix
2121
- /// \param [in] item The sycl::nd_item index space class
2122
2120
/// \param [in] trans Indicates whether the matrix is to be loaded transposed
2123
- template <typename T, typename ItemT>
2124
- void ldmatrix (uintptr_t addr, T *m1, T *m2, const ItemT &item,
2125
- bool trans = false ) {
2121
+ template <typename T>
2122
+ void ldmatrix (uintptr_t addr, T *m1, T *m2, bool trans = false ) {
2126
2123
// Load 1st matrix
2127
- ldmatrix (addr, m1, item, trans, 0 );
2124
+ ldmatrix (addr, m1, trans, 0 );
2128
2125
// Load 2nd matrix
2129
- ldmatrix (addr, m2, item, trans, 1 );
2126
+ ldmatrix (addr, m2, trans, 1 );
2130
2127
}
2131
2128
2132
2129
/// Loads four 8x8 b16 matrices from local memory to private memory (32-bits per wi)
@@ -2137,19 +2134,17 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, const ItemT &item,
2137
2134
/// \param [in] m2 The private memory to store data of 2nd matrix
2138
2135
/// \param [in] m3 The private memory to store data of 3rd matrix
2139
2136
/// \param [in] m4 The private memory to store data of 4th matrix
2140
- /// \param [in] item The sycl::nd_item index space class
2141
2137
/// \param [in] trans Indicates whether the matrix is to be loaded transposed
2142
- template <typename T, typename ItemT>
2143
- void ldmatrix (uintptr_t addr, T *m1, T *m2, T *m3, T *m4, const ItemT &item,
2144
- bool trans = false ) {
2138
+ template <typename T>
2139
+ void ldmatrix (uintptr_t addr, T *m1, T *m2, T *m3, T *m4, bool trans = false ) {
2145
2140
// Load 1st matrix
2146
- ldmatrix (addr, m1, item, trans, 0 );
2141
+ ldmatrix (addr, m1, trans, 0 );
2147
2142
// Load 2nd matrix
2148
- ldmatrix (addr, m2, item, trans, 1 );
2143
+ ldmatrix (addr, m2, trans, 1 );
2149
2144
// Load 3rd matrix
2150
- ldmatrix (addr, m3, item, trans, 2 );
2145
+ ldmatrix (addr, m3, trans, 2 );
2151
2146
// Load 4th matrix
2152
- ldmatrix (addr, m4, item, trans, 3 );
2147
+ ldmatrix (addr, m4, trans, 3 );
2153
2148
}
2154
2149
2155
2150
} // namespace matrix
0 commit comments