@@ -2056,8 +2056,26 @@ class joint_matrix {
20562056 const size_t num_elements;
20572057};
20582058
2059- // / Loads 1 8x8 b16 matrix from local memory to private memory (32-bits per wi)
2060- // / Requires the sub-group size of kernel calling this function to be 32
2059+ // / Loads 1 8x8 b16 (128 bytes) matrix from private memory to local memory per
2060+ // / sub-group. Requires the sub-group size of kernel calling this function to
2061+ // / be 32. 'mat' specifies the matrix index to be loaded. The first '(mat + 1) *
2062+ // / 8' work items of sub-group contain the starting address of their respective
2063+ // / matrix row in 'addr'. After distributing addresses to other work items, each
2064+ // / of the 32 work items load 32-bits (2 packed 16-bit data) into 'm' for a
2065+ // / total of 128 bytes. 'trans' specifies to perform a transposed/non-transposed
2066+ // / load by each work item like below
2067+ // / Row Major: Each row of the matrix is loaded by a group of 4 work items(wi)
2068+ // / row-0: wi0 wi0 wi1 wi1 ... wi3 wi3
2069+ // / row-1: wi4 wi4 wi5 wi5 ... wi7 wi7
2070+ // / ...
2071+ // / row-6: wi24 wi24 wi25 wi25 ... wi27 wi27
2072+ // / row-7: wi28 wi28 wi29 wi29 ... wi31 wi31
2073+ // / Col Major: Each col of the matrix is loaded by a group of 4 work items(wi)
2074+ // / row-0: wi0 wi4 wi8 ... wi28
2075+ // / row-1: wi0 wi4 wi8 ... wi28
2076+ // / ...
2077+ // / row-6: wi3 wi7 wi11 ... wi31
2078+ // / row-7: wi3 wi7 wi11 ... wi31
20612079// / \tparam [in] T The type of result variable
20622080// / \param [in] addr The address of the matrix in local memory
20632081// / \param [in] m The private memory to store the matrix
@@ -2111,8 +2129,25 @@ void ldmatrix(uintptr_t addr, T *m, bool trans = false, unsigned mat = 0) {
21112129 }
21122130}
21132131
2114- // / Loads 2 8x8 b16 matrix from local memory to private memory (32-bits per wi)
2115- // / Requires the sub-group size of kernel calling this function to be 32
2132+ // / Loads 2 8x8 b16 (256 bytes) matrix from private memory to local memory per
2133+ // / sub-group. Requires the sub-group size of kernel calling this function to
2134+ // / be 32. The first 16 work items of sub-group contain the starting address of
2135+ // / their respective matrix row in 'addr'. After distributing addresses to other
2136+ // / work items, each of the 32 work items load 64-bits (32-bits per matrix) into
2137+ // / 'm1' & 'm2' for a total of 256 bytes. 'trans' specifies to perform a
2138+ // / transposed/non-transposed load by each work item like below
2139+ // / Row Major: Each row of the matrices is loaded by a group of 4 work items(wi)
2140+ // / row-0: wi0 wi0 wi1 wi1 ... wi3 wi3
2141+ // / row-1: wi4 wi4 wi5 wi5 ... wi7 wi7
2142+ // / ...
2143+ // / row-6: wi24 wi24 wi25 wi25 ... wi27 wi27
2144+ // / row-7: wi28 wi28 wi29 wi29 ... wi31 wi31
2145+ // / Col Major: Each col of the matrices is loaded by a group of 4 work items(wi)
2146+ // / row-0: wi0 wi4 wi8 ... wi28
2147+ // / row-1: wi0 wi4 wi8 ... wi28
2148+ // / ...
2149+ // / row-6: wi3 wi7 wi11 ... wi31
2150+ // / row-7: wi3 wi7 wi11 ... wi31
21162151// / \tparam [in] T The type of result variable
21172152// / \param [in] addr The address of the matrix in local memory
21182153// / \param [in] m1 The private memory to store data of 1st matrix
@@ -2126,8 +2161,26 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, bool trans = false) {
21262161 ldmatrix (addr, m2, trans, 1 );
21272162}
21282163
2129- // / Loads 4 8x8 b16 matrix from local memory to private memory (32-bits per wi)
2130- // / Requires the sub-group size of kernel calling this function to be 32
2164+ // / Loads 4 8x8 b16 (512 bytes) matrix from private memory to local memory per
2165+ // / sub-group. Requires the sub-group size of kernel calling this function to
2166+ // / be 32. Each work item of sub-group contains the starting address of their
2167+ // / respective matrix row in 'addr'.
2168+ // / After distributing addresses to other work items, each of the 32 work items
2169+ // / load 128-bits (32-bits per matrix) into 'm1', 'm2', 'm3' & 'm4' for a total
2170+ // / of 512 bytes. 'trans' specifies to perform a transposed/non-transposed load
2171+ // / by each work item like below
2172+ // / Row Major: Each row of the matrices is loaded by a group of 4 work items(wi)
2173+ // / row-0: wi0 wi0 wi1 wi1 ... wi3 wi3
2174+ // / row-1: wi4 wi4 wi5 wi5 ... wi7 wi7
2175+ // / ...
2176+ // / row-6: wi24 wi24 wi25 wi25 ... wi27 wi27
2177+ // / row-7: wi28 wi28 wi29 wi29 ... wi31 wi31
2178+ // / Col Major: Each col of the matrices is loaded by a group of 4 work items(wi)
2179+ // / row-0: wi0 wi4 wi8 ... wi28
2180+ // / row-1: wi0 wi4 wi8 ... wi28
2181+ // / ...
2182+ // / row-6: wi3 wi7 wi11 ... wi31
2183+ // / row-7: wi3 wi7 wi11 ... wi31
21312184// / \tparam [in] T The type of result variable
21322185// / \param [in] addr The address of the matrix in local memory
21332186// / \param [in] m1 The private memory to store data of 1st matrix
0 commit comments