@@ -2056,8 +2056,26 @@ class joint_matrix {
2056
2056
const size_t num_elements;
2057
2057
};
2058
2058
2059
- // / Loads 1 8x8 b16 matrix from local memory to private memory (32-bits per wi)
2060
- // / Requires the sub-group size of kernel calling this function to be 32
2059
+ // / Loads 1 8x8 b16 (128 bytes) matrix from local memory to private memory per
2060
+ // / sub-group. Requires the sub-group size of kernel calling this function to
2061
+ // / be 32. 'mat' specifies the matrix index to be loaded. The first '(mat + 1) *
2062
+ // / 8' work items of sub-group contain the starting address of their respective
2063
+ // / matrix row in 'addr'. After distributing addresses to other work items, each
2064
+ // / of the 32 work items loads 32 bits (two packed 16-bit values) into 'm' for a
2065
+ // / total of 128 bytes. 'trans' selects whether each work item performs a
2066
+ // / transposed or non-transposed load, as shown below:
2067
+ // / Row Major: Each row of the matrix is loaded by a group of 4 work items(wi)
2068
+ // / row-0: wi0 wi0 wi1 wi1 ... wi3 wi3
2069
+ // / row-1: wi4 wi4 wi5 wi5 ... wi7 wi7
2070
+ // / ...
2071
+ // / row-6: wi24 wi24 wi25 wi25 ... wi27 wi27
2072
+ // / row-7: wi28 wi28 wi29 wi29 ... wi31 wi31
2073
+ // / Col Major: Each col of the matrix is loaded by a group of 4 work items(wi)
2074
+ // / row-0: wi0 wi4 wi8 ... wi28
2075
+ // / row-1: wi0 wi4 wi8 ... wi28
2076
+ // / ...
2077
+ // / row-6: wi3 wi7 wi11 ... wi31
2078
+ // / row-7: wi3 wi7 wi11 ... wi31
2061
2079
// / \tparam [in] T The type of result variable
2062
2080
// / \param [in] addr The address of the matrix in local memory
2063
2081
// / \param [in] m The private memory to store the matrix
@@ -2111,8 +2129,25 @@ void ldmatrix(uintptr_t addr, T *m, bool trans = false, unsigned mat = 0) {
2111
2129
}
2112
2130
}
2113
2131
2114
- // / Loads 2 8x8 b16 matrix from local memory to private memory (32-bits per wi)
2115
- // / Requires the sub-group size of kernel calling this function to be 32
2132
+ // / Loads 2 8x8 b16 (256 bytes) matrices from local memory to private memory per
2133
+ // / sub-group. Requires the sub-group size of kernel calling this function to
2134
+ // / be 32. The first 16 work items of sub-group contain the starting address of
2135
+ // / their respective matrix row in 'addr'. After distributing addresses to other
2136
+ // / work items, each of the 32 work items loads 64 bits (32 bits per matrix) into
2137
+ // / 'm1' & 'm2' for a total of 256 bytes. 'trans' selects whether each work
2138
+ // / item performs a transposed or non-transposed load, as shown below:
2139
+ // / Row Major: Each row of the matrices is loaded by a group of 4 work items(wi)
2140
+ // / row-0: wi0 wi0 wi1 wi1 ... wi3 wi3
2141
+ // / row-1: wi4 wi4 wi5 wi5 ... wi7 wi7
2142
+ // / ...
2143
+ // / row-6: wi24 wi24 wi25 wi25 ... wi27 wi27
2144
+ // / row-7: wi28 wi28 wi29 wi29 ... wi31 wi31
2145
+ // / Col Major: Each col of the matrices is loaded by a group of 4 work items(wi)
2146
+ // / row-0: wi0 wi4 wi8 ... wi28
2147
+ // / row-1: wi0 wi4 wi8 ... wi28
2148
+ // / ...
2149
+ // / row-6: wi3 wi7 wi11 ... wi31
2150
+ // / row-7: wi3 wi7 wi11 ... wi31
2116
2151
// / \tparam [in] T The type of result variable
2117
2152
// / \param [in] addr The address of the matrix in local memory
2118
2153
// / \param [in] m1 The private memory to store data of 1st matrix
@@ -2126,8 +2161,26 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, bool trans = false) {
2126
2161
ldmatrix (addr, m2, trans, 1 );
2127
2162
}
2128
2163
2129
- // / Loads 4 8x8 b16 matrix from local memory to private memory (32-bits per wi)
2130
- // / Requires the sub-group size of kernel calling this function to be 32
2164
+ // / Loads 4 8x8 b16 (512 bytes) matrices from local memory to private memory per
2165
+ // / sub-group. Requires the sub-group size of kernel calling this function to
2166
+ // / be 32. Each work item of the sub-group contains the starting address of its
2167
+ // / respective matrix row in 'addr'.
2168
+ // / After distributing addresses to other work items, each of the 32 work items
2169
+ // / loads 128 bits (32 bits per matrix) into 'm1', 'm2', 'm3' & 'm4' for a total
2170
+ // / of 512 bytes. 'trans' selects whether each work item performs a transposed
2171
+ // / or non-transposed load, as shown below:
2172
+ // / Row Major: Each row of the matrices is loaded by a group of 4 work items(wi)
2173
+ // / row-0: wi0 wi0 wi1 wi1 ... wi3 wi3
2174
+ // / row-1: wi4 wi4 wi5 wi5 ... wi7 wi7
2175
+ // / ...
2176
+ // / row-6: wi24 wi24 wi25 wi25 ... wi27 wi27
2177
+ // / row-7: wi28 wi28 wi29 wi29 ... wi31 wi31
2178
+ // / Col Major: Each col of the matrices is loaded by a group of 4 work items(wi)
2179
+ // / row-0: wi0 wi4 wi8 ... wi28
2180
+ // / row-1: wi0 wi4 wi8 ... wi28
2181
+ // / ...
2182
+ // / row-6: wi3 wi7 wi11 ... wi31
2183
+ // / row-7: wi3 wi7 wi11 ... wi31
2131
2184
// / \tparam [in] T The type of result variable
2132
2185
// / \param [in] addr The address of the matrix in local memory
2133
2186
// / \param [in] m1 The private memory to store data of 1st matrix
0 commit comments