@@ -2061,13 +2061,13 @@ class joint_matrix {
2061
2061
// / \tparam [in] T The type of result variable
2062
2062
// / \param [in] addr The address of the matrix in shared memory
2063
2063
// / \param [in] m The local memory to store the matrix
2064
- // / \param [in] item_ct1 The sycl::nd_item object
2064
+ // / \param [in] item The sycl::nd_item index space class
2065
2065
// / \param [in] trans Indicates whether the matrix to be loaded transposed
2066
2066
// / \param [in] mat The matrix index to be loaded
2067
- template <typename T>
2068
- void ldmatrix (uintptr_t addr, T *m, const sycl::nd_item< 3 > &item_ct1 ,
2069
- bool trans = false , unsigned mat = 0 ) {
2070
- int lane = item_ct1. get_local_id ( 2 ) % 32 ;
2067
+ template <typename T, typename ItemT >
2068
+ void ldmatrix (uintptr_t addr, T *m, const ItemT &item, bool trans = false ,
2069
+ unsigned mat = 0 ) {
2070
+ int lane = item. get_sub_group (). get_local_linear_id () ;
2071
2071
2072
2072
int lane_group8_row = lane / 8 ;
2073
2073
int lane_group8_col = lane % 8 ;
@@ -2080,7 +2080,7 @@ void ldmatrix(uintptr_t addr, T *m, const sycl::nd_item<3> &item_ct1,
2080
2080
2081
2081
// Broadcast the address from the source lane
2082
2082
auto recv_addr_uintp = dpct::select_from_sub_group (
2083
- item_ct1 .get_sub_group (), addr, mat * 8 + src_lane);
2083
+ item .get_sub_group (), addr, mat * 8 + src_lane);
2084
2084
2085
2085
// Cast the received address from uintptr_t to the type of 'm'
2086
2086
auto recv_addr = reinterpret_cast <T *>(recv_addr_uintp);
@@ -2093,9 +2093,9 @@ void ldmatrix(uintptr_t addr, T *m, const sycl::nd_item<3> &item_ct1,
2093
2093
2094
2094
// Broadcast the address from the source lane
2095
2095
auto recv_addr_uintp_1 = dpct::select_from_sub_group (
2096
- item_ct1 .get_sub_group (), addr, mat * 8 + src_lane);
2096
+ item .get_sub_group (), addr, mat * 8 + src_lane);
2097
2097
auto recv_addr_uintp_2 = dpct::select_from_sub_group (
2098
- item_ct1 .get_sub_group (), addr, mat * 8 + src_lane + 1 );
2098
+ item .get_sub_group (), addr, mat * 8 + src_lane + 1 );
2099
2099
2100
2100
// Cast the received address from uintptr_t to 'half *'
2101
2101
auto recv_addr_1 = reinterpret_cast <sycl::half *>(recv_addr_uintp_1);
@@ -2118,15 +2118,15 @@ void ldmatrix(uintptr_t addr, T *m, const sycl::nd_item<3> &item_ct1,
2118
2118
// / \param [in] addr The address of the matrix in shared memory
2119
2119
// / \param [in] m1 The local memory to store data of 1st matrix
2120
2120
// / \param [in] m2 The local memory to store data of 2nd matrix
2121
- // / \param [in] item_ct1 The sycl::nd_item object
2121
+ // / \param [in] item The sycl::nd_item index space class
2122
2122
// / \param [in] trans Indicates whether the matrix to be loaded transposed
2123
- template <typename T>
2124
- void ldmatrix (uintptr_t addr, T *m1, T *m2, const sycl::nd_item< 3 > &item_ct1 ,
2123
+ template <typename T, typename ItemT >
2124
+ void ldmatrix (uintptr_t addr, T *m1, T *m2, const ItemT &item ,
2125
2125
bool trans = false ) {
2126
2126
// Load 1st matrix
2127
- ldmatrix (addr, m1, item_ct1 , trans, 0 );
2127
+ ldmatrix (addr, m1, item , trans, 0 );
2128
2128
// Load 2nd matrix
2129
- ldmatrix (addr, m2, item_ct1 , trans, 1 );
2129
+ ldmatrix (addr, m2, item , trans, 1 );
2130
2130
}
2131
2131
2132
2132
// / Loads 4 8x8 b16 matrix from shared memory to local memory (32-bits per wi)
@@ -2137,19 +2137,19 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, const sycl::nd_item<3> &item_ct1,
2137
2137
// / \param [in] m2 The local memory to store data of 2nd matrix
2138
2138
// / \param [in] m3 The local memory to store data of 3rd matrix
2139
2139
// / \param [in] m4 The local memory to store data of 4th matrix
2140
- // / \param [in] item_ct1 The sycl::nd_item object
2140
+ // / \param [in] item The sycl::nd_item index space class
2141
2141
// / \param [in] trans Indicates whether the matrix to be loaded transposed
2142
- template <typename T>
2143
- void ldmatrix (uintptr_t addr, T *m1, T *m2, T *m3, T *m4,
2144
- const sycl::nd_item< 3 > &item_ct1, bool trans = false ) {
2142
+ template <typename T, typename ItemT >
2143
+ void ldmatrix (uintptr_t addr, T *m1, T *m2, T *m3, T *m4, const ItemT &item,
2144
+ bool trans = false ) {
2145
2145
// Load 1st matrix
2146
- ldmatrix (addr, m1, item_ct1 , trans, 0 );
2146
+ ldmatrix (addr, m1, item , trans, 0 );
2147
2147
// Load 2nd matrix
2148
- ldmatrix (addr, m2, item_ct1 , trans, 1 );
2148
+ ldmatrix (addr, m2, item , trans, 1 );
2149
2149
// Load 3rd matrix
2150
- ldmatrix (addr, m3, item_ct1 , trans, 2 );
2150
+ ldmatrix (addr, m3, item , trans, 2 );
2151
2151
// Load 4th matrix
2152
- ldmatrix (addr, m4, item_ct1 , trans, 3 );
2152
+ ldmatrix (addr, m4, item , trans, 3 );
2153
2153
}
2154
2154
2155
2155
} // namespace matrix
0 commit comments