Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 8be0da2

Browse files
committed
init arch Xe2
1 parent 8f0abc4 commit 8be0da2

File tree

3 files changed

+56
-16
lines changed

3 files changed

+56
-16
lines changed

include/common/core/arch_config.hpp

Lines changed: 48 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,8 @@ struct load_store_attr_t {
3131
static constexpr bool has_hw_block_2d = false;
3232
};
3333

34-
template <>
35-
struct load_store_attr_t<msg_type::block_2d, gpu_arch::XeHpc> {
36-
/// HW limitation checks https://gfxspecs.intel.com/Predator/Home/Index/55490
34+
template <msg_type message_type, gpu_arch arg_tag>
35+
struct xe_plus_load_store_attr_t {
3736
static constexpr bool has_hw_block_2d = true;
3837
static constexpr uint32_t max_load_height_in_elem = 32;
3938
static constexpr uint32_t max_load_width_in_bytes = 64;
@@ -55,10 +54,9 @@ struct load_store_attr_t<msg_type::block_2d, gpu_arch::XeHpc> {
5554

5655
template <msg_type message_type, gpu_arch arg_tag>
5756
struct client_load_store_attr_base_t {
58-
/// HW limitation checks https://gfxspecs.intel.com/Predator/Home/Index/55490
5957
static constexpr bool has_hw_block_2d = false;
60-
static constexpr uint32_t max_load_height_in_elem = 32;
61-
static constexpr uint32_t max_load_width_in_bytes = 64;
58+
static constexpr uint32_t max_load_height_in_elem = 0;
59+
static constexpr uint32_t max_load_width_in_bytes = 0;
6260
static constexpr uint32_t max_trans_load_width_in_bytes = 32;
6361
static constexpr uint32_t max_vnni_load_width_in_elems = 16;
6462
static constexpr uint32_t min_vnni_load_height_in_bytes = 4;
@@ -87,6 +85,18 @@ struct load_store_attr_t<msg_type::block_2d, gpu_arch::XeLpg>
8785
msg_type::block_2d,
8886
gpu_arch::XeLpg> {};
8987

88+
template <>
89+
struct load_store_attr_t<msg_type::block_2d, gpu_arch::XeHpc>
90+
: public client_load_store_attr_base_t<
91+
msg_type::block_2d,
92+
gpu_arch::XeHpc> {};
93+
94+
template <>
95+
struct load_store_attr_t<msg_type::block_2d, gpu_arch::Xe2>
96+
: public client_load_store_attr_base_t<
97+
msg_type::block_2d,
98+
gpu_arch::Xe2> {};
99+
90100
template <gpu_arch arch_tag>
91101
inline constexpr bool arch_has_2d_load_store =
92102
load_store_attr_t<msg_type::block_2d, arch_tag>::has_hw_block_2d;
@@ -105,6 +115,13 @@ struct load_store_attr_t<msg_type::block_1d, gpu_arch::XeHpc> {
105115
static constexpr uint32_t max_prefetch_vec_len = 64;
106116
};
107117

118+
template <>
119+
struct load_store_attr_t<msg_type::block_1d, gpu_arch::Xe2> {
120+
static constexpr uint32_t max_load_vec_len = 512;
121+
static constexpr uint32_t max_store_vec_len = 512;
122+
static constexpr uint32_t max_prefetch_vec_len = 64;
123+
};
124+
108125
struct dpas_attr_base_t {
109126
static constexpr bool has_xmx = true;
110127
static constexpr uint32_t systolic_depth = 8;
@@ -129,6 +146,11 @@ struct dpas_attr_t<gpu_arch::XeHpg> : public dpas_attr_base_t {
129146
static constexpr uint32_t n_fixed_limit = 8;
130147
};
131148

149+
template <>
150+
struct dpas_attr_t<gpu_arch::Xe2> : public dpas_attr_t<gpu_arch::XeHpc> {
151+
static constexpr uint32_t systolic_depth = 4;
152+
};
153+
132154
template <gpu_arch arch_tag>
133155
inline constexpr bool arch_has_xmx = dpas_attr_t<arch_tag>::has_xmx;
134156

@@ -162,6 +184,10 @@ template <>
162184
struct register_bytes_t<gpu_arch::XeLpg> {
163185
static constexpr uint32_t reg_in_bytes = 32;
164186
};
187+
template <>
188+
struct register_bytes_t<gpu_arch::Xe2> {
189+
static constexpr uint32_t reg_in_bytes = 64;
190+
};
165191

166192
template <grf_mode grf_num_mode, gpu_arch arch_tag>
167193
struct register_attr_t {
@@ -236,10 +262,25 @@ struct arch_attr_t<gpu_arch::XeLpg> {
236262

237263
using dpas_attr = dpas_attr_t<gpu_arch::XeLpg>;
238264

239-
static constexpr uint32_t max_wg_num = 64;
265+
static constexpr uint32_t max_wg_num = 16;
240266
static constexpr uint32_t local_mem_size = 64 * 1024;
241267
};
242268

269+
template <>
270+
struct arch_attr_t<gpu_arch::Xe2> {
271+
template <msg_type message_type = msg_type::block_2d>
272+
using load_store_attr = load_store_attr_t<message_type, gpu_arch::Xe2>;
273+
274+
template <grf_mode grf_num_mode = grf_mode::double_grf>
275+
using register_attr = register_attr_t<grf_num_mode, gpu_arch::Xe2>;
276+
277+
using dpas_attr = dpas_attr_t<gpu_arch::Xe2>;
278+
279+
static constexpr uint32_t max_wg_num = 16;
280+
static constexpr uint32_t local_mem_size = 128 * 1024;
281+
};
282+
283+
243284
/// @} xetla_core_arch_config
244285

245286
} // namespace gpu::xetla

include/common/core/common_types.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
#include <cstdint>
2222

2323
namespace gpu::xetla {
24-
enum class gpu_arch : uint8_t { XeLpg = 0, XeHpg = 1, XeHpc = 2 };
24+
enum class gpu_arch : uint8_t { XeLpg = 0, XeHpg = 1, XeHpc = 2, Xe2 = 3 };
2525

2626
enum class grf_mode : uint8_t { normal = 0, double_grf = 1 };
2727

include/group/gemm/compute_policy.hpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -118,16 +118,15 @@ struct compute_policy_default_fpu<
118118
static constexpr int sync_freq = perf_tuning_knob::sync_freq;
119119
static constexpr int k_stride = perf_tuning_knob::k_stride;
120120

121-
static constexpr uint32_t block_size_y_a =
122-
arch_tag_ == gpu_arch::XeLpg ? 8 : 16;
123-
static constexpr uint32_t block_bytes_x_a = 32;
121+
static constexpr uint32_t block_size_y_a = 16;
122+
using mma_attr = mma_attr_t<arch_tag_, block_size_y_a>;
123+
static constexpr uint32_t block_bytes_x_a = mma_attr::mma_k_in_bytes;
124124
static constexpr uint32_t block_size_x_a =
125125
block_bytes_x_a / sizeof(dtype_mma_a);
126-
static constexpr uint32_t block_bytes_x_b =
127-
arch_attr_t<arch_tag>::template register_attr<>::reg_in_bytes;
128-
static constexpr uint32_t block_size_x_b =
129-
block_bytes_x_b / sizeof(dtype_mma_b);
130-
static constexpr uint32_t block_size_y_b = block_size_x_a;
126+
static constexpr uint32_t block_size_x_b = mma_attr::mma_n_in_elem;
127+
static constexpr uint32_t block_bytes_y_b = mma_attr::mma_k_in_bytes;
128+
static constexpr uint32_t block_size_y_b =
129+
block_bytes_y_b / sizeof(dtype_mma_b);
131130
};
132131

133132
/// @} xetla_gemm

0 commit comments

Comments
 (0)