Convert block ptr type layouts to Subgroup2DBlockEncoding layouts #4463
@@ -0,0 +1,181 @@
// RUN: triton-opt %s -split-input-file -allow-unregistered-dialect --tritonintelgpu-optimize-block-io-encoding | FileCheck %s

// COM: test complete example
#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 16], warpsPerCTA = [16, 2], order = [1, 0]}>
// CHECK-DAG: #[[$SUBGROUP_BLOCK_A:.+]] = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [8, 16], numBlocks=2, order=[1, 0], kWidth=1, threadsPerWarp=16}>
// CHECK-DAG: #[[$SUBGROUP_BLOCK_B:.+]] = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [16, 16], numBlocks=2, order=[0, 1], kWidth=2, threadsPerWarp=16}>
// CHECK-DAG: #[[$DPAS:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32, ttig.min_sg_size = 16 : i32, ttig.support_dpas, ttig.support_sg_2d_block} {
  tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>) {
    %c4_i32 = arith.constant 4 : i32
    %c256_i32 = arith.constant 256 : i32
    %c1024_i64 = arith.constant 1024 : i64
    %c5120_i64 = arith.constant 5120 : i64
    %c1_i64 = arith.constant 1 : i64
    %c0_i32 = arith.constant 0 : i32
    %c4096_i64 = arith.constant 4096 : i64
    %c32_i32 = arith.constant 32 : i32
    %c64_i32 = arith.constant 64 : i32
    %c5120_i32 = arith.constant 5120 : i32
    %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #blocked>

    // CHECK: %[[MAKE_TENSOR_PTR_A:.*]] = tt.make_tensor_ptr {{.*}} : <tensor<256x32xf16, #[[$SUBGROUP_BLOCK_A]]>>
    %10 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c5120_i64], [%c5120_i64, %c1_i64], [%c256_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xf16, #blocked1>>
    // CHECK: %[[MAKE_TENSOR_PTR_B:.*]] = tt.make_tensor_ptr {{.*}} : <tensor<32x256xf16, #[[$SUBGROUP_BLOCK_B]]>>
    %12 = tt.make_tensor_ptr %arg1, [%c5120_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%c0_i32, %c256_i32] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #blocked2>>
    // CHECK: %[[RES:.*]]:3 = scf.for {{.*}} iter_args({{.*}} = {{.*}}, %[[ARG5:.*]] = %[[MAKE_TENSOR_PTR_A]], %[[ARG6:.*]] = %[[MAKE_TENSOR_PTR_B]])
    %13:3 = scf.for %arg3 = %c0_i32 to %c5120_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %10, %arg6 = %12) -> (tensor<256x256xf32, #blocked>, !tt.ptr<tensor<256x32xf16, #blocked1>>, !tt.ptr<tensor<32x256xf16, #blocked2>>) : i32 {
      %17 = tt.load %arg5 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<256x32xf16, #blocked1>>
      // CHECK: %[[A_LOAD:.*]] = tt.load %[[ARG5]] {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<256x32xf16, #[[$SUBGROUP_BLOCK_A]]>>
      // CHECK: {{.*}} = ttg.convert_layout %[[A_LOAD]] : tensor<256x32xf16, #[[$SUBGROUP_BLOCK_A]]> -> tensor<256x32xf16, #blocked1>
      %18 = tt.load %arg6 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #blocked2>>
      // CHECK: %[[B_LOAD:.*]] = tt.load %[[ARG6]] {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #[[$SUBGROUP_BLOCK_B]]>>
      // CHECK: {{.*}} = ttg.convert_layout %[[B_LOAD]] : tensor<32x256xf16, #[[$SUBGROUP_BLOCK_B]]> -> tensor<32x256xf16, #blocked2>
      %19 = ttg.convert_layout %17 : tensor<256x32xf16, #blocked1> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>
      %20 = ttg.convert_layout %18 : tensor<32x256xf16, #blocked2> -> tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>
      %21 = ttg.convert_layout %arg4 : tensor<256x256xf32, #blocked> -> tensor<256x256xf32, #mma>
      %22 = ttg.convert_layout %19 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
      %23 = ttg.convert_layout %20 : tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
      // CHECK: tt.dot {{.*}} : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>> -> tensor<256x256xf32, #[[$DPAS]]>
      %24 = tt.dot %22, %23, %21, inputPrecision = tf32 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma>
      %25 = ttg.convert_layout %24 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked>
      // CHECK: %[[ADVANCE_A:.*]] = tt.advance {{.*}} : <tensor<256x32xf16, #[[$SUBGROUP_BLOCK_A]]>>
      %26 = tt.advance %arg5, [%c0_i32, %c32_i32] : <tensor<256x32xf16, #blocked1>>
      // CHECK: %[[ADVANCE_B:.*]] = tt.advance {{.*}} : <tensor<32x256xf16, #[[$SUBGROUP_BLOCK_B]]>>
      %27 = tt.advance %arg6, [%c32_i32, %c0_i32] : <tensor<32x256xf16, #blocked2>>
      // CHECK: scf.yield {{.*}}, %[[ADVANCE_A]], %[[ADVANCE_B]]
      scf.yield %25, %26, %27 : tensor<256x256xf32, #blocked>, !tt.ptr<tensor<256x32xf16, #blocked1>>, !tt.ptr<tensor<32x256xf16, #blocked2>>
    }
    %14 = tt.make_tensor_ptr %arg2, [%c1024_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%c0_i32, %c256_i32] {order = array<i32: 1, 0>} : <tensor<256x256xf16, #blocked2>>
    // CHECK: arith.truncf %[[RES]]#0 : tensor<256x256xf32, #blocked> to tensor<256x256xf16, #blocked>
    %15 = arith.truncf %13#0 : tensor<256x256xf32, #blocked> to tensor<256x256xf16, #blocked>
Review comment: Add a CHECK to verify that %13#0 has the dpas layout.

Reply: It won't have the DPAS layout, it will be blocked. The DPAS layout is converted to blocked in #25.
    %16 = ttg.convert_layout %15 : tensor<256x256xf16, #blocked> -> tensor<256x256xf16, #blocked2>
    tt.store %14, %16 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x256xf16, #blocked2>>
    tt.return
  }
}
Review comment: Here is a test containing an scf.while loop. It is currently failing; I'd expect it to work given that the pass contains some code to handle while loops.

Reply: I think it is probably failing because the advance occurs before the load and the argument to advance is not
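The failing pattern the reply describes might look like the following minimal sketch. This is a hypothetical reduction, not the reviewer's actual test; it reuses the shapes, encodings, and op attributes from the tests in this file. The point it illustrates: the `tt.advance` consumes the loop-carried pointer before any `tt.load` has anchored the new encoding on it.

```mlir
// Hypothetical reduction of the failing scf.while case: tt.advance runs
// before tt.load in the loop body, so when the pass visits the advance its
// pointer operand has not yet been rewritten to the Subgroup2DBlock encoding.
%1 = scf.while (%ptr = %a_ptr) : (!tt.ptr<tensor<256x32xf16, #blocked1>>) -> (!tt.ptr<tensor<256x32xf16, #blocked1>>) {
  %cond = "dummy.evaluate_condition"() : () -> i1
  scf.condition(%cond) %ptr : !tt.ptr<tensor<256x32xf16, #blocked1>>
} do {
^bb0(%ptr: !tt.ptr<tensor<256x32xf16, #blocked1>>):
  // Advance first ...
  %next = tt.advance %ptr, [%c0_i32, %c32_i32] : <tensor<256x32xf16, #blocked1>>
  // ... then load through the advanced pointer.
  %val = tt.load %next {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<256x32xf16, #blocked1>>
  scf.yield %next : !tt.ptr<tensor<256x32xf16, #blocked1>>
}
```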
// -----

// COM: Test while loop / nested tt.advance
#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}>
#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
// CHECK-DAG: #[[$BLOCKED:.+]] = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}>
// CHECK-DAG: #[[$SUBGROUP_2D_BLOCK:.+]] = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [8, 16], numBlocks=2, order=[1, 0], kWidth=1, threadsPerWarp=16}>
// CHECK-DAG: #[[$DPAS:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32, ttig.min_sg_size = 16 : i32, ttig.support_dpas, ttig.support_sg_2d_block} {
  tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<f16>) {
    %c1024_i64 = arith.constant 1024 : i64
    %c5120_i64 = arith.constant 5120 : i64
    %c1_i64 = arith.constant 1 : i64
    %c256_i32 = arith.constant 256 : i32
    %c0_i32 = arith.constant 0 : i32
    %c32_i32 = arith.constant 32 : i32

    // CHECK: %[[A_PTR:.*]] = tt.make_tensor_ptr %arg0, {{.*}} : <tensor<256x32xf16, #[[$SUBGROUP_2D_BLOCK]]>>
    %a_ptr = tt.make_tensor_ptr %arg0, [%c1024_i64, %c5120_i64], [%c5120_i64, %c1_i64], [%c256_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xf16, #blocked1>>

    // CHECK: scf.while {{.*}} : (!tt.ptr<tensor<256x32xf16, #[[$SUBGROUP_2D_BLOCK]]>>) -> !tt.ptr<tensor<256x32xf16, #[[$SUBGROUP_2D_BLOCK]]>>
    %1 = scf.while (%a_ptr_crt = %a_ptr) : (!tt.ptr<tensor<256x32xf16, #blocked1>>) -> (!tt.ptr<tensor<256x32xf16, #blocked1>>) {
      %2 = "dummy.evaluate_condition"() : () -> i1
      // CHECK: scf.condition({{.*}}) {{.*}} : !tt.ptr<tensor<256x32xf16, #[[$SUBGROUP_2D_BLOCK]]>>
      scf.condition(%2) %a_ptr_crt : !tt.ptr<tensor<256x32xf16, #blocked1>>
    } do {
    ^bb0(%a_ptr_crt: !tt.ptr<tensor<256x32xf16, #blocked1>>):
      // CHECK: ^bb0({{.*}}: !tt.ptr<tensor<256x32xf16, #[[$SUBGROUP_2D_BLOCK]]>>):

      // CHECK: %[[A_LOAD:.*]] = tt.load {{.*}} : !tt.ptr<tensor<256x32xf16, #[[$SUBGROUP_2D_BLOCK]]>>
      %3 = tt.load %a_ptr_crt {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<256x32xf16, #blocked1>>
      // CHECK: ttg.convert_layout %[[A_LOAD]] : tensor<256x32xf16, #[[$SUBGROUP_2D_BLOCK]]> -> tensor<256x32xf16, #[[$BLOCKED]]>
      // CHECK: ttg.convert_layout {{.*}} : tensor<256x32xf16, #[[$BLOCKED]]> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>>
      %4 = ttg.convert_layout %3 : tensor<256x32xf16, #blocked1> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>

      %cstB = arith.constant dense<0.000000e+00> : tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
      %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>

      // CHECK: tt.dot {{.*}} : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>> -> tensor<256x256xf32, #[[$DPAS]]>
      %5 = tt.dot %4, %cstB, %cst, inputPrecision = tf32 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma>
      %6 = ttg.convert_layout %5 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked1>
      // CHECK: tt.advance {{.*}} : <tensor<256x32xf16, #[[$SUBGROUP_2D_BLOCK]]>>
      %7 = tt.advance %a_ptr_crt, [%c0_i32, %c32_i32] : <tensor<256x32xf16, #blocked1>>

      // CHECK: scf.yield {{.*}} : !tt.ptr<tensor<256x32xf16, #[[$SUBGROUP_2D_BLOCK]]>>
      scf.yield %a_ptr_crt : !tt.ptr<tensor<256x32xf16, #blocked1>>
    }
    tt.return
  }
}

// -----

// COM: test complex control flow
// COM: Note that instead of using tt.advance we make a new tensor ptr each time. This is nice, because it lets us test that we can find the MakeTensorPtr op inside the scf.if.
#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 16], warpsPerCTA = [16, 2], order = [1, 0]}>
// CHECK-DAG: #[[$SUBGROUP_BLOCK_A:.+]] = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [8, 16], numBlocks=2, order=[1, 0], kWidth=1, threadsPerWarp=16}>
// CHECK-DAG: #[[$SUBGROUP_BLOCK_B:.+]] = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [16, 16], numBlocks=2, order=[0, 1], kWidth=2, threadsPerWarp=16}>
// CHECK-DAG: #[[$DPAS:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32, ttig.min_sg_size = 16 : i32, ttig.support_dpas, ttig.support_sg_2d_block} {
  // CHECK-LABEL: @matmul_change_block_ptr_in_prologue
  tt.func @matmul_change_block_ptr_in_prologue(%a_base: !tt.ptr<f16>,
                                               %b_base: !tt.ptr<f16>) {
    %c0_i64 = arith.constant 0 : i64
    %c1_i64 = arith.constant 1 : i64
    %k_tiles = arith.constant 32 : i64
    %true = arith.constant true
    %false = arith.constant false

    %zero = arith.constant dense<0.0> : tensor<128x128xf32, #blocked>

    // CHECK: %[[A_UNDEF:.*]] = ub.poison : !tt.ptr<tensor<128x64xf16, #[[$SUBGROUP_BLOCK_A]]>>
    // CHECK: %[[B_UNDEF:.*]] = ub.poison : !tt.ptr<tensor<64x128xf16, #[[$SUBGROUP_BLOCK_B]]>>
    %a_ptr_undef = ub.poison : !tt.ptr<tensor<128x64xf16, #blocked1>>
    %b_ptr_undef = ub.poison : !tt.ptr<tensor<64x128xf16, #blocked2>>
    // CHECK: scf.for {{.*}} iter_args({{.*}} = {{.*}}, %[[A_PTR:.*]] = %[[A_UNDEF]], %[[B_PTR:.*]] = %[[B_UNDEF]])
    scf.for %k = %c0_i64 to %k_tiles step %c1_i64 iter_args(%acc = %zero, %flag = %true, %a_ptr = %a_ptr_undef, %b_ptr = %b_ptr_undef) -> (tensor<128x128xf32, #blocked>, i1, !tt.ptr<tensor<128x64xf16, #blocked1>>, !tt.ptr<tensor<64x128xf16, #blocked2>>) : i64 {
      %do_prologue = "prologue_cond"(%k) : (i64) -> i1
      // CHECK: %[[PTRS:.*]]:2 = scf.if {{.*}} -> (!tt.ptr<tensor<128x64xf16, #[[$SUBGROUP_BLOCK_A]]>>, !tt.ptr<tensor<64x128xf16, #[[$SUBGROUP_BLOCK_B]]>>)
      %cur_a_ptr, %cur_b_ptr = scf.if %do_prologue -> (!tt.ptr<tensor<128x64xf16, #blocked1>>, !tt.ptr<tensor<64x128xf16, #blocked2>>) {
        %off_m, %off_n, %off_k = "get_offsets"(%k) : (i64) -> (i32, i32, i32)
        // CHECK: tt.make_tensor_ptr {{.*}} : <tensor<128x64xf16, #[[$SUBGROUP_BLOCK_A]]>>
        %next_a_ptr = tt.make_tensor_ptr %a_base, [%k, %k], [%c1_i64, %c1_i64], [%off_m, %off_k] {order = array<i32: 1, 0>} : <tensor<128x64xf16, #blocked1>>
        // CHECK: tt.make_tensor_ptr {{.*}} : <tensor<64x128xf16, #[[$SUBGROUP_BLOCK_B]]>>
        %next_b_ptr = tt.make_tensor_ptr %b_base, [%k, %k], [%c1_i64, %c1_i64], [%off_n, %off_k] {order = array<i32: 1, 0>} : <tensor<64x128xf16, #blocked2>>
        // CHECK: scf.yield {{.*}} : !tt.ptr<tensor<128x64xf16, #[[$SUBGROUP_BLOCK_A]]>>, !tt.ptr<tensor<64x128xf16, #[[$SUBGROUP_BLOCK_B]]>>
        scf.yield %next_a_ptr, %next_b_ptr : !tt.ptr<tensor<128x64xf16, #blocked1>>, !tt.ptr<tensor<64x128xf16, #blocked2>>
      } else {
        // CHECK: scf.yield {{.*}} : !tt.ptr<tensor<128x64xf16, #[[$SUBGROUP_BLOCK_A]]>>, !tt.ptr<tensor<64x128xf16, #[[$SUBGROUP_BLOCK_B]]>>
        scf.yield %a_ptr, %b_ptr : !tt.ptr<tensor<128x64xf16, #blocked1>>, !tt.ptr<tensor<64x128xf16, #blocked2>>
      }

      // CHECK: %[[A:.*]] = tt.load %[[PTRS]]#0 {{.*}} : !tt.ptr<tensor<128x64xf16, #[[$SUBGROUP_BLOCK_A]]>>
      %a = tt.load %cur_a_ptr {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #blocked1>>
      // CHECK: ttg.convert_layout %[[A]] : tensor<128x64xf16, #[[$SUBGROUP_BLOCK_A]]> -> tensor<128x64xf16, #blocked1>
      // CHECK: %[[B:.*]] = tt.load %[[PTRS]]#1 {{.*}} : !tt.ptr<tensor<64x128xf16, #[[$SUBGROUP_BLOCK_B]]>>
      %b = tt.load %cur_b_ptr {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<64x128xf16, #blocked2>>
      // CHECK: {{.*}} = ttg.convert_layout %[[B]] : tensor<64x128xf16, #[[$SUBGROUP_BLOCK_B]]> -> tensor<64x128xf16, #blocked2>
      %a_dot = ttg.convert_layout %a : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>
      %b_dot = ttg.convert_layout %b : tensor<64x128xf16, #blocked2> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>
      %a_dot_dpas = ttg.convert_layout %a_dot : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
      %b_dot_dpas = ttg.convert_layout %b_dot : tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
      %accum = ttg.convert_layout %acc : tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #mma>
      %c = tt.dot %a_dot_dpas, %b_dot_dpas, %accum, inputPrecision = tf32 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma>
      %c_out = ttg.convert_layout %c : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #blocked>

      %do_epilogue = arith.cmpi eq, %k, %c0_i64 : i64
      %use_acc = arith.select %do_epilogue, %false, %true : i1
      scf.if %do_epilogue {
        "acc_user"(%c_out) : (tensor<128x128xf32, #blocked>) -> ()
      }
      // CHECK: scf.yield {{.*}} : {{.*}}, i1, !tt.ptr<tensor<128x64xf16, #[[$SUBGROUP_BLOCK_A]]>>, !tt.ptr<tensor<64x128xf16, #[[$SUBGROUP_BLOCK_B]]>>
      scf.yield %c_out, %use_acc, %cur_a_ptr, %cur_b_ptr : tensor<128x128xf32, #blocked>, i1, !tt.ptr<tensor<128x64xf16, #blocked1>>, !tt.ptr<tensor<64x128xf16, #blocked2>>
    }

    tt.return
  }
}
@@ -409,4 +409,22 @@ def TritonIntelGPUReduceVariableLiveness
"mlir::scf::SCFDialect", | ||
"mlir::arith::ArithDialect"]; | ||
} | ||
|
||
def TritonIntelGPUOptimizeBlockIOEncodingPass | ||
: Pass<"tritonintelgpu-optimize-block-io-encoding", "mlir::ModuleOp"> { | ||
let summary = "Set encodings on candidates for Subgroup 2D Block IO ops"; | ||
|
||
let description = [{ | ||
    Set the Subgroup2DBlock encoding on tensor ptr types that are candidates for Subgroup 2D Block IO lowering.

    The goal is to change the tensor ptr type to use the new encoding so the LoadOp will use the new encoding, allowing the
    encoding to be an anchor layout during RemoveLayoutConversions. To avoid duplicating work in RemoveLayoutConversions, a
    ConvertLayout op to the existing encoding replaces the result of the LoadOp.
  }];

  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
                           "mlir::triton::gpu::intel::TritonIntelGPUDialect",
                           "mlir::triton::TritonDialect"];
}

#endif // TRITON_INTEL_GPU_PASSES
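To make the description above concrete, here is a hedged before/after sketch of the rewrite, distilled from the lit test in this PR. The operand names (`%base`, `%d0`, etc.) are placeholders, and `#subgroup_2d_block` abbreviates the full `#ttig.subgroup_2d_block<...>` attribute whose parameters the pass computes.

```mlir
// Before: the block pointer and the load carry the original blocked encoding.
%ptr = tt.make_tensor_ptr %base, [%d0, %d1], [%s0, %s1], [%o0, %o1] {order = array<i32: 1, 0>} : <tensor<256x32xf16, #blocked1>>
%val = tt.load %ptr {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<256x32xf16, #blocked1>>

// After: the pointer type uses the Subgroup2DBlock encoding (an anchor for
// RemoveLayoutConversions), and a convert_layout back to the original
// encoding replaces the load's result, so downstream users are unchanged.
%ptr = tt.make_tensor_ptr %base, [%d0, %d1], [%s0, %s1], [%o0, %o1] {order = array<i32: 1, 0>} : <tensor<256x32xf16, #subgroup_2d_block>>
%raw = tt.load %ptr {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<256x32xf16, #subgroup_2d_block>>
%val = ttg.convert_layout %raw : tensor<256x32xf16, #subgroup_2d_block> -> tensor<256x32xf16, #blocked1>
```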
Review comment: Add a CHECK to ensure that the loop init args have the new tensor ptr type with the %mma layout. The return types of the scf.for operation should be checked too.

Reply: The return type of the scf is not modified, so I am not sure what we need to check. The tt.dot return type is unchanged, and the make tensor ptr / store uses the same layout as before. I can add a check to make sure the arguments of the scf.for loop are using the new values, but the verifier for scf.for should not allow a type mismatch, so I don't think we need another type check (I briefly looked through other lit examples and did not see arg type checking for scf.for).