@@ -60,7 +60,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.th
60
60
61
61
// -----
62
62
63
- // COM: Test while loop / tt.advance before tt.load (TODO)
63
+ // COM: Test while loop / nested tt.advance
64
64
#blocked1 = #ttg.blocked <{sizePerThread = [1 , 8 ], threadsPerWarp = [4 , 4 ], warpsPerCTA = [32 , 1 ], order = [1 , 0 ]}>
65
65
#mma = #ttig.dpas <{repeatCount = 8 , systolicDepth = 8 , executionSize = 16 , opsPerChan = 2 , threadsPerWarp = 16 , warpsPerCTA = [8 , 4 ], repCluster = [4 , 2 ], A = [32 , 16 ], B = [16 , 32 ], C = [32 , 32 ]}>
66
66
// CHECK-DAG: #[[$BLOCKED:.+]] = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}>
@@ -99,8 +99,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.th
99
99
// CHECK: tt.dot {{.*}} : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>> -> tensor<256x256xf32, #[[$DPAS]]>
100
100
%5 = tt.dot %4 , %cstB , %cst , inputPrecision = tf32 : tensor <256 x32 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 1 }>> * tensor <32 x256 xf16 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>> -> tensor <256 x256 xf32 , #mma >
101
101
%6 = ttg.convert_layout %5 : tensor <256 x256 xf32 , #mma > -> tensor <256 x256 xf32 , #blocked1 >
102
- // COM: TODO: support nested tt.advance
103
- // %3 = tt.advance %a_ptr_crt, [%c0_i32, %c32_i32] : <tensor<256x32xf16, #blocked1>>
102
+ // CHECK: tt.advance {{.*}} : <tensor<256x32xf16, #[[$SUBGROUP_2D_BLOCK]]>>
103
+ %7 = tt.advance %a_ptr_crt , [%c0_i32 , %c32_i32 ] : <tensor <256 x32 xf16 , #blocked1 >>
104
104
105
105
// CHECK: scf.yield {{.*}} : !tt.ptr<tensor<256x32xf16, #[[$SUBGROUP_2D_BLOCK]]>>
106
106
scf.yield %a_ptr_crt : !tt.ptr <tensor <256 x32 xf16 , #blocked1 >>
0 commit comments