Commit bb5a79a

Merge OpenAI Triton commit 993c8da (#4521)
This PR changes the Triton base from 16961b7 to 993c8da (Jun 13). Pass rate: 97.11% -> 97.12%
2 parents e279ab7 + 82dc6ba commit bb5a79a

File tree: 11 files changed (+100, -73 lines)

python/test/gluon/test_frontend.py
Lines changed: 11 additions & 0 deletions

@@ -252,6 +252,8 @@ def shared_memory_cast_kernel():
     smem = ttgl.allocate_shared_memory(ttgl.int8, [2, 256, 128], layout_a)
     perm = smem.index(0).permute((1, 0))
     ttgl.static_assert(perm.type.layout == layout_T)
+    # Check that the MLIR type and Gluon types match by emitting a call.
+    anchor_noinline(perm)

     layout_b: ttgl.constexpr = ttgl.NVMMASharedLayout(swizzle_byte_width=64, transposed=False, element_bitwidth=16,
                                                       rank=4, cta_order=[3, 2, 1, 0])

@@ -279,11 +281,15 @@ def test_shared_memory_cast(fresh_knobs):
 %c0_i32_0 = arith.constant 0 : i32
 %1 = ttg.memdesc_subview %0[%c0_i32_0, %c0_i32, %c0_i32] : !ttg.memdesc<2x256x128xi8, #shared, #smem, mutable> -> !ttg.memdesc<256x128xi8, #shared, #smem, mutable, 2x256x128>
 %2 = ttg.memdesc_trans %1 {order = array<i32: 1, 0>} : !ttg.memdesc<256x128xi8, #shared, #smem, mutable, 2x256x128> -> !ttg.memdesc<128x256xi8, #shared1, #smem, mutable, 2x128x256>
+tt.call @"test_frontend.anchor_noinline__MDi8S128_256SLNVMMA_64_8_True_False_NVMMALAS[2, 128, 256]ASMD__"(%2) : (!ttg.memdesc<128x256xi8, #shared1, #smem, mutable, 2x128x256>) -> ()
 %3 = ttg.local_alloc : () -> !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable>
 %4 = ttg.memdesc_reshape %3 : !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared3, #smem, mutable, 32x1x4x64>
 %5 = ttg.memdesc_reinterpret %3 : !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable> -> !ttg.memdesc<1024xi8, #shared4, #smem, mutable>
 tt.return
 }
+tt.func private @"test_frontend.anchor_noinline__MDi8S128_256SLNVMMA_64_8_True_False_NVMMALAS[2, 128, 256]ASMD__"(%arg0: !ttg.memdesc<128x256xi8, #shared1, #smem, mutable, 2x128x256>) attributes {noinline = true} {
+tt.return
+}
 }
 """)

@@ -318,6 +324,11 @@ def anchor(x):
     pass


+@gluon.jit(noinline=True)
+def anchor_noinline(x):
+    pass
+
+
 @filecheck_test
 @gluon.jit
 def test_warp_specialize():
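
Note: the new anchor is a jitted no-op that the compiler may not inline, so the frontend must emit a real tt.call whose argument type has to agree with the Gluon-level shared_memory_descriptor type. A minimal sketch of that pattern, assuming the import path used by this test module (the decorator and body come from the hunks above; the import is an assumption, not verified here):

    # Hedged sketch: a noinline jitted no-op forces a tt.call to be emitted,
    # so the MLIR memdesc type and the Gluon type must match at the call site.
    from triton.experimental import gluon  # assumed import path


    @gluon.jit(noinline=True)
    def anchor_noinline(x):
        pass  # the emitted call, not the body, is what the FileCheck output inspects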

python/triton/experimental/gluon/language/_core.py
Lines changed: 1 addition & 1 deletion

@@ -173,7 +173,7 @@ def _flatten_ir_types(self, builder: GluonOpBuilder, out: List[ir.type]) -> None
         out.append(self.to_ir(builder))

     def __str__(self) -> str:
-        return f"shared_memory_descriptor<{self.element_ty}, {self.shape}, {self.layout}>"
+        return f"shared_memory_descriptor<{self.element_ty}, {self.shape}, {self.layout}, {self.alloc_shape}>"

     def __eq__(self, other) -> bool:
         return (type(self) is type(other) and self.shape == other.shape and self.layout == other.layout
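
The repr now surfaces the allocation shape alongside the view shape. A quick illustration of the updated string layout with made-up values (field names come from the diff; nothing here calls the real class):

    # Hypothetical values, only to show the new repr format.
    element_ty, shape, layout, alloc_shape = "i8", [128, 256], "NVMMASharedLayout(...)", [2, 128, 256]
    print(f"shared_memory_descriptor<{element_ty}, {shape}, {layout}, {alloc_shape}>")
    # shared_memory_descriptor<i8, [128, 256], NVMMASharedLayout(...), [2, 128, 256]>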

python/triton/experimental/gluon/language/_semantic.py
Lines changed: 2 additions & 2 deletions

@@ -191,8 +191,8 @@ def memdesc_trans(self, mem_desc, order):

         handle = self.builder.create_memdesc_trans(mem_desc.handle, order)
         layout = self.builder.get_gluon_layout_from_memdesc(handle)
-        return ttgl.shared_memory_descriptor(handle, element_ty=mem_desc.dtype, shape=shape, alloc_shape=alloc_shape,
-                                             layout=layout)
+        return ttgl.shared_memory_descriptor(handle, element_ty=mem_desc.dtype, shape=shape,
+                                             alloc_shape=new_alloc_shape, layout=layout)

     def memdesc_reshape(self, mem_desc, shape, layout):
         ty = ttgl.shared_memory_descriptor_type(mem_desc.dtype, shape, layout, mem_desc.type.alloc_shape)
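
The fix passes the transposed allocation shape (`new_alloc_shape`, computed earlier in the function and not shown in this hunk) rather than the pre-transpose value, which matches the IR in the test above: a 2x256x128 allocation viewed as 256x128 transposes to allocation shape 2x128x256. A minimal sketch of that bookkeeping, assuming only the trailing dims of the allocation shape are permuted along with the view (names here are illustrative, not the Triton API):

    def permute_alloc_shape(alloc_shape, shape, order):
        # Leading dims of the allocation (e.g. a buffering dimension) are kept;
        # only the trailing dims that correspond to the view are permuted.
        extra = len(alloc_shape) - len(shape)
        return alloc_shape[:extra] + [alloc_shape[extra + i] for i in order]

    print(permute_alloc_shape([2, 256, 128], [256, 128], [1, 0]))  # [2, 128, 256]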

python/triton_kernels/tests/test_routing.py
Lines changed: 1 addition & 1 deletion

@@ -19,7 +19,7 @@ def init_data(n_tokens, n_expts_tot, dtype=torch.float32, device="cuda"):
 @pytest.mark.parametrize("n_expts_tot, n_expts_act", [(128, 32), (1500, 8)])
 @pytest.mark.parametrize("use_expt_indx", [False, True])
 @pytest.mark.parametrize("sm_first", [True, False])
-@pytest.mark.skipif(is_hip, reason="Tests are currently broken on AMD")
+@pytest.mark.skipif(is_hip(), reason="Tests are currently broken on AMD")
 def test_op(n_tokens_pad, n_tokens_raw, n_expts_tot, n_expts_act, sm_first, use_expt_indx, device):
     torch.manual_seed(2)
     if n_tokens_raw is None:
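
The one-character fix matters because skipif evaluates the truthiness of its first argument: a function object is always truthy, so passing `is_hip` uncalled would skip the test on every backend. A standalone illustration with a dummy predicate (the real `is_hip` comes from the Triton kernels test helpers):

    # Dummy stand-in for the real is_hip() helper; always reports "not AMD".
    def is_hip():
        return False

    print(bool(is_hip))    # True  -> skipif(is_hip, ...) skips unconditionally
    print(bool(is_hip()))  # False -> skipif(is_hip(), ...) skips only on HIP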

python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py
Lines changed: 6 additions & 2 deletions

@@ -21,6 +21,7 @@ class OptFlags:
     split_k: int
     fused_scatter: bool
     is_persistent: bool
+    idle_sms: int
     epilogue_subtile: int | None
     arch: str
     target_kernel_kwargs: dict

@@ -116,6 +117,7 @@ def make_default_opt_flags_amd(
         split_k=split_k,
         fused_scatter=constraints.get('fused_scatter', False),
         is_persistent=is_persistent,
+        idle_sms=0,
         epilogue_subtile=constraints.get('epilogue_subtile', None),
         arch=None,
         target_kernel_kwargs=target_kernel_kwargs,

@@ -140,7 +142,7 @@ def make_default_opt_flags_nvidia(
     epilogue_effective_itemsize,
     constraints,
 ):
-    constraints_supported = ["block_m", "block_k", "split_k", "fused_scatter", "is_persistent", "epilogue_subtile", "num_stages"]
+    constraints_supported = ["block_m", "block_k", "split_k", "fused_scatter", "is_persistent", "epilogue_subtile", "num_stages", "idle_sms"]
     assert not any([c not in constraints_supported for c in constraints]), constraints.keys()
     # tokens per expert
     if routing_data is None:

@@ -236,6 +238,7 @@ def make_default_opt_flags_nvidia(
         epilogue_subtile=epilogue_subtile,
         arch=arch,
         target_kernel_kwargs=dict(),
+        idle_sms=constraints.get("idle_sms", 0),
     )
     # check constraints
     assert all(getattr(ret, ck) == cv for ck, cv in constraints.items() if cv is not None), f"{ret} != {constraints}"

@@ -283,7 +286,8 @@ def make_opt_flags(
         return _opt_flags
     args = [out_dtype, lhs_dtype, rhs_dtype, precision_config, microscaling_ctx, m, n, k,
             routing_data, can_use_persistent_tma, can_use_fused_scatter,
-            enforce_bitwise_invariance, epilogue_effective_itemsize, _opt_flags_constraints]
+            enforce_bitwise_invariance, epilogue_effective_itemsize,
+            _opt_flags_constraints]
     backend = triton.runtime.driver.active.get_current_target().backend
     if backend == "hip":
         return make_default_opt_flags_amd(*args)
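
With idle_sms added to the OptFlags dataclass and to the supported NVIDIA constraint keys (presumably to leave some SMs free for other work), callers can now pin it like any other knob, while the AMD path hard-codes it to 0. A small sketch of how such a constraint dict is validated and consumed, mirroring the assert in the hunk above (simplified; the real function builds a full OptFlags object):

    constraints_supported = ["block_m", "block_k", "split_k", "fused_scatter",
                             "is_persistent", "epilogue_subtile", "num_stages", "idle_sms"]

    constraints = {"idle_sms": 8}  # hypothetical caller-supplied constraint
    assert not any(c not in constraints_supported for c in constraints), constraints.keys()
    idle_sms = constraints.get("idle_sms", 0)  # defaults to 0 when unspecified
    print(idle_sms)  # 8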

setup.py
Lines changed: 14 additions & 8 deletions

@@ -50,6 +50,11 @@ class editable_wheel:
 from python.build_helpers import get_base_dir, get_cmake_dir


+def is_git_repo():
+    """Return True if this file resides in a git repository"""
+    return (Path(__file__).parent / ".git").is_dir()
+
+
 @dataclass
 class Backend:
     name: str

@@ -71,13 +76,14 @@ def prepare(backend_name: str, backend_src_dir: str = None, is_external: bool =
         assert backend_name in os.listdir(
             root_dir), f"{backend_name} is requested for install but not present in {root_dir}"

-        try:
-            subprocess.run(["git", "submodule", "update", "--init", f"{backend_name}"], check=True,
-                           stdout=subprocess.DEVNULL, cwd=root_dir)
-        except subprocess.CalledProcessError:
-            pass
-        except FileNotFoundError:
-            pass
+        if is_git_repo():
+            try:
+                subprocess.run(["git", "submodule", "update", "--init", f"{backend_name}"], check=True,
+                               stdout=subprocess.DEVNULL, cwd=root_dir)
+            except subprocess.CalledProcessError:
+                pass
+            except FileNotFoundError:
+                pass

         backend_src_dir = os.path.join(root_dir, backend_name)

@@ -775,7 +781,7 @@ def get_git_branch():


 def get_git_version_suffix():
-    if not (Path(__file__).parent / ".git").is_dir():
+    if not is_git_repo():
         return "" # Not a git checkout
     branch = get_git_branch()
     if branch.startswith("release"):
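
The new helper centralizes the .git check so the submodule update is only attempted inside a git checkout (and not, say, when building from a source tarball), and get_git_version_suffix now reuses it. The helper as added, shown standalone for clarity (setup.py already imports Path; the import is repeated here so the sketch is self-contained):

    from pathlib import Path

    def is_git_repo():
        """Return True if this file resides in a git repository"""
        return (Path(__file__).parent / ".git").is_dir()

    # Guarding the submodule update with this check avoids shelling out to git
    # in builds that have no .git directory at all.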

test/Conversion/cvt_to_llvm.mlir
Lines changed: 1 addition & 1 deletion

@@ -127,7 +127,7 @@ tt.func private @convert_layout_blocked_blocked(%arg0: tensor<16x16xi32, #blocke
 // to this, we choose to fall back to the shared memory implementation.

 // CHECK-NOT: shfl.sync.idx
-// CHECK: st.shared
+// CHECK: store

 %0 = ttg.convert_layout %arg0 : tensor<16x16xi32, #blocked0> -> tensor<16x16xi32, #blocked1>
 tt.return %0 : tensor<16x16xi32, #blocked1>

test/Conversion/tritongpu_to_llvm.mlir
Lines changed: 32 additions & 38 deletions

@@ -804,7 +804,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 // CHECK-LABEL: convert_layout_blocked_blocked
 tt.func @convert_layout_blocked_blocked(%arg0: tensor<32x32xf32, #blocked0>) {
 // CHECK: llvm.mlir.addressof @global_smem
-// CHECK-COUNT-8: llvm.inline_asm {{.*}} st.shared
+// CHECK-COUNT-8: llvm.store
 // CHECK: nvvm.barrier0
 // CHECK-COUNT-8: llvm.load
 %0 = ttg.convert_layout %arg0 : tensor<32x32xf32, #blocked0> -> tensor<32x32xf32, #blocked1>

@@ -821,10 +821,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 // CHECK-LABEL: convert_layout_blocked_blocked_vec
 tt.func @convert_layout_blocked_blocked_vec(%arg0: tensor<32x32xf32, #blocked0>) {
 // CHECK: llvm.mlir.addressof @global_smem
-// CHECK: llvm.inline_asm
-// CHECK: st.shared
-// CHECK: llvm.inline_asm
-// CHECK: st.shared
+// CHECK: llvm.store
+// CHECK: llvm.store
 // CHECK: nvvm.barrier0
 // CHECK: llvm.load
 // CHECK: llvm.load

@@ -859,14 +857,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
 // CHECK-LABEL: convert_layout_blocked_blocked_multi_rep
 tt.func @convert_layout_blocked_blocked_multi_rep(%arg0: tensor<16x16xf32, #blocked0>) {
 // CHECK: llvm.mlir.addressof @global_smem
-// CHECK: llvm.inline_asm
-// CHECK: st.shared
+// CHECK: llvm.store
 // CHECK: nvvm.barrier0
 // CHECK: llvm.load
 // CHECK: llvm.load
 // CHECK: nvvm.barrier0
-// CHECK: llvm.inline_asm
-// CHECK: st.shared
+// CHECK: llvm.store
 // CHECK: nvvm.barrier0
 // CHECK: llvm.load
 // CHECK: llvm.load

@@ -1024,10 +1020,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 // CHECK: llvm.mlir.global external @global_smem
 // CHECK-LABEL: convert_layout_mmav2_block
 tt.func @convert_layout_mmav2_blocked(%arg0: tensor<32x16xf32, #mma>) {
-// CHECK: llvm.inline_asm
-// CHECK-SAME: st.shared
-// CHECK: llvm.inline_asm
-// CHECK-SAME: st.shared
+// CHECK: llvm.store
+// CHECK: llvm.store
 // CHECK: nvvm.barrier0
 // CHECK: llvm.load
 %0 = ttg.convert_layout %arg0 : tensor<32x16xf32, #mma> -> tensor<32x16xf32, #blocked0>

@@ -1042,7 +1036,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
 // CHECK-LABEL: convert_layout_mmav2_dot_reg
 tt.func @convert_layout_mmav2_dot_reg(%arg0: tensor<16x16xf16, #mma>) {
-// CHECK-NOT: st.shared
+// CHECK-NOT: llvm.store
 // CHECK-NOT: llvm.load
 %0 = ttg.convert_layout %arg0 : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #dot1>
 tt.return

@@ -1056,7 +1050,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
 // CHECK-LABEL: convert_layout_mmav2_dot_reg
 tt.func @convert_layout_mmav2_dot_reg(%arg0: tensor<1x16xf16, #mma>) {
-// CHECK-NOT: st.shared
+// CHECK-NOT: llvm.store
 // CHECK-NOT: llvm.load
 %0 = ttg.convert_layout %arg0 : tensor<1x16xf16, #mma> -> tensor<1x16xf16, #dot1>
 tt.return

@@ -1072,7 +1066,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 // CHECK-LABEL: convert_layout_slice_mmav2_blocked_reg
 tt.func @convert_layout_slice_mmav2_blocked_reg(%arg0: tensor<1xf16, #slice>) {
-// CHECK-NOT: st.shared
+// CHECK-NOT: llvm.store
 // CHECK-NOT: llvm.load
 %0 = ttg.convert_layout %arg0 : tensor<1xf16, #slice> -> tensor<1xf16, #blocked>
 tt.return

@@ -1087,7 +1081,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 // CHECK-LABEL: convert_layout_mmav3_mmav3_0
 tt.func @convert_layout_mmav3_mmav3_0(%arg0: tensor<64x64xf16, #mma0>) {
-// CHECK-NOT: st.shared
+// CHECK-NOT: llvm.store
 // CHECK-NOT: llvm.load
 %0 = ttg.convert_layout %arg0 : tensor<64x64xf16, #mma0> -> tensor<64x64xf16, #mma1>
 tt.return

@@ -1102,7 +1096,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 // CHECK-LABEL: convert_layout_mmav3_mmav3_1
 tt.func @convert_layout_mmav3_mmav3_1(%arg0: tensor<64x64xf16, #mma1>) {
-// CHECK-NOT: st.shared
+// CHECK-NOT: llvm.store
 // CHECK-NOT: llvm.load
 %0 = ttg.convert_layout %arg0 : tensor<64x64xf16, #mma1> -> tensor<64x64xf16, #mma0>
 tt.return

@@ -1117,7 +1111,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 // CHECK-LABEL: convert_layout_mmav3_mmav3_2
 tt.func @convert_layout_mmav3_mmav3_2(%arg0: tensor<16x16xf16, #mma1>) {
-// CHECK-NOT: st.shared
+// CHECK-NOT: llvm.store
 // CHECK-NOT: llvm.load
 %0 = ttg.convert_layout %arg0 : tensor<16x16xf16, #mma1> -> tensor<16x16xf16, #mma0>
 tt.return

@@ -1132,7 +1126,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 // CHECK-LABEL: convert_layout_mmav3_mmav3_3
 tt.func @convert_layout_mmav3_mmav3_3(%arg0: tensor<1x64xf16, #mma1>) {
-// CHECK-NOT: st.shared
+// CHECK-NOT: llvm.store
 // CHECK-NOT: llvm.load
 %0 = ttg.convert_layout %arg0 : tensor<1x64xf16, #mma1> -> tensor<1x64xf16, #mma0>
 tt.return

@@ -1146,7 +1140,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
 // CHECK-LABEL: convert_layout_mmav2_dot_reg
 tt.func @convert_layout_mmav2_dot_reg(%arg0: tensor<16x16xf16, #mma>) {
-// CHECK-NOT: st.shared
+// CHECK-NOT: llvm.store
 // CHECK-NOT: llvm.load
 %0 = ttg.convert_layout %arg0 : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #dot1>
 tt.return

@@ -1161,7 +1155,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 // CHECK-LABEL: convert_layout_mmav3_mmav3_0
 tt.func @convert_layout_mmav3_mmav3_0(%arg0: tensor<64x64xf16, #mma0>) {
-// CHECK-NOT: st.shared
+// CHECK-NOT: llvm.store
 // CHECK-NOT: llvm.load
 %0 = ttg.convert_layout %arg0 : tensor<64x64xf16, #mma0> -> tensor<64x64xf16, #mma1>
 tt.return

@@ -1176,7 +1170,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 // CHECK-LABEL: convert_layout_mmav3_mmav3_1
 tt.func @convert_layout_mmav3_mmav3_1(%arg0: tensor<64x64xf16, #mma1>) {
-// CHECK-NOT: st.shared
+// CHECK-NOT: llvm.store
 // CHECK-NOT: llvm.load
 %0 = ttg.convert_layout %arg0 : tensor<64x64xf16, #mma1> -> tensor<64x64xf16, #mma0>
 tt.return

@@ -1191,7 +1185,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 // CHECK-LABEL: convert_layout_mmav3_mmav3_2
 tt.func @convert_layout_mmav3_mmav3_2(%arg0: tensor<16x16xf16, #mma1>) {
-// CHECK-NOT: st.shared
+// CHECK-NOT: llvm.store
 // CHECK-NOT: llvm.load
 %0 = ttg.convert_layout %arg0 : tensor<16x16xf16, #mma1> -> tensor<16x16xf16, #mma0>
 tt.return

@@ -1206,7 +1200,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 // CHECK-LABEL: convert_layout_mmav3_mmav3_3
 tt.func @convert_layout_mmav3_mmav3_3(%arg0: tensor<1x64xf16, #mma1>) {
-// CHECK-NOT: st.shared
+// CHECK-NOT: llvm.store
 // CHECK-NOT: llvm.load
 %0 = ttg.convert_layout %arg0 : tensor<1x64xf16, #mma1> -> tensor<1x64xf16, #mma0>
 tt.return

@@ -1221,28 +1215,28 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
 // CHECK: llvm.mlir.global external @global_smem
 // CHECK-LABEL: convert_layout_mmav3_transpose
 tt.func @convert_layout_mmav3_transpose(%arg0: tensor<128x256xf8E5M2, #mma>) {
-// CHECK-COUNT-16: st.shared.b8
+// CHECK-COUNT-16: llvm.store {{.*}} : vector<1xi8>
 // CHECK: nvvm.barrier0
 // CHECK: llvm.load {{.*}} -> vector<4xi32>
-// CHECK-COUNT-16: st.shared.b8
+// CHECK-COUNT-16: llvm.store {{.*}} : vector<1xi8>
 // CHECK: nvvm.barrier0
 // CHECK: llvm.load {{.*}} -> vector<4xi32>
-// CHECK-COUNT-16: st.shared.b8
+// CHECK-COUNT-16: llvm.store {{.*}} : vector<1xi8>
 // CHECK: nvvm.barrier0
 // CHECK: llvm.load {{.*}} -> vector<4xi32>
-// CHECK-COUNT-16: st.shared.b8
+// CHECK-COUNT-16: llvm.store {{.*}} : vector<1xi8>
 // CHECK: nvvm.barrier0
 // CHECK: llvm.load {{.*}} -> vector<4xi32>
-// CHECK-COUNT-16: st.shared.b8
+// CHECK-COUNT-16: llvm.store {{.*}} : vector<1xi8>
 // CHECK: nvvm.barrier0
 // CHECK: llvm.load {{.*}} -> vector<4xi32>
-// CHECK-COUNT-16: st.shared.b8
+// CHECK-COUNT-16: llvm.store {{.*}} : vector<1xi8>
 // CHECK: nvvm.barrier0
 // CHECK: llvm.load {{.*}} -> vector<4xi32>
-// CHECK-COUNT-16: st.shared.b8
+// CHECK-COUNT-16: llvm.store {{.*}} : vector<1xi8>
 // CHECK: nvvm.barrier0
 // CHECK: llvm.load {{.*}} -> vector<4xi32>
-// CHECK-COUNT-16: st.shared.b8
+// CHECK-COUNT-16: llvm.store {{.*}} : vector<1xi8>
 // CHECK: nvvm.barrier0
 // CHECK: llvm.load {{.*}} -> vector<4xi32>
 %0 = ttg.convert_layout %arg0 : tensor<128x256xf8E5M2, #mma> -> tensor<128x256xf8E5M2, #blocked>

@@ -1301,7 +1295,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
 // CHECK-LABEL: convert_blocked_to_blocked_ptr
 tt.func @convert_blocked_to_blocked_ptr(%src:tensor<32x!tt.ptr<f32>, #blocked0>) {
 // CHECK: llvm.ptrtoint
-// CHECK: inline_asm{{.*}}st.shared
+// CHECK: llvm.store
 // CHECK: nvvm.barrier0
 // CHECK: llvm.inttoptr
 // CHECK-COUNT-4: llvm.insertvalue

@@ -1319,13 +1313,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
 // CHECK-LABEL: linear_layout_with_multiple_iterations
 tt.func @linear_layout_with_multiple_iterations(%src: tensor<8x4xbf16, #linear>) {
 %cvt = ttg.convert_layout %src : tensor<8x4xbf16, #linear> -> tensor<8x4xbf16, #linear1>
-// CHECK: inline_asm{{.*}}st.shared.v2
+// CHECK: llvm.store {{.*}} : vector<2xi16>
 // CHECK: nvvm.barrier0
-// CHECK: llvm.load
+// CHECK: llvm.load {{.*}} -> i16
 // CHECK: nvvm.barrier0
-// CHECK: inline_asm{{.*}}st.shared.v2
+// CHECK: llvm.store {{.*}} : vector<2xi16>
 // CHECK: nvvm.barrier0
-// CHECK: llvm.load
+// CHECK: llvm.load {{.*}} -> i16
 tt.return
 }
 }
