Skip to content

Commit 09fec1e

Browse files
committed
Add pass to convert block load to subgroup 2d block encoding types
1 parent bb5a79a commit 09fec1e

File tree

6 files changed

+400
-0
lines changed

6 files changed

+400
-0
lines changed
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --tritonintelgpu-optimize-block-io-encoding | FileCheck %s
2+
3+
#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}>
4+
#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}>
5+
#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 16], warpsPerCTA = [16, 2], order = [1, 0]}>
6+
// CHECK: #mma = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [8, 16], numBlocks=2, order=[1, 0], kWidth=1, threadsPerWarp=16}>
7+
// CHECK: #mma1 = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [16, 16], numBlocks=2, order=[0, 1], kWidth=2, threadsPerWarp=16}>
8+
// CHECK: #mma2 = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
9+
#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
10+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32, ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64"} {
11+
tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>) attributes {noinline = false} {
12+
%c4_i32 = arith.constant 4 : i32
13+
%c256_i32 = arith.constant 256 : i32
14+
%c1024_i64 = arith.constant 1024 : i64
15+
%c5120_i64 = arith.constant 5120 : i64
16+
%c1_i64 = arith.constant 1 : i64
17+
%c0_i32 = arith.constant 0 : i32
18+
%c4096_i64 = arith.constant 4096 : i64
19+
%c32_i32 = arith.constant 32 : i32
20+
%c64_i32 = arith.constant 64 : i32
21+
%c5120_i32 = arith.constant 5120 : i32
22+
%cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #blocked>
23+
%0 = tt.get_program_id x : i32
24+
%1 = arith.divsi %0, %c64_i32 : i32
25+
%2 = arith.muli %1, %c4_i32 : i32
26+
%3 = arith.subi %c4_i32, %2 : i32
27+
%4 = arith.minsi %3, %c4_i32 : i32
28+
%5 = arith.remsi %0, %4 : i32
29+
%6 = arith.addi %2, %5 : i32
30+
%7 = arith.remsi %0, %c64_i32 : i32
31+
%8 = arith.divsi %7, %4 : i32
32+
%9 = arith.muli %6, %c256_i32 : i32
33+
// CHECK: tt.make_tensor_ptr {{.*}} : <tensor<256x32xf16, #mma>>
34+
%10 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c5120_i64], [%c5120_i64, %c1_i64], [%9, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xf16, #blocked1>>
35+
%11 = arith.muli %8, %c256_i32 : i32
36+
// CHECK: tt.make_tensor_ptr {{.*}} : <tensor<32x256xf16, #mma1>>
37+
%12 = tt.make_tensor_ptr %arg1, [%c5120_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%c0_i32, %11] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #blocked2>>
38+
%13:3 = scf.for %arg3 = %c0_i32 to %c5120_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %10, %arg6 = %12) -> (tensor<256x256xf32, #blocked>, !tt.ptr<tensor<256x32xf16, #blocked1>>, !tt.ptr<tensor<32x256xf16, #blocked2>>) : i32 {
39+
%17 = tt.load %arg5 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<256x32xf16, #blocked1>>
40+
// CHECK: %[[A_LOAD:.*]] = tt.load %arg5 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<256x32xf16, #mma>>
41+
// CHECK: {{.*}} = ttg.convert_layout %[[A_LOAD]] : tensor<256x32xf16, #mma> -> tensor<256x32xf16, #blocked1>
42+
%18 = tt.load %arg6 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #blocked2>>
43+
// CHECK: %[[B_LOAD:.*]] = tt.load %arg6 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #mma1>>
44+
// CHECK: {{.*}} = ttg.convert_layout %[[B_LOAD]] : tensor<32x256xf16, #mma1> -> tensor<32x256xf16, #blocked2>
45+
%19 = ttg.convert_layout %17 : tensor<256x32xf16, #blocked1> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>
46+
%20 = ttg.convert_layout %18 : tensor<32x256xf16, #blocked2> -> tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>
47+
%21 = ttg.convert_layout %arg4 : tensor<256x256xf32, #blocked> -> tensor<256x256xf32, #mma>
48+
%22 = ttg.convert_layout %19 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
49+
%23 = ttg.convert_layout %20 : tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
50+
// CHECK: tt.dot {{.*}} : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma2, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma2, kWidth = 2}>> -> tensor<256x256xf32, #mma2>
51+
%24 = tt.dot %22, %23, %21, inputPrecision = tf32 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma>
52+
%25 = ttg.convert_layout %24 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked>
53+
// CHECK: tt.advance {{.*}} : <tensor<256x32xf16, #mma>>
54+
%26 = tt.advance %arg5, [%c0_i32, %c32_i32] : <tensor<256x32xf16, #blocked1>>
55+
// CHECK: tt.advance {{.*}} : <tensor<32x256xf16, #mma1>>
56+
%27 = tt.advance %arg6, [%c32_i32, %c0_i32] : <tensor<32x256xf16, #blocked2>>
57+
scf.yield %25, %26, %27 : tensor<256x256xf32, #blocked>, !tt.ptr<tensor<256x32xf16, #blocked1>>, !tt.ptr<tensor<32x256xf16, #blocked2>>
58+
}
59+
%14 = tt.make_tensor_ptr %arg2, [%c1024_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%9, %11] {order = array<i32: 1, 0>} : <tensor<256x256xf16, #blocked2>>
60+
%15 = arith.truncf %13#0 : tensor<256x256xf32, #blocked> to tensor<256x256xf16, #blocked>
61+
%16 = ttg.convert_layout %15 : tensor<256x256xf16, #blocked> -> tensor<256x256xf16, #blocked2>
62+
tt.store %14, %16 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x256xf16, #blocked2>>
63+
tt.return
64+
}
65+
}

third_party/intel/backend/compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,7 @@ def make_ttgir(mod, metadata, opt, properties):
280280

281281
intel.passes.ttgpuir.add_accelerate_matmul(pm)
282282
intel.passes.ttgpuir.add_materialize_block_pointer(pm)
283+
intel.passes.ttgpuir.add_optimize_block_load_encoding(pm)
283284
intel.passes.ttgpuir.add_remove_layout_conversions(pm)
284285
intel.passes.ttgpuir.add_pipeline(pm, opt.num_stages, XPUBackend.get_split_barrier_scope(opt))
285286

third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,4 +409,15 @@ def TritonIntelGPUReduceVariableLiveness
409409
"mlir::scf::SCFDialect",
410410
"mlir::arith::ArithDialect"];
411411
}
412+
413+
def TritonIntelGPUOptimizeBlockIOEncodingPass : Pass<"tritonintelgpu-optimize-block-io-encoding", "mlir::ModuleOp"> {
414+
let summary = "Set encodings on candidates for Subgroup 2D Block IO ops";
415+
416+
let description = [{
417+
Set the Subgroup2DBlock encoding on tensor ptr types that are candidates for Subgroup 2D Block IO lowering. The goal is to change the tensor ptr type to use the new encoding so the LoadOp will use the new encoding, allowing the encoding to be an anchor layout during RemoveLayoutConversions. To avoid duplicating work in RemoveLayoutConversions, a ConvertLayout op to the existing encoding replaces the result of the LoadOp.
418+
}];
419+
420+
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", "mlir::triton::gpu::intel::TritonIntelGPUDialect", "mlir::triton::TritonDialect"];
421+
}
422+
412423
#endif // TRITON_INTEL_GPU_PASSES

third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ add_triton_library(TritonIntelGPUTransforms
55
DistributeToWarps.cpp
66
MatchTargetSize.cpp
77
MaterializeBlockPointer.cpp
8+
OptimizeBlockIOEncoding.cpp
89
OptimizeDotOperands.cpp
910
OptimizeReductionLocality.cpp
1011
Pipeliner/MatmulLoopPipeline.cpp

0 commit comments

Comments
 (0)