Skip to content

Commit 87f0227

Browse files
thevar1abletstellar
authored andcommitted
[InstCombine] Avoid folding select(umin(X, Y), X) with min/max values in false arm (#143020)
Fixes #139050. This patch adds a check to avoid folding min/max reduction into select, which may block loop vectorization. The issue is that the following snippet: ``` declare i8 @llvm.umin.i8(i8, i8) define i8 @masked_min_fold_bug(i8 %acc, i8 %val, i8 %mask) { ; CHECK-LABEL: @masked_min_fold_bug( ; CHECK: %cond = icmp eq i8 %mask, 0 ; CHECK: %masked_val = select i1 %cond, i8 %val, i8 255 ; CHECK: call i8 @llvm.umin.i8(i8 %acc, i8 %masked_val) ; %cond = icmp eq i8 %mask, 0 %masked_val = select i1 %cond, i8 %val, i8 255 %res = call i8 @llvm.umin.i8(i8 %acc, i8 %masked_val) ret i8 %res } ``` is being optimized to the following code, which can not be vectorized later. ``` declare i8 @llvm.umin.i8(i8, i8) #0 define i8 @masked_min_fold_bug(i8 %acc, i8 %val, i8 %mask) { %cond = icmp eq i8 %mask, 0 %1 = call i8 @llvm.umin.i8(i8 %acc, i8 %val) %res = select i1 %cond, i8 %1, i8 %acc ret i8 %res } attributes #0 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) } ``` Expected: ``` declare i8 @llvm.umin.i8(i8, i8) #0 define i8 @masked_min_fold_bug(i8 %acc, i8 %val, i8 %mask) { %cond = icmp eq i8 %mask, 0 %masked_val = select i1 %cond, i8 %val, i8 -1 %res = call i8 @llvm.umin.i8(i8 %acc, i8 %masked_val) ret i8 %res } attributes #0 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) } ``` https://godbolt.org/z/cYMheKE5r (cherry picked from commit 07fa6d1)
1 parent df43f93 commit 87f0227

File tree

3 files changed

+94
-12
lines changed

3 files changed

+94
-12
lines changed

llvm/lib/Transforms/InstCombine/InstructionCombining.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1697,6 +1697,15 @@ Instruction *InstCombinerImpl::FoldOpIntoSelect(Instruction &Op, SelectInst *SI,
16971697
if (SI->getType()->isIntOrIntVectorTy(1))
16981698
return nullptr;
16991699

1700+
// Avoid breaking min/max reduction pattern,
1701+
// which is necessary for vectorization later.
1702+
if (isa<MinMaxIntrinsic>(&Op))
1703+
for (Value *IntrinOp : Op.operands())
1704+
if (auto *PN = dyn_cast<PHINode>(IntrinOp))
1705+
for (Value *PhiOp : PN->operands())
1706+
if (PhiOp == &Op)
1707+
return nullptr;
1708+
17001709
// Test if a FCmpInst instruction is used exclusively by a select as
17011710
// part of a minimum or maximum operation. If so, refrain from doing
17021711
// any other folding. This helps out other analyses which understand

llvm/test/Transforms/InstCombine/select.ll

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4901,3 +4901,50 @@ define i32 @src_simplify_2x_at_once_and(i32 %x, i32 %y) {
49014901
%cond = select i1 %and0, i32 %sub, i32 %xor
49024902
ret i32 %cond
49034903
}
4904+
4905+
define void @no_fold_masked_min_loop(ptr nocapture readonly %vals, ptr nocapture readonly %masks, ptr nocapture %out, i64 %n) {
4906+
; CHECK-LABEL: @no_fold_masked_min_loop(
4907+
; CHECK-NEXT: entry:
4908+
; CHECK-NEXT: br label [[LOOP:%.*]]
4909+
; CHECK: loop:
4910+
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[NEXT_INDEX:%.*]], [[LOOP]] ]
4911+
; CHECK-NEXT: [[ACC:%.*]] = phi i8 [ -1, [[ENTRY]] ], [ [[RES:%.*]], [[LOOP]] ]
4912+
; CHECK-NEXT: [[VAL_PTR:%.*]] = getelementptr inbounds i8, ptr [[VALS:%.*]], i64 [[INDEX]]
4913+
; CHECK-NEXT: [[MASK_PTR:%.*]] = getelementptr inbounds i8, ptr [[MASKS:%.*]], i64 [[INDEX]]
4914+
; CHECK-NEXT: [[VAL:%.*]] = load i8, ptr [[VAL_PTR]], align 1
4915+
; CHECK-NEXT: [[MASK:%.*]] = load i8, ptr [[MASK_PTR]], align 1
4916+
; CHECK-NEXT: [[COND:%.*]] = icmp eq i8 [[MASK]], 0
4917+
; CHECK-NEXT: [[MASKED_VAL:%.*]] = select i1 [[COND]], i8 [[VAL]], i8 -1
4918+
; CHECK-NEXT: [[RES]] = call i8 @llvm.umin.i8(i8 [[ACC]], i8 [[MASKED_VAL]])
4919+
; CHECK-NEXT: [[NEXT_INDEX]] = add i64 [[INDEX]], 1
4920+
; CHECK-NEXT: [[DONE:%.*]] = icmp eq i64 [[NEXT_INDEX]], [[N:%.*]]
4921+
; CHECK-NEXT: br i1 [[DONE]], label [[EXIT:%.*]], label [[LOOP]]
4922+
; CHECK: exit:
4923+
; CHECK-NEXT: store i8 [[RES]], ptr [[OUT:%.*]], align 1
4924+
; CHECK-NEXT: ret void
4925+
;
4926+
entry:
4927+
br label %loop
4928+
4929+
loop:
4930+
%index = phi i64 [0, %entry], [%next_index, %loop]
4931+
%acc = phi i8 [255, %entry], [%res, %loop]
4932+
4933+
%val_ptr = getelementptr inbounds i8, ptr %vals, i64 %index
4934+
%mask_ptr = getelementptr inbounds i8, ptr %masks, i64 %index
4935+
4936+
%val = load i8, ptr %val_ptr, align 1
4937+
%mask = load i8, ptr %mask_ptr, align 1
4938+
4939+
%cond = icmp eq i8 %mask, 0
4940+
%masked_val = select i1 %cond, i8 %val, i8 -1
4941+
%res = call i8 @llvm.umin.i8(i8 %acc, i8 %masked_val)
4942+
4943+
%next_index = add i64 %index, 1
4944+
%done = icmp eq i64 %next_index, %n
4945+
br i1 %done, label %exit, label %loop
4946+
4947+
exit:
4948+
store i8 %res, ptr %out, align 1
4949+
ret void
4950+
}

llvm/test/Transforms/PhaseOrdering/X86/vector-reductions.ll

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -326,26 +326,52 @@ cleanup:
326326
ret i1 %retval.0
327327
}
328328

329-
; From https://github.com/llvm/llvm-project/issues/139050.
330-
; FIXME: This should be vectorized.
331329
define i8 @masked_min_reduction(ptr %data, ptr %mask) {
332330
; CHECK-LABEL: @masked_min_reduction(
333331
; CHECK-NEXT: entry:
334332
; CHECK-NEXT: br label [[VECTOR_BODY:%.*]]
335-
; CHECK: loop:
333+
; CHECK: vector.body:
336334
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
337-
; CHECK-NEXT: [[ACC:%.*]] = phi i8 [ -1, [[ENTRY]] ], [ [[TMP21:%.*]], [[VECTOR_BODY]] ]
335+
; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <32 x i8> [ splat (i8 -1), [[ENTRY]] ], [ [[TMP16:%.*]], [[VECTOR_BODY]] ]
336+
; CHECK-NEXT: [[VEC_PHI1:%.*]] = phi <32 x i8> [ splat (i8 -1), [[ENTRY]] ], [ [[TMP17:%.*]], [[VECTOR_BODY]] ]
337+
; CHECK-NEXT: [[VEC_PHI2:%.*]] = phi <32 x i8> [ splat (i8 -1), [[ENTRY]] ], [ [[TMP18:%.*]], [[VECTOR_BODY]] ]
338+
; CHECK-NEXT: [[VEC_PHI3:%.*]] = phi <32 x i8> [ splat (i8 -1), [[ENTRY]] ], [ [[TMP19:%.*]], [[VECTOR_BODY]] ]
338339
; CHECK-NEXT: [[DATA:%.*]] = getelementptr i8, ptr [[DATA1:%.*]], i64 [[INDEX]]
339-
; CHECK-NEXT: [[VAL:%.*]] = load i8, ptr [[DATA]], align 1
340+
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i8, ptr [[DATA]], i64 32
341+
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr i8, ptr [[DATA]], i64 64
342+
; CHECK-NEXT: [[TMP3:%.*]] = getelementptr i8, ptr [[DATA]], i64 96
343+
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <32 x i8>, ptr [[DATA]], align 1
344+
; CHECK-NEXT: [[WIDE_LOAD4:%.*]] = load <32 x i8>, ptr [[TMP1]], align 1
345+
; CHECK-NEXT: [[WIDE_LOAD5:%.*]] = load <32 x i8>, ptr [[TMP2]], align 1
346+
; CHECK-NEXT: [[WIDE_LOAD6:%.*]] = load <32 x i8>, ptr [[TMP3]], align 1
340347
; CHECK-NEXT: [[TMP7:%.*]] = getelementptr i8, ptr [[MASK:%.*]], i64 [[INDEX]]
341-
; CHECK-NEXT: [[M:%.*]] = load i8, ptr [[TMP7]], align 1
342-
; CHECK-NEXT: [[COND:%.*]] = icmp eq i8 [[M]], 0
343-
; CHECK-NEXT: [[TMP0:%.*]] = tail call i8 @llvm.umin.i8(i8 [[ACC]], i8 [[VAL]])
344-
; CHECK-NEXT: [[TMP21]] = select i1 [[COND]], i8 [[TMP0]], i8 [[ACC]]
345-
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw nsw i64 [[INDEX]], 1
348+
; CHECK-NEXT: [[TMP5:%.*]] = getelementptr i8, ptr [[TMP7]], i64 32
349+
; CHECK-NEXT: [[TMP6:%.*]] = getelementptr i8, ptr [[TMP7]], i64 64
350+
; CHECK-NEXT: [[TMP22:%.*]] = getelementptr i8, ptr [[TMP7]], i64 96
351+
; CHECK-NEXT: [[WIDE_LOAD7:%.*]] = load <32 x i8>, ptr [[TMP7]], align 1
352+
; CHECK-NEXT: [[WIDE_LOAD8:%.*]] = load <32 x i8>, ptr [[TMP5]], align 1
353+
; CHECK-NEXT: [[WIDE_LOAD9:%.*]] = load <32 x i8>, ptr [[TMP6]], align 1
354+
; CHECK-NEXT: [[WIDE_LOAD10:%.*]] = load <32 x i8>, ptr [[TMP22]], align 1
355+
; CHECK-NEXT: [[TMP8:%.*]] = icmp eq <32 x i8> [[WIDE_LOAD7]], zeroinitializer
356+
; CHECK-NEXT: [[TMP9:%.*]] = icmp eq <32 x i8> [[WIDE_LOAD8]], zeroinitializer
357+
; CHECK-NEXT: [[TMP10:%.*]] = icmp eq <32 x i8> [[WIDE_LOAD9]], zeroinitializer
358+
; CHECK-NEXT: [[TMP11:%.*]] = icmp eq <32 x i8> [[WIDE_LOAD10]], zeroinitializer
359+
; CHECK-NEXT: [[TMP12:%.*]] = select <32 x i1> [[TMP8]], <32 x i8> [[WIDE_LOAD]], <32 x i8> splat (i8 -1)
360+
; CHECK-NEXT: [[TMP13:%.*]] = select <32 x i1> [[TMP9]], <32 x i8> [[WIDE_LOAD4]], <32 x i8> splat (i8 -1)
361+
; CHECK-NEXT: [[TMP14:%.*]] = select <32 x i1> [[TMP10]], <32 x i8> [[WIDE_LOAD5]], <32 x i8> splat (i8 -1)
362+
; CHECK-NEXT: [[TMP15:%.*]] = select <32 x i1> [[TMP11]], <32 x i8> [[WIDE_LOAD6]], <32 x i8> splat (i8 -1)
363+
; CHECK-NEXT: [[TMP16]] = tail call <32 x i8> @llvm.umin.v32i8(<32 x i8> [[VEC_PHI]], <32 x i8> [[TMP12]])
364+
; CHECK-NEXT: [[TMP17]] = tail call <32 x i8> @llvm.umin.v32i8(<32 x i8> [[VEC_PHI1]], <32 x i8> [[TMP13]])
365+
; CHECK-NEXT: [[TMP18]] = tail call <32 x i8> @llvm.umin.v32i8(<32 x i8> [[VEC_PHI2]], <32 x i8> [[TMP14]])
366+
; CHECK-NEXT: [[TMP19]] = tail call <32 x i8> @llvm.umin.v32i8(<32 x i8> [[VEC_PHI3]], <32 x i8> [[TMP15]])
367+
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 128
346368
; CHECK-NEXT: [[TMP20:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024
347-
; CHECK-NEXT: br i1 [[TMP20]], label [[EXIT:%.*]], label [[VECTOR_BODY]]
348-
; CHECK: exit:
369+
; CHECK-NEXT: br i1 [[TMP20]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
370+
; CHECK: middle.block:
371+
; CHECK-NEXT: [[RDX_MINMAX:%.*]] = tail call <32 x i8> @llvm.umin.v32i8(<32 x i8> [[TMP16]], <32 x i8> [[TMP17]])
372+
; CHECK-NEXT: [[RDX_MINMAX11:%.*]] = tail call <32 x i8> @llvm.umin.v32i8(<32 x i8> [[RDX_MINMAX]], <32 x i8> [[TMP18]])
373+
; CHECK-NEXT: [[RDX_MINMAX12:%.*]] = tail call <32 x i8> @llvm.umin.v32i8(<32 x i8> [[RDX_MINMAX11]], <32 x i8> [[TMP19]])
374+
; CHECK-NEXT: [[TMP21:%.*]] = tail call i8 @llvm.vector.reduce.umin.v32i8(<32 x i8> [[RDX_MINMAX12]])
349375
; CHECK-NEXT: ret i8 [[TMP21]]
350376
;
351377
entry:

0 commit comments

Comments
 (0)