
[optimize-dot-operands]: Fuse load and trans operations - part 2 #4468


Merged: 17 commits merged into main on Jun 20, 2025

Conversation

@etiotto (Contributor) commented Jun 9, 2025

This PR enhances the new transformation pass aimed at fusing tt.load and tt.trans operations. Specifically, it adds support for loop-carried arguments used (possibly transitively) by a candidate tt.load that should be fused with a tt.trans.
Example:

    %10 = tt.make_tensor_ptr %arg1, [%c1024_i64, %c1_i64], [%c1_i64, %c1024_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xbf16, #linear>>
    %13:3 = scf.for %arg3 = %c0_i32 to %c1024_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %c0_i32, %arg6 = %10) -> (tensor<256x256xf32, #mma>, i32, !tt.ptr<tensor<256x32xbf16, #linear>>) : i32 {
      %17 = tt.advance %9, [%c256_i32, %arg5] : <tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
      %18 = tt.load %17 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
      %19 = tt.advance %arg6, [%c16_i32, %arg5] : <tensor<256x32xbf16, #linear>>
      %20 = tt.load %19 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "column_major"} : !tt.ptr<tensor<256x32xbf16, #linear>>
      %21 = tt.trans %20 {order = array<i32: 1, 0>} : tensor<256x32xbf16, #linear> -> tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
      %22 = tt.dot %18, %21, %arg4, inputPrecision = tf32 : tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma>
      %23 = arith.addi %arg5, %c32_i32 : i32
      scf.yield %22, %23, %19 : tensor<256x256xf32, #mma>, i32, !tt.ptr<tensor<256x32xbf16, #linear>>
    }

Here the load %20 is a candidate for fusion with the tt.trans operation. The pointer argument used by the candidate load (%19) is produced by a tt.advance operation, which in turn uses the loop-carried pointer %arg6.
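For illustration, here is a rough sketch of what the loop body could look like after fusion. This is not the pass's literal output: it assumes the fusion rewrites the tensor pointer in transposed form (swapping the shape, stride, offset, order, and boundary-check operands) so that the load produces the dot-operand layout directly and the tt.trans can be deleted:

    %10 = tt.make_tensor_ptr %arg1, [%c1_i64, %c1024_i64], [%c1024_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 0, 1>} : <tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
    %13:3 = scf.for %arg3 = %c0_i32 to %c1024_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %c0_i32, %arg6 = %10) -> (tensor<256x256xf32, #mma>, i32, !tt.ptr<tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>) : i32 {
      // ... %17/%18 (the opIdx = 0 advance and load) are unchanged ...
      %19 = tt.advance %arg6, [%arg5, %c16_i32] : <tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
      %20 = tt.load %19 {boundaryCheck = array<i32: 1, 0>, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
      %22 = tt.dot %18, %20, %arg4, inputPrecision = tf32 : tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma>
      %23 = arith.addi %arg5, %c32_i32 : i32
      scf.yield %22, %23, %19 : tensor<256x256xf32, #mma>, i32, !tt.ptr<tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
    }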

etiotto added 8 commits June 4, 2025 21:01
@etiotto self-assigned this Jun 9, 2025
@etiotto marked this pull request as ready for review June 10, 2025 19:40
@etiotto requested a review from a team June 17, 2025 15:44
@alexbaden (Contributor) commented
If I understand the problem correctly, we are trying to rewrite the original MakeTensorPtr load to be transposed because we cannot "see through" the trans op when lowering our load op from TTGPU IR to LLVM.

I had a similar problem where I needed to rewrite the type of a MakeTensorPtr, and I struggled to think up all the different combinations of IR (e.g. Whitney's comment above, and also the case where a MakeTensorPtr has multiple descendant loads, some transposed and some not, etc.). I did some poking around and, on the advice of Jeff from OpenAI, looked at the NVIDIA pass that optimizes tensor descriptor encodings: https://github.com/intel/intel-xpu-backend-for-triton/blob/main/lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeDescriptorEncoding.cpp. That pass does roughly the same thing we want to do (change the type along a def-use chain), but for tensor descriptors instead of tensor pointers. I was able to adapt it in #4463 to apply to tensor pointers, changing the layout of a tensor-ptr def-use chain from blocked to Subgroup 2D Block IO. The advantage of changing the type directly is that you only have to handle the specific control-flow operators (yield, for, if, etc.) in isolation; you don't have to worry about downstream effects, because you are not rewriting the operator.

That being said, I wonder if we could resolve the original problem without changing the tensor-ptr type. It might not be possible right now, since the only information carried about the 2D block encoding is in the string attribute, but it may become possible to convey this information with the Subgroup 2D Block encoding layout in the future.
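To make the "handle control flow in isolation" point concrete, here is a minimal hypothetical sketch; the #old and #new encodings and all SSA names are placeholders, not attributes from this PR. When the encoding in a tensor-ptr type changes, an scf.for only needs its four tied type occurrences updated in lockstep (the iter_args initializer, the region block argument, the scf.yield operand, and the loop result); nothing downstream needs to be reasoned about, because no operator is replaced:

    // Before rewriting: the loop carries a pointer with the #old encoding.
    %p0 = tt.make_tensor_ptr %base, [%d0, %d1], [%s0, %s1], [%o0, %o1] {order = array<i32: 1, 0>} : <tensor<256x32xbf16, #old>>
    %r = scf.for %i = %lb to %ub step %st iter_args(%p = %p0) -> (!tt.ptr<tensor<256x32xbf16, #old>>) : i32 {
      %q = tt.advance %p, [%c0_i32, %i] : <tensor<256x32xbf16, #old>>
      scf.yield %q : !tt.ptr<tensor<256x32xbf16, #old>>
    }

    // After rewriting the def to #new: only the types tied to the scf.for
    // (initializer, block argument, yield operand, loop result) change.
    %p0 = tt.make_tensor_ptr %base, [%d0, %d1], [%s0, %s1], [%o0, %o1] {order = array<i32: 1, 0>} : <tensor<256x32xbf16, #new>>
    %r = scf.for %i = %lb to %ub step %st iter_args(%p = %p0) -> (!tt.ptr<tensor<256x32xbf16, #new>>) : i32 {
      %q = tt.advance %p, [%c0_i32, %i] : <tensor<256x32xbf16, #new>>
      scf.yield %q : !tt.ptr<tensor<256x32xbf16, #new>>
    }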

etiotto added 3 commits June 19, 2025 17:23
@etiotto merged commit 4bc28e2 into main Jun 20, 2025
15 checks passed
@etiotto deleted the etiotto.merge_load_with_trans.2 branch June 20, 2025 14:26
Development

Successfully merging this pull request may close these issues.

[TransOp fusion]: Fuse tt.trans with tt.load to exploit 2D block read operations