@@ -2056,60 +2056,99 @@ class joint_matrix {
2056
2056
const size_t num_elements;
2057
2057
};
2058
2058
2059
+ // / Loads 1 8x8 b16 matrix from shared memory to local memory (32-bits per wi)
2060
+ // / Requires the sub-group size of kernel calling this function to be 32
2061
+ // / \tparam [in] T The type of result variable
2062
+ // / \param [in] addr The address of the matrix in shared memory
2063
+ // / \param [in] m The local memory to store the matrix
2064
+ // / \param [in] item_ct1 The sycl::nd_item object
2065
+ // / \param [in] trans Indicates whether the matrix to be loaded transposed
2066
+ // / \param [in] mat The matrix index to be loaded
2059
2067
template <typename T>
2060
2068
void ldmatrix (uintptr_t addr, T *m, const sycl::nd_item<3 > &item_ct1,
2061
2069
bool trans = false , unsigned mat = 0 ) {
2062
- int lane = item_ct1.get_local_id (2 );
2070
+ int lane = item_ct1.get_local_id (2 ) % 32 ;
2063
2071
2064
- int group = lane / 8 ;
2065
- int sub = lane % 8 ;
2066
- int src_base = group * 2 ;
2072
+ int lane_group8_row = lane / 8 ;
2073
+ int lane_group8_col = lane % 8 ;
2067
2074
2068
2075
if (!trans) {
2069
2076
// calculate the source lane
2070
- int src_lane = (sub / 4 ) ? (src_base + 1 ) : src_base;
2077
+ int src_lane = 2 * lane_group8_row;
2078
+ if (lane_group8_col >= 4 )
2079
+ src_lane += 1 ;
2071
2080
2072
2081
// Broadcast the address from the source lane
2073
2082
auto recv_addr_uintp = dpct::select_from_sub_group (
2074
2083
item_ct1.get_sub_group (), addr, mat * 8 + src_lane);
2084
+
2085
+ // Cast the received address from uintptr_t to the type of 'm'
2075
2086
auto recv_addr = reinterpret_cast <T *>(recv_addr_uintp);
2076
2087
2077
2088
// Non-transposed load
2078
- *m = recv_addr[sub % 4 ];
2089
+ *m = recv_addr[lane_group8_col % 4 ];
2079
2090
} else {
2080
2091
// calculate the source lane
2081
2092
int src_lane = (lane % 4 ) * 2 ;
2082
2093
2083
- // Broadcast the address from the source lane:
2094
+ // Broadcast the address from the source lane
2084
2095
auto recv_addr_uintp_1 = dpct::select_from_sub_group (
2085
2096
item_ct1.get_sub_group (), addr, mat * 8 + src_lane);
2086
2097
auto recv_addr_uintp_2 = dpct::select_from_sub_group (
2087
2098
item_ct1.get_sub_group (), addr, mat * 8 + src_lane + 1 );
2099
+
2100
+ // Cast the received address from uintptr_t to 'half *'
2088
2101
auto recv_addr_1 = reinterpret_cast <sycl::half *>(recv_addr_uintp_1);
2089
2102
auto recv_addr_2 = reinterpret_cast <sycl::half *>(recv_addr_uintp_2);
2090
2103
2091
2104
// Transposed load
2092
- int index = ( lane / 4 ) ;
2105
+ int index = lane / 4 ;
2093
2106
sycl::half val0 = recv_addr_1[index ];
2094
2107
sycl::half val1 = recv_addr_2[index ];
2108
+
2109
+ // Combine the two 16-bits into one 32-bit value
2095
2110
sycl::half2 val = sycl::half2 (val0, val1);
2096
2111
*m = *reinterpret_cast <T *>(&val);
2097
2112
}
2098
2113
}
2099
2114
2115
+ // / Loads 2 8x8 b16 matrix from shared memory to local memory (32-bits per wi)
2116
+ // / Requires the sub-group size of kernel calling this function to be 32
2117
+ // / \tparam [in] T The type of result variable
2118
+ // / \param [in] addr The address of the matrix in shared memory
2119
+ // / \param [in] m1 The local memory to store data of 1st matrix
2120
+ // / \param [in] m2 The local memory to store data of 2nd matrix
2121
+ // / \param [in] item_ct1 The sycl::nd_item object
2122
+ // / \param [in] trans Indicates whether the matrix to be loaded transposed
2100
2123
template <typename T>
2101
2124
void ldmatrix (uintptr_t addr, T *m1, T *m2, const sycl::nd_item<3 > &item_ct1,
2102
2125
bool trans = false ) {
2126
+ // Load 1st matrix
2103
2127
ldmatrix (addr, m1, item_ct1, trans, 0 );
2128
+ // Load 2nd matrix
2104
2129
ldmatrix (addr, m2, item_ct1, trans, 1 );
2105
2130
}
2106
2131
2132
+ // / Loads 4 8x8 b16 matrix from shared memory to local memory (32-bits per wi)
2133
+ // / Requires the sub-group size of kernel calling this function to be 32
2134
+ // / \tparam [in] T The type of result variable
2135
+ // / \param [in] addr The address of the matrix in shared memory
2136
+ // / \param [in] m1 The local memory to store data of 1st matrix
2137
+ // / \param [in] m2 The local memory to store data of 2nd matrix
2138
+ // / \param [in] m3 The local memory to store data of 3rd matrix
2139
+ // / \param [in] m4 The local memory to store data of 4th matrix
2140
+ // / \param [in] item_ct1 The sycl::nd_item object
2141
+ // / \param [in] trans Indicates whether the matrix to be loaded transposed
2107
2142
template <typename T>
2108
2143
void ldmatrix (uintptr_t addr, T *m1, T *m2, T *m3, T *m4,
2109
2144
const sycl::nd_item<3 > &item_ct1, bool trans = false ) {
2145
+ // Load 1st matrix
2110
2146
ldmatrix (addr, m1, item_ct1, trans, 0 );
2147
+ // Load 2nd matrix
2111
2148
ldmatrix (addr, m2, item_ct1, trans, 1 );
2149
+ // Load 3rd matrix
2112
2150
ldmatrix (addr, m3, item_ct1, trans, 2 );
2151
+ // Load 4th matrix
2113
2152
ldmatrix (addr, m4, item_ct1, trans, 3 );
2114
2153
}
2115
2154
0 commit comments