@@ -2056,8 +2056,24 @@ class joint_matrix {
2056
2056
const size_t num_elements;
2057
2057
};
2058
2058
2059
- // / Loads 1 8x8 b16 matrix from local memory to private memory (32-bits per wi)
2060
- // / Requires the sub-group size of kernel calling this function to be 32
2059
+ // / Loads 1 8x8 b16 matrix from local memory to private memory per sub-group.
2060
+ // / Requires the sub-group size of kernel calling this function to be 32.
2061
+ // / Each of the first 8 work items contains the starting address of its
2062
+ // / respective matrix row.
2063
+ // / Each of the 32 work items loads 32 bits (2 packed 16-bit values) for a total
2064
+ // / of 128 bytes.
2065
+ // / Row Major: Each row of the matrix is loaded by a group of 4 work items
2066
+ // / r0: t0 t1 t2 t3
2067
+ // / r1: t4 t5 t6 t7
2068
+ // / ...
2069
+ // / r6: t24 t25 t26 t27
2070
+ // / r7: t28 t29 t30 t31
2071
+ // / Col Major: Each col of the matrix is loaded by a group of 4 work items
2072
+ // / r0: t0 t4 t8 ... t28
2073
+ // / r1: t0 t4 t8 ... t28
2074
+ // / ...
2075
+ // / r6: t3 t7 t11 ... t31
2076
+ // / r7: t3 t7 t11 ... t31
2061
2077
// / \tparam [in] T The type of result variable
2062
2078
// / \param [in] addr The address of the matrix in local memory
2063
2079
// / \param [in] m The private memory to store the matrix
@@ -2111,8 +2127,12 @@ void ldmatrix(uintptr_t addr, T *m, bool trans = false, unsigned mat = 0) {
2111
2127
}
2112
2128
}
2113
2129
2114
- // / Loads 2 8x8 b16 matrix from local memory to private memory (32-bits per wi)
2115
- // / Requires the sub-group size of kernel calling this function to be 32
2130
+ // / Loads 2 8x8 b16 matrix from local memory to private memory per sub-group.
2131
+ // / Requires the sub-group size of kernel calling this function to be 32.
2132
+ // / Each of the first 16 work items contains the starting address of its
2133
+ // / respective matrix row.
2134
+ // / Each of the 32 work items loads 64 bits (32 bits per matrix) for a total
2135
+ // / of 256 bytes.
2116
2136
// / \tparam [in] T The type of result variable
2117
2137
// / \param [in] addr The address of the matrix in local memory
2118
2138
// / \param [in] m1 The private memory to store data of 1st matrix
@@ -2126,8 +2146,12 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, bool trans = false) {
2126
2146
ldmatrix (addr, m2, trans, 1 );
2127
2147
}
2128
2148
2129
- // / Loads 4 8x8 b16 matrix from local memory to private memory (32-bits per wi)
2130
- // / Requires the sub-group size of kernel calling this function to be 32
2149
+ // / Loads 4 8x8 b16 matrix from local memory to private memory per sub-group.
2150
+ // / Requires the sub-group size of kernel calling this function to be 32.
2151
+ // / Each of the 32 work items contains the starting address of its
2152
+ // / respective matrix row.
2153
+ // / Each of the 32 work items loads 128 bits (32 bits per matrix) for a total
2154
+ // / of 512 bytes.
2131
2155
// / \tparam [in] T The type of result variable
2132
2156
// / \param [in] addr The address of the matrix in local memory
2133
2157
// / \param [in] m1 The private memory to store data of 1st matrix
0 commit comments