Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 90662a8

Browse files
committed
bugfix for update offset_x/offset_y
1 parent 4f481df commit 90662a8

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

include/subgroup/tile/impl/load_xe.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ tile_load(tile_t& tile, payload_t& payload) {
9393
static constexpr gpu_arch arch_tag = payload_t::arch_tag;
9494

9595
static constexpr reg_layout reg_layout_ = tile_desc::register_layout;
96+
// In the case of pack, tranpose is in vnni format
9697
static constexpr bool is_vnni_reverse =
9798
payload_t::mem_transpose_dtype_less4bytes &&
9899
((reg_layout_ == reg_layout::tiled) ||
@@ -188,14 +189,13 @@ tile_load(tile_t& tile, payload_t& payload) {
188189
((block_size_y * sizeof(dtype)) % sizeof(load_dtype) == 0),
189190
"check vnni limitation for DW transpose");
190191

191-
// auto payload_2d = payload.payloads.xetla_format<uint32_t, num_block, 16>();
192192
#pragma unroll
193193
for (uint32_t i = 0; i < num_block_y; ++i) {
194-
constexpr uint32_t load_block_elems = block_elems * arr_len;
195194
int offset_y = i * block_size_y;
196195
#pragma unroll
197196
for (uint32_t j = 0; j < num_block_x; j += arr_len) {
198197
int32_t offset_x = j * block_size_x;
198+
constexpr uint32_t load_block_elems = block_elems * arr_len;
199199
auto reg_blk = tile.reg.xetla_select<load_block_elems, 1>(
200200
(i * num_block_x + j) * block_elems);
201201
constexpr uint32_t ld_blk_height = (reg_transpose && trans)

include/subgroup/tile/impl/payload_xe.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,12 +230,14 @@ struct mem_payload_t<
230230
__XETLA_API void update_tdesc(int offset) {
231231
auto payloads_2d = payloads.xetla_format<uint32_t, num_block, 16>();
232232
if constexpr (update_dir == tdesc_update_dir::x_dir) {
233+
offset_x += offset / scale_factor;
233234
#pragma unroll
234235
for (uint32_t i = 0; i < num_block; i++) {
235236
xetla_update_tdesc_offsetx(
236237
payloads_2d.row(i), offset / int32_t(scale_factor));
237238
}
238239
} else {
240+
offset_y += offset;
239241
#pragma unroll
240242
for (uint32_t i = 0; i < num_block; i++) {
241243
xetla_update_tdesc_offsety(payloads_2d.row(i), offset);

0 commit comments

Comments
 (0)