Convert block ptr type layouts to Subgroup2DBlockEncoding layouts #4463

Draft: alexbaden wants to merge 14 commits into main from alex/optimize_block_encoding

Conversation

@alexbaden commented Jun 9, 2025

To ensure Subgroup2DBlockIO layouts survive RemoveLayoutConversions, we need to identify them as anchor layouts on the LoadOp. To do that, we need to change the load layout to a Subgroup2DBlockIO layout and propagate that layout as necessary through the existing IR. This PR introduces TritonIntelGPU::OptimizeBlockEncoding, which replaces the encoding on all ptr types between the candidate LoadOp and the MakeTensorPtr op that creates the pointer for the load. The pass is based on a similar upstream pass, TritonNVIDIAGPU::OptimizeDescriptorEncoding, which modifies layouts in place for tensor descriptors. To avoid modifying any non-ptr types, which should be changed via ConvertLayoutOp, we add a dummy layout conversion after the LoadOp; it is removed when the old load layout is removed in RemoveLayoutConversions.
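
For reference, the core of the rewrite might look roughly like the sketch below (hypothetical variable names such as ptr, loadOp, and newEncoding; this illustrates the approach described above, not the pass's actual code):

// Sketch: give the pointee tensor of the ptr type the new encoding, retype
// the load, then insert a conversion back to the old tensor type right after
// the load so non-ptr users are untouched. RemoveLayoutConversions can later
// fold the conversion away once the new layout is an anchor on the load.
auto oldPtrTy = cast<triton::PointerType>(ptr.getType());
auto oldTensorTy = cast<RankedTensorType>(oldPtrTy.getPointeeType());
auto newTensorTy = RankedTensorType::get(
    oldTensorTy.getShape(), oldTensorTy.getElementType(), newEncoding);
ptr.setType(triton::PointerType::get(newTensorTy, oldPtrTy.getAddressSpace()));

loadOp.getResult().setType(newTensorTy);
OpBuilder b(loadOp->getContext());
b.setInsertionPointAfter(loadOp);
auto cvt = b.create<triton::gpu::ConvertLayoutOp>(
    loadOp.getLoc(), oldTensorTy, loadOp.getResult());
loadOp.getResult().replaceAllUsesExcept(cvt.getResult(), cvt);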

With this pass we are able to create the desired chain Subgroup2DBlockIO -> ConvertLayout -> DPAS, which will allow us to begin using Subgroup2DBlockIO in the LLVM lowering and to drive the layout conversion via LinearLayout objects. The pass is introduced in this PR but not yet enabled; I did include a lit test that covers the most common for-loop use case, and I intend to improve test coverage as we expand use of the pass.

Depends on #4461

Closes #4362

@alexbaden force-pushed the alex/optimize_block_encoding branch 3 times, most recently from 6d8d666 to 098c6c8 on June 12, 2025 at 21:28
@alexbaden marked this pull request as ready for review on June 12, 2025 at 21:48
@alexbaden force-pushed the alex/optimize_block_encoding branch 2 times, most recently from b86b08f to 9b56562 on June 13, 2025 at 01:53
alexbaden added a commit that referenced this pull request Jun 17, 2025
Adds the end-to-end dot product on block ptr testing to the block load
unit test (maybe it should be renamed `test_block_ptr.py`?). Adds
additional shapes, A transpose, and B transpose. The cold runtime (no
cache) is approximately 1 minute on PVC 1100 in my environment. I picked
the block shapes somewhat randomly, trying to balance breadth and
runtime.

This somewhat duplicates tutorial 10 but allows us to run many more
combinations in a shorter time. I added this because #4463 is passing CI
but has a few bugs that are not caught by the existing unit tests,
including the tutorials.

---------

Co-authored-by: Whitney Tsang <whitney.tsang@intel.com>
@alexbaden force-pushed the alex/optimize_block_encoding branch from 9b56562 to 09fec1e on June 17, 2025 at 18:21
def TritonIntelGPUOptimizeBlockIOEncodingPass : Pass<"tritonintelgpu-optimize-block-io-encoding", "mlir::ModuleOp"> {
let summary = "Set encodings on candidates for Subgroup 2D Block IO ops";

let description = [{
Contributor

Question: tt.load operations feeding a tt.dot operation are expected to have "dot" encoding in their type. How does this work with the new encoding? Will we need a layout conversion operation to change the layout from "BlockEncoding" to "Dot" encoding?

Contributor Author

Yes, currently a tt.load op is expected to have DotOperandEncoding and the block_io tag if we want to lower it to Subgroup 2D Block loads. This pass changes the type of the tt.load to have a Subgroup2DBlockEncoding layout but inserts a layout conversion back to the original type. Because the Subgroup2DBlockEncoding layout is not an anchor layout, it is replaced in RemoveLayoutConversions with the DotOperandEncoding layout.
In #4500 I make Subgroup2DBlockEncoding an anchor layout and teach LoadStoreOpToLLVM to "see through" the layout conversion, which becomes a no-op. But the ultimate goal (which I am working on now) is to move the layout conversion (from subgroup 2d block load to DPAS in registers) out of LoadStoreOpToLLVM and into a ConvertLayoutOp, so that eventually LoadStoreOpToLLVM would not depend directly on DPAS to lower block ptr / tensor desc loads to 2D block IO instructions.

etiotto commented Jun 18, 2025

@chengjunlu have you taken an initial look at this proposed pass? Opinions?

@etiotto left a comment

Can we add lit tests that exercise an scf.while loop and an if statement as well?

Comment on lines +9 to +12
namespace mlir {
namespace triton {
namespace gpu::intel {
Contributor

Suggested change
namespace mlir {
namespace triton {
namespace gpu::intel {
namespace mlir::triton::gpu::intel {

Contributor Author

why?

class TritonIntelGPUOptimizeBlockIOEncodingPass
: public impl::TritonIntelGPUOptimizeBlockIOEncodingPassBase<
TritonIntelGPUOptimizeBlockIOEncodingPass> {

Contributor

[Suggestion]: Can you put the public members at the top and the private members at the bottom of the class? That way the public interface is up front.

Contributor Author

It seemed like upstream passes put runOnOperation at the bottom of the file by convention, so when you want to know what a pass does you start by scrolling all the way down and reading bottom to top. I tried to follow that, but if we want to standardize on and write down a different convention, I think that would be fine.

// CHECK: #mma1 = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [16, 16], numBlocks=2, order=[0, 1], kWidth=2, threadsPerWarp=16}>
// CHECK: #mma2 = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32, ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64"} {
Contributor

Can the module declaration be simplified to:
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {

Contributor Author

No. I removed some of them, but I left a combination of attributes that I believe are relevant and attributes that are required.

%11 = arith.muli %8, %c256_i32 : i32
// CHECK: tt.make_tensor_ptr {{.*}} : <tensor<32x256xf16, #mma1>>
%12 = tt.make_tensor_ptr %arg1, [%c5120_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%c0_i32, %11] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #blocked2>>
%13:3 = scf.for %arg3 = %c0_i32 to %c5120_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %10, %arg6 = %12) -> (tensor<256x256xf32, #blocked>, !tt.ptr<tensor<256x32xf16, #blocked1>>, !tt.ptr<tensor<32x256xf16, #blocked2>>) : i32 {
Contributor

Add a CHECK to ensure that the loop init args have the new tensor ptr type with the #mma layout.
The return types of the scf.for operation should be CHECK-ed too.

Contributor Author

The return type of the scf.for is not modified, so I am not sure what we need to check. The tt.dot return type is unchanged, and the make_tensor_ptr / store uses the same layout as before.

 %26 = tt.dot %24, %25, %23, inputPrecision = tf32 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma2, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma2, kWidth = 2}>> -> tensor<256x256xf32, #mma2>
      %27 = ttg.convert_layout %26 : tensor<256x256xf32, #mma2> -> tensor<256x256xf32, #blocked>
      %28 = tt.advance %arg5, [%c0_i32, %c32_i32] : <tensor<256x32xf16, #mma>>
      %29 = tt.advance %arg6, [%c32_i32, %c0_i32] : <tensor<32x256xf16, #mma1>>
      scf.yield %27, %28, %29 : tensor<256x256xf32, #blocked>, !tt.ptr<tensor<256x32xf16, #mma>>, !tt.ptr<tensor<32x256xf16, #mma1>>
    }
    %14 = tt.make_tensor_ptr %arg2, [%c1024_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%9, %11] {order = array<i32: 1, 0>} : <tensor<256x256xf16, #blocked2>>
    %15 = arith.truncf %13#0 : tensor<256x256xf32, #blocked> to tensor<256x256xf16, #blocked>
    %16 = ttg.convert_layout %15 : tensor<256x256xf16, #blocked> -> tensor<256x256xf16, #blocked2>
    tt.store %14, %16 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x256xf16, #blocked2>>
    tt.return

I can add a check to make sure the arguments of the scf.for loop are using the new values, but the verifier for scf.for should not allow a type mismatch, so I don't think we need another type check (I briefly looked through other lit examples and did not see arg type checking for scf.for).

scf.yield %25, %26, %27 : tensor<256x256xf32, #blocked>, !tt.ptr<tensor<256x32xf16, #blocked1>>, !tt.ptr<tensor<32x256xf16, #blocked2>>
}
%14 = tt.make_tensor_ptr %arg2, [%c1024_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%9, %11] {order = array<i32: 1, 0>} : <tensor<256x256xf16, #blocked2>>
%15 = arith.truncf %13#0 : tensor<256x256xf32, #blocked> to tensor<256x256xf16, #blocked>
Contributor

Add CHECK to verify that %13#0 has the dpas layout.

Contributor Author

It won't have the DPAS layout; it will be blocked. The DPAS layout is converted to blocked in %25.

namespace {

SmallVector<Value> getTiedArgs(Operation *op, int resultIdx) {
if (auto forOp = dyn_cast<scf::ForOp>(op)) {
Contributor

The code for scf.for and scf.while can be shared by dynamically casting op to a LoopLikeOpInterface, which both operations implement.

Contributor Author

We still need to access op-specific methods, don't we?

Contributor

LoopLikeOpInterface should let you get the loop info through common accessor names.
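
For illustration, an interface-based version might look like the sketch below (assuming LoopLikeOpInterface's common accessors cover everything getTiedArgs needs, which is the open question in this thread):

#include "mlir/Interfaces/LoopLikeInterface.h"

// Sketch: fetch the tied init value, region iter-arg, yielded value, and
// loop result for resultIdx through the interface's common accessors instead
// of casting to scf::ForOp / scf::WhileOp separately.
SmallVector<Value> getTiedArgsViaInterface(Operation *op, int resultIdx) {
  auto loop = dyn_cast<LoopLikeOpInterface>(op);
  if (!loop)
    return {};
  Value init = loop.getInits()[resultIdx];
  Value iterArg = loop.getRegionIterArgs()[resultIdx];
  Value yielded = loop.getYieldedValues()[resultIdx];
  Value result = loop->getResult(resultIdx);
  return {init, iterArg, yielded, result};
}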

auto yieldVal =
whileOp.getBeforeBody()->getTerminator()->getOperand(resultIdx);
auto initVal = whileOp.getOperands()[resultIdx];
return {iterArg, result, iterArg, initVal};
Contributor

After a return there is no need for an else if; can you please change it to an if?

@alexbaden
Hi @etiotto, we just discussed in the architecture meeting refactoring SmallVector<Value> getTiedArgs(Operation *op, int resultIdx) into a common utility, since it is duplicated from the upstream pass: https://github.com/triton-lang/triton/blob/main/lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeDescriptorEncoding.cpp#L131. So I would prefer not to make cosmetic changes that might be controversial when merging upstream.

tt.store %14, %16 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x256xf16, #blocked2>>
tt.return
}
}
Contributor

Here is a test containing an scf.while loop:

#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 16], warpsPerCTA = [16, 2], order = [1, 0]}>
#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>

module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32} {
  tt.func public @test_while_loop(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: i1, %B: tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>)  {
    %c0_i32 = arith.constant 0 : i32
    %c1_i64 = arith.constant 1 : i64
    %c32_i32 = arith.constant 32 : i32
    %c256_i32 = arith.constant 256 : i32
    %c1024_i64 = arith.constant 1024 : i64
    %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
    %0 = tt.make_tensor_ptr %arg1, [%c1024_i64, %c1_i64], [%c1_i64, %c1024_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xf16, #blocked1>>
    %1:2 = scf.while (%arg3 = %0, %arg4 = %c0_i32) : (!tt.ptr<tensor<256x32xf16, #blocked1>>, i32) -> (!tt.ptr<tensor<256x32xf16, #blocked1>>, i32) {
      scf.condition(%arg2) %arg3, %arg4 : !tt.ptr<tensor<256x32xf16, #blocked1>>, i32
    } do {
    ^bb0(%arg3: !tt.ptr<tensor<256x32xf16, #blocked1>>, %arg4: i32):
      %3 = tt.advance %arg3, [%c256_i32, %c0_i32] : <tensor<256x32xf16, #blocked1>>
      %A = tt.load %3 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<256x32xf16, #blocked1>>
      %convA1 = ttg.convert_layout %A : tensor<256x32xf16, #blocked1> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>
      %convA2 = ttg.convert_layout %convA1 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
      %res = tt.dot %convA2, %B, %cst, inputPrecision = tf32 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma>
      %7 = arith.addi %arg4, %c32_i32 : i32
      scf.yield %3, %7 : !tt.ptr<tensor<256x32xf16, #blocked1>>, i32
    }
    tt.return
  }
}

This is currently failing; I'd expect it to work given that the pass contains some code to handle while loops.

PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0.      Program arguments: /home/jovyan/intel-xpu-backend-for-triton/build/cmake.linux-x86_64-cpython-3.10/bin/triton-opt /home/jovyan/tmp/test4.mlir --split-input-file --tritonintelgpu-optimize-block-io-encoding
 #0 0x00005633d27e59b7 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) (/home/jovyan/intel-xpu-backend-for-triton/build/cmake.linux-x86_64-cpython-3.10/bin/triton-opt+0x483d9b7)
 #1 0x00005633d27e361e llvm::sys::RunSignalHandlers() (/home/jovyan/intel-xpu-backend-for-triton/build/cmake.linux-x86_64-cpython-3.10/bin/triton-opt+0x483b61e)
 #2 0x00005633d27e60b5 SignalHandler(int, siginfo_t*, void*) Signals.cpp:0:0
 #3 0x00007fbcfefb1520 (/lib/x86_64-linux-gnu/libc.so.6+0x42520)
 #4 0x00007fbcff0059fc pthread_kill (/lib/x86_64-linux-gnu/libc.so.6+0x969fc)
 #5 0x00007fbcfefb1476 gsignal (/lib/x86_64-linux-gnu/libc.so.6+0x42476)
 #6 0x00007fbcfef977f3 abort (/lib/x86_64-linux-gnu/libc.so.6+0x287f3)
 #7 0x00005633d2793231 (/home/jovyan/intel-xpu-backend-for-triton/build/cmake.linux-x86_64-cpython-3.10/bin/triton-opt+0x47eb231)
 #8 0x00005633ce562bf1 mlir::triton::getMakeTensorPtrOp(mlir::Value) /home/jovyan/intel-xpu-backend-for-triton/lib/Dialect/Triton/IR/Utility.cpp:93:1
 #9 0x00005633ce562639 getMakeTensorPtrOpImpl(mlir::Operation*, mlir::Value) /home/jovyan/intel-xpu-backend-for-triton/lib/Dialect/Triton/IR/Utility.cpp:30:34
#10 0x00005633ce562989 mlir::triton::getMakeTensorPtrOp(mlir::Value) /home/jovyan/intel-xpu-backend-for-triton/lib/Dialect/Triton/IR/Utility.cpp:69:34
#11 0x00005633cebcf32e mlir::triton::gpu::intel::TritonIntelGPUOptimizeBlockIOEncodingPass::getSubgroup2DBlockLayoutForOperand(mlir::Value, mlir::triton::gpu::intel::DpasEncodingAttr, llvm::MapVector<mlir::Operation*, mlir::Attribute, llvm::DenseMap<mlir::Operation*, unsigned int, llvm::DenseMapInfo<mlir::Operation*, void>, llvm::detail::DenseMapPair<mlir::Operation*, unsigned int>>, llvm::SmallVector<std::pair<mlir::Operation*, mlir::Attribute>, 0u>>&) /home/jovyan/intel-xpu-backend-for-triton/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp:245:57
#12 0x00005633cebcf8a3 mlir::triton::gpu::intel::TritonIntelGPUOptimizeBlockIOEncodingPass::runOnOperation()::'lambda'(mlir::triton::DotOp)::operator()(mlir::triton::DotOp) const /home/jovyan/intel-xpu-backend-for-triton/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp:300:41
#13 0x00005633cebd253f _ZZN4mlir6detail4walkILNS_9WalkOrderE1ENS_15ForwardIteratorEZNS_6triton3gpu5intel41TritonIntelGPUOptimizeBlockIOEncodingPass14runOnOperationEvEUlNS4_5DotOpEE_S8_vEENSt9enable_ifIXaantsrSt11disjunctionIJSt7is_sameIT2_PNS_9OperationEESC_ISD_PNS_6RegionEESC_ISD_PNS_5BlockEEEE5valuesrSC_IT3_vE5valueESO_E4typeESF_OT1_ENKUlSF_E_clESF_ /home/jovyan/.triton/llvm/llvm-8957e64a-ubuntu-x64/include/mlir/IR/Visitors.h:336:20
#14 0x00005633cebd3c98 _ZN4llvm12function_refIFvPN4mlir9OperationEEE11callback_fnIZNS1_6detail4walkILNS1_9WalkOrderE1ENS1_15ForwardIteratorEZNS1_6triton3gpu5intel41TritonIntelGPUOptimizeBlockIOEncodingPass14runOnOperationEvEUlNSB_5DotOpEE_SF_vEENSt9enable_ifIXaantsrSt11disjunctionIJSt7is_sameIT2_S3_ESJ_ISK_PNS1_6RegionEESJ_ISK_PNS1_5BlockEEEE5valuesrSJ_IT3_vE5valueEST_E4typeES3_OT1_EUlS3_E_EEvlS3_ /home/jovyan/.triton/llvm/llvm-8957e64a-ubuntu-x64/include/llvm/ADT/STLFunctionalExtras.h:47:40
#15 0x00005633ce354d47 llvm::function_ref<void (mlir::Operation*)>::operator()(mlir::Operation*) const /home/jovyan/.triton/llvm/llvm-8957e64a-ubuntu-x64/include/llvm/ADT/STLFunctionalExtras.h:69:62
#16 0x00005633ce352580 void mlir::detail::walk<mlir::ForwardIterator>(mlir::Operation*, llvm::function_ref<void (mlir::Operation*)>, mlir::WalkOrder) /home/jovyan/.triton/llvm/llvm-8957e64a-ubuntu-x64/include/mlir/IR/Visitors.h:187:1
#17 0x00005633ce3524ef void mlir::detail::walk<mlir::ForwardIterator>(mlir::Operation*, llvm::function_ref<void (mlir::Operation*)>, mlir::WalkOrder) /home/jovyan/.triton/llvm/llvm-8957e64a-ubuntu-x64/include/mlir/IR/Visitors.h:179:7
#18 0x00005633ce3524ef void mlir::detail::walk<mlir::ForwardIterator>(mlir::Operation*, llvm::function_ref<void (mlir::Operation*)>, mlir::WalkOrder) /home/jovyan/.triton/llvm/llvm-8957e64a-ubuntu-x64/include/mlir/IR/Visitors.h:179:7
#19 0x00005633ce3524ef void mlir::detail::walk<mlir::ForwardIterator>(mlir::Operation*, llvm::function_ref<void (mlir::Operation*)>, mlir::WalkOrder) /home/jovyan/.triton/llvm/llvm-8957e64a-ubuntu-x64/include/mlir/IR/Visitors.h:179:7
#20 0x00005633cebd25b4 _ZN4mlir6detail4walkILNS_9WalkOrderE1ENS_15ForwardIteratorEZNS_6triton3gpu5intel41TritonIntelGPUOptimizeBlockIOEncodingPass14runOnOperationEvEUlNS4_5DotOpEE_S8_vEENSt9enable_ifIXaantsrSt11disjunctionIJSt7is_sameIT2_PNS_9OperationEESC_ISD_PNS_6RegionEESC_ISD_PNS_5BlockEEEE5valuesrSC_IT3_vE5valueESO_E4typeESF_OT1_ /home/jovyan/.triton/llvm/llvm-8957e64a-ubuntu-x64/include/mlir/IR/Visitors.h:341:38
#21 0x00005633cebd12f0 _ZN4mlir9Operation4walkILNS_9WalkOrderE1ENS_15ForwardIteratorEZNS_6triton3gpu5intel41TritonIntelGPUOptimizeBlockIOEncodingPass14runOnOperationEvEUlNS4_5DotOpEE_vEENSt9enable_ifIXeqsrN4llvm15function_traitsINSt5decayIT1_E4typeEXsrSt8is_classISG_E5valueEEE8num_argsLi1EET2_E4typeEOSE_ /home/jovyan/.triton/llvm/llvm-8957e64a-ubuntu-x64/include/mlir/IR/Operation.h:798:75
#22 0x00005633cebd019f _ZN4mlir7OpState4walkILNS_9WalkOrderE1ENS_15ForwardIteratorEZNS_6triton3gpu5intel41TritonIntelGPUOptimizeBlockIOEncodingPass14runOnOperationEvEUlNS4_5DotOpEE_vEENSt9enable_ifIXeqsrN4llvm15function_traitsINSt5decayIT1_E4typeEXsrSt8is_classISG_E5valueEEE8num_argsLi1EET2_E4typeEOSE_ /home/jovyan/.triton/llvm/llvm-8957e64a-ubuntu-x64/include/mlir/IR/OpDefinition.h:169:68
#23 0x00005633cebcf962 mlir::triton::gpu::intel::TritonIntelGPUOptimizeBlockIOEncodingPass::runOnOperation() /home/jovyan/intel-xpu-backend-for-triton/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp:311:21
...

Contributor Author

I think it is probably failing because the advance occurs before the load, and the argument to advance is not the MakeTensorPtrOp but the while-loop construct. Why do you expect this to work? Can this ever be generated from a Triton kernel written in Python? I have run this PR on all the tests and benchmarks and not seen this error, so if it is legal then we are missing unit test coverage for this pattern.

@alexbaden force-pushed the alex/optimize_block_encoding branch from 09fec1e to 37e5f75 on June 23, 2025 at 19:13

<< itr->second.desiredEncoding << " for value "
<< typedVal << ". Ensure new encoding "
<< info.desiredEncoding << " matches.\n");
assert(itr->second == info && "already visited encoding info for "
Contributor

I think the assertion is never triggered, since only a subset of use kinds is considered.

Contributor Author

Right, but if a use is missed wouldn't we see the IR fail to verify?

@alexbaden

FYI getTiedArgs was refactored upstream into shared utils, so I will remove it and use the shared impl once we merge in that change: triton-lang/triton@4d791f0

@alexbaden

I added a test with more complex control flow: scf.if, scf.for, and scf.yield. I borrowed it from the warp specialization test, and I don't think it makes much sense from a kernel perspective (it might be better with a tensor descriptor), but it has the nice property that the MakeTensorPtr op is inside the scf.if, which tests our ability to find the MakeTensorPtr op inside complex control flow and then rewrite it.

I also wanted to hunt down the source of the crash @etiotto posted above. It turns out it is the result of upstream code that assumes every tt.ptr can be traced back to a MakeTensorPtr: https://github.com/intel/intel-xpu-backend-for-triton/blob/main/lib/Dialect/Triton/IR/Utility.cpp#L92. Maybe this is worth changing, but upstream still runs the RewriteTensorPtr pass, so presumably all kernels should be well-formed and hitting this assert is truly an error.

@alexbaden

Thanks @etiotto and @whitneywhtsang for your comments. I have increased the testing to show more structured control flow examples, including if, while, and for loops. While doing this I discovered a bug in the getTiedArgs function, which I fixed and which is also merged upstream. I also fixed an issue in getMakeTensorPtr for while loops, which I have PRed separately (#4567) but included and tested here (I think it is nice to keep that commit logically separate). Finally, I added support for tt.advance in nested control flow. I believe this should address all the remaining comments above.

alexbaden added a commit that referenced this pull request Jun 25, 2025
Properly handle while loop in `getMakeTensorPtr`. 

This is common code, but given upstream is deprecating the feature I
don't think they want to take this change. The test for this is in #4463
but I think it is better to keep the commit logically separate and not
have it squashed and merged with #4463.
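
The shape of that fix is presumably something like the sketch below: map a while-loop block argument back to the value that feeds it, so the walk toward the defining tt.make_tensor_ptr can continue (the helper name and the exact mapping are assumptions, not the actual patch):

#include "mlir/Dialect/SCF/IR/SCF.h"

// Sketch: a block argument of scf.while comes either from the loop inits
// (before region) or from the scf.condition operands (after region); return
// the feeding value so the caller can keep searching for tt.make_tensor_ptr.
static mlir::Value lookThroughWhileArg(mlir::BlockArgument arg) {
  auto whileOp =
      llvm::dyn_cast<mlir::scf::WhileOp>(arg.getOwner()->getParentOp());
  if (!whileOp)
    return arg;
  unsigned idx = arg.getArgNumber();
  if (arg.getOwner() == whileOp.getBeforeBody())
    return whileOp.getInits()[idx];
  return whileOp.getConditionOp().getArgs()[idx];
}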
@alexbaden force-pushed the alex/optimize_block_encoding branch from 05dab1c to a915a90 on June 25, 2025 at 13:03
@whitneywhtsang

FYI getTiedArgs was refactored upstream into shared utils, so I will remove it and use the shared impl once we merge in that change: triton-lang/triton@4d791f0

Merged to our repo.

alexbaden commented Jun 26, 2025

We need the bug fix too (triton-lang/triton@e71689d), but I don't think we need to hold this PR for those merges.

oldTensorTy.getShape(), oldTensorTy.getElementType(), newEncoding);

val.setType(PointerType::get(newTensorTy, oldType.getAddressSpace()));
if (einfo.requiresConvert) {
Contributor

What's the scenario where requiresConvert is false and there exists a user of type LoadOp? I wonder if requiresConvert is needed.

Contributor Author

The goal of the algorithm in the pass is to attach new encoding info to each value in the initial propagation and then update all values in the second pass. It seemed like requiresConvert fit more naturally in the first phase, where the encoding info is determined and propagated. But I think you are right that we could look at all users of each Value in the second phase and add the convert if any one of the users is a load.

I think I prefer storing requiresConvert in the first phase, both because it seems to fit more naturally there and because it serves as an additional sanity check on the algorithm (if we see the same Value again, we check that requiresConvert matches the last time we saw that Value). I suppose we could store the index of the user in requiresConvert instead of making it a boolean and avoid the second users loop in phase 2, but the users loop doesn't seem too expensive. I also think we could extend this pass to propagate transpose layouts, so having requiresConvert, or at least the machinery to process additional properties, may be useful.
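
As a rough picture of the bookkeeping being discussed (the field names desiredEncoding and requiresConvert follow the snippets above; the map and helper names and the merge logic are illustrative assumptions):

#include <cassert>
#include "llvm/ADT/MapVector.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Value.h"

// Sketch: phase 1 records, per value, the desired encoding plus whether a
// ConvertLayoutOp back to the old type is needed (i.e. some user is the
// candidate LoadOp). Revisiting a value double-checks the recorded info.
struct EncodingInfo {
  mlir::Attribute desiredEncoding; // target Subgroup2DBlock encoding
  bool requiresConvert = false;    // convert back to the old type after the load
  bool operator==(const EncodingInfo &other) const {
    return desiredEncoding == other.desiredEncoding &&
           requiresConvert == other.requiresConvert;
  }
};

void recordEncoding(llvm::MapVector<mlir::Value, EncodingInfo> &infoMap,
                    mlir::Value val, const EncodingInfo &info) {
  auto [itr, inserted] = infoMap.insert({val, info});
  if (!inserted)
    assert(itr->second == info && "encoding info changed between visits");
  (void)itr;
}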

@alexbaden force-pushed the alex/optimize_block_encoding branch from a915a90 to fa5910c on June 27, 2025 at 00:48
@alexbaden

For some reason I can't reply to your tt.advance comment, but yes, I added explicit processing for tt.advance because the result of tt.advance was not being updated. The operand was being updated as part of the existing control-flow handling. If tt.advance were used by an op other than the LoadOp targeted by this pass, we would see a failure, but that failure would be caught during the pass pipeline and surfaced as a type mismatch, which would be pretty easy to detect. In all the benchmarks and unit tests I have run, I have not seen such a scenario.
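
In code, that extra handling is roughly the following (a sketch; it relies on the result of tt.advance sharing the type of its pointer operand):

// Sketch: once the pointer operand of a tt.advance carries the new encoding,
// mirror that type onto the op's result so downstream users agree.
if (auto advanceOp = llvm::dyn_cast<mlir::triton::AdvanceOp>(user))
  advanceOp.getResult().setType(advanceOp.getPtr().getType());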

@whitneywhtsang

We need the bug fix too: triton-lang/triton@e71689d but I don't think we need to hold this PR for those merges.

FYI merged too.

@whitneywhtsang left a comment

Action items (ARs) captured from the architecture meeting:

  • Rebase to utilize the upstream getTiedArgs implementation.
  • Remove the requiresConvert field.
  • Add a lit test for the case where the result of an advance op is utilized by the load op.
  • Ensure that the performance of tutorial 6 remains unaffected by these changes.

@whitneywhtsang force-pushed the alex/optimize_block_encoding branch 2 times, most recently from 778c666 to a43cbe7 on July 14, 2025 at 17:47
@whitneywhtsang force-pushed the alex/optimize_block_encoding branch from a43cbe7 to 7cbe3fe on July 14, 2025 at 17:48
@etiotto marked this pull request as draft on July 15, 2025 at 18:40
Successfully merging this pull request may close these issues:

Introduce a new pass to change LoadOp layouts to Subgroup2DBlock layouts (#4362)