Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 473707e

Browse files
sunjiweiswiftDDEle
authored andcommitted
update prefetch
1 parent 44ddaeb commit 473707e

File tree

5 files changed

+423
-406
lines changed

5 files changed

+423
-406
lines changed

include/common/core/arch_config.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ struct load_store_attr_t<msg_type::block_1d, arch_tag> {
119119
static constexpr uint32_t max_aligned_load_vec_len = 256;
120120
static constexpr uint32_t max_store_vec_len = 256;
121121
static constexpr uint32_t max_aligned_store_vec_len = 256;
122-
static constexpr uint32_t max_prefetch_vec_len = 32;
122+
static constexpr uint32_t max_prefetch_vec_len = 256;
123123
static constexpr uint32_t max_channel_num = 16;
124124
};
125125

include/subgroup/tile/impl/load_xe.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,10 @@ tile_load(tile_t& tile, payload_t& payload) {
214214
// arch_tag>(tdesc);
215215
reg_tmp.xetla_format<native_type_t<load_dtype>>() = xetla_load_global<
216216
native_type_t<load_dtype>,
217-
block_size_x / scale_factor,
218-
block_size_y,
217+
(mem_transpose ? ld_blk_size_y : block_size_x) / scale_factor,
218+
(mem_transpose ? block_size_x : ld_blk_size_y),
219+
// block_size_x / scale_factor,
220+
// ld_blk_size_y,
219221
arr_len,
220222
trans,
221223
mem_transform,

include/subgroup/tile/impl/payload_xe.hpp

Lines changed: 84 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1170,10 +1170,9 @@ struct mem_payload_t<
11701170
static constexpr uint32_t tile_size_y = tile_desc::tile_size_y;
11711171
static constexpr uint32_t block_size_x = tile_desc::block_size_x;
11721172
static constexpr uint32_t block_size_y = tile_desc::block_size_y;
1173-
static constexpr uint32_t tile_bytes =
1174-
tile_size_x * tile_size_y * sizeof(dtype);
1173+
static constexpr uint32_t tile_bytes = tile_desc::tile_elems * sizeof(dtype);
11751174
static constexpr uint32_t block_bytes =
1176-
block_size_x * block_size_y * sizeof(dtype);
1175+
tile_desc::block_elems * sizeof(dtype);
11771176
using this_payload_t =
11781177
mem_payload_t<mem_desc_t, tile_desc, msg_type::block_2d, arch_tag_>;
11791178

@@ -1250,7 +1249,7 @@ struct mem_payload_t<
12501249
base_offset = mem_transpose
12511250
? base_x * pitch_in_bytes + base_y * sizeof(dtype)
12521251
: base_y * pitch_in_bytes + base_x * sizeof(dtype);
1253-
base_ptr = (mem_dtype*)mem_tdesc.base.base;
1252+
base_ptr = reinterpret_cast<mem_dtype*>(mem_tdesc.base.base);
12541253

12551254
xetla_vector<uint32_t, num_channel> channel_index =
12561255
xetla_vector_gen<uint32_t, num_channel>(0, 1);
@@ -1734,10 +1733,8 @@ struct prefetch_payload_t<
17341733
static constexpr uint32_t tile_size_y = tile_desc::tile_size_y;
17351734
static constexpr uint32_t block_size_x = tile_desc::block_size_x;
17361735
static constexpr uint32_t block_size_y = tile_desc::block_size_y;
1737-
static constexpr uint32_t tile_bytes =
1738-
tile_size_x * tile_size_y * sizeof(dtype);
1739-
static constexpr uint32_t block_bytes =
1740-
block_size_x * block_size_y * sizeof(dtype);
1736+
static constexpr uint32_t tile_bytes = tile_desc::block_elems * sizeof(dtype);
1737+
static constexpr uint32_t block_bytes = tile_desc::tile_elems * sizeof(dtype);
17411738

17421739
private:
17431740
using this_payload_t =
@@ -1751,67 +1748,75 @@ struct prefetch_payload_t<
17511748
static constexpr bool trans = (mem_transpose ^ reg_transpose) &&
17521749
!(std::is_same_v<dtype_, int4x2> || std::is_same_v<dtype_, int4x8>);
17531750

1754-
using prefetch_dtype = typename std::conditional<
1751+
using prefetch_dtype = typename std::conditional_t<
17551752
(alignment_in_bytes % (sizeof(uint64_t)) == 0),
17561753
uint64_t,
1757-
typename std::conditional<
1754+
typename std::conditional_t<
17581755
(alignment_in_bytes % (sizeof(uint32_t)) == 0),
17591756
uint32_t,
1760-
dtype>::type>::type;
1757+
dtype>>;
17611758
static constexpr uint32_t pack_factor =
17621759
sizeof(prefetch_dtype) / sizeof(dtype);
17631760

1764-
static constexpr uint32_t min_store_bytes = 16 * sizeof(dtype);
1765-
static constexpr uint32_t max_store_bytes = 32 * sizeof(dtype);
1766-
static constexpr uint32_t simd_channel =
1767-
((tile_bytes % max_store_bytes) == 0 &&
1768-
(block_bytes % max_store_bytes) == 0)
1769-
? 32
1770-
: 16;
1771-
static constexpr uint32_t num_channel = mem_transpose
1772-
? (simd_channel >= block_size_x) ? block_size_x : simd_channel
1773-
: (simd_channel >= block_size_y) ? block_size_y
1774-
: simd_channel;
1761+
static constexpr uint32_t vector_size =
1762+
((mem_transpose ? block_size_y : block_size_x) + pack_factor - 1) /
1763+
pack_factor;
17751764

1776-
static constexpr uint32_t vector_size = mem_transpose
1777-
? (block_size_y + pack_factor - 1) / pack_factor
1778-
: (block_size_x + pack_factor - 1) / pack_factor;
1765+
using load_store_attr = load_store_attr_t<msg_type::block_1d, arch_tag>;
1766+
static constexpr uint32_t max_prefetch_vec_len =
1767+
load_store_attr::max_prefetch_vec_len;
17791768

1780-
static constexpr uint32_t mem_tile_size_w =
1781-
mem_transpose ? tile_size_y : tile_size_x;
1782-
static constexpr uint32_t mem_tile_size_h =
1783-
mem_transpose ? tile_size_x : tile_size_y;
1784-
using load_store_attr =
1785-
typename arch_attr_t<arch_tag>::template load_store_attr<message_type>;
1786-
static constexpr uint32_t special_prefetch_width =
1787-
load_store_attr::special_prefetch_width_in_bytes / sizeof(dtype);
1788-
static constexpr uint32_t normal_prefetch_width =
1789-
load_store_attr::max_load_width_in_bytes / sizeof(dtype);
1790-
static constexpr bool is_special_prefetch =
1791-
(mem_tile_size_w % special_prefetch_width) == 0;
1769+
static constexpr uint32_t max_channel =
1770+
max_prefetch_vec_len / (vector_size * sizeof(prefetch_dtype));
17921771

1793-
static constexpr uint32_t block_size_w = is_special_prefetch
1794-
? special_prefetch_width
1795-
: (normal_prefetch_width > mem_tile_size_w ? mem_tile_size_w
1796-
: normal_prefetch_width);
1797-
static constexpr uint32_t block_size_h =
1798-
load_store_attr::max_load_height_in_elem;
1799-
// could have over-prefetch, but that's should be fine
1800-
static constexpr uint32_t max_num_block_w =
1801-
(mem_tile_size_w + block_size_w - 1) / block_size_w;
1802-
static constexpr uint32_t num_coop_sg = num_coop_sg_;
1803-
static constexpr uint32_t num_coop_sg_w =
1804-
detail::gcd<num_coop_sg, max_num_block_w>::value;
1805-
static constexpr uint32_t num_coop_sg_h = num_coop_sg / num_coop_sg_w;
1772+
static constexpr uint32_t select_channel(const uint32_t channel) {
1773+
return (channel >= load_store_attr::max_channel_num)
1774+
? load_store_attr::max_channel_num
1775+
: channel >= 16 ? 16
1776+
: channel >= 8 ? 8
1777+
: 1;
1778+
}
18061779

1807-
static constexpr uint32_t num_block_w = max_num_block_w / num_coop_sg_w;
1808-
static constexpr uint32_t tile_size_w = block_size_w * num_block_w;
1809-
static constexpr uint32_t tile_size_h =
1810-
(mem_tile_size_h + num_coop_sg_h - 1) / num_coop_sg_h;
1811-
static constexpr uint32_t num_block_h =
1812-
(tile_size_h + block_size_h - 1) / block_size_h;
1780+
static constexpr uint32_t num_channel = select_channel(
1781+
std::min(mem_transpose ? block_size_x : block_size_y, max_channel));
1782+
1783+
// static constexpr uint32_t mem_tile_size_w =
1784+
// mem_transpose ? tile_size_y : tile_size_x;
1785+
// static constexpr uint32_t mem_tile_size_h =
1786+
// mem_transpose ? tile_size_x : tile_size_y;
1787+
1788+
// static constexpr uint32_t special_prefetch_width =
1789+
// load_store_attr::special_prefetch_width_in_bytes / sizeof(dtype);
1790+
// static constexpr uint32_t normal_prefetch_width =
1791+
// load_store_attr::max_load_width_in_bytes / sizeof(dtype);
1792+
// static constexpr bool is_special_prefetch =
1793+
// (mem_tile_size_w % special_prefetch_width) == 0;
1794+
1795+
// static constexpr uint32_t block_size_w = is_special_prefetch
1796+
// ? special_prefetch_width
1797+
// : (normal_prefetch_width > mem_tile_size_w ? mem_tile_size_w
1798+
// : normal_prefetch_width);
1799+
// static constexpr uint32_t block_size_h =
1800+
// load_store_attr::max_load_height_in_elem;
1801+
// // could have over-prefetch, but that's should be fine
1802+
// static constexpr uint32_t max_num_block_w =
1803+
// (mem_tile_size_w + block_size_w - 1) / block_size_w;
1804+
// static constexpr uint32_t num_coop_sg = num_coop_sg_;
1805+
// static constexpr uint32_t num_coop_sg_w =
1806+
// detail::gcd<num_coop_sg, max_num_block_w>::value;
1807+
// static constexpr uint32_t num_coop_sg_h = num_coop_sg / num_coop_sg_w;
1808+
1809+
// static constexpr uint32_t num_block_w = max_num_block_w / num_coop_sg_w;
1810+
// static constexpr uint32_t tile_size_w = block_size_w * num_block_w;
1811+
// static constexpr uint32_t tile_size_h =
1812+
// (mem_tile_size_h + num_coop_sg_h - 1) / num_coop_sg_h;
1813+
// static constexpr uint32_t num_block_h =
1814+
// (tile_size_h + block_size_h - 1) / block_size_h;
18131815

18141816
xetla_vector<uint32_t, num_channel> channel_offset;
1817+
xetla_vector<uint32_t, num_channel> step_x;
1818+
xetla_vector<uint32_t, num_channel> step_y;
1819+
18151820
uint64_t base_offset;
18161821
uint32_t base_x;
18171822
uint32_t base_y;
@@ -1848,13 +1853,15 @@ struct prefetch_payload_t<
18481853
return *this;
18491854
}
18501855

1851-
inline prefetch_payload_t(mem_desc_t& mem_desc, uint32_t coop_id = 0) {
1852-
uint32_t coop_id_x = coop_id % num_coop_sg_w;
1853-
uint32_t coop_id_y = coop_id / num_coop_sg_w;
1856+
inline prefetch_payload_t(
1857+
mem_desc_t& mem_desc,
1858+
[[maybe_unused]] uint32_t coop_id = 0) {
1859+
// uint32_t coop_id_x = coop_id % num_coop_sg_w;
1860+
// uint32_t coop_id_y = coop_id / num_coop_sg_w;
18541861

18551862
pitch_in_bytes = mem_desc.shape.stride * sizeof(dtype);
1856-
base_x = mem_desc.coord.x + coop_id_x * tile_size_w;
1857-
base_y = mem_desc.coord.y + coop_id_y * tile_size_h;
1863+
base_x = mem_desc.coord.x;
1864+
base_y = mem_desc.coord.y;
18581865
width_in_elems = mem_desc.shape.x;
18591866
height_in_elems = mem_desc.shape.y;
18601867
base_offset = mem_transpose
@@ -1874,13 +1881,15 @@ struct prefetch_payload_t<
18741881
int surface_pitch,
18751882
int surface_offset_x,
18761883
int surface_offset_y,
1877-
uint32_t coop_id = 0) {
1878-
uint32_t coop_id_x = coop_id % num_coop_sg_w;
1879-
uint32_t coop_id_y = coop_id / num_coop_sg_w;
1884+
[[maybe_unused]] uint32_t coop_id = 0) {
1885+
// uint32_t coop_id_x = coop_id % num_coop_sg_w;
1886+
// uint32_t coop_id_y = coop_id / num_coop_sg_w;
1887+
// base_x = surface_offset_x + coop_id_x * tile_size_w;
1888+
// base_y = surface_offset_y + coop_id_y * tile_size_h;
18801889

18811890
pitch_in_bytes = surface_pitch * sizeof(dtype);
1882-
base_x = surface_offset_x + coop_id_x * tile_size_w;
1883-
base_y = surface_offset_y + coop_id_y * tile_size_h;
1891+
base_x = surface_offset_x;
1892+
base_y = surface_offset_y;
18841893
width_in_elems = surface_width;
18851894
height_in_elems = surface_height;
18861895
base_offset = mem_transpose
@@ -1893,13 +1902,17 @@ struct prefetch_payload_t<
18931902
channel_offset = channel_index * pitch_in_bytes;
18941903
}
18951904

1896-
inline void init(mem_desc_t& mem_desc, uint32_t coop_id = 0) {
1897-
uint32_t coop_id_x = coop_id % num_coop_sg_w;
1898-
uint32_t coop_id_y = coop_id / num_coop_sg_w;
1905+
inline void init(
1906+
mem_desc_t& mem_desc,
1907+
[[maybe_unused]] uint32_t coop_id = 0) {
1908+
// uint32_t coop_id_x = coop_id % num_coop_sg_w;
1909+
// uint32_t coop_id_y = coop_id / num_coop_sg_w;
1910+
// base_x = mem_desc.coord.x + coop_id_x * tile_size_w;
1911+
// base_y = mem_desc.coord.y + coop_id_y * tile_size_h;
18991912

19001913
pitch_in_bytes = mem_desc.shape.stride * sizeof(dtype);
1901-
base_x = mem_desc.coord.x + coop_id_x * tile_size_w;
1902-
base_y = mem_desc.coord.y + coop_id_y * tile_size_h;
1914+
base_x = mem_desc.coord.x;
1915+
base_y = mem_desc.coord.y;
19031916
width_in_elems = mem_desc.shape.x;
19041917
height_in_elems = mem_desc.shape.y;
19051918
base_offset = mem_transpose

include/subgroup/tile/impl/prefetch_xe.hpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,7 @@ tile_prefetch(payload_t& payload) {
104104
using prefetch_dtype = typename payload_t::prefetch_dtype;
105105
constexpr uint32_t num_channel = payload_t::num_channel;
106106
#pragma unroll
107-
for (uint32_t i = 0; i < tile_desc::tile_size_y / tile_desc::block_size_y;
108-
i++) {
107+
for (uint32_t i = 0; i < tile_desc::num_block_y; i++) {
109108
uint32_t offset_y = i * tile_desc::block_size_y;
110109
#pragma unroll
111110
for (uint32_t j = 0; j < tile_desc::num_block_x; j++) {
@@ -126,7 +125,6 @@ tile_prefetch(payload_t& payload) {
126125
L2>(
127126
payload.base_ptr,
128127
payload.channel_offset + payload.base_offset + address_offset);
129-
// }
130128
}
131129
}
132130
}

0 commit comments

Comments
 (0)