@@ -1170,10 +1170,9 @@ struct mem_payload_t<
1170
1170
static constexpr uint32_t tile_size_y = tile_desc::tile_size_y;
1171
1171
static constexpr uint32_t block_size_x = tile_desc::block_size_x;
1172
1172
static constexpr uint32_t block_size_y = tile_desc::block_size_y;
1173
- static constexpr uint32_t tile_bytes =
1174
- tile_size_x * tile_size_y * sizeof (dtype);
1173
+ static constexpr uint32_t tile_bytes = tile_desc::tile_elems * sizeof (dtype);
1175
1174
static constexpr uint32_t block_bytes =
1176
- block_size_x * block_size_y * sizeof (dtype);
1175
+ tile_desc::block_elems * sizeof (dtype);
1177
1176
using this_payload_t =
1178
1177
mem_payload_t <mem_desc_t , tile_desc, msg_type::block_2d, arch_tag_>;
1179
1178
@@ -1250,7 +1249,7 @@ struct mem_payload_t<
1250
1249
base_offset = mem_transpose
1251
1250
? base_x * pitch_in_bytes + base_y * sizeof (dtype)
1252
1251
: base_y * pitch_in_bytes + base_x * sizeof (dtype);
1253
- base_ptr = ( mem_dtype*) mem_tdesc.base .base ;
1252
+ base_ptr = reinterpret_cast < mem_dtype*>( mem_tdesc.base .base ) ;
1254
1253
1255
1254
xetla_vector<uint32_t , num_channel> channel_index =
1256
1255
xetla_vector_gen<uint32_t , num_channel>(0 , 1 );
@@ -1734,10 +1733,8 @@ struct prefetch_payload_t<
1734
1733
static constexpr uint32_t tile_size_y = tile_desc::tile_size_y;
1735
1734
static constexpr uint32_t block_size_x = tile_desc::block_size_x;
1736
1735
static constexpr uint32_t block_size_y = tile_desc::block_size_y;
1737
- static constexpr uint32_t tile_bytes =
1738
- tile_size_x * tile_size_y * sizeof (dtype);
1739
- static constexpr uint32_t block_bytes =
1740
- block_size_x * block_size_y * sizeof (dtype);
1736
+ static constexpr uint32_t tile_bytes = tile_desc::tile_elems * sizeof (dtype); // bytes in the whole tile; NOTE(review): assumes tile_elems is the tile element count — confirm against tile_desc
1737
+ static constexpr uint32_t block_bytes = tile_desc::block_elems * sizeof (dtype); // bytes in one block (tile_elems/block_elems were swapped between these two constants)
1741
1738
1742
1739
private:
1743
1740
using this_payload_t =
@@ -1751,67 +1748,75 @@ struct prefetch_payload_t<
1751
1748
static constexpr bool trans = (mem_transpose ^ reg_transpose) &&
1752
1749
!(std::is_same_v<dtype_, int4x2> || std::is_same_v<dtype_, int4x8>);
1753
1750
1754
- using prefetch_dtype = typename std::conditional <
1751
+ using prefetch_dtype = typename std::conditional_t <
1755
1752
(alignment_in_bytes % (sizeof (uint64_t )) == 0 ),
1756
1753
uint64_t ,
1757
- typename std::conditional <
1754
+ typename std::conditional_t <
1758
1755
(alignment_in_bytes % (sizeof (uint32_t )) == 0 ),
1759
1756
uint32_t ,
1760
- dtype>::type>::type ;
1757
+ dtype>> ;
1761
1758
static constexpr uint32_t pack_factor =
1762
1759
sizeof (prefetch_dtype) / sizeof (dtype);
1763
1760
1764
- static constexpr uint32_t min_store_bytes = 16 * sizeof (dtype);
1765
- static constexpr uint32_t max_store_bytes = 32 * sizeof (dtype);
1766
- static constexpr uint32_t simd_channel =
1767
- ((tile_bytes % max_store_bytes) == 0 &&
1768
- (block_bytes % max_store_bytes) == 0 )
1769
- ? 32
1770
- : 16 ;
1771
- static constexpr uint32_t num_channel = mem_transpose
1772
- ? (simd_channel >= block_size_x) ? block_size_x : simd_channel
1773
- : (simd_channel >= block_size_y) ? block_size_y
1774
- : simd_channel;
1761
+ static constexpr uint32_t vector_size =
1762
+ ((mem_transpose ? block_size_y : block_size_x) + pack_factor - 1 ) /
1763
+ pack_factor;
1775
1764
1776
- static constexpr uint32_t vector_size = mem_transpose
1777
- ? (block_size_y + pack_factor - 1 ) / pack_factor
1778
- : (block_size_x + pack_factor - 1 ) / pack_factor ;
1765
+ using load_store_attr = load_store_attr_t <msg_type::block_1d, arch_tag>;
1766
+ static constexpr uint32_t max_prefetch_vec_len =
1767
+ load_store_attr::max_prefetch_vec_len ;
1779
1768
1780
- static constexpr uint32_t mem_tile_size_w =
1781
- mem_transpose ? tile_size_y : tile_size_x;
1782
- static constexpr uint32_t mem_tile_size_h =
1783
- mem_transpose ? tile_size_x : tile_size_y;
1784
- using load_store_attr =
1785
- typename arch_attr_t <arch_tag>::template load_store_attr<message_type>;
1786
- static constexpr uint32_t special_prefetch_width =
1787
- load_store_attr::special_prefetch_width_in_bytes / sizeof (dtype);
1788
- static constexpr uint32_t normal_prefetch_width =
1789
- load_store_attr::max_load_width_in_bytes / sizeof (dtype);
1790
- static constexpr bool is_special_prefetch =
1791
- (mem_tile_size_w % special_prefetch_width) == 0 ;
1769
+ static constexpr uint32_t max_channel =
1770
+ max_prefetch_vec_len / (vector_size * sizeof (prefetch_dtype));
1792
1771
1793
- static constexpr uint32_t block_size_w = is_special_prefetch
1794
- ? special_prefetch_width
1795
- : (normal_prefetch_width > mem_tile_size_w ? mem_tile_size_w
1796
- : normal_prefetch_width);
1797
- static constexpr uint32_t block_size_h =
1798
- load_store_attr::max_load_height_in_elem;
1799
- // could have over-prefetch, but that's should be fine
1800
- static constexpr uint32_t max_num_block_w =
1801
- (mem_tile_size_w + block_size_w - 1 ) / block_size_w;
1802
- static constexpr uint32_t num_coop_sg = num_coop_sg_;
1803
- static constexpr uint32_t num_coop_sg_w =
1804
- detail::gcd<num_coop_sg, max_num_block_w>::value;
1805
- static constexpr uint32_t num_coop_sg_h = num_coop_sg / num_coop_sg_w;
1772
+ static constexpr uint32_t select_channel (const uint32_t channel) { // Maps a requested channel count onto one of: load_store_attr::max_channel_num, 16, 8, or 1.
1773
+ return (channel >= load_store_attr::max_channel_num)
1774
+ ? load_store_attr::max_channel_num // requests at or above the attribute's maximum are capped to it
1775
+ : channel >= 16 ? 16
1776
+ : channel >= 8 ? 8
1777
+ : 1 ; // any request below 8 falls back to a single channel
1778
+ }
1806
1779
1807
- static constexpr uint32_t num_block_w = max_num_block_w / num_coop_sg_w;
1808
- static constexpr uint32_t tile_size_w = block_size_w * num_block_w;
1809
- static constexpr uint32_t tile_size_h =
1810
- (mem_tile_size_h + num_coop_sg_h - 1 ) / num_coop_sg_h;
1811
- static constexpr uint32_t num_block_h =
1812
- (tile_size_h + block_size_h - 1 ) / block_size_h;
1780
+ static constexpr uint32_t num_channel = select_channel(
1781
+ std::min (mem_transpose ? block_size_x : block_size_y, max_channel));
1782
+
1783
+ // static constexpr uint32_t mem_tile_size_w =
1784
+ // mem_transpose ? tile_size_y : tile_size_x;
1785
+ // static constexpr uint32_t mem_tile_size_h =
1786
+ // mem_transpose ? tile_size_x : tile_size_y;
1787
+
1788
+ // static constexpr uint32_t special_prefetch_width =
1789
+ // load_store_attr::special_prefetch_width_in_bytes / sizeof(dtype);
1790
+ // static constexpr uint32_t normal_prefetch_width =
1791
+ // load_store_attr::max_load_width_in_bytes / sizeof(dtype);
1792
+ // static constexpr bool is_special_prefetch =
1793
+ // (mem_tile_size_w % special_prefetch_width) == 0;
1794
+
1795
+ // static constexpr uint32_t block_size_w = is_special_prefetch
1796
+ // ? special_prefetch_width
1797
+ // : (normal_prefetch_width > mem_tile_size_w ? mem_tile_size_w
1798
+ // : normal_prefetch_width);
1799
+ // static constexpr uint32_t block_size_h =
1800
+ // load_store_attr::max_load_height_in_elem;
1801
+ // // could have over-prefetch, but that should be fine
1802
+ // static constexpr uint32_t max_num_block_w =
1803
+ // (mem_tile_size_w + block_size_w - 1) / block_size_w;
1804
+ // static constexpr uint32_t num_coop_sg = num_coop_sg_;
1805
+ // static constexpr uint32_t num_coop_sg_w =
1806
+ // detail::gcd<num_coop_sg, max_num_block_w>::value;
1807
+ // static constexpr uint32_t num_coop_sg_h = num_coop_sg / num_coop_sg_w;
1808
+
1809
+ // static constexpr uint32_t num_block_w = max_num_block_w / num_coop_sg_w;
1810
+ // static constexpr uint32_t tile_size_w = block_size_w * num_block_w;
1811
+ // static constexpr uint32_t tile_size_h =
1812
+ // (mem_tile_size_h + num_coop_sg_h - 1) / num_coop_sg_h;
1813
+ // static constexpr uint32_t num_block_h =
1814
+ // (tile_size_h + block_size_h - 1) / block_size_h;
1813
1815
1814
1816
xetla_vector<uint32_t , num_channel> channel_offset;
1817
+ xetla_vector<uint32_t , num_channel> step_x;
1818
+ xetla_vector<uint32_t , num_channel> step_y;
1819
+
1815
1820
uint64_t base_offset;
1816
1821
uint32_t base_x;
1817
1822
uint32_t base_y;
@@ -1848,13 +1853,15 @@ struct prefetch_payload_t<
1848
1853
return *this ;
1849
1854
}
1850
1855
1851
- inline prefetch_payload_t (mem_desc_t & mem_desc, uint32_t coop_id = 0 ) {
1852
- uint32_t coop_id_x = coop_id % num_coop_sg_w;
1853
- uint32_t coop_id_y = coop_id / num_coop_sg_w;
1856
+ inline prefetch_payload_t (
1857
+ mem_desc_t & mem_desc,
1858
+ [[maybe_unused]] uint32_t coop_id = 0 ) {
1859
+ // uint32_t coop_id_x = coop_id % num_coop_sg_w;
1860
+ // uint32_t coop_id_y = coop_id / num_coop_sg_w;
1854
1861
1855
1862
pitch_in_bytes = mem_desc.shape .stride * sizeof (dtype);
1856
- base_x = mem_desc.coord .x + coop_id_x * tile_size_w ;
1857
- base_y = mem_desc.coord .y + coop_id_y * tile_size_h ;
1863
+ base_x = mem_desc.coord .x ;
1864
+ base_y = mem_desc.coord .y ;
1858
1865
width_in_elems = mem_desc.shape .x ;
1859
1866
height_in_elems = mem_desc.shape .y ;
1860
1867
base_offset = mem_transpose
@@ -1874,13 +1881,15 @@ struct prefetch_payload_t<
1874
1881
int surface_pitch,
1875
1882
int surface_offset_x,
1876
1883
int surface_offset_y,
1877
- uint32_t coop_id = 0 ) {
1878
- uint32_t coop_id_x = coop_id % num_coop_sg_w;
1879
- uint32_t coop_id_y = coop_id / num_coop_sg_w;
1884
+ [[maybe_unused]] uint32_t coop_id = 0 ) {
1885
+ // uint32_t coop_id_x = coop_id % num_coop_sg_w;
1886
+ // uint32_t coop_id_y = coop_id / num_coop_sg_w;
1887
+ // base_x = surface_offset_x + coop_id_x * tile_size_w;
1888
+ // base_y = surface_offset_y + coop_id_y * tile_size_h;
1880
1889
1881
1890
pitch_in_bytes = surface_pitch * sizeof (dtype);
1882
- base_x = surface_offset_x + coop_id_x * tile_size_w ;
1883
- base_y = surface_offset_y + coop_id_y * tile_size_h ;
1891
+ base_x = surface_offset_x;
1892
+ base_y = surface_offset_y;
1884
1893
width_in_elems = surface_width;
1885
1894
height_in_elems = surface_height;
1886
1895
base_offset = mem_transpose
@@ -1893,13 +1902,17 @@ struct prefetch_payload_t<
1893
1902
channel_offset = channel_index * pitch_in_bytes;
1894
1903
}
1895
1904
1896
- inline void init (mem_desc_t & mem_desc, uint32_t coop_id = 0 ) {
1897
- uint32_t coop_id_x = coop_id % num_coop_sg_w;
1898
- uint32_t coop_id_y = coop_id / num_coop_sg_w;
1905
+ inline void init (
1906
+ mem_desc_t & mem_desc,
1907
+ [[maybe_unused]] uint32_t coop_id = 0 ) {
1908
+ // uint32_t coop_id_x = coop_id % num_coop_sg_w;
1909
+ // uint32_t coop_id_y = coop_id / num_coop_sg_w;
1910
+ // base_x = mem_desc.coord.x + coop_id_x * tile_size_w;
1911
+ // base_y = mem_desc.coord.y + coop_id_y * tile_size_h;
1899
1912
1900
1913
pitch_in_bytes = mem_desc.shape .stride * sizeof (dtype);
1901
- base_x = mem_desc.coord .x + coop_id_x * tile_size_w ;
1902
- base_y = mem_desc.coord .y + coop_id_y * tile_size_h ;
1914
+ base_x = mem_desc.coord .x ;
1915
+ base_y = mem_desc.coord .y ;
1903
1916
width_in_elems = mem_desc.shape .x ;
1904
1917
height_in_elems = mem_desc.shape .y ;
1905
1918
base_offset = mem_transpose
0 commit comments