Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 468c1ea

Browse files
committed
opt store_xe.hpp
1 parent bba4180 commit 468c1ea

File tree

4 files changed

+90
-145
lines changed

4 files changed

+90
-145
lines changed

include/common/core/arch_config.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ struct load_store_attr_t<msg_type::block_2d, gpu_arch::XeHpc> {
4747
// BlockWidth must be 1,2,4 for qwords and be in range [1..8] for dwords.
4848
static constexpr uint32_t max_trans_load_width_in_bytes = 32;
4949

50+
// BlockHeight must be 8 for qwords and be in range [1..32] for dwords.
51+
static constexpr uint32_t max_trans_load_height_in_elem = 32;
52+
5053
// If Transformed is true
5154
// BlockWidth must be in range [4..16] for bytes and [2..16] for word.
5255
static constexpr uint32_t max_vnni_load_width_in_elems = 16;

include/subgroup/tile/impl/load_xe.hpp

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,6 @@ tile_load(tile_t& tile, payload_t& payload) {
8989

9090
static constexpr uint32_t num_block_x = tile_desc::num_block_x;
9191
static constexpr uint32_t num_block_y = tile_desc::num_block_y;
92-
// static constexpr uint32_t num_block = tile_desc::num_block;
9392

9493
static constexpr gpu_arch arch_tag = payload_t::arch_tag;
9594

@@ -181,19 +180,9 @@ tile_load(tile_t& tile, payload_t& payload) {
181180
for (uint32_t i = 0; i < num_block_y; ++i) {
182181
constexpr uint32_t load_block_elems = block_elems * arr_len;
183182
int offset_y = i * block_size_y;
184-
// auto payload_row =
185-
// payload_2d.xetla_select<num_block_x, 1, 16, 1>(i * num_block_x, 0);
186-
// detail::reset_tile_desc_core<
187-
// num_block_x,
188-
// block_size_x,
189-
// ld_blk_size_y,
190-
// scale_factor,
191-
// arr_len,
192-
// mem_transpose>(payload_row);
193183
#pragma unroll
194184
for (uint32_t j = 0; j < num_block_x; j += arr_len) {
195185
int32_t offset_x = j * block_size_x;
196-
// xetla_tdescriptor tdesc = payload_row.row(j);
197186
auto reg_blk = tile.reg.xetla_select<load_block_elems, 1>(
198187
(i * num_block_x + j) * block_elems);
199188
constexpr uint32_t ld_blk_height = (reg_transpose && trans)
@@ -215,7 +204,8 @@ tile_load(tile_t& tile, payload_t& payload) {
215204
mem_transform,
216205
L1,
217206
L2>(
218-
payload.base_ptr,
207+
reinterpret_cast<const native_type_t<load_dtype*>>(
208+
payload.base_ptr),
219209
payload.surface_width,
220210
payload.surface_height,
221211
payload.surface_pitch,
@@ -273,7 +263,8 @@ tile_load(tile_t& tile, payload_t& payload) {
273263
mem_transform,
274264
L1,
275265
L2>(
276-
payload.base_ptr,
266+
reinterpret_cast<const native_type_t<load_dtype*>>(
267+
payload.base_ptr),
277268
payload.surface_width,
278269
payload.surface_height,
279270
payload.surface_pitch,
@@ -335,7 +326,8 @@ tile_load(tile_t& tile, payload_t& payload) {
335326
mem_transform,
336327
L1,
337328
L2>(
338-
payload.base_ptr,
329+
reinterpret_cast<const native_type_t<load_dtype*>>(
330+
payload.base_ptr),
339331
payload.surface_width,
340332
payload.surface_height,
341333
payload.surface_pitch,
@@ -402,7 +394,8 @@ tile_load(tile_t& tile, payload_t& payload) {
402394
mem_transform,
403395
L1,
404396
L2>(
405-
payload.base_ptr,
397+
reinterpret_cast<const native_type_t<load_dtype*>>(
398+
payload.base_ptr),
406399
payload.surface_width,
407400
payload.surface_height,
408401
payload.surface_pitch,

include/subgroup/tile/impl/payload_xe.hpp

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,28 @@ struct mem_payload_t<
8484
using mem_dtype = typename std::
8585
conditional_t<mem_transpose_dtype_less4bytes, uint32_t, dtype>;
8686
static constexpr uint32_t scale_factor = sizeof(mem_dtype) / sizeof(dtype);
87-
mem_dtype* base_ptr;
87+
88+
using load_store_attr = load_store_attr_t<msg_type::block_2d, arch_tag>;
89+
90+
static constexpr uint32_t max_load_width_in_elem = trans
91+
? load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype)
92+
: load_store_attr::max_load_width_in_bytes / sizeof(dtype);
93+
static constexpr uint32_t max_load_height_in_elem = trans
94+
? load_store_attr::max_trans_load_height_in_elem
95+
: load_store_attr::max_load_height_in_elem;
96+
97+
static constexpr uint32_t max_store_width_in_elem =
98+
load_store_attr::max_store_width_in_bytes / sizeof(dtype);
99+
static constexpr uint32_t max_store_height_in_elem =
100+
load_store_attr::max_store_height_in_elem;
101+
102+
static constexpr uint32_t elems_per_CL =
103+
load_store_attr::cache_line_size_in_bytes / sizeof(dtype);
104+
105+
static constexpr uint32_t elems_per_reg =
106+
register_bytes_t<arch_tag>::reg_in_bytes / sizeof(dtype);
107+
108+
dtype* base_ptr;
88109
uint32_t surface_width;
89110
uint32_t surface_height;
90111
uint32_t surface_pitch;
@@ -105,7 +126,7 @@ struct mem_payload_t<
105126
}
106127

107128
inline mem_payload_t(mem_desc_t& mem_desc) {
108-
this->base_ptr = (mem_dtype*)mem_desc.base.base;
129+
this->base_ptr = (dtype*)mem_desc.base.base;
109130
this->surface_width =
110131
(mem_transpose ? mem_desc.shape.y : mem_desc.shape.x) * sizeof(dtype);
111132
this->surface_height =
@@ -130,7 +151,7 @@ struct mem_payload_t<
130151
uint32_t surface_pitch,
131152
int32_t surface_offset_x = 0,
132153
int32_t surface_offset_y = 0) {
133-
this->base_ptr = (mem_dtype*)p;
154+
this->base_ptr = p;
134155
this->surface_width = surface_width * sizeof(dtype);
135156
this->surface_height = surface_height;
136157
this->surface_pitch = surface_pitch * sizeof(dtype);
@@ -151,7 +172,7 @@ struct mem_payload_t<
151172
}
152173

153174
__XETLA_API void init(mem_desc_t& mem_desc) {
154-
this->base_ptr = (mem_dtype*)mem_desc.base.base;
175+
this->base_ptr = (dtype*)mem_desc.base.base;
155176
this->surface_width =
156177
(mem_transpose ? mem_desc.shape.y : mem_desc.shape.x) * sizeof(dtype);
157178
this->surface_height =
@@ -184,7 +205,7 @@ struct mem_payload_t<
184205
uint32_t surface_pitch,
185206
int32_t surface_offset_x = 0,
186207
int32_t surface_offset_y = 0) {
187-
this->base_ptr = (mem_dtype*)p;
208+
this->base_ptr = p;
188209
this->surface_width = surface_width * sizeof(dtype);
189210
this->surface_height = surface_height;
190211
this->surface_pitch = surface_pitch * sizeof(dtype);

0 commit comments

Comments
 (0)