From 0ca51af26facb34ab907b4476d38603243373a0f Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Wed, 4 Jun 2025 21:01:26 +0000 Subject: [PATCH 01/14] [WIP]: Fuse load and trans operations Signed-off-by: Tiotto, Ettore --- test/TritonIntelGPU/dot-operands.mlir | 33 +++ third_party/intel/backend/compiler.py | 1 + .../TritonIntelGPU/Transforms/Passes.td | 14 + .../TritonIntelGPUTransforms/CMakeLists.txt | 1 + .../OptimizeDotOperands.cpp | 262 ++++++++++++++++++ third_party/intel/triton_xpu.cc | 2 + 6 files changed, 313 insertions(+) create mode 100644 test/TritonIntelGPU/dot-operands.mlir create mode 100644 third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp diff --git a/test/TritonIntelGPU/dot-operands.mlir b/test/TritonIntelGPU/dot-operands.mlir new file mode 100644 index 0000000000..2b0be6ebec --- /dev/null +++ b/test/TritonIntelGPU/dot-operands.mlir @@ -0,0 +1,33 @@ +// RUN: triton-opt %s -split-input-file -tritonintelgpu-optimize-dot-operands -canonicalize | FileCheck %s + +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}> +#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [16, 0], [0, 16], [128, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0]], warp = [[32, 0], [64, 0], [0, 0], [0, 0], [0, 0]], block = []}> +#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32, ttig.support_dpas, ttig.support_sg_2d_block} { + + // COM: tt.load -> tt.trans -> tt.dot chain, not in a loop. + // COM: Expecting the load to be "fused" with the transpose + tt.func public @fuseLoadWithTrans(%arg0: !tt.ptr>>, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr>) { + %c0_i32 = arith.constant 0 : i32 + %c1_i64 = arith.constant 1 : i64 + %c256_i32 = arith.constant 256 : i32 + %c1024_i64 = arith.constant 1024 : i64 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %0 = tt.make_tensor_ptr %arg1, [%c1024_i64, %c1_i64], [%c1_i64, %c1024_i64], [%c0_i32, %c0_i32] {order = array} : > + %1 = tt.load %arg0 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr>> + %2 = tt.advance %0, [%c256_i32, %c0_i32] : > + %3 = tt.load %2 {boundaryCheck = array, ttig.block_io = "column_major"} : !tt.ptr> + %4 = tt.trans %3 {order = array} : tensor<256x32xbf16, #linear> -> tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %5 = tt.dot %1, %4, %cst, inputPrecision = tf32 : tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma> + %6 = ttg.convert_layout %5 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked> + tt.store %arg2, %6 {boundaryCheck = array} : !tt.ptr> + tt.return + } + // CHECK-LABEL: fuseLoadWithTrans + // CHECK-NOT: tt.trans + // CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg1, [%c1_i64, %c1024_i64], [%c1024_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : >> + // CHECK: [[ADV:%.*]] = tt.advance [[PTR]], [%c0_i32, %c256_i32] : >> + // CHECK: [[LOAD_B:%.*]] = tt.load [[ADV]] {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr>> + // CHECK: tt.dot {{.*}}, [[LOAD_B]], {{.*}}, inputPrecision = tf32 : tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * 
tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma>
+
+}
diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py
index 0fdfe86d52..05d2bd4636 100644
--- a/third_party/intel/backend/compiler.py
+++ b/third_party/intel/backend/compiler.py
@@ -294,6 +294,7 @@ def make_ttgir(mod, metadata, opt, properties):
         passes.ttgpuir.add_fuse_nested_loops(pm)
         passes.ttgpuir.add_optimize_thread_locality(pm)
+        intel.passes.ttgpuir.add_optimize_dot_operands(pm)
         passes.ttgpuir.add_optimize_dot_operands(pm, True)
         passes.common.add_cse(pm)
         passes.ttgpuir.add_prefetch(pm)
diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td
index 48175ef66b..071af8eeb7 100644
--- a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td
+++ b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td
@@ -27,6 +27,19 @@ def TritonIntelGPUAccelerateMatmul
   ];
 }
 
+def TritonIntelGPUOptimizeDotOperands
+    : Pass<"tritonintelgpu-optimize-dot-operands", "mlir::ModuleOp"> {
+  let summary = "Intel optimize dot operands";
+
+  let description = [{
+    Re-arrange the layouts of tensors used as matrix multiplication operands to
+    promote the use of hardware-accelerated operations.
+  }];
+
+  let dependentDialects = ["mlir::triton::TritonDialect",
+                           "mlir::triton::gpu::TritonGPUDialect"];
+}
+
 def TritonIntelGPUCoalesce
     : Pass<"tritonintelgpu-coalesce", "mlir::ModuleOp"> {
   let summary = "Intel Coalesce";
@@ -382,4 +395,5 @@ def TritonIntelGPURewriteStackPtr
     "mlir::arith::ArithDialect"
   ];
 }
+
 #endif // TRITON_INTEL_GPU_PASSES
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt b/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt
index 0bc36f03ba..2424572139 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt
@@ -5,6 +5,7 @@ add_triton_library(TritonIntelGPUTransforms
   DistributeToWarps.cpp
   MatchTargetSize.cpp
   MaterializeBlockPointer.cpp
+  OptimizeDotOperands.cpp
   OptimizeReductionLocality.cpp
   Pipeliner/MatmulLoopPipeline.cpp
   Pipeliner/SoftwarePipeliner.cpp
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp
new file mode 100644
index 0000000000..7bee076f2f
--- /dev/null
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp
@@ -0,0 +1,262 @@
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
+
+#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
+#include "intel/include/Utils/Utility.h"
+
+#include "triton/Dialect/Triton/IR/Dialect.h"
+#include "triton/Dialect/Triton/IR/Types.h"
+#include "triton/Dialect/TritonGPU/IR/Attributes.h"
+#include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/IR/LayoutUtility.h"
+#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
+#include <optional>
+
+using namespace mlir;
+namespace tt = mlir::triton;
+namespace ttg = mlir::triton::gpu;
+namespace ttgi = mlir::triton::gpu::intel;
+
+namespace mlir::triton::gpu::intel {
+#define GEN_PASS_DEF_TRITONINTELGPUOPTIMIZEDOTOPERANDS
+#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h.inc"
+} // namespace mlir::triton::gpu::intel
+
+namespace {
+
+// Transform:
+//   %ptr = make_block_ptr [shX, shX], [stX, stY], [offX, offY]
+//        : tt.ptr<tensor<MxNxelemTy>>
+//   %load = tt.load %ptr, {blockIO=<row_major|column_major>}
+//        : tt.ptr<tensor<MxNxelemTy>>
+//   %trans = tt.trans %load : tt.ptr<tensor<NxMxelemTy>>
+//   tt.dot(%a, %trans)
+// into:
+//   %ptr = make_block_ptr [shX, shX], [stX, stY], [offX, offY]
+//        : tt.ptr<tensor<NxMxelemTy>>
+//   %load = tt.load %ptr, {blockIO=<column_major|row_major>}
+//        : tt.ptr<tensor<NxMxelemTy>>
+//   tt.dot(%a, %load)
+class FuseTransWithLoad : public OpRewritePattern<tt::TransOp> {
+public:
+  using OpRewritePattern<tt::TransOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tt::TransOp transOp,
+                                PatternRewriter &rewriter) const override {
+    if (!isCandidate(transOp))
+      return failure();
+
+    auto tensorType = cast<RankedTensorType>(transOp.getType());
+    Attribute dotEncoding =
+        cast<ttg::DotOperandEncodingAttr>(tensorType.getEncoding());
+    auto loadOp = cast<tt::LoadOp>(transOp.getSrc().getDefiningOp());
+    tt::MakeTensorPtrOp makeTensorPtrOp =
+        *triton::intel::findDefiningMakeTensorPtrOp(loadOp.getPtr());
+    llvm::errs() << "makeTensorPtrOp: " << makeTensorPtrOp << "\n";
+
+    // Create a MakeTensorPtrOp yielding a block pointer to the transposed
+    // tensor.
+    auto ptrType = cast<tt::PointerType>(makeTensorPtrOp.getType());
+    auto newPtrType =
+        tt::PointerType::get(tensorType, ptrType.getAddressSpace());
+    SmallVector<Value> newShape(llvm::reverse(makeTensorPtrOp.getShape()));
+    SmallVector<Value> newStrides(llvm::reverse(makeTensorPtrOp.getStrides()));
+    SmallVector<Value> newOffsets(llvm::reverse(makeTensorPtrOp.getOffsets()));
+
+    OpBuilder builder(makeTensorPtrOp);
+    Value ptr = builder.create<tt::MakeTensorPtrOp>(
+        makeTensorPtrOp.getLoc(), newPtrType, makeTensorPtrOp.getBase(),
+        newShape, newStrides, newOffsets, makeTensorPtrOp.getOrderAttr());
+    assert(makeTensorPtrOp->hasOneUse() && "Expecing single user");
+    llvm::errs() << "newMakeTensorPtrOp: " << ptr << "\n";
+
+    // Transitively update users of the block pointer.
+    Operation *makeTensorPtrOpUser = *makeTensorPtrOp->getUsers().begin();
+    if (auto advanceOp = dyn_cast<tt::AdvanceOp>(makeTensorPtrOpUser)) {
+      llvm::errs() << "user is advance: " << advanceOp << "\n";
+      ptr = updateAdvanceOpChain(advanceOp, loadOp, ptr);
+    } else {
+      // TODO: handle loop init args (scf.for only for now).
+      assert(makeTensorPtrOpUser == loadOp &&
+             "Expecting the load to be the user");
+    }
+
+    // Replace the load+transpose with a new load operation that uses the
+    // transposed block pointer.
+    auto newLoadOp = rewriter.create<tt::LoadOp>(
+        loadOp.getLoc(), ptr, loadOp.getMask(), loadOp.getOther(),
+        loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(),
+        loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile());
+    llvm::errs() << "newLoadOp: " << newLoadOp << "\n";
+
+    StringRef blockIOAttrName =
+        ttgi::TritonIntelGPUDialect::getBlockIOAttrName();
+    StringAttr attr = loadOp->getAttrOfType<StringAttr>(blockIOAttrName);
+    StringAttr newAttr =
+        (attr == "row_major")
+            ? StringAttr::get(loadOp->getContext(), "column_major")
+        : (attr == "column_major")
+            ? StringAttr::get(loadOp->getContext(), "row_major")
+            : nullptr;
+    assert(newAttr && "Expecting a valid blockIO attribute");
+
+    newLoadOp->setAttr(blockIOAttrName, newAttr);
+
+    transOp->replaceAllUsesWith(newLoadOp);
+
+    [[maybe_unused]] auto moduleOp = newLoadOp->getParentOfType<ModuleOp>();
+    moduleOp->dumpPretty();
+    assert(succeeded(verify(moduleOp)) && "Module verification failed");
+
+    return success();
+  }
+
+private:
+  // Candidate is of the form:
+  //   tt.dot(tt.trans(tt.load(..., {blockIO=...})))
+  // Where:
+  //  - the transpose result is used only by the dot operation, and
+  //  - the transpose operation uses the result of a 2-dim load operation on a
+  //    block pointer (transitively) defined by a `make_tensor_ptr` in the same
+  //    function, and
+  //  - each operation in the def-use chain origination at the `make_tensor_ptr`
+  //    and terminating at the load has a single user.
+  bool isCandidate(tt::TransOp transOp) const {
+    assert(transOp && "Expecting a valid transpose operation");
+
+    bool transOpUsedOnlyByDotOp =
+        transOp->hasOneUse() &&
+        isa<tt::DotOp>(*transOp->getUsers().begin());
+    Attribute transOpEncoding = transOp.getType().getEncoding();
+    if (!transOpUsedOnlyByDotOp || !transOpEncoding ||
+        !isa<ttg::DotOperandEncodingAttr>(transOpEncoding))
+      return false;
+
+    Operation *defOp = transOp.getSrc().getDefiningOp();
+    if (!defOp || !isa<tt::LoadOp>(defOp))
+      return false;
+
+    llvm::errs() << "at line " << __LINE__ << "\n";
+    return isCandidate(cast<tt::LoadOp>(defOp));
+  }
+
+  bool isCandidate(tt::LoadOp loadOp) const {
+    assert(loadOp && "Expecting a valid load operation");
+
+    bool loadOpHasBlockIOAttr = loadOp->hasAttrOfType<StringAttr>(
+        ttgi::TritonIntelGPUDialect::getBlockIOAttrName());
+    if (!loadOp->hasOneUse() || !loadOpHasBlockIOAttr)
+      return false;
+
+    llvm::errs() << "at line " << __LINE__ << "\n";
+    auto ptrType = cast<tt::PointerType>(loadOp.getPtr().getType());
+    if (!isTensorPointerType(ptrType) ||
+        cast<RankedTensorType>(ptrType.getPointeeType()).getRank() != 2)
+      return false;
+
+    llvm::errs() << "at line " << __LINE__ << "\n";
+    std::optional<tt::MakeTensorPtrOp> defOp =
+        *triton::intel::findDefiningMakeTensorPtrOp(loadOp.getPtr());
+    if (!defOp || !singleUsersInChain(*defOp, loadOp)) {
+      llvm::errs() << "at line " << __LINE__ << "\n";
+      return false;
+    }
+    llvm::errs() << "at line " << __LINE__ << "\n";
+    return true;
+  }
+
+  bool singleUsersInChain(Operation *start, Operation *end) const {
+    assert(start && end && "Expecting valid operations");
+    Operation *currentOp = start;
+    while (currentOp != end) {
+      llvm::errs() << "currentOp: " << *currentOp << "\n";
+      if (!currentOp->hasOneUse()) {
+        llvm::errs() << "at line " << __LINE__ << "\n";
+        return false;
+      }
+
+      currentOp = *currentOp->getUsers().begin();
+      if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
+        for (BlockArgument arg : forOp.getRegionIterArgs()) {
+          Value initArg = forOp.getInitArgs()[arg.getArgNumber() - 1];
+          if (initArg == currentOp->getResult(0)) {
+            if (!arg.hasOneUse()) {
+              llvm::errs() << "at line " << __LINE__ << "\n";
+              return false;
+            }
+
+            currentOp = *arg.getUsers().begin();
+            break;
+          }
+        }
+      }
+    }
+
+    llvm::errs() << "at line " << __LINE__ << "\n";
+    return true;
+  }
+
+  // Recursively update the operands in a chain of AdvanceOps, after setting the
+  // pointer operand of the first one.
+  tt::AdvanceOp updateAdvanceOpChain(tt::AdvanceOp advanceOp, tt::LoadOp loadOp,
+                                     Value ptr) const {
+    assert(advanceOp->hasOneUse() && "Expecting single user");
+    assert(tt::isTensorPointerType(ptr.getType()) &&
+           "Expecting a block pointer");
+
+    Operation *user = *advanceOp->getUsers().begin();
+    if (auto loadUser = dyn_cast<tt::LoadOp>(user)) {
+      assert(loadUser == loadOp &&
+             "chain should be terminated by candidate load");
+      OpBuilder rewriter(advanceOp);
+      SmallVector<Value> newOffsets(llvm::reverse(advanceOp.getOffsets()));
+      return rewriter.create<tt::AdvanceOp>(advanceOp.getLoc(), ptr.getType(),
+                                            ptr, newOffsets);
+    }
+
+    if (auto advanceOp = dyn_cast<tt::AdvanceOp>(user)) {
+      OpBuilder rewriter(advanceOp);
+      SmallVector<Value> newOffsets(llvm::reverse(advanceOp.getOffsets()));
+      ptr = rewriter.create<tt::AdvanceOp>(advanceOp.getLoc(), ptr.getType(),
+                                           ptr, newOffsets);
+      return updateAdvanceOpChain(advanceOp, loadOp, ptr);
+    }
+
+    llvm::errs() << "user: " << *user << "\n";
+    llvm_unreachable("Unexpected user");
+
+    return nullptr;
+  }
+};
+
+} // namespace
+
+class TritonIntelGPUOptimizeDotOperandsPass
+    : public triton::gpu::intel::impl::TritonIntelGPUOptimizeDotOperandsBase<
+          TritonIntelGPUOptimizeDotOperandsPass> {
+public:
+  using triton::gpu::intel::impl::TritonIntelGPUOptimizeDotOperandsBase<
+      TritonIntelGPUOptimizeDotOperandsPass>::
+      TritonIntelGPUOptimizeDotOperandsBase;
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    ModuleOp m = getOperation();
+
+    OpPassManager pm;
+    pm.addPass(mlir::createCanonicalizerPass());
+    if (failed(runPipeline(pm, m)))
+      return signalPassFailure();
+
+    mlir::RewritePatternSet patterns(context);
+    patterns.add<FuseTransWithLoad>(context);
+    if (failed(applyPatternsGreedily(m, std::move(patterns))))
+      signalPassFailure();
+  }
+};
diff --git a/third_party/intel/triton_xpu.cc b/third_party/intel/triton_xpu.cc
index 387ac71045..589e86ffeb 100644
--- a/third_party/intel/triton_xpu.cc
+++ b/third_party/intel/triton_xpu.cc
@@ -78,6 +78,8 @@ void init_triton_intel_passes_ttgpuir(py::module &&m) {
                             enum gpu::intel::SplitBarrierScope);
   ADD_PASS_WRAPPER_0("add_remove_layout_conversions",
                      gpu::intel::createTritonIntelGPURemoveLayoutConversions);
+  ADD_PASS_WRAPPER_0("add_optimize_dot_operands",
+                     gpu::intel::createTritonIntelGPUOptimizeDotOperands);
   ADD_PASS_WRAPPER_0("add_coalesce", gpu::intel::createTritonIntelGPUCoalesce);
   ADD_PASS_OPTION_WRAPPER_2("add_prefetch_block",
                             gpu::intel::createTritonIntelGPUPrefetchBlock, int,

From 07599e3449b501082743cb3a19f0737c8320e3e4 Mon Sep 17 00:00:00 2001
From: "Tiotto, Ettore"
Date: Thu, 5 Jun 2025 17:10:44 +0000
Subject: [PATCH 02/14] Limit candidates to operations with no associated region.

Signed-off-by: Tiotto, Ettore
---
 test/TritonIntelGPU/dot-operands.mlir         | 42 +++++++++++++--
 .../OptimizeDotOperands.cpp                   | 53 ++++++-------------
 2 files changed, 54 insertions(+), 41 deletions(-)

diff --git a/test/TritonIntelGPU/dot-operands.mlir b/test/TritonIntelGPU/dot-operands.mlir
index 2b0be6ebec..dd1109f072 100644
--- a/test/TritonIntelGPU/dot-operands.mlir
+++ b/test/TritonIntelGPU/dot-operands.mlir
@@ -6,8 +6,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.th
 
   // COM: tt.load -> tt.trans -> tt.dot chain, not in a loop.
- // COM: Expecting the load to be "fused" with the transpose - tt.func public @fuseLoadWithTrans(%arg0: !tt.ptr>>, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr>) { + tt.func public @fuseLoadWithTrans1(%arg0: !tt.ptr>>, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr>) { %c0_i32 = arith.constant 0 : i32 %c1_i64 = arith.constant 1 : i64 %c256_i32 = arith.constant 256 : i32 @@ -23,11 +22,48 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.th tt.store %arg2, %6 {boundaryCheck = array} : !tt.ptr> tt.return } - // CHECK-LABEL: fuseLoadWithTrans + // CHECK-LABEL: fuseLoadWithTrans1 // CHECK-NOT: tt.trans // CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg1, [%c1_i64, %c1024_i64], [%c1024_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : >> // CHECK: [[ADV:%.*]] = tt.advance [[PTR]], [%c0_i32, %c256_i32] : >> // CHECK: [[LOAD_B:%.*]] = tt.load [[ADV]] {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr>> // CHECK: tt.dot {{.*}}, [[LOAD_B]], {{.*}}, inputPrecision = tf32 : tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma> + // COM: tt.load -> tt.trans -> tt.dot chain, in a loop. + // COM: where the 'make_tensor_ptr' result is not loop carried. + tt.func public @fuseLoadWithTrans2(%arg0: !tt.ptr>>, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr>) { + %c0_i32 = arith.constant 0 : i32 + %c1_i64 = arith.constant 1 : i64 + %c32_i32 = arith.constant 32 : i32 + %c256_i32 = arith.constant 256 : i32 + %c1024_i64 = arith.constant 1024 : i64 + %c1024_i32 = arith.constant 1024 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %0 = tt.make_tensor_ptr %arg1, [%c1024_i64, %c1_i64], [%c1_i64, %c1024_i64], [%c0_i32, %c0_i32] {order = array} : > + %res:2 = scf.for %arg3 = %c0_i32 to %c1024_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %c0_i32) -> (tensor<256x256xf32, #mma>, i32) : i32 { + %1 = tt.load %arg0 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr>> + %2 = tt.advance %0, [%c256_i32, %c0_i32] : > + %3 = tt.load %2 {boundaryCheck = array, ttig.block_io = "column_major"} : !tt.ptr> + %4 = tt.trans %3 {order = array} : tensor<256x32xbf16, #linear> -> tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %5 = tt.dot %1, %4, %arg4, inputPrecision = tf32 : tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma> + %6 = arith.addi %arg5, %c32_i32 : i32 + scf.yield %5, %6 : tensor<256x256xf32, #mma>, i32 + } + %6 = ttg.convert_layout %res#0 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked> + tt.store %arg2, %6 {boundaryCheck = array} : !tt.ptr> + tt.return + } + // CHECK-LABEL: fuseLoadWithTrans2 + // CHECK-NOT: tt.trans + // CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg1, [%c1_i64, %c1024_i64], [%c1024_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : >> + // CHECK: scf.for {{.*}} + // CHECK: [[ADV:%.*]] = tt.advance [[PTR]], [%c0_i32, %c256_i32] : >> + // CHECK: [[LOAD_B:%.*]] = tt.load [[ADV]] {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr>> + // CHECK: tt.dot {{.*}}, [[LOAD_B]], {{.*}}, inputPrecision = tf32 : tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma> + // CHECK: 
scf.yield + + + + + } diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp index 7bee076f2f..cb7d6b7dc9 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp @@ -1,3 +1,5 @@ +#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h" +#include "intel/include/Utils/Utility.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/TypeUtilities.h" @@ -6,18 +8,17 @@ #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" - -#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h" -#include "intel/include/Utils/Utility.h" - #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/LayoutUtility.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/Support/Debug.h" #include +#define DEBUG_TYPE "tritonintelgpu-optimize-dot-operands" + using namespace mlir; namespace tt = mlir::triton; namespace ttg = mlir::triton::gpu; @@ -52,13 +53,14 @@ class FuseTransWithLoad : public OpRewritePattern { if (!isCandidate(transOp)) return failure(); + LLVM_DEBUG(llvm::dbgs() << "Candidate: " << transOp << "\n"); auto tensorType = cast(transOp.getType()); Attribute dotEncoding = cast(tensorType.getEncoding()); auto loadOp = cast(transOp.getSrc().getDefiningOp()); tt::MakeTensorPtrOp makeTensorPtrOp = *triton::intel::findDefiningMakeTensorPtrOp(loadOp.getPtr()); - llvm::errs() << "makeTensorPtrOp: " << makeTensorPtrOp << "\n"; + LLVM_DEBUG(llvm::dbgs() << "makeTensorPtrOp: " << makeTensorPtrOp << "\n"); // Create a MakeTensorPtrOp yielding a block pointer to the transposed // tensor. @@ -73,13 +75,12 @@ class FuseTransWithLoad : public OpRewritePattern { Value ptr = builder.create( makeTensorPtrOp.getLoc(), newPtrType, makeTensorPtrOp.getBase(), newShape, newStrides, newOffsets, makeTensorPtrOp.getOrderAttr()); - assert(makeTensorPtrOp->hasOneUse() && "Expecing single user"); - llvm::errs() << "newMakeTensorPtrOp: " << ptr << "\n"; + assert(makeTensorPtrOp->hasOneUse() && "Expecting single user"); + LLVM_DEBUG(llvm::dbgs() << "newMakeTensorPtrOp: " << ptr << "\n"); // Transitively update users of the block pointer. Operation *makeTensorPtrOpUser = *makeTensorPtrOp->getUsers().begin(); if (auto advanceOp = dyn_cast(makeTensorPtrOpUser)) { - llvm::errs() << "user is advance: " << advanceOp << "\n"; ptr = updateAdvanceOpChain(advanceOp, loadOp, ptr); } else { // TODO: handle loop init args (scf.for only for now). 
@@ -93,7 +94,6 @@ class FuseTransWithLoad : public OpRewritePattern { loadOp.getLoc(), ptr, loadOp.getMask(), loadOp.getOther(), loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(), loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); - llvm::errs() << "newLoadOp: " << newLoadOp << "\n"; StringRef blockIOAttrName = ttgi::TritonIntelGPUDialect::getBlockIOAttrName(); @@ -107,11 +107,11 @@ class FuseTransWithLoad : public OpRewritePattern { assert(newAttr && "Expecting a valid blockIO attribute"); newLoadOp->setAttr(blockIOAttrName, newAttr); + LLVM_DEBUG(llvm::dbgs() << "newLoadOp: " << newLoadOp << "\n"); transOp->replaceAllUsesWith(newLoadOp); [[maybe_unused]] auto moduleOp = newLoadOp->getParentOfType(); - moduleOp->dumpPretty(); assert(succeeded(verify(moduleOp)) && "Module verification failed"); return success(); @@ -142,7 +142,6 @@ class FuseTransWithLoad : public OpRewritePattern { if (!defOp || !isa(defOp)) return false; - llvm::errs() << "at line " << __LINE__ << "\n"; return isCandidate(cast(defOp)); } @@ -154,51 +153,31 @@ class FuseTransWithLoad : public OpRewritePattern { if (!loadOp->hasOneUse() || !loadOpHasBlockIOAttr) return false; - llvm::errs() << "at line " << __LINE__ << "\n"; auto ptrType = cast(loadOp.getPtr().getType()); if (!isTensorPointerType(ptrType) || cast(ptrType.getPointeeType()).getRank() != 2) return false; - llvm::errs() << "at line " << __LINE__ << "\n"; std::optional defOp = *triton::intel::findDefiningMakeTensorPtrOp(loadOp.getPtr()); - if (!defOp || !singleUsersInChain(*defOp, loadOp)) { - llvm::errs() << "at line " << __LINE__ << "\n"; + if (!defOp || !singleUsersInChain(*defOp, loadOp)) return false; - } - llvm::errs() << "at line " << __LINE__ << "\n"; + return true; } bool singleUsersInChain(Operation *start, Operation *end) const { assert(start && end && "Expecting valid operations"); Operation *currentOp = start; + while (currentOp != end) { - llvm::errs() << "currentOp: " << *currentOp << "\n"; - if (!currentOp->hasOneUse()) { - llvm::errs() << "at line " << __LINE__ << "\n"; + // TODO: extend to handle loops. 
+ if ((currentOp->getNumRegions() != 0) || !currentOp->hasOneUse()) return false; - } currentOp = *currentOp->getUsers().begin(); - if (auto forOp = dyn_cast(currentOp)) { - for (BlockArgument arg : forOp.getRegionIterArgs()) { - Value initArg = forOp.getInitArgs()[arg.getArgNumber() - 1]; - if (initArg == currentOp->getResult(0)) { - if (!arg.hasOneUse()) { - llvm::errs() << "at line " << __LINE__ << "\n"; - return false; - } - - currentOp = *arg.getUsers().begin(); - break; - } - } - } } - llvm::errs() << "at line " << __LINE__ << "\n"; return true; } @@ -228,9 +207,7 @@ class FuseTransWithLoad : public OpRewritePattern { return updateAdvanceOpChain(advanceOp, loadOp, ptr); } - llvm::errs() << "user: " << *user << "\n"; llvm_unreachable("Unexpected user"); - return nullptr; } }; From b1a2c1f486a247fa9007bd1e469dd5328920e2bc Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Fri, 6 Jun 2025 13:56:38 +0000 Subject: [PATCH 03/14] Allow candidates in for loop Signed-off-by: Tiotto, Ettore --- .../OptimizeDotOperands.cpp | 86 +++++++++++++++++-- 1 file changed, 78 insertions(+), 8 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp index cb7d6b7dc9..0c2be71720 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp @@ -1,5 +1,3 @@ -#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h" -#include "intel/include/Utils/Utility.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/TypeUtilities.h" @@ -8,12 +6,16 @@ #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" +#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h" +#include "intel/include/Utils/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/LayoutUtility.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" #include @@ -107,11 +109,11 @@ class FuseTransWithLoad : public OpRewritePattern { assert(newAttr && "Expecting a valid blockIO attribute"); newLoadOp->setAttr(blockIOAttrName, newAttr); - LLVM_DEBUG(llvm::dbgs() << "newLoadOp: " << newLoadOp << "\n"); + LLVM_DEBUG(llvm::errs() << "newLoadOp: " << newLoadOp << "\n"); transOp->replaceAllUsesWith(newLoadOp); - [[maybe_unused]] auto moduleOp = newLoadOp->getParentOfType(); + [[maybe_unused]] auto moduleOp = newLoadOp->getParentOfType(); assert(succeeded(verify(moduleOp)) && "Module verification failed"); return success(); @@ -166,16 +168,84 @@ class FuseTransWithLoad : public OpRewritePattern { return true; } + // Determine whether all operations in the def-use chain from \p start to + // \p end have a single user. + // Note: we allow an operation in the def-use chain to have an additional user + // if the operation is in a for loop, and the additional user is the yield + // operation, provided that the result yielded is not used after the loop. 
+ // Example: + // make_tensor_ptr -> advance -> load (OK) + // make_tensor_ptr -> for init_arg -> advance -> load (OK) + // -> yield (OK) + // make_tensor_ptr -> for init_arg -> advance -> load (OK) + // -> yield -> load (NOT OK) + // bool singleUsersInChain(Operation *start, Operation *end) const { assert(start && end && "Expecting valid operations"); Operation *currentOp = start; - while (currentOp != end) { - // TODO: extend to handle loops. - if ((currentOp->getNumRegions() != 0) || !currentOp->hasOneUse()) + auto validate = [](Operation *op, Operation *&nextOp) { + assert(nextOp == nullptr); + + if (op->hasOneUse()) + return true; + if (!op->getParentOfType()) return false; - currentOp = *currentOp->getUsers().begin(); + SmallVector users(op->getUsers()); + if (users.size() > 2 || llvm::none_of(users, [](Operation *op) { + return isa(op); + })) + return false; + + auto yieldOp = cast(*llvm::find_if( + users, [](Operation *user) { return isa(user); })); + auto yieldedValUsedAfterLoop = + [&op, &yieldOp]() { + auto it = llvm::find_if(yieldOp->getOpOperands(), + [&op](OpOperand &operand) { + return operand.get() == op->getResult(0); + }); + assert(it != yieldOp->getOpOperands().end()); + OpOperand &operand = *it; + auto forOp = cast(yieldOp->getParentOp()); + OpResult res = forOp->getResult(operand.getOperandNumber()); + return !res.getUsers().empty(); + }; + if (yieldedValUsedAfterLoop()) + return false; + + nextOp = *llvm::find_if( + users, [](Operation *user) { return !isa(user); }); + return true; + }; + + while (currentOp != end) { + Operation *user = nullptr; + if (!validate(currentOp, user)) { + LLVM_DEBUG(llvm::dbgs() << currentOp << " fails safety checks\n"); + return false; + } + + user = (!user) ? user = *currentOp->getUsers().begin() : user; + if (user->getNumRegions() == 0) { + currentOp = user; + continue; + } + + // Find the next operation in the def-use chain inside the lop body. 
+ if (auto forOp = dyn_cast(user)) { + for (BlockArgument arg : forOp.getRegionIterArgs()) { + Value initArg = forOp.getInitArgs()[arg.getArgNumber() - 1]; + if (initArg == currentOp->getResult(0)) { + if (!arg.hasOneUse()) + return false; + + currentOp = *arg.getUsers().begin(); + break; + } + } + } } return true; From 5eafc6b8c19fcbd70a4aa0466727515c1b6724ad Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Fri, 6 Jun 2025 13:57:32 +0000 Subject: [PATCH 04/14] Fix precommit Signed-off-by: Tiotto, Ettore --- .../OptimizeDotOperands.cpp | 29 +++++++++---------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp index 0c2be71720..50bebc4a8d 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp @@ -1,3 +1,5 @@ +#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h" +#include "intel/include/Utils/Utility.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/TypeUtilities.h" @@ -6,8 +8,6 @@ #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" -#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h" -#include "intel/include/Utils/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" @@ -113,7 +113,7 @@ class FuseTransWithLoad : public OpRewritePattern { transOp->replaceAllUsesWith(newLoadOp); - [[maybe_unused]] auto moduleOp = newLoadOp->getParentOfType(); + [[maybe_unused]] auto moduleOp = newLoadOp->getParentOfType(); assert(succeeded(verify(moduleOp)) && "Module verification failed"); return success(); @@ -200,18 +200,17 @@ class FuseTransWithLoad : public OpRewritePattern { auto yieldOp = cast(*llvm::find_if( users, [](Operation *user) { return isa(user); })); - auto yieldedValUsedAfterLoop = - [&op, &yieldOp]() { - auto it = llvm::find_if(yieldOp->getOpOperands(), - [&op](OpOperand &operand) { - return operand.get() == op->getResult(0); - }); - assert(it != yieldOp->getOpOperands().end()); - OpOperand &operand = *it; - auto forOp = cast(yieldOp->getParentOp()); - OpResult res = forOp->getResult(operand.getOperandNumber()); - return !res.getUsers().empty(); - }; + auto yieldedValUsedAfterLoop = [&op, &yieldOp]() { + auto it = + llvm::find_if(yieldOp->getOpOperands(), [&op](OpOperand &operand) { + return operand.get() == op->getResult(0); + }); + assert(it != yieldOp->getOpOperands().end()); + OpOperand &operand = *it; + auto forOp = cast(yieldOp->getParentOp()); + OpResult res = forOp->getResult(operand.getOperandNumber()); + return !res.getUsers().empty(); + }; if (yieldedValUsedAfterLoop()) return false; From 5181bb3c9c4c944d9a0ecbd0e99957bd52dd38fc Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Fri, 6 Jun 2025 15:05:19 +0000 Subject: [PATCH 05/14] Better traces Signed-off-by: Tiotto, Ettore --- .../lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp index 50bebc4a8d..93a61890c2 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp +++ 
b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp @@ -55,14 +55,14 @@ class FuseTransWithLoad : public OpRewritePattern { if (!isCandidate(transOp)) return failure(); - LLVM_DEBUG(llvm::dbgs() << "Candidate: " << transOp << "\n"); + LLVM_DEBUG(llvm::dbgs() << "Found candidate:\n\t" << transOp << "\n"); auto tensorType = cast(transOp.getType()); Attribute dotEncoding = cast(tensorType.getEncoding()); auto loadOp = cast(transOp.getSrc().getDefiningOp()); tt::MakeTensorPtrOp makeTensorPtrOp = *triton::intel::findDefiningMakeTensorPtrOp(loadOp.getPtr()); - LLVM_DEBUG(llvm::dbgs() << "makeTensorPtrOp: " << makeTensorPtrOp << "\n"); + LLVM_DEBUG(llvm::dbgs() << "makeTensorPtrOp:\n\t" << makeTensorPtrOp << "\n"); // Create a MakeTensorPtrOp yielding a block pointer to the transposed // tensor. @@ -78,7 +78,7 @@ class FuseTransWithLoad : public OpRewritePattern { makeTensorPtrOp.getLoc(), newPtrType, makeTensorPtrOp.getBase(), newShape, newStrides, newOffsets, makeTensorPtrOp.getOrderAttr()); assert(makeTensorPtrOp->hasOneUse() && "Expecting single user"); - LLVM_DEBUG(llvm::dbgs() << "newMakeTensorPtrOp: " << ptr << "\n"); + LLVM_DEBUG(llvm::dbgs() << "newMakeTensorPtrOp:\n\t" << ptr << "\n"); // Transitively update users of the block pointer. Operation *makeTensorPtrOpUser = *makeTensorPtrOp->getUsers().begin(); @@ -276,6 +276,9 @@ class FuseTransWithLoad : public OpRewritePattern { return updateAdvanceOpChain(advanceOp, loadOp, ptr); } + // TODO: add support for loops (advanceOp cound be consumed by a loop + // init_arg). + llvm_unreachable("Unexpected user"); return nullptr; } From 2329dd7e866b28895f0d5f10337783a386ce8cef Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Mon, 9 Jun 2025 21:59:02 +0000 Subject: [PATCH 06/14] Allow fusing load+trans when load ptr is loop carried Signed-off-by: Tiotto, Ettore --- test/TritonIntelGPU/dot-operands.mlir | 127 ++++++++- .../OptimizeDotOperands.cpp | 249 +++++++++++++----- 2 files changed, 303 insertions(+), 73 deletions(-) diff --git a/test/TritonIntelGPU/dot-operands.mlir b/test/TritonIntelGPU/dot-operands.mlir index dd1109f072..119ffad693 100644 --- a/test/TritonIntelGPU/dot-operands.mlir +++ b/test/TritonIntelGPU/dot-operands.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -split-input-file -tritonintelgpu-optimize-dot-operands -canonicalize | FileCheck %s +// RUN: triton-opt %s -tritonintelgpu-optimize-dot-operands | FileCheck %s #blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}> #linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [16, 0], [0, 16], [128, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0]], warp = [[32, 0], [64, 0], [0, 0], [0, 0], [0, 0]], block = []}> @@ -18,8 +18,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.th %3 = tt.load %2 {boundaryCheck = array, ttig.block_io = "column_major"} : !tt.ptr> %4 = tt.trans %3 {order = array} : tensor<256x32xbf16, #linear> -> tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> %5 = tt.dot %1, %4, %cst, inputPrecision = tf32 : tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma> - %6 = ttg.convert_layout %5 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked> - tt.store %arg2, %6 {boundaryCheck = array} : !tt.ptr> tt.return } // CHECK-LABEL: fuseLoadWithTrans1 @@ -49,8 +47,6 @@ module attributes {"ttg.num-ctas" = 1 : 
i32, "ttg.num-warps" = 32 : i32, "ttg.th %6 = arith.addi %arg5, %c32_i32 : i32 scf.yield %5, %6 : tensor<256x256xf32, #mma>, i32 } - %6 = ttg.convert_layout %res#0 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked> - tt.store %arg2, %6 {boundaryCheck = array} : !tt.ptr> tt.return } // CHECK-LABEL: fuseLoadWithTrans2 @@ -62,8 +58,127 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.th // CHECK: tt.dot {{.*}}, [[LOAD_B]], {{.*}}, inputPrecision = tf32 : tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma> // CHECK: scf.yield + // COM: tt.load -> tt.trans -> tt.dot chain, in a loop. + // COM: where the 'make_tensor_ptr' result is loop carried. + tt.func public @fuseLoadWithTrans3(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}) { + %c4_i32 = arith.constant 4 : i32 + %c1024_i32 = arith.constant 1024 : i32 + %c0_i32 = arith.constant 0 : i32 + %c32_i32 = arith.constant 32 : i32 + %c256_i32 = arith.constant 256 : i32 + %c16_i32 = arith.constant 16 : i32 + %c1_i64 = arith.constant 1 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %0 = tt.get_program_id x : i32 + %1 = arith.divsi %0, %c16_i32 : i32 + %2 = arith.muli %1, %c4_i32 : i32 + %3 = arith.subi %c4_i32, %2 : i32 + %4 = arith.minsi %3, %c4_i32 : i32 + %5 = arith.remsi %0, %c16_i32 : i32 + %6 = arith.remsi %5, %4 : i32 + %7 = arith.addi %2, %6 : i32 + %8 = arith.divsi %5, %4 : i32 + %9 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c1024_i64], [%c1024_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : >> + %10 = tt.make_tensor_ptr %arg1, [%c1024_i64, %c1_i64], [%c1_i64, %c1024_i64], [%c0_i32, %c0_i32] {order = array} : > + %13:3 = scf.for %arg3 = %c0_i32 to %c1024_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %c0_i32, %arg6 = %10) -> (tensor<256x256xf32, #mma>, i32, !tt.ptr>) : i32 { + %17 = tt.advance %9, [%c256_i32, %arg5] : >> + %18 = tt.load %17 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr>> + %19 = tt.advance %arg6, [%c16_i32, %arg5] : > + %20 = tt.load %19 {boundaryCheck = array, ttig.block_io = "column_major"} : !tt.ptr> + %21 = tt.trans %20 {order = array} : tensor<256x32xbf16, #linear> -> tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %22 = tt.dot %18, %21, %arg4, inputPrecision = tf32 : tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma> + %23 = arith.addi %arg5, %c32_i32 : i32 + scf.yield %22, %23, %19 : tensor<256x256xf32, #mma>, i32, !tt.ptr> + } + tt.return + } + // CHECK-LABEL: fuseLoadWithTrans3 + // CHECK-NOT: tt.trans + // CHECK [[IDX1:%.*]] = arith.muli + // CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg1, [%c1_i64, %c1024_i64], [%c1024_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : >> + // CHECK: scf.for {{.*}} iter_args({{.*}}, [[ARG5:%.*]] = {{.*}}, [[ARG6:%.*]] = [[PTR]]) + // CHECK: [[ADV:%.*]] = tt.advance [[ARG6]], [[[ARG5]], %c16_i32] : >> + // CHECK: [[LOAD_B:%.*]] = tt.load [[ADV]] {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr>> + // CHECK: tt.dot {{.*}}, [[LOAD_B]], {{.*}}, inputPrecision = tf32 : tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xbf16, 
#ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma> + // CHECK: scf.yield {{.*}}, {{.*}}, [[ADV]] + // COM: Ensure load is not fused with transpose if the loop result corresponding to the pointer used by the load + // COM: that 'feeds' the transpose operation is used. + tt.func public @doNotFuseLoadWithTrans1(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}) { + %c4_i32 = arith.constant 4 : i32 + %c1024_i32 = arith.constant 1024 : i32 + %c0_i32 = arith.constant 0 : i32 + %c32_i32 = arith.constant 32 : i32 + %c256_i32 = arith.constant 256 : i32 + %c16_i32 = arith.constant 16 : i32 + %c1_i64 = arith.constant 1 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %0 = tt.get_program_id x : i32 + %1 = arith.divsi %0, %c16_i32 : i32 + %2 = arith.muli %1, %c4_i32 : i32 + %3 = arith.subi %c4_i32, %2 : i32 + %4 = arith.minsi %3, %c4_i32 : i32 + %5 = arith.remsi %0, %c16_i32 : i32 + %6 = arith.remsi %5, %4 : i32 + %7 = arith.addi %2, %6 : i32 + %8 = arith.divsi %5, %4 : i32 + %9 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c1024_i64], [%c1024_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : >> + %10 = tt.make_tensor_ptr %arg1, [%c1024_i64, %c1_i64], [%c1_i64, %c1024_i64], [%c0_i32, %c0_i32] {order = array} : > + %13:3 = scf.for %arg3 = %c0_i32 to %c1024_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %c0_i32, %arg6 = %10) -> (tensor<256x256xf32, #mma>, i32, !tt.ptr>) : i32 { + %17 = tt.advance %9, [%c256_i32, %arg5] : >> + %18 = tt.load %17 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr>> + %19 = tt.advance %arg6, [%c16_i32, %arg5] : > + %20 = tt.load %19 {boundaryCheck = array, ttig.block_io = "column_major"} : !tt.ptr> + %21 = tt.trans %20 {order = array} : tensor<256x32xbf16, #linear> -> tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %22 = tt.dot %18, %21, %arg4, inputPrecision = tf32 : tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma> + %23 = arith.addi %arg5, %c32_i32 : i32 + scf.yield %22, %23, %19 : tensor<256x256xf32, #mma>, i32, !tt.ptr> + } + %15 = tt.advance %13#2, [%c16_i32, %c16_i32] : > + tt.return + } + // CHECK-LABEL: doNotFuseLoadWithTrans1 + // CHECK: tt.trans - + // COM: Ensure load is not fused with transpose if there are multiple users in the loop body. 
+ tt.func public @doNotFuseLoadWithTrans2(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}) { + %c4_i32 = arith.constant 4 : i32 + %c1024_i32 = arith.constant 1024 : i32 + %c0_i32 = arith.constant 0 : i32 + %c32_i32 = arith.constant 32 : i32 + %c256_i32 = arith.constant 256 : i32 + %c16_i32 = arith.constant 16 : i32 + %c1_i64 = arith.constant 1 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %cst_1 = arith.constant dense<1.000000e+00> : tensor<256x32xbf16, #linear> + %0 = tt.get_program_id x : i32 + %1 = arith.divsi %0, %c16_i32 : i32 + %2 = arith.muli %1, %c4_i32 : i32 + %3 = arith.subi %c4_i32, %2 : i32 + %4 = arith.minsi %3, %c4_i32 : i32 + %5 = arith.remsi %0, %c16_i32 : i32 + %6 = arith.remsi %5, %4 : i32 + %7 = arith.addi %2, %6 : i32 + %8 = arith.divsi %5, %4 : i32 + %9 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c1024_i64], [%c1024_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : >> + %10 = tt.make_tensor_ptr %arg1, [%c1024_i64, %c1_i64], [%c1_i64, %c1024_i64], [%c0_i32, %c0_i32] {order = array} : > + %13:3 = scf.for %arg3 = %c0_i32 to %c1024_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %c0_i32, %arg6 = %10) -> (tensor<256x256xf32, #mma>, i32, !tt.ptr>) : i32 { + %17 = tt.advance %9, [%c256_i32, %arg5] : >> + %18 = tt.load %17 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr>> + %19 = tt.advance %arg6, [%c16_i32, %arg5] : > + %20 = tt.load %19 {boundaryCheck = array, ttig.block_io = "column_major"} : !tt.ptr> + %21 = tt.trans %20 {order = array} : tensor<256x32xbf16, #linear> -> tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %22 = tt.dot %18, %21, %arg4, inputPrecision = tf32 : tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma> + tt.store %19, %cst_1 {boundaryCheck = array} : !tt.ptr> + %23 = arith.addi %arg5, %c32_i32 : i32 + scf.yield %22, %23, %19 : tensor<256x256xf32, #mma>, i32, !tt.ptr> + } + tt.return + } + // CHECK-LABEL: doNotFuseLoadWithTrans2 + // CHECK: tt.trans } diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp index 93a61890c2..2e236f0dcb 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp @@ -1,8 +1,10 @@ #include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h" #include "intel/include/Utils/Utility.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Value.h" #include "mlir/IR/Verifier.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/LogicalResult.h" @@ -46,27 +48,39 @@ namespace { // %load = tt.load %ptr, {blockIO=} // : tt.ptr // tt.dot(%a, %load) -class FuseTransWithLoad : public OpRewritePattern { +class FuseTransWithLoad { +private: + tt::FuncOp funcOp; + SmallPtrSet cleanUp; + public: - using OpRewritePattern::OpRewritePattern; + FuseTransWithLoad(tt::FuncOp funcOp) : funcOp(funcOp) {} + + void run() { + funcOp.walk([&](tt::TransOp transOp) { + if (isCandidate(transOp)) + fuse(transOp); + }); - LogicalResult matchAndRewrite(tt::TransOp transOp, - PatternRewriter &rewriter) const override { - if 
(!isCandidate(transOp)) - return failure(); + if (!cleanUp.empty()) + finalize(); + + [[maybe_unused]] auto moduleOp = funcOp->getParentOfType(); + assert(succeeded(verify(moduleOp)) && "Module verification failed"); + } + void fuse(tt::TransOp transOp) { LLVM_DEBUG(llvm::dbgs() << "Found candidate:\n\t" << transOp << "\n"); - auto tensorType = cast(transOp.getType()); - Attribute dotEncoding = - cast(tensorType.getEncoding()); auto loadOp = cast(transOp.getSrc().getDefiningOp()); tt::MakeTensorPtrOp makeTensorPtrOp = *triton::intel::findDefiningMakeTensorPtrOp(loadOp.getPtr()); - LLVM_DEBUG(llvm::dbgs() << "makeTensorPtrOp:\n\t" << makeTensorPtrOp << "\n"); + LLVM_DEBUG(llvm::dbgs() + << "makeTensorPtrOp:\n\t" << makeTensorPtrOp << "\n"); // Create a MakeTensorPtrOp yielding a block pointer to the transposed - // tensor. + // tensor... auto ptrType = cast(makeTensorPtrOp.getType()); + auto tensorType = cast(transOp.getType()); auto newPtrType = tt::PointerType::get(tensorType, ptrType.getAddressSpace()); SmallVector newShape(llvm::reverse(makeTensorPtrOp.getShape())); @@ -80,43 +94,8 @@ class FuseTransWithLoad : public OpRewritePattern { assert(makeTensorPtrOp->hasOneUse() && "Expecting single user"); LLVM_DEBUG(llvm::dbgs() << "newMakeTensorPtrOp:\n\t" << ptr << "\n"); - // Transitively update users of the block pointer. - Operation *makeTensorPtrOpUser = *makeTensorPtrOp->getUsers().begin(); - if (auto advanceOp = dyn_cast(makeTensorPtrOpUser)) { - ptr = updateAdvanceOpChain(advanceOp, loadOp, ptr); - } else { - // TODO: handle loop init args (scf.for only for now). - assert(makeTensorPtrOpUser == loadOp && - "Expecting the load to be the user"); - } - - // Replace the load+transpose with a new load operation that uses the - // transposed block pointer. - auto newLoadOp = rewriter.create( - loadOp.getLoc(), ptr, loadOp.getMask(), loadOp.getOther(), - loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(), - loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); - - StringRef blockIOAttrName = - ttgi::TritonIntelGPUDialect::getBlockIOAttrName(); - StringAttr attr = loadOp->getAttrOfType(blockIOAttrName); - StringAttr newAttr = - (attr == "row_major") - ? StringAttr::get(loadOp->getContext(), "column_major") - : (attr == "column_major") - ? StringAttr::get(loadOp->getContext(), "row_major") - : nullptr; - assert(newAttr && "Expecting a valid blockIO attribute"); - - newLoadOp->setAttr(blockIOAttrName, newAttr); - LLVM_DEBUG(llvm::errs() << "newLoadOp: " << newLoadOp << "\n"); - - transOp->replaceAllUsesWith(newLoadOp); - - [[maybe_unused]] auto moduleOp = newLoadOp->getParentOfType(); - assert(succeeded(verify(moduleOp)) && "Module verification failed"); - - return success(); + // ... and propagate it through the def-use chain. 
+ propagateToUsers(ptr, makeTensorPtrOp, makeTensorPtrOp, transOp); } private: @@ -161,7 +140,7 @@ class FuseTransWithLoad : public OpRewritePattern { return false; std::optional defOp = - *triton::intel::findDefiningMakeTensorPtrOp(loadOp.getPtr()); + triton::intel::findDefiningMakeTensorPtrOp(loadOp.getPtr()); if (!defOp || !singleUsersInChain(*defOp, loadOp)) return false; @@ -222,7 +201,7 @@ class FuseTransWithLoad : public OpRewritePattern { while (currentOp != end) { Operation *user = nullptr; if (!validate(currentOp, user)) { - LLVM_DEBUG(llvm::dbgs() << currentOp << " fails safety checks\n"); + LLVM_DEBUG(llvm::dbgs() << *currentOp << " fails safety checks\n"); return false; } @@ -232,7 +211,7 @@ class FuseTransWithLoad : public OpRewritePattern { continue; } - // Find the next operation in the def-use chain inside the lop body. + // Find the next operation in the def-use chain inside the loop body. if (auto forOp = dyn_cast(user)) { for (BlockArgument arg : forOp.getRegionIterArgs()) { Value initArg = forOp.getInitArgs()[arg.getArgNumber() - 1]; @@ -278,34 +257,170 @@ class FuseTransWithLoad : public OpRewritePattern { // TODO: add support for loops (advanceOp cound be consumed by a loop // init_arg). - + llvm_unreachable("Unexpected user"); return nullptr; } + + // Propagate \p newVal to users of \p origOp. + void propagateToUsers(Value newVal, Value origVal, Operation *origOp, + Operation *sentinel) { + assert(origOp && sentinel && "Expecting valid operations"); + const SmallVector users(origOp->getUsers()); + for (Operation *user : users) + propagateToUser(newVal, origVal, user, sentinel); + } + + // If \p user is not \p sentinel, propagate \p newVal to \p user. Otherwise + // terminate the propagation. + void propagateToUser(Value newVal, Value origVal, Operation *user, + Operation *sentinel) { + assert(user && sentinel && "Expecting valid operations"); + assert(llvm::is_contained(origVal.getUsers(), user) && "Invalid usage"); + + LLVM_DEBUG({ + llvm::dbgs() << "In " << __func__ << "\n"; + llvm::dbgs() << "user of "; + if (origVal.getDefiningOp()) { + llvm::dbgs() << "\n\t" << *origVal.getDefiningOp() << "\n"; + } else { + origVal.printAsOperand(llvm::dbgs(), {}); + llvm::dbgs() << " "; + } + llvm::dbgs() << "is:\n\t"; + user->dumpPretty(); + }); + + if (user == sentinel) { + LLVM_DEBUG(llvm::dbgs() << "Reached sentinel\n"); + sentinel->replaceAllUsesWith(newVal.getDefiningOp()); + cleanUp.insert(sentinel); + return; + } + + Location loc = user->getLoc(); + if (auto advanceOp = dyn_cast(user)) { + OpBuilder rewriter(advanceOp); + SmallVector newOffsets(llvm::reverse(advanceOp.getOffsets())); + auto newAdvanceOp = rewriter.create(loc, newVal.getType(), + newVal, newOffsets); + LLVM_DEBUG(llvm::dbgs() << "\tnewAdvanceOp: " << newAdvanceOp << "\n"); + cleanUp.insert(advanceOp); + return propagateToUsers(newAdvanceOp, advanceOp.getResult(), advanceOp, + sentinel); + } + + if (auto loadOp = dyn_cast(user)) { + OpBuilder rewriter(loadOp); + auto newLoadOp = rewriter.create( + loadOp.getLoc(), newVal, loadOp.getMask(), loadOp.getOther(), + loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(), + loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); + + StringRef blockIOAttrName = + ttgi::TritonIntelGPUDialect::getBlockIOAttrName(); + StringAttr attr = loadOp->getAttrOfType(blockIOAttrName); + StringAttr newAttr = + (attr == "row_major") + ? StringAttr::get(loadOp->getContext(), "column_major") + : (attr == "column_major") + ? 
StringAttr::get(loadOp->getContext(), "row_major") + : nullptr; + assert(newAttr && "Expecting a valid blockIO attribute"); + + newLoadOp->setAttr(blockIOAttrName, newAttr); + LLVM_DEBUG(llvm::dbgs() << "\tnewLoadOp: " << newLoadOp << "\n"); + cleanUp.insert(loadOp); + return propagateToUsers(newLoadOp, loadOp.getResult(), loadOp, sentinel); + } + + if (auto yieldOp = dyn_cast(user)) { + int opNum = -1; + for (OpOperand &operand : yieldOp->getOpOperands()) { + if (operand.get() == origVal) { + opNum = operand.getOperandNumber(); + yieldOp->setOperand(operand.getOperandNumber(), newVal); + break; + } + } + + // Update the yield's parent operation result type. + Operation *parentOp = yieldOp->getParentOp(); + for (OpResult res : parentOp->getOpResults()) { + int resNum = res.getResultNumber(); + if (resNum == opNum) + res.setType(newVal.getType()); + } + return; + } + + if (auto forOp = dyn_cast(user)) + return propagateToLoop(newVal, origVal, forOp, sentinel); + } + + void propagateToLoop(Value newVal, Value origVal, LoopLikeOpInterface loopOp, + Operation *sentinel) { + assert(sentinel && sentinel != loopOp && "Unexpected sentinel kind"); + LLVM_DEBUG({ + llvm::dbgs() << "In " << __func__ << "\n"; + llvm::dbgs() << "newVal: " << newVal << "\n"; + }); + + for (auto [initArg, rgnInitArg, yieldVal, loopRes] : + llvm::zip(loopOp.getInitsMutable(), loopOp.getRegionIterArgs(), + loopOp.getYieldedValues(), loopOp->getResults())) { + if (initArg.get() == origVal) { + initArg.set(newVal); + rgnInitArg.setType(initArg.get().getType()); + const SmallVector users(rgnInitArg.getUsers()); + for (Operation *user : users) + propagateToUser(rgnInitArg, rgnInitArg, user, sentinel); + } + } + } + + // Cleanup unused operations. + void finalize() { + bool erasedOperation; + do { + erasedOperation = false; + SmallPtrSet erased; + for (Operation *op : cleanUp) { + if (!op->getUsers().empty() || !op->getRegions().empty()) + continue; + + erased.insert(op); + op->erase(); + erasedOperation = true; + } + cleanUp.remove_if([&](Operation *op) { return erased.contains(op); }); + } while (erasedOperation); + + // Remove operations that contain a region. 
+ for (Operation *op : cleanUp) { + if (!op->getUsers().empty()) + continue; + op->erase(); + } + } }; } // namespace class TritonIntelGPUOptimizeDotOperandsPass - : public triton::gpu::intel::impl::TritonIntelGPUOptimizeDotOperandsBase< + : public ttgi::impl::TritonIntelGPUOptimizeDotOperandsBase< TritonIntelGPUOptimizeDotOperandsPass> { + public: - using triton::gpu::intel::impl::TritonIntelGPUOptimizeDotOperandsBase< + using ttgi::impl::TritonIntelGPUOptimizeDotOperandsBase< TritonIntelGPUOptimizeDotOperandsPass>:: TritonIntelGPUOptimizeDotOperandsBase; - void runOnOperation() override { - MLIRContext *context = &getContext(); - ModuleOp m = getOperation(); - - OpPassManager pm; - pm.addPass(mlir::createCanonicalizerPass()); - if (failed(runPipeline(pm, m))) - return signalPassFailure(); - - mlir::RewritePatternSet patterns(context); - patterns.add(context); - if (failed(applyPatternsGreedily(m, std::move(patterns)))) - signalPassFailure(); + void runOnOperation() final { + ModuleOp moduleOp = getOperation(); + moduleOp.walk([](tt::FuncOp funcOp) { + FuseTransWithLoad fuser(funcOp); + fuser.run(); + }); } }; From 475eef7926c31adf7165cd1d414158e7d1dc961a Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Tue, 10 Jun 2025 14:08:33 +0000 Subject: [PATCH 07/14] Fix failing tutorial 09 Signed-off-by: Tiotto, Ettore --- third_party/intel/lib/Utils/Utility.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/third_party/intel/lib/Utils/Utility.cpp b/third_party/intel/lib/Utils/Utility.cpp index e5b23bd7ca..277e7c5807 100644 --- a/third_party/intel/lib/Utils/Utility.cpp +++ b/third_party/intel/lib/Utils/Utility.cpp @@ -1,6 +1,7 @@ #include "intel/include/Utils/Utility.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/Transforms/DialectConversion.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include @@ -49,6 +50,8 @@ std::optional findDefiningMakeTensorPtrOp(Value val) { return findDefiningMakeTensorPtrOp(loopArg); } + if (auto poisonOp = val.getDefiningOp()) + return std::nullopt; if (auto advanceOp = val.getDefiningOp()) return findDefiningMakeTensorPtrOp(advanceOp.getPtr()); if (auto makePtrOp = val.getDefiningOp()) From 617dc0d039e25247084d2467873381bcf28741da Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Thu, 12 Jun 2025 21:54:48 +0000 Subject: [PATCH 08/14] Allow trans user to be any operation as long as def-use chain end is tt.dot Signed-off-by: Tiotto, Ettore --- third_party/intel/backend/compiler.py | 2 +- .../OptimizeDotOperands.cpp | 34 ++++++++++++++----- 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py index 0f346b4fe1..009f4b6c14 100644 --- a/third_party/intel/backend/compiler.py +++ b/third_party/intel/backend/compiler.py @@ -280,6 +280,7 @@ def make_ttgir(mod, metadata, opt, properties): intel.passes.ttgpuir.add_accelerate_matmul(pm) intel.passes.ttgpuir.add_materialize_block_pointer(pm) intel.passes.ttgpuir.add_remove_layout_conversions(pm) + intel.passes.ttgpuir.add_optimize_dot_operands(pm) intel.passes.ttgpuir.add_pipeline(pm, opt.num_stages, XPUBackend.get_split_barrier_scope(opt)) if (opt.reduce_variable_liveness): @@ -287,7 +288,6 @@ def make_ttgir(mod, metadata, opt, properties): passes.ttgpuir.add_fuse_nested_loops(pm) passes.ttgpuir.add_optimize_thread_locality(pm) - intel.passes.ttgpuir.add_optimize_dot_operands(pm) passes.ttgpuir.add_optimize_dot_operands(pm, True) 
passes.common.add_cse(pm) passes.ttgpuir.add_prefetch(pm) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp index 70fd1f692a..727cc5318b 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp @@ -99,7 +99,7 @@ class FuseTransWithLoad { // Candidate is of the form: // tt.dot(tt.trans(tt.load(..., {blockIO=...}))) // Where: - // - the transpose result is used only by the dot operation, and + // - the transpose result is used by the dot operation, and // - the transpose operation uses the result of a 2-dim load operation on a // block pointer (transitively) defined by a `make_tensor_ptr` in the same // function, and @@ -108,11 +108,26 @@ class FuseTransWithLoad { bool isCandidate(tt::TransOp transOp) const { assert(transOp && "Expecting a valid transpose operation"); - bool transOpUsedOnlyByDotOp = - transOp->hasOneUse() && - isa(*transOp->getUsers().begin()); + // Check whether \p transOp is used by a `dotOp` directly or indirectly + // (each operation in the def-use chain need to have a single user). + auto usedByDotOp = [](tt::TransOp transOp) { + if (!transOp->hasOneUse()) + return false; + + Operation *user = *transOp->getUsers().begin(); + while (user) { + if (isa(user)) + return true; + if (!user->hasOneUse()) + break; + user = *user->getUsers().begin(); + } + + return false; + }; + Attribute transOpEncoding = transOp.getType().getEncoding(); - if (!transOpUsedOnlyByDotOp || !transOpEncoding || + if (!usedByDotOp(transOp) || !transOpEncoding || !isa(transOpEncoding)) return false; @@ -165,14 +180,16 @@ class FuseTransWithLoad { if (op->hasOneUse()) return true; - if (!op->getParentOfType()) + if (!op->getParentOfType()) { return false; + } SmallVector users(op->getUsers()); if (users.size() > 2 || llvm::none_of(users, [](Operation *op) { return isa(op); - })) + })) { return false; + } auto yieldOp = cast(*llvm::find_if( users, [](Operation *user) { return isa(user); })); @@ -187,8 +204,9 @@ class FuseTransWithLoad { OpResult res = forOp->getResult(operand.getOperandNumber()); return !res.getUsers().empty(); }; - if (yieldedValUsedAfterLoop()) + if (yieldedValUsedAfterLoop()) { return false; + } nextOp = *llvm::find_if( users, [](Operation *user) { return !isa(user); }); From dd8979df01666c92ec96a1d995cd6076a0538b94 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Tue, 17 Jun 2025 15:43:55 +0000 Subject: [PATCH 09/14] Address code review comments Signed-off-by: Tiotto, Ettore --- .../OptimizeDotOperands.cpp | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp index 727cc5318b..e51bd2b472 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp @@ -51,10 +51,10 @@ class FuseTransWithLoad { SmallPtrSet cleanUp; public: - FuseTransWithLoad(tt::FuncOp funcOp) : funcOp(funcOp) {} + FuseTransWithLoad() = default; - void run() { - funcOp.walk([&](tt::TransOp transOp) { + void run(ModuleOp moduleOp) { + moduleOp.walk([&](tt::TransOp transOp) { if (isCandidate(transOp)) fuse(transOp); }); @@ -62,7 +62,6 @@ class FuseTransWithLoad { if (!cleanUp.empty()) tt::intel::eraseOperations(cleanUp); - [[maybe_unused]] 
auto moduleOp = funcOp->getParentOfType(); assert(succeeded(verify(moduleOp)) && "Module verification failed"); } @@ -375,9 +374,7 @@ class TritonIntelGPUOptimizeDotOperandsPass void runOnOperation() final { ModuleOp moduleOp = getOperation(); - moduleOp.walk([](tt::FuncOp funcOp) { - FuseTransWithLoad fuser(funcOp); - fuser.run(); - }); + FuseTransWithLoad fuser; + fuser.run(moduleOp); } }; From d3cb92bcdc39d3e73de9f4853d15ab4ff096ea21 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Tue, 17 Jun 2025 18:09:42 +0000 Subject: [PATCH 10/14] Address code review comments Signed-off-by: Tiotto, Ettore --- test/TritonIntelGPU/dot-operands.mlir | 43 ++++++++++++++++++- .../OptimizeDotOperands.cpp | 22 +++++----- 2 files changed, 53 insertions(+), 12 deletions(-) diff --git a/test/TritonIntelGPU/dot-operands.mlir b/test/TritonIntelGPU/dot-operands.mlir index 6c94650856..90e9e0a2b9 100644 --- a/test/TritonIntelGPU/dot-operands.mlir +++ b/test/TritonIntelGPU/dot-operands.mlir @@ -169,7 +169,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.th #linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [16, 0], [0, 16], [128, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0]], warp = [[32, 0], [64, 0], [0, 0], [0, 0], [0, 0]], block = []}> #mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32, ttig.support_sg_2d_block} { - // COM: Ensure load is not fused with transpose if there are multiple users in the loop body. + // COM: Ensure load is not fused with transpose if there are multiple users of an operation in the def-use chain containing the load + transpose. + // COM: In this case `%19` is used by the load that feeds the transpose and by a store operation. tt.func public @doNotFuseLoadWithTrans2(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}) { %c4_i32 = arith.constant 4 : i32 %c1024_i32 = arith.constant 1024 : i32 @@ -208,3 +209,43 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.th // CHECK-LABEL: doNotFuseLoadWithTrans2 // CHECK: tt.trans } + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}> +#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [16, 0], [0, 16], [128, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0]], warp = [[32, 0], [64, 0], [0, 0], [0, 0], [0, 0]], block = []}> +#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32, ttig.support_sg_2d_block} { + // COM: tt.load -> tt.trans -> tt.dot chain, in a loop. + // COM: where the 'make_tensor_ptr' result is not loop carried. 
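+  // COM: The advanced block pointer feeding the transposed load is also yielded by an scf.if, so the def-use chain checks reject the candidate and fusion is expected to be skipped.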
+ tt.func public @doNotFuseLoadWithTrans3(%arg0: !tt.ptr>>, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %cond: i1) { + %c0_i32 = arith.constant 0 : i32 + %c1_i64 = arith.constant 1 : i64 + %c32_i32 = arith.constant 32 : i32 + %c256_i32 = arith.constant 256 : i32 + %c1024_i64 = arith.constant 1024 : i64 + %c1024_i32 = arith.constant 1024 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %0 = tt.make_tensor_ptr %arg1, [%c1024_i64, %c1_i64], [%c1_i64, %c1024_i64], [%c0_i32, %c0_i32] {order = array} : > + %a = tt.make_tensor_ptr %arg1, [%c1024_i64, %c1_i64], [%c1_i64, %c1024_i64], [%c0_i32, %c0_i32] {order = array} : > + + %res:2 = scf.for %arg3 = %c0_i32 to %c1024_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %c0_i32) -> (tensor<256x256xf32, #mma>, i32) : i32 { + %1 = tt.load %arg0 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr>> + %2 = tt.advance %0, [%c256_i32, %c0_i32] : > + %3 = scf.if %cond -> !tt.ptr> { + scf.yield %2 : !tt.ptr> + } else { + scf.yield %a : !tt.ptr> + } + %a4 = tt.load %3 {boundaryCheck = array, ttig.block_io = "column_major"} : !tt.ptr> + %b4 = tt.load %2 {boundaryCheck = array, ttig.block_io = "column_major"} : !tt.ptr> + %5 = tt.trans %b4 {order = array} : tensor<256x32xbf16, #linear> -> tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %6 = tt.dot %1, %5, %arg4, inputPrecision = tf32 : tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma> + %7 = arith.addi %arg5, %c32_i32 : i32 + scf.yield %6, %7 : tensor<256x256xf32, #mma>, i32 + } + tt.return + } + // CHECK-LABEL: doNotFuseLoadWithTrans3 + // CHECK: tt.trans +} diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp index e51bd2b472..d7b3ca87b0 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp @@ -179,19 +179,18 @@ class FuseTransWithLoad { if (op->hasOneUse()) return true; - if (!op->getParentOfType()) { + if (!op->getParentOfType()) return false; - } + + auto forOp = op->getParentOfType(); + auto yieldOp = cast(forOp.getBody()->getTerminator()); SmallVector users(op->getUsers()); - if (users.size() > 2 || llvm::none_of(users, [](Operation *op) { - return isa(op); - })) { + if (users.size() > 2 || llvm::none_of(users, [&](Operation *user) { + return user == yieldOp; + })) return false; - } - auto yieldOp = cast(*llvm::find_if( - users, [](Operation *user) { return isa(user); })); auto yieldedValUsedAfterLoop = [&op, &yieldOp]() { auto it = llvm::find_if(yieldOp->getOpOperands(), [&op](OpOperand &operand) { @@ -203,9 +202,9 @@ class FuseTransWithLoad { OpResult res = forOp->getResult(operand.getOperandNumber()); return !res.getUsers().empty(); }; - if (yieldedValUsedAfterLoop()) { + + if (yieldedValUsedAfterLoop()) return false; - } nextOp = *llvm::find_if( users, [](Operation *user) { return !isa(user); }); @@ -215,7 +214,8 @@ class FuseTransWithLoad { while (currentOp != end) { Operation *user = nullptr; if (!validate(currentOp, user)) { - LLVM_DEBUG(llvm::dbgs() << *currentOp << " fails safety checks\n"); + LLVM_DEBUG(llvm::dbgs() + << "Fails safety checks: " << *currentOp << "\n"); return false; } From c1a6949cdc2d956b948fb6d6a1a5ec4a14a36bcd Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Wed, 18 Jun 
2025 14:26:21 +0000 Subject: [PATCH 11/14] Address code review comments Signed-off-by: Tiotto, Ettore --- .../lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp index d7b3ca87b0..40a3b3430b 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp @@ -347,9 +347,8 @@ class FuseTransWithLoad { llvm::dbgs() << "newVal: " << newVal << "\n"; }); - for (auto [initArg, rgnInitArg, yieldVal, loopRes] : - llvm::zip(loopOp.getInitsMutable(), loopOp.getRegionIterArgs(), - loopOp.getYieldedValues(), loopOp->getResults())) { + for (auto [initArg, rgnInitArg] : + llvm::zip(loopOp.getInitsMutable(), loopOp.getRegionIterArgs())) { if (initArg.get() == origVal) { initArg.set(newVal); rgnInitArg.setType(initArg.get().getType()); From ceace6ca58c302d465e4f54a2a6641517258da12 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Thu, 19 Jun 2025 17:23:26 +0000 Subject: [PATCH 12/14] Simplify unit test Signed-off-by: Tiotto, Ettore --- test/TritonIntelGPU/dot-operands.mlir | 30 +++++++++++---------------- 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/test/TritonIntelGPU/dot-operands.mlir b/test/TritonIntelGPU/dot-operands.mlir index 90e9e0a2b9..cf240fafe5 100644 --- a/test/TritonIntelGPU/dot-operands.mlir +++ b/test/TritonIntelGPU/dot-operands.mlir @@ -1,11 +1,10 @@ // RUN: triton-opt %s -split-input-file -tritonintelgpu-optimize-dot-operands | FileCheck %s -#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}> #linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [16, 0], [0, 16], [128, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0]], warp = [[32, 0], [64, 0], [0, 0], [0, 0], [0, 0]], block = []}> #mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32, ttig.support_sg_2d_block} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} { // COM: tt.load -> tt.trans -> tt.dot chain, not in a loop. 
- tt.func public @fuseLoadWithTrans1(%arg0: !tt.ptr>>, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr>) { + tt.func public @fuseLoadWithTrans1(%arg0: !tt.ptr>>, %arg1: !tt.ptr) { %c0_i32 = arith.constant 0 : i32 %c1_i64 = arith.constant 1 : i64 %c256_i32 = arith.constant 256 : i32 @@ -29,13 +28,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.th // ----- -#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}> #linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [16, 0], [0, 16], [128, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0]], warp = [[32, 0], [64, 0], [0, 0], [0, 0], [0, 0]], block = []}> #mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32, ttig.support_sg_2d_block} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} { // COM: tt.load -> tt.trans -> tt.dot chain, in a loop. // COM: where the 'make_tensor_ptr' result is not loop carried. - tt.func public @fuseLoadWithTrans2(%arg0: !tt.ptr>>, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr>) { + tt.func public @fuseLoadWithTrans2(%arg0: !tt.ptr>>, %arg1: !tt.ptr) { %c0_i32 = arith.constant 0 : i32 %c1_i64 = arith.constant 1 : i64 %c32_i32 = arith.constant 32 : i32 @@ -67,13 +65,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.th // ----- -#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}> #linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [16, 0], [0, 16], [128, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0]], warp = [[32, 0], [64, 0], [0, 0], [0, 0], [0, 0]], block = []}> #mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32, ttig.support_sg_2d_block} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} { // COM: tt.load -> tt.trans -> tt.dot chain, in a loop. // COM: where the 'make_tensor_ptr' result is loop carried. 
- tt.func public @fuseLoadWithTrans3(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}) { + tt.func public @fuseLoadWithTrans3(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr) { %c4_i32 = arith.constant 4 : i32 %c1024_i32 = arith.constant 1024 : i32 %c0_i32 = arith.constant 0 : i32 @@ -119,13 +116,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.th // ----- -#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}> #linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [16, 0], [0, 16], [128, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0]], warp = [[32, 0], [64, 0], [0, 0], [0, 0], [0, 0]], block = []}> #mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32, ttig.support_sg_2d_block} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} { // COM: Ensure load is not fused with transpose if the loop result corresponding to the pointer used by the load // COM: that 'feeds' the transpose operation is used. - tt.func public @doNotFuseLoadWithTrans1(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}) { + tt.func public @doNotFuseLoadWithTrans1(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr) { %c4_i32 = arith.constant 4 : i32 %c1024_i32 = arith.constant 1024 : i32 %c0_i32 = arith.constant 0 : i32 @@ -165,13 +161,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.th // ----- -#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}> #linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [16, 0], [0, 16], [128, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0]], warp = [[32, 0], [64, 0], [0, 0], [0, 0], [0, 0]], block = []}> #mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32, ttig.support_sg_2d_block} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} { // COM: Ensure load is not fused with transpose if there are multiple users of an operation in the def-use chain containing the load + transpose. // COM: In this case `%19` is used by the load that feeds the transpose and by a store operation. 
- tt.func public @doNotFuseLoadWithTrans2(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}) { + tt.func public @doNotFuseLoadWithTrans2(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr) { %c4_i32 = arith.constant 4 : i32 %c1024_i32 = arith.constant 1024 : i32 %c0_i32 = arith.constant 0 : i32 @@ -212,13 +207,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.th // ----- -#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}> #linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [16, 0], [0, 16], [128, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0]], warp = [[32, 0], [64, 0], [0, 0], [0, 0], [0, 0]], block = []}> #mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32, ttig.support_sg_2d_block} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} { // COM: tt.load -> tt.trans -> tt.dot chain, in a loop. // COM: where the 'make_tensor_ptr' result is not loop carried. - tt.func public @doNotFuseLoadWithTrans3(%arg0: !tt.ptr>>, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %cond: i1) { + tt.func public @doNotFuseLoadWithTrans3(%arg0: !tt.ptr>>, %arg1: !tt.ptr, %cond: i1) { %c0_i32 = arith.constant 0 : i32 %c1_i64 = arith.constant 1 : i64 %c32_i32 = arith.constant 32 : i32 From 920ae90b00c3885e668e532515e5e2bd8db40ca6 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Thu, 19 Jun 2025 20:24:15 +0000 Subject: [PATCH 13/14] Address code review comments Signed-off-by: Tiotto, Ettore --- test/TritonIntelGPU/dot-operands.mlir | 35 +++++++++++++++++-- .../OptimizeDotOperands.cpp | 33 +++++++++-------- 2 files changed, 52 insertions(+), 16 deletions(-) diff --git a/test/TritonIntelGPU/dot-operands.mlir b/test/TritonIntelGPU/dot-operands.mlir index cf240fafe5..babdf2e34a 100644 --- a/test/TritonIntelGPU/dot-operands.mlir +++ b/test/TritonIntelGPU/dot-operands.mlir @@ -210,8 +210,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} { #linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [16, 0], [0, 16], [128, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0]], warp = [[32, 0], [64, 0], [0, 0], [0, 0], [0, 0]], block = []}> #mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} { - // COM: tt.load -> tt.trans -> tt.dot chain, in a loop. - // COM: where the 'make_tensor_ptr' result is not loop carried. + // COM: Ensure load is not fused with transpose if the block ptr used by the load operation is yielded by a if statement (current limitation). 
tt.func public @doNotFuseLoadWithTrans3(%arg0: !tt.ptr>>, %arg1: !tt.ptr, %cond: i1) { %c0_i32 = arith.constant 0 : i32 %c1_i64 = arith.constant 1 : i64 @@ -243,3 +242,35 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} { // CHECK-LABEL: doNotFuseLoadWithTrans3 // CHECK: tt.trans } + +// ----- + +#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [16, 0], [0, 16], [128, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0]], warp = [[32, 0], [64, 0], [0, 0], [0, 0], [0, 0]], block = []}> +#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} { + // COM: Ensure load is not fused with transpose when it is in a while loop (current limitation). + tt.func public @doNotFuseLoadWithTrans4(%arg0: !tt.ptr>>, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i1) { + %c0_i32 = arith.constant 0 : i32 + %c1_i64 = arith.constant 1 : i64 + %c32_i32 = arith.constant 32 : i32 + %c256_i32 = arith.constant 256 : i32 + %c1024_i64 = arith.constant 1024 : i64 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %0 = tt.make_tensor_ptr %arg1, [%c1024_i64, %c1_i64], [%c1_i64, %c1024_i64], [%c0_i32, %c0_i32] {order = array} : > + %1:2 = scf.while (%arg3 = %0, %arg4 = %c0_i32) : (!tt.ptr>, i32) -> (!tt.ptr>, i32) { + scf.condition(%arg2) %arg3, %arg4 : !tt.ptr>, i32 + } do { + ^bb0(%arg3: !tt.ptr>, %arg4: i32): + %2 = tt.load %arg0 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr>> + %3 = tt.advance %arg3, [%c256_i32, %c0_i32] : > + %4 = tt.load %3 {boundaryCheck = array, ttig.block_io = "column_major"} : !tt.ptr> + %5 = tt.trans %4 {order = array} : tensor<256x32xbf16, #linear> -> tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %6 = tt.dot %2, %5, %cst, inputPrecision = tf32 : tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma> + %7 = arith.addi %arg4, %c32_i32 : i32 + scf.yield %3, %7 : !tt.ptr>, i32 + } + tt.return + } + // CHECK-LABEL: doNotFuseLoadWithTrans4 + // CHECK: tt.trans +} diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp index 40a3b3430b..4c41d63ed3 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp @@ -6,6 +6,7 @@ #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Value.h" #include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/Passes.h" @@ -16,6 +17,7 @@ #include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" #include #define DEBUG_TYPE "tritonintelgpu-optimize-dot-operands" @@ -179,11 +181,12 @@ class FuseTransWithLoad { if (op->hasOneUse()) return true; - if (!op->getParentOfType()) + if (!op->getParentOfType()) return false; - auto forOp = op->getParentOfType(); - auto yieldOp = cast(forOp.getBody()->getTerminator()); + auto loopOp = op->getParentOfType(); + auto yieldOp = cast( + 
loopOp.getYieldedValues()[0].getParentBlock()->getTerminator()); SmallVector users(op->getUsers()); if (users.size() > 2 || llvm::none_of(users, [&](Operation *user) { @@ -198,8 +201,8 @@ class FuseTransWithLoad { }); assert(it != yieldOp->getOpOperands().end()); OpOperand &operand = *it; - auto forOp = cast(yieldOp->getParentOp()); - OpResult res = forOp->getResult(operand.getOperandNumber()); + auto loopOp = cast(yieldOp->getParentOp()); + OpResult res = loopOp->getResult(operand.getOperandNumber()); return !res.getUsers().empty(); }; @@ -225,11 +228,14 @@ class FuseTransWithLoad { continue; } + if (isa(user)) + return false; + // Find the next operation in the def-use chain inside the loop body. - if (auto forOp = dyn_cast(user)) { - for (BlockArgument arg : forOp.getRegionIterArgs()) { - Value initArg = forOp.getInitArgs()[arg.getArgNumber() - 1]; - if (initArg == currentOp->getResult(0)) { + if (auto loopOp = dyn_cast(user)) { + for (auto [arg, init] : + llvm::zip(loopOp.getRegionIterArgs(), loopOp.getInits())) { + if (init == currentOp->getResult(0)) { if (!arg.hasOneUse()) return false; @@ -327,16 +333,15 @@ class FuseTransWithLoad { // Update the yield's parent operation result type. Operation *parentOp = yieldOp->getParentOp(); - for (OpResult res : parentOp->getOpResults()) { - int resNum = res.getResultNumber(); - if (resNum == opNum) - res.setType(newVal.getType()); - } + OpResult res = parentOp->getOpResult(opNum); + res.setType(newVal.getType()); return; } if (auto forOp = dyn_cast(user)) return propagateToLoop(newVal, origVal, forOp, sentinel); + + llvm_unreachable("Unexpected kind of user"); } void propagateToLoop(Value newVal, Value origVal, LoopLikeOpInterface loopOp, From a22e80a64aa0089eaadd989cdc6edadc7943f3d3 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Thu, 19 Jun 2025 20:36:00 +0000 Subject: [PATCH 14/14] Address code review comments Signed-off-by: Tiotto, Ettore --- .../lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp index 4c41d63ed3..bff9f17673 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp @@ -231,6 +231,8 @@ class FuseTransWithLoad { if (isa(user)) return false; + [[maybe_unused]] Operation *oldCurrentOp = currentOp; + // Find the next operation in the def-use chain inside the loop body. if (auto loopOp = dyn_cast(user)) { for (auto [arg, init] : @@ -244,6 +246,8 @@ class FuseTransWithLoad { } } } + + assert(currentOp != oldCurrentOp && "Infinite loop detected!"); } return true;
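
The reachability test that patch 08 introduces can also be read in isolation. The sketch below restates the `usedByDotOp` check as a free function; the includes, the `tt` namespace alias, and the name `reachesDotThroughSingleUsers` are illustrative additions for this summary, not code taken verbatim from OptimizeDotOperands.cpp.

#include "mlir/IR/Operation.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

namespace tt = mlir::triton;

// Walk the def-use chain rooted at the transpose: every step must have exactly
// one user, and the chain must terminate at a tt.dot for the transpose to be a
// fusion candidate.
static bool reachesDotThroughSingleUsers(tt::TransOp transOp) {
  if (!transOp->hasOneUse())
    return false;
  mlir::Operation *user = *transOp->getUsers().begin();
  while (user) {
    if (mlir::isa<tt::DotOp>(user))
      return true;
    if (!user->hasOneUse())
      break;
    user = *user->getUsers().begin();
  }
  return false;
}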