Skip to content

Commit c923a94

Browse files
committed
!fixup generalize, create wide IV if needed.
1 parent 08d5635 commit c923a94

File tree

12 files changed

+752
-486
lines changed

12 files changed

+752
-486
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8587,7 +8587,7 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
85878587

85888588
// Apply mandatory transformation to handle reductions with multiple in-loop
85898589
// uses if possible, bail out otherwise.
8590-
if (!VPlanTransforms::handleMultiUseReductions(*Plan, *PSE.getSE(), OrigLoop))
8590+
if (!VPlanTransforms::handleMultiUseReductions(*Plan))
85918591
return nullptr;
85928592
// Apply mandatory transformation to handle FP maxnum/minnum reduction with
85938593
// NaNs if possible, bail out otherwise.

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2155,7 +2155,9 @@ class VPWidenInductionRecipe : public VPHeaderPHIRecipe {
21552155
/// incoming value, its start value.
21562156
unsigned getNumIncoming() const override { return 1; }
21572157

2158-
PHINode *getPHINode() const { return cast<PHINode>(getUnderlyingValue()); }
2158+
PHINode *getPHINode() const {
2159+
return cast_if_present<PHINode>(getUnderlyingValue());
2160+
}
21592161

21602162
/// Returns the induction descriptor for the recipe.
21612163
const InductionDescriptor &getInductionDescriptor() const { return IndDesc; }

llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp

Lines changed: 133 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1000,48 +1000,139 @@ bool VPlanTransforms::handleMaxMinNumReductions(VPlan &Plan) {
10001000
return true;
10011001
}
10021002

1003-
/// Try to convert FindLastIV to FindFirstIV reduction when using a strict
1004-
/// predicate. Returns the new FindFirstIVPhiR on success, nullptr on failure.
1005-
static VPReductionPHIRecipe *
1006-
tryConvertToFindFirstIV(VPlan &Plan, VPReductionPHIRecipe *FindLastIVPhiR,
1007-
VPValue *IVOp, ScalarEvolution &SE, const Loop *L) {
1008-
Type *Ty = VPTypeAnalysis(Plan).inferScalarType(FindLastIVPhiR);
1009-
unsigned NumBits = Ty->getIntegerBitWidth();
1010-
1011-
// Determine the reduction kind and sentinel based on the IV range.
1012-
RecurKind NewKind;
1013-
VPValue *NewSentinel;
1014-
auto *AR = cast<SCEVAddRecExpr>(vputils::getSCEVExprForVPValue(IVOp, SE, L));
1015-
if (RecurrenceDescriptor::isValidIVRangeForFindIV(
1016-
AR, /*IsSigned=*/true, /*IsFindFirstIV=*/true, SE)) {
1017-
NewKind = RecurKind::FindFirstIVSMin;
1018-
NewSentinel = Plan.getConstantInt(APInt::getSignedMaxValue(NumBits));
1019-
} else if (RecurrenceDescriptor::isValidIVRangeForFindIV(
1020-
AR, /*IsSigned=*/false, /*IsFindFirstIV=*/true, SE)) {
1021-
NewKind = RecurKind::FindFirstIVUMin;
1022-
NewSentinel = Plan.getConstantInt(APInt::getMaxValue(NumBits));
1023-
} else {
1024-
return nullptr;
1003+
/// For argmin/argmax reductions with strict predicates, convert the existing
1004+
/// FindLastIV reduction to a new UMin reduction of a wide canonical IV. If the
1005+
/// original IV was not canonical, a new canonical wide IV is added, and the
1006+
/// final result is scaled back to the original IV.
1007+
static bool handleStrictArgMinArgMax(VPlan &Plan,
1008+
VPReductionPHIRecipe *MinMaxPhiR,
1009+
VPReductionPHIRecipe *FindIVPhiR,
1010+
VPWidenIntOrFpInductionRecipe *WideIV,
1011+
VPInstruction *MinMaxResult) {
1012+
Type *Ty = Plan.getVectorLoopRegion()->getCanonicalIVType();
1013+
if (Ty != VPTypeAnalysis(Plan).inferScalarType(FindIVPhiR))
1014+
return false;
1015+
1016+
// If the original wide IV is not canonical, create a new one. The wide IV is
1017+
// guaranteed to not wrap for all lanes that are active in the vector loop.
1018+
if (!WideIV->isCanonical()) {
1019+
VPValue *Zero = Plan.getOrAddLiveIn(ConstantInt::get(Ty, 0));
1020+
VPValue *One = Plan.getOrAddLiveIn(ConstantInt::get(Ty, 1));
1021+
auto *WidenCanIV = new VPWidenIntOrFpInductionRecipe(
1022+
nullptr, Zero, One, WideIV->getVFValue(),
1023+
WideIV->getInductionDescriptor(), VPIRFlags(), WideIV->getDebugLoc());
1024+
WidenCanIV->insertBefore(WideIV);
1025+
1026+
// Update the select to use the wide canonical IV.
1027+
auto *SelectRecipe = cast<VPSingleDefRecipe>(
1028+
FindIVPhiR->getBackedgeValue()->getDefiningRecipe());
1029+
if (SelectRecipe->getOperand(1) == WideIV)
1030+
SelectRecipe->setOperand(1, WidenCanIV);
1031+
else if (SelectRecipe->getOperand(2) == WideIV)
1032+
SelectRecipe->setOperand(2, WidenCanIV);
10251033
}
10261034

1027-
// Create the new FindFirstIV reduction recipe.
1028-
assert(!FindLastIVPhiR->isInLoop() && !FindLastIVPhiR->isOrdered());
1029-
ReductionStyle Style = RdxUnordered{FindLastIVPhiR->getVFScaleFactor()};
1030-
auto *FindFirstIVPhiR =
1031-
new VPReductionPHIRecipe(nullptr, NewKind, *NewSentinel, Style,
1032-
FindLastIVPhiR->hasUsesOutsideReductionChain());
1033-
FindFirstIVPhiR->addOperand(FindLastIVPhiR->getBackedgeValue());
1035+
// Create the new UMin reduction recipe to track the minimum index.
1036+
assert(!FindIVPhiR->isInLoop() && !FindIVPhiR->isOrdered() &&
1037+
"inloop and ordered reductions not supported");
1038+
VPValue *MaxInt =
1039+
Plan.getConstantInt(APInt::getMaxValue(Ty->getIntegerBitWidth()));
1040+
ReductionStyle Style = RdxUnordered{FindIVPhiR->getVFScaleFactor()};
1041+
auto *MinIdxPhiR = new VPReductionPHIRecipe(
1042+
dyn_cast_or_null<PHINode>(FindIVPhiR->getUnderlyingValue()),
1043+
RecurKind::UMin, *MaxInt, Style,
1044+
FindIVPhiR->hasUsesOutsideReductionChain());
1045+
MinIdxPhiR->addOperand(FindIVPhiR->getBackedgeValue());
1046+
MinIdxPhiR->insertBefore(FindIVPhiR);
10341047

1035-
FindFirstIVPhiR->insertBefore(FindLastIVPhiR);
10361048
VPInstruction *FindLastIVResult =
1037-
findUserOf<VPInstruction::ComputeFindIVResult>(FindLastIVPhiR);
1038-
FindLastIVPhiR->replaceAllUsesWith(FindFirstIVPhiR);
1039-
FindLastIVResult->setOperand(2, NewSentinel);
1040-
return FindFirstIVPhiR;
1049+
findUserOf<VPInstruction::ComputeFindIVResult>(FindIVPhiR);
1050+
VPInstruction *MinMaxresult =
1051+
findUserOf<VPInstruction::ComputeReductionResult>(MinMaxPhiR);
1052+
MinMaxresult->moveBefore(*FindLastIVResult->getParent(),
1053+
FindLastIVResult->getIterator());
1054+
1055+
// The reduction using MinMaxPhiR needs adjusting to compute the correct
1056+
// result:
1057+
// 1. We need to find the first canonical IV for which the condition based on
1058+
// the
1059+
// min/max recurrence is true,
1060+
// 2. Compare the partial min/max reduction result to its final value and,
1061+
// 3. Select the lanes of the partial UMin reduction of the canonical wide IV
1062+
// which
1063+
// correspond to the lanes matching the min/max reduction result.
1064+
// 4. Scale the final selected canonical IV back to the original IV using
1065+
// VPDerivedIVRecipe.
1066+
// 5. If the minimum value matches the start value, the condition in the loop
1067+
// was never true, return the start value in that case.
1068+
//
1069+
// The original reductions need adjusting:
1070+
// For example, this transforms
1071+
// vp<%min.result> = compute-reduction-result ir<%min.val>,
1072+
// ir<%min.val.next>
1073+
// vp<%find.iv.result> = compute-find-iv-result ir<%min.idx>, ir<0>,
1074+
// SENTINEL, vp<%min.idx.next>
1075+
//
1076+
// into:
1077+
// vp<%min.result> = compute-reduction-result ir<%min.val>, ir<%min.val.next>
1078+
// vp<%final.min.cmp> = icmp eq ir<%min.val.next>, vp<%min.result>
1079+
// vp<%final.min.iv> = select vp<%final.min.cmp>, ir<%min.idx.next>, ir<-1>
1080+
// vp<%13> = compute-reduction-result ir<%min.idx>, vp<%final.min.iv>
1081+
// vp<%scaled.result.iv> = DERIVED-IV ir<20> + vp<%13> * ir<1>
1082+
// vp<%threshold.cmp> = icmp slt vp<%min.result>, ir<0>
1083+
// vp<%final.result> = select vp<%threshold.cmp>, vp<%scaled.result.iv>,
1084+
// ir<%original.start>
1085+
1086+
VPBuilder Builder(FindLastIVResult);
1087+
VPValue *MinMaxExiting = MinMaxResult->getOperand(1);
1088+
auto *FinalMinMaxCmp =
1089+
Builder.createICmp(CmpInst::ICMP_EQ, MinMaxExiting, MinMaxResult);
1090+
VPValue *LastIVExiting = FindLastIVResult->getOperand(3);
1091+
auto *FinalIVSelect =
1092+
Builder.createSelect(FinalMinMaxCmp, LastIVExiting, MaxInt);
1093+
VPSingleDefRecipe *FinalResult = Builder.createNaryOp(
1094+
VPInstruction::ComputeReductionResult, {MinIdxPhiR, FinalIVSelect}, {},
1095+
FindLastIVResult->getDebugLoc());
1096+
1097+
// If we used a new wide canonical IV, convert the reduction result back to the
1098+
// original IV scale before the final select.
1099+
if (!WideIV->isCanonical()) {
1100+
auto *DerivedIVRecipe =
1101+
new VPDerivedIVRecipe(InductionDescriptor::IK_IntInduction,
1102+
nullptr, // No FPBinOp for integer induction
1103+
WideIV->getStartValue(), FinalResult,
1104+
WideIV->getStepValue(), "derived.iv.result");
1105+
DerivedIVRecipe->insertBefore(&*Builder.getInsertPoint());
1106+
FinalResult = DerivedIVRecipe;
1107+
}
1108+
1109+
auto GetPred = [&]() {
1110+
switch (MinMaxPhiR->getRecurrenceKind()) {
1111+
case RecurKind::UMin:
1112+
return CmpInst::ICMP_ULT;
1113+
case RecurKind::SMin:
1114+
return CmpInst::ICMP_SLT;
1115+
case RecurKind::UMax:
1116+
return CmpInst::ICMP_UGT;
1117+
case RecurKind::SMax:
1118+
return CmpInst::ICMP_SGT;
1119+
default:
1120+
llvm_unreachable("");
1121+
}
1122+
};
1123+
// If the final min/max value matches the start value, the condition in the
1124+
// loop was always false, i.e. no induction value has been selected. If that's
1125+
// the case, use the original start value.
1126+
VPValue *MinMaxLT =
1127+
Builder.createICmp(GetPred(), MinMaxresult, MinMaxPhiR->getStartValue());
1128+
VPValue *Res = Builder.createSelect(MinMaxLT, FinalResult,
1129+
FindLastIVResult->getOperand(1));
1130+
FindIVPhiR->replaceAllUsesWith(MinIdxPhiR);
1131+
FindLastIVResult->replaceAllUsesWith(Res);
1132+
return true;
10411133
}
10421134

1043-
bool VPlanTransforms::handleMultiUseReductions(VPlan &Plan, ScalarEvolution &SE,
1044-
const Loop *L) {
1135+
bool VPlanTransforms::handleMultiUseReductions(VPlan &Plan) {
10451136
for (auto &PhiR : make_early_inc_range(
10461137
Plan.getVectorLoopRegion()->getEntryBasicBlock()->phis())) {
10471138
auto *MinMaxPhiR = dyn_cast<VPReductionPHIRecipe>(&PhiR);
@@ -1151,13 +1242,14 @@ bool VPlanTransforms::handleMultiUseReductions(VPlan &Plan, ScalarEvolution &SE,
11511242
if (!IsValidPredicate)
11521243
return false;
11531244

1154-
// For strict predicates, transform try to convert FindLastIV to
1155-
// FindFirstIV.
1245+
// For strict predicates, use a UMin reduction to find the minimum index.
1246+
// Canonical IVs (0, 1, 2, ...) are guaranteed not to wrap in the vector
1247+
// loop, so UMin can always be used.
11561248
bool IsStrictPredicate = ICmpInst::isLT(Pred) || ICmpInst::isGT(Pred);
11571249
if (IsStrictPredicate) {
1158-
FindIVPhiR = tryConvertToFindFirstIV(Plan, FindIVPhiR, IVOp, SE, L);
1159-
if (!FindIVPhiR)
1160-
return false;
1250+
return handleStrictArgMinArgMax(Plan, MinMaxPhiR, FindIVPhiR,
1251+
cast<VPWidenIntOrFpInductionRecipe>(IVOp),
1252+
MinMaxResult);
11611253
}
11621254

11631255
// The reduction using MinMaxPhiR needs adjusting to compute the correct

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1190,6 +1190,7 @@ bool VPInstruction::opcodeMayReadOrWriteFromMemory() const {
11901190
case VPInstruction::BuildVector:
11911191
case VPInstruction::CalculateTripCountMinusVF:
11921192
case VPInstruction::CanonicalIVIncrementForPart:
1193+
case VPInstruction::ComputeFindIVResult:
11931194
case VPInstruction::ExtractLane:
11941195
case VPInstruction::ExtractLastLane:
11951196
case VPInstruction::ExtractLastPart:

llvm/lib/Transforms/Vectorize/VPlanTransforms.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,7 @@ struct VPlanTransforms {
148148
/// Try to legalize reductions with multiple in-loop uses. Currently only
149149
/// min/max reductions used by FindLastIV and FindFirstIV reductions are
150150
/// supported. Otherwise return false.
151-
static bool handleMultiUseReductions(VPlan &Plan, ScalarEvolution &SE,
152-
const Loop *L);
151+
static bool handleMultiUseReductions(VPlan &Plan);
153152

154153
/// Try to have all users of fixed-order recurrences appear after the recipe
155154
/// defining their previous value, by either sinking users or hoisting recipes

0 commit comments

Comments
 (0)