Skip to content

Commit a915a90

Browse files
committed
propagate layout to tt.advance
1 parent 82d9ce2 commit a915a90

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

test/TritonIntelGPU/optimize-block-io-encoding.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.th
6060

6161
// -----
6262

63-
// COM: Test while loop / tt.advance before tt.load (TODO)
63+
// COM: Test while loop / nested tt.advance
6464
#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}>
6565
#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
6666
// CHECK-DAG: #[[$BLOCKED:.+]] = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}>
@@ -99,8 +99,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.th
9999
// CHECK: tt.dot {{.*}} : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>> -> tensor<256x256xf32, #[[$DPAS]]>
100100
%5 = tt.dot %4, %cstB, %cst, inputPrecision = tf32 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma>
101101
%6 = ttg.convert_layout %5 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked1>
102-
// COM: TODO: support nested tt.advance
103-
// %3 = tt.advance %a_ptr_crt, [%c0_i32, %c32_i32] : <tensor<256x32xf16, #blocked1>>
102+
// CHECK: tt.advance {{.*}} : <tensor<256x32xf16, #[[$SUBGROUP_2D_BLOCK]]>>
103+
%7 = tt.advance %a_ptr_crt, [%c0_i32, %c32_i32] : <tensor<256x32xf16, #blocked1>>
104104

105105
// CHECK: scf.yield {{.*}} : !tt.ptr<tensor<256x32xf16, #[[$SUBGROUP_2D_BLOCK]]>>
106106
scf.yield %a_ptr_crt : !tt.ptr<tensor<256x32xf16, #blocked1>>

third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,12 @@ void rewriteTensorLayoutsForOp(Attribute encoding, Operation *op) {
117117
} else if (isa<scf::YieldOp>(op)) {
118118
auto vals = getTiedArgs(op->getParentOp(), use.getOperandNumber());
119119
updateEncoding(vals, EncodingInfo{encoding});
120+
} else if (isa<AdvanceOp>(op)) {
121+
// The operand will be updated when the MakeTensorPtr op result is
122+
// updated. Make sure the result type matches.
123+
for (auto result : op->getResults())
124+
if (auto desc = dyn_cast<TypedValue<PointerType>>(result))
125+
updateEncoding(desc, EncodingInfo{encoding});
120126
}
121127
}
122128

0 commit comments

Comments
 (0)