diff --git a/test/TritonIntelGPU/dot-operands.mlir b/test/TritonIntelGPU/dot-operands.mlir index 8115b9d0f7..babdf2e34a 100644 --- a/test/TritonIntelGPU/dot-operands.mlir +++ b/test/TritonIntelGPU/dot-operands.mlir @@ -1,11 +1,10 @@ -// RUN: triton-opt %s -split-input-file -tritonintelgpu-optimize-dot-operands -canonicalize | FileCheck %s +// RUN: triton-opt %s -split-input-file -tritonintelgpu-optimize-dot-operands | FileCheck %s -#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}> #linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [16, 0], [0, 16], [128, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0]], warp = [[32, 0], [64, 0], [0, 0], [0, 0], [0, 0]], block = []}> #mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32, ttig.support_sg_2d_block} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} { // COM: tt.load -> tt.trans -> tt.dot chain, not in a loop. - tt.func public @fuseLoadWithTrans1(%arg0: !tt.ptr>>, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr>) { + tt.func public @fuseLoadWithTrans1(%arg0: !tt.ptr>>, %arg1: !tt.ptr) { %c0_i32 = arith.constant 0 : i32 %c1_i64 = arith.constant 1 : i64 %c256_i32 = arith.constant 256 : i32 @@ -17,8 +16,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.th %3 = tt.load %2 {boundaryCheck = array, ttig.block_io = "column_major"} : !tt.ptr> %4 = tt.trans %3 {order = array} : tensor<256x32xbf16, #linear> -> tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> %5 = tt.dot %1, %4, %cst, inputPrecision = tf32 : tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma> - %6 = ttg.convert_layout %5 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked> - tt.store %arg2, %6 {boundaryCheck = array} : !tt.ptr> tt.return } // CHECK-LABEL: fuseLoadWithTrans1 @@ -31,13 +28,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.th // ----- -#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}> #linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [16, 0], [0, 16], [128, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0]], warp = [[32, 0], [64, 0], [0, 0], [0, 0], [0, 0]], block = []}> #mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32, ttig.support_sg_2d_block} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} { // COM: tt.load -> tt.trans -> tt.dot chain, in a loop. // COM: where the 'make_tensor_ptr' result is not loop carried. 
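// Illustrative sketch (not part of the patch): the tests in this file check that a
// column-major 2D block load followed by tt.trans is folded into a single load of a
// transposed block pointer: the make_tensor_ptr shape/strides/offsets are reversed and
// the ttig.block_io hint is flipped between "row_major" and "column_major". A minimal
// stand-alone C++ model of that bookkeeping follows; 'BlockPtrDesc' and
// 'transposeBlockPtr' are hypothetical names, not types or functions of the pass.
#include <algorithm>
#include <array>
#include <cstdint>
#include <string>

struct BlockPtrDesc {
  std::array<int64_t, 2> shape;    // tt.make_tensor_ptr shape operands
  std::array<int64_t, 2> strides;  // tt.make_tensor_ptr stride operands
  std::array<int64_t, 2> offsets;  // tt.make_tensor_ptr offset operands
  std::string blockIO;             // "row_major" or "column_major"
};

// Build the descriptor of the transposed view: reverse every dimension list and flip
// the block-IO hint, so loading through it already yields the transposed tile.
static BlockPtrDesc transposeBlockPtr(BlockPtrDesc desc) {
  std::reverse(desc.shape.begin(), desc.shape.end());
  std::reverse(desc.strides.begin(), desc.strides.end());
  std::reverse(desc.offsets.begin(), desc.offsets.end());
  desc.blockIO = (desc.blockIO == "row_major") ? "column_major" : "row_major";
  return desc;
}

// For the B operand above: {shape={1024, 1}, strides={1, 1024}, blockIO="column_major"}
// becomes {shape={1, 1024}, strides={1024, 1}, blockIO="row_major"}, matching the CHECK
// lines for the fused make_tensor_ptr/load.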
- tt.func public @fuseLoadWithTrans2(%arg0: !tt.ptr>>, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr>) { + tt.func public @fuseLoadWithTrans2(%arg0: !tt.ptr>>, %arg1: !tt.ptr) { %c0_i32 = arith.constant 0 : i32 %c1_i64 = arith.constant 1 : i64 %c32_i32 = arith.constant 32 : i32 @@ -55,8 +51,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.th %6 = arith.addi %arg5, %c32_i32 : i32 scf.yield %5, %6 : tensor<256x256xf32, #mma>, i32 } - %6 = ttg.convert_layout %res#0 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked> - tt.store %arg2, %6 {boundaryCheck = array} : !tt.ptr> tt.return } // CHECK-LABEL: fuseLoadWithTrans2 @@ -68,3 +62,215 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.th // CHECK: tt.dot {{.*}}, [[LOAD_B]], {{.*}}, inputPrecision = tf32 : tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma> // CHECK: scf.yield } + +// ----- + +#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [16, 0], [0, 16], [128, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0]], warp = [[32, 0], [64, 0], [0, 0], [0, 0], [0, 0]], block = []}> +#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} { + // COM: tt.load -> tt.trans -> tt.dot chain, in a loop. + // COM: where the 'make_tensor_ptr' result is loop carried. + tt.func public @fuseLoadWithTrans3(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr) { + %c4_i32 = arith.constant 4 : i32 + %c1024_i32 = arith.constant 1024 : i32 + %c0_i32 = arith.constant 0 : i32 + %c32_i32 = arith.constant 32 : i32 + %c256_i32 = arith.constant 256 : i32 + %c16_i32 = arith.constant 16 : i32 + %c1_i64 = arith.constant 1 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %0 = tt.get_program_id x : i32 + %1 = arith.divsi %0, %c16_i32 : i32 + %2 = arith.muli %1, %c4_i32 : i32 + %3 = arith.subi %c4_i32, %2 : i32 + %4 = arith.minsi %3, %c4_i32 : i32 + %5 = arith.remsi %0, %c16_i32 : i32 + %6 = arith.remsi %5, %4 : i32 + %7 = arith.addi %2, %6 : i32 + %8 = arith.divsi %5, %4 : i32 + %9 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c1024_i64], [%c1024_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : >> + %10 = tt.make_tensor_ptr %arg1, [%c1024_i64, %c1_i64], [%c1_i64, %c1024_i64], [%c0_i32, %c0_i32] {order = array} : > + %13:3 = scf.for %arg3 = %c0_i32 to %c1024_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %c0_i32, %arg6 = %10) -> (tensor<256x256xf32, #mma>, i32, !tt.ptr>) : i32 { + %17 = tt.advance %9, [%c256_i32, %arg5] : >> + %18 = tt.load %17 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr>> + %19 = tt.advance %arg6, [%c16_i32, %arg5] : > + %20 = tt.load %19 {boundaryCheck = array, ttig.block_io = "column_major"} : !tt.ptr> + %21 = tt.trans %20 {order = array} : tensor<256x32xbf16, #linear> -> tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %22 = tt.dot %18, %21, %arg4, inputPrecision = tf32 : tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma> + %23 = arith.addi %arg5, %c32_i32 : i32 + scf.yield %22, %23, %19 : 
tensor<256x256xf32, #mma>, i32, !tt.ptr> + } + tt.return + } + // CHECK-LABEL: fuseLoadWithTrans3 + // CHECK-NOT: tt.trans + // CHECK [[IDX1:%.*]] = arith.muli + // CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg1, [%c1_i64, %c1024_i64], [%c1024_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : >> + // CHECK: scf.for {{.*}} iter_args({{.*}}, [[ARG5:%.*]] = {{.*}}, [[ARG6:%.*]] = [[PTR]]) + // CHECK: [[ADV:%.*]] = tt.advance [[ARG6]], [[[ARG5]], %c16_i32] : >> + // CHECK: [[LOAD_B:%.*]] = tt.load [[ADV]] {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr>> + // CHECK: tt.dot {{.*}}, [[LOAD_B]], {{.*}}, inputPrecision = tf32 : tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma> + // CHECK: scf.yield {{.*}}, {{.*}}, [[ADV]] +} + +// ----- + +#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [16, 0], [0, 16], [128, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0]], warp = [[32, 0], [64, 0], [0, 0], [0, 0], [0, 0]], block = []}> +#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} { + // COM: Ensure load is not fused with transpose if the loop result corresponding to the pointer used by the load + // COM: that 'feeds' the transpose operation is used. + tt.func public @doNotFuseLoadWithTrans1(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr) { + %c4_i32 = arith.constant 4 : i32 + %c1024_i32 = arith.constant 1024 : i32 + %c0_i32 = arith.constant 0 : i32 + %c32_i32 = arith.constant 32 : i32 + %c256_i32 = arith.constant 256 : i32 + %c16_i32 = arith.constant 16 : i32 + %c1_i64 = arith.constant 1 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %0 = tt.get_program_id x : i32 + %1 = arith.divsi %0, %c16_i32 : i32 + %2 = arith.muli %1, %c4_i32 : i32 + %3 = arith.subi %c4_i32, %2 : i32 + %4 = arith.minsi %3, %c4_i32 : i32 + %5 = arith.remsi %0, %c16_i32 : i32 + %6 = arith.remsi %5, %4 : i32 + %7 = arith.addi %2, %6 : i32 + %8 = arith.divsi %5, %4 : i32 + %9 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c1024_i64], [%c1024_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : >> + %10 = tt.make_tensor_ptr %arg1, [%c1024_i64, %c1_i64], [%c1_i64, %c1024_i64], [%c0_i32, %c0_i32] {order = array} : > + %13:3 = scf.for %arg3 = %c0_i32 to %c1024_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %c0_i32, %arg6 = %10) -> (tensor<256x256xf32, #mma>, i32, !tt.ptr>) : i32 { + %17 = tt.advance %9, [%c256_i32, %arg5] : >> + %18 = tt.load %17 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr>> + %19 = tt.advance %arg6, [%c16_i32, %arg5] : > + %20 = tt.load %19 {boundaryCheck = array, ttig.block_io = "column_major"} : !tt.ptr> + %21 = tt.trans %20 {order = array} : tensor<256x32xbf16, #linear> -> tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %22 = tt.dot %18, %21, %arg4, inputPrecision = tf32 : tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma> + %23 = arith.addi %arg5, %c32_i32 : i32 + scf.yield %22, %23, %19 : tensor<256x256xf32, #mma>, i32, !tt.ptr> + } + %15 = tt.advance %13#2, [%c16_i32, %c16_i32] : > + tt.return + } + // 
CHECK-LABEL: doNotFuseLoadWithTrans1 + // CHECK: tt.trans +} + +// ----- + +#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [16, 0], [0, 16], [128, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0]], warp = [[32, 0], [64, 0], [0, 0], [0, 0], [0, 0]], block = []}> +#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} { + // COM: Ensure load is not fused with transpose if there are multiple users of an operation in the def-use chain containing the load + transpose. + // COM: In this case `%19` is used by the load that feeds the transpose and by a store operation. + tt.func public @doNotFuseLoadWithTrans2(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr) { + %c4_i32 = arith.constant 4 : i32 + %c1024_i32 = arith.constant 1024 : i32 + %c0_i32 = arith.constant 0 : i32 + %c32_i32 = arith.constant 32 : i32 + %c256_i32 = arith.constant 256 : i32 + %c16_i32 = arith.constant 16 : i32 + %c1_i64 = arith.constant 1 : i64 + %c1024_i64 = arith.constant 1024 : i64 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %cst_1 = arith.constant dense<1.000000e+00> : tensor<256x32xbf16, #linear> + %0 = tt.get_program_id x : i32 + %1 = arith.divsi %0, %c16_i32 : i32 + %2 = arith.muli %1, %c4_i32 : i32 + %3 = arith.subi %c4_i32, %2 : i32 + %4 = arith.minsi %3, %c4_i32 : i32 + %5 = arith.remsi %0, %c16_i32 : i32 + %6 = arith.remsi %5, %4 : i32 + %7 = arith.addi %2, %6 : i32 + %8 = arith.divsi %5, %4 : i32 + %9 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c1024_i64], [%c1024_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : >> + %10 = tt.make_tensor_ptr %arg1, [%c1024_i64, %c1_i64], [%c1_i64, %c1024_i64], [%c0_i32, %c0_i32] {order = array} : > + %13:3 = scf.for %arg3 = %c0_i32 to %c1024_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %c0_i32, %arg6 = %10) -> (tensor<256x256xf32, #mma>, i32, !tt.ptr>) : i32 { + %17 = tt.advance %9, [%c256_i32, %arg5] : >> + %18 = tt.load %17 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr>> + %19 = tt.advance %arg6, [%c16_i32, %arg5] : > + %20 = tt.load %19 {boundaryCheck = array, ttig.block_io = "column_major"} : !tt.ptr> + %21 = tt.trans %20 {order = array} : tensor<256x32xbf16, #linear> -> tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %22 = tt.dot %18, %21, %arg4, inputPrecision = tf32 : tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma> + tt.store %19, %cst_1 {boundaryCheck = array} : !tt.ptr> + %23 = arith.addi %arg5, %c32_i32 : i32 + scf.yield %22, %23, %19 : tensor<256x256xf32, #mma>, i32, !tt.ptr> + } + tt.return + } + // CHECK-LABEL: doNotFuseLoadWithTrans2 + // CHECK: tt.trans +} + +// ----- + +#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [16, 0], [0, 16], [128, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0]], warp = [[32, 0], [64, 0], [0, 0], [0, 0], [0, 0]], block = []}> +#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} { + // COM: Ensure load is not fused with transpose if the block ptr used by the load operation is 
yielded by an if statement (current limitation). + tt.func public @doNotFuseLoadWithTrans3(%arg0: !tt.ptr>>, %arg1: !tt.ptr, %cond: i1) { + %c0_i32 = arith.constant 0 : i32 + %c1_i64 = arith.constant 1 : i64 + %c32_i32 = arith.constant 32 : i32 + %c256_i32 = arith.constant 256 : i32 + %c1024_i64 = arith.constant 1024 : i64 + %c1024_i32 = arith.constant 1024 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %0 = tt.make_tensor_ptr %arg1, [%c1024_i64, %c1_i64], [%c1_i64, %c1024_i64], [%c0_i32, %c0_i32] {order = array} : > + %a = tt.make_tensor_ptr %arg1, [%c1024_i64, %c1_i64], [%c1_i64, %c1024_i64], [%c0_i32, %c0_i32] {order = array} : > + + %res:2 = scf.for %arg3 = %c0_i32 to %c1024_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %c0_i32) -> (tensor<256x256xf32, #mma>, i32) : i32 { + %1 = tt.load %arg0 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr>> + %2 = tt.advance %0, [%c256_i32, %c0_i32] : > + %3 = scf.if %cond -> !tt.ptr> { + scf.yield %2 : !tt.ptr> + } else { + scf.yield %a : !tt.ptr> + } + %a4 = tt.load %3 {boundaryCheck = array, ttig.block_io = "column_major"} : !tt.ptr> + %b4 = tt.load %2 {boundaryCheck = array, ttig.block_io = "column_major"} : !tt.ptr> + %5 = tt.trans %b4 {order = array} : tensor<256x32xbf16, #linear> -> tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %6 = tt.dot %1, %5, %arg4, inputPrecision = tf32 : tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma> + %7 = arith.addi %arg5, %c32_i32 : i32 + scf.yield %6, %7 : tensor<256x256xf32, #mma>, i32 + } + tt.return + } + // CHECK-LABEL: doNotFuseLoadWithTrans3 + // CHECK: tt.trans +} + +// ----- + +#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [16, 0], [0, 16], [128, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0]], warp = [[32, 0], [64, 0], [0, 0], [0, 0], [0, 0]], block = []}> +#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} { + // COM: Ensure load is not fused with transpose when it is in a while loop (current limitation).
+ tt.func public @doNotFuseLoadWithTrans4(%arg0: !tt.ptr>>, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i1) { + %c0_i32 = arith.constant 0 : i32 + %c1_i64 = arith.constant 1 : i64 + %c32_i32 = arith.constant 32 : i32 + %c256_i32 = arith.constant 256 : i32 + %c1024_i64 = arith.constant 1024 : i64 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %0 = tt.make_tensor_ptr %arg1, [%c1024_i64, %c1_i64], [%c1_i64, %c1024_i64], [%c0_i32, %c0_i32] {order = array} : > + %1:2 = scf.while (%arg3 = %0, %arg4 = %c0_i32) : (!tt.ptr>, i32) -> (!tt.ptr>, i32) { + scf.condition(%arg2) %arg3, %arg4 : !tt.ptr>, i32 + } do { + ^bb0(%arg3: !tt.ptr>, %arg4: i32): + %2 = tt.load %arg0 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr>> + %3 = tt.advance %arg3, [%c256_i32, %c0_i32] : > + %4 = tt.load %3 {boundaryCheck = array, ttig.block_io = "column_major"} : !tt.ptr> + %5 = tt.trans %4 {order = array} : tensor<256x32xbf16, #linear> -> tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %6 = tt.dot %2, %5, %cst, inputPrecision = tf32 : tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma> + %7 = arith.addi %arg4, %c32_i32 : i32 + scf.yield %3, %7 : !tt.ptr>, i32 + } + tt.return + } + // CHECK-LABEL: doNotFuseLoadWithTrans4 + // CHECK: tt.trans +} diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py index 853188d934..009f4b6c14 100644 --- a/third_party/intel/backend/compiler.py +++ b/third_party/intel/backend/compiler.py @@ -280,6 +280,7 @@ def make_ttgir(mod, metadata, opt, properties): intel.passes.ttgpuir.add_accelerate_matmul(pm) intel.passes.ttgpuir.add_materialize_block_pointer(pm) intel.passes.ttgpuir.add_remove_layout_conversions(pm) + intel.passes.ttgpuir.add_optimize_dot_operands(pm) intel.passes.ttgpuir.add_pipeline(pm, opt.num_stages, XPUBackend.get_split_barrier_scope(opt)) if (opt.reduce_variable_liveness): diff --git a/third_party/intel/include/Utils/Utility.h b/third_party/intel/include/Utils/Utility.h index 642cfcb1ef..f2842c1982 100644 --- a/third_party/intel/include/Utils/Utility.h +++ b/third_party/intel/include/Utils/Utility.h @@ -29,6 +29,9 @@ bool isConstant(Value val, int64_t expected); Value getFinalValue(Value value); +// Erase the operations in \p operations. +void eraseOperations(SmallPtrSetImpl &operations); + } // namespace mlir::triton::intel #endif // TRITON_INTEL_UTILS_UTILITY_H diff --git a/third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp b/third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp index ebec88b95f..cf0fd2663d 100644 --- a/third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp +++ b/third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp @@ -95,7 +95,9 @@ struct TritonIntelTensorDescToBlockPointer .Default([&](auto) { return WalkResult::advance(); }); }); - finalize(); + if (!cleanUp.empty()) + tt::intel::eraseOperations(cleanUp); + assert(succeeded(verify(moduleOp)) && "Module verification failed"); } @@ -267,31 +269,6 @@ struct TritonIntelTensorDescToBlockPointer return success(); } - void finalize() { - // Cleanup unused operations. 
- bool erasedOperation; - do { - erasedOperation = false; - SmallPtrSet erased; - for (Operation *op : cleanUp) { - if (!op->getUsers().empty() || !op->getRegions().empty()) - continue; - - erased.insert(op); - op->erase(); - erasedOperation = true; - } - cleanUp.remove_if([&](Operation *op) { return erased.contains(op); }); - } while (erasedOperation); - - // Remove operations that contain a region. - for (Operation *op : cleanUp) { - if (!op->getUsers().empty()) - continue; - op->erase(); - } - } - private: SmallPtrSet cleanUp; }; diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp index fbabe3ee11..bff9f17673 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp @@ -1,19 +1,23 @@ #include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h" #include "intel/include/Utils/Utility.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" -#include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/LayoutUtility.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" #include #define DEBUG_TYPE "tritonintelgpu-optimize-dot-operands" @@ -43,27 +47,38 @@ namespace { // %load = tt.load %ptr, {blockIO=} // : tt.ptr // tt.dot(%a, %load) -class FuseTransWithLoad : public OpRewritePattern { +class FuseTransWithLoad { +private: + tt::FuncOp funcOp; + SmallPtrSet cleanUp; + public: - using OpRewritePattern::OpRewritePattern; + FuseTransWithLoad() = default; - LogicalResult matchAndRewrite(tt::TransOp transOp, - PatternRewriter &rewriter) const override { - if (!isCandidate(transOp)) - return failure(); + void run(ModuleOp moduleOp) { + moduleOp.walk([&](tt::TransOp transOp) { + if (isCandidate(transOp)) + fuse(transOp); + }); - LLVM_DEBUG(llvm::dbgs() << "Candidate: " << transOp << "\n"); - auto tensorType = cast(transOp.getType()); - Attribute dotEncoding = - cast(tensorType.getEncoding()); + if (!cleanUp.empty()) + tt::intel::eraseOperations(cleanUp); + + assert(succeeded(verify(moduleOp)) && "Module verification failed"); + } + + void fuse(tt::TransOp transOp) { + LLVM_DEBUG(llvm::dbgs() << "Found candidate:\n\t" << transOp << "\n"); auto loadOp = cast(transOp.getSrc().getDefiningOp()); tt::MakeTensorPtrOp makeTensorPtrOp = *triton::intel::findDefiningMakeTensorPtrOp(loadOp.getPtr()); - LLVM_DEBUG(llvm::dbgs() << "makeTensorPtrOp: " << makeTensorPtrOp << "\n"); + LLVM_DEBUG(llvm::dbgs() + << "makeTensorPtrOp:\n\t" << makeTensorPtrOp << "\n"); // Create a MakeTensorPtrOp yielding a block pointer to the transposed - // tensor. + // tensor... 
auto ptrType = cast(makeTensorPtrOp.getType()); + auto tensorType = cast(transOp.getType()); auto newPtrType = tt::PointerType::get(tensorType, ptrType.getAddressSpace()); SmallVector newShape(llvm::reverse(makeTensorPtrOp.getShape())); @@ -75,49 +90,17 @@ class FuseTransWithLoad : public OpRewritePattern { makeTensorPtrOp.getLoc(), newPtrType, makeTensorPtrOp.getBase(), newShape, newStrides, newOffsets, makeTensorPtrOp.getOrderAttr()); assert(makeTensorPtrOp->hasOneUse() && "Expecting single user"); - LLVM_DEBUG(llvm::dbgs() << "newMakeTensorPtrOp: " << ptr << "\n"); - - // Transitively update users of the block pointer. - Operation *makeTensorPtrOpUser = *makeTensorPtrOp->getUsers().begin(); - if (auto advanceOp = dyn_cast(makeTensorPtrOpUser)) { - ptr = updateAdvanceOpChain(advanceOp, loadOp, ptr); - } else { - // TODO: handle loop init args (scf.for only for now). - assert(makeTensorPtrOpUser == loadOp && - "Expecting the load to be the user"); - } + LLVM_DEBUG(llvm::dbgs() << "newMakeTensorPtrOp:\n\t" << ptr << "\n"); - // Replace the load+transpose with a new load operation that uses the - // transposed block pointer. - auto newLoadOp = rewriter.create( - loadOp.getLoc(), ptr, loadOp.getMask(), loadOp.getOther(), - loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(), - loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); - - StringRef blockIOAttrName = - ttgi::TritonIntelGPUDialect::getBlockIOAttrName(); - StringAttr attr = loadOp->getAttrOfType(blockIOAttrName); - StringAttr newAttr = - (attr == "row_major") - ? StringAttr::get(loadOp->getContext(), "column_major") - : (attr == "column_major") - ? StringAttr::get(loadOp->getContext(), "row_major") - : nullptr; - assert(newAttr && "Expecting a valid blockIO attribute"); - - newLoadOp->setAttr(blockIOAttrName, newAttr); - LLVM_DEBUG(llvm::dbgs() << "newLoadOp: " << newLoadOp << "\n"); - - transOp->replaceAllUsesWith(newLoadOp); - - return success(); + // ... and propagate it through the def-use chain. + propagateToUsers(ptr, makeTensorPtrOp, makeTensorPtrOp, transOp); } private: // Candidate is of the form: // tt.dot(tt.trans(tt.load(..., {blockIO=...}))) // Where: - // - the transpose result is used only by the dot operation, and + // - the transpose result is used (directly or indirectly) by the dot operation, and // - the transpose operation uses the result of a 2-dim load operation on a // block pointer (transitively) defined by a `make_tensor_ptr` in the same // function, and @@ -126,11 +109,26 @@ class FuseTransWithLoad : public OpRewritePattern { bool isCandidate(tt::TransOp transOp) const { assert(transOp && "Expecting a valid transpose operation"); - bool transOpUsedOnlyByDotOp = - transOp->hasOneUse() && - isa(*transOp->getUsers().begin()); + // Check whether \p transOp is used by a `dotOp` directly or indirectly + // (each operation in the def-use chain needs to have a single user).
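// A simplified, dependency-free model of the check described above: walk forward from
// the transpose through single-use operations and accept the candidate only if the walk
// reaches a dot. 'Node' and 'Kind' are illustrative stand-ins for MLIR operations, not
// types used by the pass.
#include <vector>

enum class Kind { Trans, Dot, Other };

struct Node {
  Kind kind;
  std::vector<Node *> users; // operations consuming this node's result
};

static bool usedByDot(const Node *transOp) {
  if (transOp->users.size() != 1)
    return false;
  const Node *user = transOp->users.front();
  while (user) {
    if (user->kind == Kind::Dot)
      return true;
    if (user->users.size() != 1)
      break; // every op on the chain must have exactly one user
    user = user->users.front();
  }
  return false;
}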
+ auto usedByDotOp = [](tt::TransOp transOp) { + if (!transOp->hasOneUse()) + return false; + + Operation *user = *transOp->getUsers().begin(); + while (user) { + if (isa(user)) + return true; + if (!user->hasOneUse()) + break; + user = *user->getUsers().begin(); + } + + return false; + }; + Attribute transOpEncoding = transOp.getType().getEncoding(); - if (!transOpUsedOnlyByDotOp || !transOpEncoding || + if (!usedByDotOp(transOp) || !transOpEncoding || !isa(transOpEncoding)) return false; @@ -162,74 +160,229 @@ class FuseTransWithLoad : public OpRewritePattern { return true; } + // Determine whether all operations in the def-use chain from \p start to + // \p end have a single user. + // Note: we allow an operation in the def-use chain to have an additional user + // if the operation is in a for loop, and the additional user is the yield + // operation, provided that the result yielded is not used after the loop. + // Example: + // make_tensor_ptr -> advance -> load (OK) + // make_tensor_ptr -> for init_arg -> advance -> load (OK) + // -> yield (OK) + // make_tensor_ptr -> for init_arg -> advance -> load (OK) + // -> yield -> load (NOT OK) + // bool singleUsersInChain(Operation *start, Operation *end) const { assert(start && end && "Expecting valid operations"); Operation *currentOp = start; + auto validate = [](Operation *op, Operation *&nextOp) { + assert(nextOp == nullptr); + + if (op->hasOneUse()) + return true; + if (!op->getParentOfType()) + return false; + + auto loopOp = op->getParentOfType(); + auto yieldOp = cast( + loopOp.getYieldedValues()[0].getParentBlock()->getTerminator()); + + SmallVector users(op->getUsers()); + if (users.size() > 2 || llvm::none_of(users, [&](Operation *user) { + return user == yieldOp; + })) + return false; + + auto yieldedValUsedAfterLoop = [&op, &yieldOp]() { + auto it = + llvm::find_if(yieldOp->getOpOperands(), [&op](OpOperand &operand) { + return operand.get() == op->getResult(0); + }); + assert(it != yieldOp->getOpOperands().end()); + OpOperand &operand = *it; + auto loopOp = cast(yieldOp->getParentOp()); + OpResult res = loopOp->getResult(operand.getOperandNumber()); + return !res.getUsers().empty(); + }; + + if (yieldedValUsedAfterLoop()) + return false; + + nextOp = *llvm::find_if( + users, [](Operation *user) { return !isa(user); }); + return true; + }; + while (currentOp != end) { - // TODO: extend to handle loops. - if ((currentOp->getNumRegions() != 0) || !currentOp->hasOneUse()) + Operation *user = nullptr; + if (!validate(currentOp, user)) { + LLVM_DEBUG(llvm::dbgs() + << "Fails safety checks: " << *currentOp << "\n"); return false; + } - currentOp = *currentOp->getUsers().begin(); + user = (!user) ? user = *currentOp->getUsers().begin() : user; + if (user->getNumRegions() == 0) { + currentOp = user; + continue; + } + + if (isa(user)) + return false; + + [[maybe_unused]] Operation *oldCurrentOp = currentOp; + + // Find the next operation in the def-use chain inside the loop body. + if (auto loopOp = dyn_cast(user)) { + for (auto [arg, init] : + llvm::zip(loopOp.getRegionIterArgs(), loopOp.getInits())) { + if (init == currentOp->getResult(0)) { + if (!arg.hasOneUse()) + return false; + + currentOp = *arg.getUsers().begin(); + break; + } + } + } + + assert(currentOp != oldCurrentOp && "Infinite loop detected!"); } return true; } - // Recursively update the operands in a chain of AdvanceOps, after setting the - // pointer operand of the first one. 
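// A hypothetical stand-alone model of the relaxation documented for singleUsersInChain
// above: an operation inside the loop may have one extra user when that user is the
// loop's yield, provided the scf.for result fed by that yield operand is dead after the
// loop. 'ChainOp' and 'Loop' are illustrative stand-ins for mlir::Operation and
// LoopLikeOpInterface, not types used by the pass.
#include <vector>

struct ChainOp;

struct Loop {
  // Users of each loop result, indexed like the corresponding yield operands.
  std::vector<std::vector<ChainOp *>> resultUsers;
};

struct ChainOp {
  std::vector<ChainOp *> users;
  bool isYield = false;
  Loop *parentLoop = nullptr; // enclosing loop, if any
  int yieldOperandIdx = -1;   // which yield operand this op feeds (-1 if none)
};

// Returns true when 'op' is acceptable in the def-use chain; 'next' is set to the
// non-yield user with which the walk should continue.
static bool validateChainOp(ChainOp *op, ChainOp *&next) {
  next = nullptr;
  if (op->users.size() == 1) {
    next = op->users.front();
    return true;
  }
  if (op->users.size() != 2 || !op->parentLoop || op->yieldOperandIdx < 0)
    return false;
  ChainOp *first = op->users[0], *second = op->users[1];
  ChainOp *yield = first->isYield ? first : second;
  ChainOp *other = first->isYield ? second : first;
  if (!yield->isYield)
    return false;
  // The scf.for result carrying this value must not be used after the loop.
  if (!op->parentLoop->resultUsers[op->yieldOperandIdx].empty())
    return false;
  next = other;
  return true;
}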
- tt::AdvanceOp updateAdvanceOpChain(tt::AdvanceOp advanceOp, tt::LoadOp loadOp, - Value ptr) const { - assert(advanceOp->hasOneUse() && "Expecting single user"); - assert(tt::isTensorPointerType(ptr.getType()) && - "Expecting a block pointer"); - - Operation *user = *advanceOp->getUsers().begin(); - if (auto loadUser = dyn_cast(user)) { - assert(loadUser == loadOp && - "chain should be terminated by candidate load"); - OpBuilder rewriter(advanceOp); - SmallVector newOffsets(llvm::reverse(advanceOp.getOffsets())); - return rewriter.create(advanceOp.getLoc(), ptr.getType(), - ptr, newOffsets); + // Propagate \p newVal to users of \p origOp. + void propagateToUsers(Value newVal, Value origVal, Operation *origOp, + Operation *sentinel) { + assert(origOp && sentinel && "Expecting valid operations"); + const SmallVector users(origOp->getUsers()); + for (Operation *user : users) + propagateToUser(newVal, origVal, user, sentinel); + } + + // If \p user is not \p sentinel, propagate \p newVal to \p user. Otherwise + // terminate the propagation. + void propagateToUser(Value newVal, Value origVal, Operation *user, + Operation *sentinel) { + assert(user && sentinel && "Expecting valid operations"); + assert(llvm::is_contained(origVal.getUsers(), user) && "Invalid usage"); + + LLVM_DEBUG({ + llvm::dbgs() << "In " << __func__ << "\n"; + llvm::dbgs() << "user of "; + if (origVal.getDefiningOp()) { + llvm::dbgs() << "\n\t" << *origVal.getDefiningOp() << "\n"; + } else { + origVal.printAsOperand(llvm::dbgs(), {}); + llvm::dbgs() << " "; + } + llvm::dbgs() << "is:\n\t"; + user->dumpPretty(); + }); + + if (user == sentinel) { + LLVM_DEBUG(llvm::dbgs() << "Reached sentinel\n"); + sentinel->replaceAllUsesWith(newVal.getDefiningOp()); + cleanUp.insert(sentinel); + return; } + Location loc = user->getLoc(); if (auto advanceOp = dyn_cast(user)) { OpBuilder rewriter(advanceOp); SmallVector newOffsets(llvm::reverse(advanceOp.getOffsets())); - ptr = rewriter.create(advanceOp.getLoc(), ptr.getType(), - ptr, newOffsets); - return updateAdvanceOpChain(advanceOp, loadOp, ptr); + auto newAdvanceOp = rewriter.create(loc, newVal.getType(), + newVal, newOffsets); + LLVM_DEBUG(llvm::dbgs() << "\tnewAdvanceOp: " << newAdvanceOp << "\n"); + cleanUp.insert(advanceOp); + return propagateToUsers(newAdvanceOp, advanceOp.getResult(), advanceOp, + sentinel); } - llvm_unreachable("Unexpected user"); - return nullptr; + if (auto loadOp = dyn_cast(user)) { + OpBuilder rewriter(loadOp); + auto newLoadOp = rewriter.create( + loadOp.getLoc(), newVal, loadOp.getMask(), loadOp.getOther(), + loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(), + loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); + + StringRef blockIOAttrName = + ttgi::TritonIntelGPUDialect::getBlockIOAttrName(); + StringAttr attr = loadOp->getAttrOfType(blockIOAttrName); + StringAttr newAttr = + (attr == "row_major") + ? StringAttr::get(loadOp->getContext(), "column_major") + : (attr == "column_major") + ? 
StringAttr::get(loadOp->getContext(), "row_major") + : nullptr; + assert(newAttr && "Expecting a valid blockIO attribute"); + + newLoadOp->setAttr(blockIOAttrName, newAttr); + LLVM_DEBUG(llvm::dbgs() << "\tnewLoadOp: " << newLoadOp << "\n"); + cleanUp.insert(loadOp); + return propagateToUsers(newLoadOp, loadOp.getResult(), loadOp, sentinel); + } + + if (auto yieldOp = dyn_cast(user)) { + int opNum = -1; + for (OpOperand &operand : yieldOp->getOpOperands()) { + if (operand.get() == origVal) { + opNum = operand.getOperandNumber(); + yieldOp->setOperand(operand.getOperandNumber(), newVal); + break; + } + } + + // Update the yield's parent operation result type. + Operation *parentOp = yieldOp->getParentOp(); + OpResult res = parentOp->getOpResult(opNum); + res.setType(newVal.getType()); + return; + } + + if (auto forOp = dyn_cast(user)) + return propagateToLoop(newVal, origVal, forOp, sentinel); + + llvm_unreachable("Unexpected kind of user"); + } + + void propagateToLoop(Value newVal, Value origVal, LoopLikeOpInterface loopOp, + Operation *sentinel) { + assert(sentinel && sentinel != loopOp && "Unexpected sentinel kind"); + LLVM_DEBUG({ + llvm::dbgs() << "In " << __func__ << "\n"; + llvm::dbgs() << "newVal: " << newVal << "\n"; + }); + + for (auto [initArg, rgnInitArg] : + llvm::zip(loopOp.getInitsMutable(), loopOp.getRegionIterArgs())) { + if (initArg.get() == origVal) { + initArg.set(newVal); + rgnInitArg.setType(initArg.get().getType()); + const SmallVector users(rgnInitArg.getUsers()); + for (Operation *user : users) + propagateToUser(rgnInitArg, rgnInitArg, user, sentinel); + } + } } }; } // namespace class TritonIntelGPUOptimizeDotOperandsPass - : public triton::gpu::intel::impl::TritonIntelGPUOptimizeDotOperandsBase< + : public ttgi::impl::TritonIntelGPUOptimizeDotOperandsBase< TritonIntelGPUOptimizeDotOperandsPass> { + public: - using triton::gpu::intel::impl::TritonIntelGPUOptimizeDotOperandsBase< + using ttgi::impl::TritonIntelGPUOptimizeDotOperandsBase< TritonIntelGPUOptimizeDotOperandsPass>:: TritonIntelGPUOptimizeDotOperandsBase; - void runOnOperation() override { - MLIRContext *context = &getContext(); - ModuleOp m = getOperation(); - - OpPassManager pm; - pm.addPass(mlir::createCanonicalizerPass()); - if (failed(runPipeline(pm, m))) - return signalPassFailure(); - - mlir::RewritePatternSet patterns(context); - patterns.add(context); - if (failed(applyPatternsGreedily(m, std::move(patterns)))) - signalPassFailure(); + void runOnOperation() final { + ModuleOp moduleOp = getOperation(); + FuseTransWithLoad fuser; + fuser.run(moduleOp); } }; diff --git a/third_party/intel/lib/Utils/Utility.cpp b/third_party/intel/lib/Utils/Utility.cpp index 277e7c5807..2f90263059 100644 --- a/third_party/intel/lib/Utils/Utility.cpp +++ b/third_party/intel/lib/Utils/Utility.cpp @@ -193,4 +193,28 @@ Value getFinalValue(Value value) { return value; } +void eraseOperations(SmallPtrSetImpl &operations) { + bool erasedOperation; + do { + erasedOperation = false; + SmallPtrSet erased; + for (Operation *op : operations) { + if (!op->getUsers().empty() || !op->getRegions().empty()) + continue; + + erased.insert(op); + op->erase(); + erasedOperation = true; + } + operations.remove_if([&](Operation *op) { return erased.contains(op); }); + } while (erasedOperation); + + // Remove operations that contain a region. 
+ for (Operation *op : operations) { + if (!op->getUsers().empty()) + continue; + op->erase(); + } +} + } // namespace mlir::triton::intel diff --git a/third_party/intel/triton_xpu.cc b/third_party/intel/triton_xpu.cc index a77e9759c3..3e0de55552 100644 --- a/third_party/intel/triton_xpu.cc +++ b/third_party/intel/triton_xpu.cc @@ -77,6 +77,8 @@ void init_triton_intel_passes_ttgpuir(py::module &&m) { gpu::intel::createIntelAllocateSharedMemory); ADD_PASS_WRAPPER_0("add_remove_layout_conversions", gpu::intel::createTritonIntelGPURemoveLayoutConversions); + ADD_PASS_WRAPPER_0("add_optimize_dot_operands", + gpu::intel::createTritonIntelGPUOptimizeDotOperands); ADD_PASS_WRAPPER_0("add_coalesce", gpu::intel::createTritonIntelGPUCoalesce); ADD_PASS_OPTION_WRAPPER_2("add_prefetch_block", gpu::intel::createTritonIntelGPUPrefetchBlock, int,
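// A stand-alone model (plain C++, no MLIR) of the clean-up order implemented by the
// eraseOperations helper added above: repeatedly erase operations that have no users
// and own no regions until a fixed point is reached, then drop the remaining user-less
// operations that do own regions (e.g. emptied loops). 'FakeOp' is a hypothetical
// stand-in for mlir::Operation; real code relies on MLIR use-lists instead of the
// manual user bookkeeping shown here.
#include <algorithm>
#include <set>
#include <vector>

struct FakeOp {
  std::vector<FakeOp *> users; // operations consuming this op's results
  bool hasRegions = false;     // e.g. an scf.for or scf.if
};

static void eraseOperations(std::set<FakeOp *> &ops) {
  bool erasedOperation;
  do {
    erasedOperation = false;
    for (FakeOp *op : std::vector<FakeOp *>(ops.begin(), ops.end())) {
      if (!op->users.empty() || op->hasRegions)
        continue;
      ops.erase(op);
      // Erasing an op removes it from the user lists of the ops it consumed;
      // MLIR does this automatically through its use-lists.
      for (FakeOp *other : ops)
        other->users.erase(
            std::remove(other->users.begin(), other->users.end(), op),
            other->users.end());
      erasedOperation = true; // another sweep may now expose new dead ops
    }
  } while (erasedOperation);

  // Finally remove user-less operations that contain regions.
  for (FakeOp *op : std::vector<FakeOp *>(ops.begin(), ops.end()))
    if (op->users.empty())
      ops.erase(op);
}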