
Commit 4bc28e2

[optimize-dot-operands]: Fuse load and trans operations - part 2 (#4468)
This PR enhances the new transformation pass aimed at fusing `tt.load` and `tt.trans` operations. Specifically, it adds support for loop-carried arguments used (possibly transitively) by the candidate `tt.load` that should be fused with a `tt.trans`. Example:

```mlir
%10 = tt.make_tensor_ptr %arg1, [%c1024_i64, %c1_i64], [%c1_i64, %c1024_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xbf16, #linear>>
%13:3 = scf.for %arg3 = %c0_i32 to %c1024_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %c0_i32, %arg6 = %10) -> (tensor<256x256xf32, #mma>, i32, !tt.ptr<tensor<256x32xbf16, #linear>>) : i32 {
  %17 = tt.advance %9, [%c256_i32, %arg5] : <tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
  %18 = tt.load %17 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
  %19 = tt.advance %arg6, [%c16_i32, %arg5] : <tensor<256x32xbf16, #linear>>
  %20 = tt.load %19 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "column_major"} : !tt.ptr<tensor<256x32xbf16, #linear>>
  %21 = tt.trans %20 {order = array<i32: 1, 0>} : tensor<256x32xbf16, #linear> -> tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
  %22 = tt.dot %18, %21, %arg4, inputPrecision = tf32 : tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma>
  %23 = arith.addi %arg5, %c32_i32 : i32
  scf.yield %22, %23, %19 : tensor<256x256xf32, #mma>, i32, !tt.ptr<tensor<256x32xbf16, #linear>>
}
```

Here the load `%20` is a candidate for fusion with the `tt.trans` operation. The pointer used by the candidate load (`%19`) is produced by a `tt.advance` operation that uses the loop-carried pointer `%arg6`.

---------

Signed-off-by: Tiotto, Ettore <ettore.tiotto@intel.com>
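To make the effect of the fusion concrete, below is a hedged, illustrative sketch of what the loop body above might reduce to once the pass fires. It is not the literal pass output; the rewritten block-pointer type, advance offsets, and dropped attributes are assumptions. The key point is that the candidate load produces the transposed dot-operand layout directly, so the explicit `tt.trans` disappears:

```mlir
// Illustrative only - not the literal pass output. After fusion the candidate
// load yields the transposed B-operand layout directly and tt.trans is gone.
%19 = tt.advance %arg6, [%arg5, %c16_i32] : <tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
%20 = tt.load %19 {ttig.block_io = "column_major"} : !tt.ptr<tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
%22 = tt.dot %18, %20, %arg4, inputPrecision = tf32 : tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma>
```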
1 parent c10bebc commit 4bc28e2

File tree

7 files changed: +496, −130 lines


test/TritonIntelGPU/dot-operands.mlir

Lines changed: 217 additions & 11 deletions
Large diffs are not rendered by default.

third_party/intel/backend/compiler.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -281,6 +281,7 @@ def make_ttgir(mod, metadata, opt, properties):
         intel.passes.ttgpuir.add_accelerate_matmul(pm)
         intel.passes.ttgpuir.add_materialize_block_pointer(pm)
         intel.passes.ttgpuir.add_remove_layout_conversions(pm)
+        intel.passes.ttgpuir.add_optimize_dot_operands(pm)
         intel.passes.ttgpuir.add_pipeline(pm, opt.num_stages, XPUBackend.get_split_barrier_scope(opt))
 
         if (opt.reduce_variable_liveness):
```

third_party/intel/include/Utils/Utility.h

Lines changed: 3 additions & 0 deletions
```diff
@@ -29,6 +29,9 @@ bool isConstant(Value val, int64_t expected);
 
 Value getFinalValue(Value value);
 
+// Erase the operations in \p operations.
+void eraseOperations(SmallPtrSetImpl<Operation *> &operations);
+
 } // namespace mlir::triton::intel
 
 #endif // TRITON_INTEL_UTILS_UTILITY_H
```

third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp

Lines changed: 3 additions & 26 deletions
```diff
@@ -95,7 +95,9 @@ struct TritonIntelTensorDescToBlockPointer
           .Default([&](auto) { return WalkResult::advance(); });
     });
 
-    finalize();
+    if (!cleanUp.empty())
+      tt::intel::eraseOperations(cleanUp);
+
     assert(succeeded(verify(moduleOp)) && "Module verification failed");
   }
 
@@ -267,31 +269,6 @@ struct TritonIntelTensorDescToBlockPointer
     return success();
   }
 
-  void finalize() {
-    // Cleanup unused operations.
-    bool erasedOperation;
-    do {
-      erasedOperation = false;
-      SmallPtrSet<Operation *, 8> erased;
-      for (Operation *op : cleanUp) {
-        if (!op->getUsers().empty() || !op->getRegions().empty())
-          continue;
-
-        erased.insert(op);
-        op->erase();
-        erasedOperation = true;
-      }
-      cleanUp.remove_if([&](Operation *op) { return erased.contains(op); });
-    } while (erasedOperation);
-
-    // Remove operations that contain a region.
-    for (Operation *op : cleanUp) {
-      if (!op->getUsers().empty())
-        continue;
-      op->erase();
-    }
-  }
-
 private:
   SmallPtrSet<Operation *, 8> cleanUp;
 };
```
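The diff shows only the declaration of `eraseOperations` in `Utility.h` and its call site; a minimal sketch of what the shared utility presumably looks like is given below, assuming it simply lifts the removed `finalize()` logic into a file such as `third_party/intel/lib/Utils/Utility.cpp` (the file placement and exact body are assumptions, reconstructed from the removed code above):

```cpp
// Hypothetical sketch of the shared helper, mirroring the finalize() logic
// removed from TensorDescToBlockPointer.cpp.
#include "mlir/IR/Operation.h"
#include "llvm/ADT/SmallPtrSet.h"

namespace mlir::triton::intel {

// Erase the operations in \p operations: dead region-free ops first (repeated
// to a fixed point), then unused ops that still own regions.
void eraseOperations(llvm::SmallPtrSetImpl<Operation *> &operations) {
  bool erasedOperation;
  do {
    erasedOperation = false;
    llvm::SmallPtrSet<Operation *, 8> erased;
    for (Operation *op : operations) {
      // Skip operations that still have users or contain regions.
      if (!op->getUsers().empty() || !op->getRegions().empty())
        continue;
      erased.insert(op);
      op->erase();
      erasedOperation = true;
    }
    operations.remove_if([&](Operation *op) { return erased.contains(op); });
  } while (erasedOperation);

  // Finally remove unused operations that contain a region.
  for (Operation *op : operations) {
    if (!op->getUsers().empty())
      continue;
    op->erase();
  }
}

} // namespace mlir::triton::intel
```

The fixed-point loop matters because erasing one dead operation can leave its producers without users, making them erasable on the next iteration.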
