@@ -1000,48 +1000,139 @@ bool VPlanTransforms::handleMaxMinNumReductions(VPlan &Plan) {
10001000 return true ;
10011001}
10021002
1003- // / Try to convert FindLastIV to FindFirstIV reduction when using a strict
1004- // / predicate. Returns the new FindFirstIVPhiR on success, nullptr on failure.
1005- static VPReductionPHIRecipe *
1006- tryConvertToFindFirstIV (VPlan &Plan, VPReductionPHIRecipe *FindLastIVPhiR,
1007- VPValue *IVOp, ScalarEvolution &SE, const Loop *L) {
1008- Type *Ty = VPTypeAnalysis (Plan).inferScalarType (FindLastIVPhiR);
1009- unsigned NumBits = Ty->getIntegerBitWidth ();
1010-
1011- // Determine the reduction kind and sentinel based on the IV range.
1012- RecurKind NewKind;
1013- VPValue *NewSentinel;
1014- auto *AR = cast<SCEVAddRecExpr>(vputils::getSCEVExprForVPValue (IVOp, SE, L));
1015- if (RecurrenceDescriptor::isValidIVRangeForFindIV (
1016- AR, /* IsSigned=*/ true , /* IsFindFirstIV=*/ true , SE)) {
1017- NewKind = RecurKind::FindFirstIVSMin;
1018- NewSentinel = Plan.getConstantInt (APInt::getSignedMaxValue (NumBits));
1019- } else if (RecurrenceDescriptor::isValidIVRangeForFindIV (
1020- AR, /* IsSigned=*/ false , /* IsFindFirstIV=*/ true , SE)) {
1021- NewKind = RecurKind::FindFirstIVUMin;
1022- NewSentinel = Plan.getConstantInt (APInt::getMaxValue (NumBits));
1023- } else {
1024- return nullptr ;
1003+ // / For argmin/argmax reductions with strict predicates, convert the existing
1004+ // / FindLastIV reduction to a new UMin reduction of a wide canonical IV. If the
1005+ // / original IV was not canonical, a new canonical wide IV is added, and the
1006+ // / final result is scaled back to the original IV.
1007+ static bool handleStrictArgMinArgMax (VPlan &Plan,
1008+ VPReductionPHIRecipe *MinMaxPhiR,
1009+ VPReductionPHIRecipe *FindIVPhiR,
1010+ VPWidenIntOrFpInductionRecipe *WideIV,
1011+ VPInstruction *MinMaxResult) {
1012+ Type *Ty = Plan.getVectorLoopRegion ()->getCanonicalIVType ();
1013+ if (Ty != VPTypeAnalysis (Plan).inferScalarType (FindIVPhiR))
1014+ return false ;
1015+
1016+ // If the original wide IV is not canonical, create a new one. The wide IV is
1017+ // guaranteed to not wrap for all lanes that are active in the vector loop.
1018+ if (!WideIV->isCanonical ()) {
1019+ VPValue *Zero = Plan.getOrAddLiveIn (ConstantInt::get (Ty, 0 ));
1020+ VPValue *One = Plan.getOrAddLiveIn (ConstantInt::get (Ty, 1 ));
1021+ auto *WidenCanIV = new VPWidenIntOrFpInductionRecipe (
1022+ nullptr , Zero, One, WideIV->getVFValue (),
1023+ WideIV->getInductionDescriptor (), VPIRFlags (), WideIV->getDebugLoc ());
1024+ WidenCanIV->insertBefore (WideIV);
1025+
1026+ // Update the select to use the wide canonical IV.
1027+ auto *SelectRecipe = cast<VPSingleDefRecipe>(
1028+ FindIVPhiR->getBackedgeValue ()->getDefiningRecipe ());
1029+ if (SelectRecipe->getOperand (1 ) == WideIV)
1030+ SelectRecipe->setOperand (1 , WidenCanIV);
1031+ else if (SelectRecipe->getOperand (2 ) == WideIV)
1032+ SelectRecipe->setOperand (2 , WidenCanIV);
10251033 }
10261034
1027- // Create the new FindFirstIV reduction recipe.
1028- assert (!FindLastIVPhiR->isInLoop () && !FindLastIVPhiR->isOrdered ());
1029- ReductionStyle Style = RdxUnordered{FindLastIVPhiR->getVFScaleFactor ()};
1030- auto *FindFirstIVPhiR =
1031- new VPReductionPHIRecipe (nullptr , NewKind, *NewSentinel, Style,
1032- FindLastIVPhiR->hasUsesOutsideReductionChain ());
1033- FindFirstIVPhiR->addOperand (FindLastIVPhiR->getBackedgeValue ());
1035+ // Create the new UMin reduction recipe to track the minimum index.
1036+ assert (!FindIVPhiR->isInLoop () && !FindIVPhiR->isOrdered () &&
1037+ " inloop and ordered reductions not supported" );
1038+ VPValue *MaxInt =
1039+ Plan.getConstantInt (APInt::getMaxValue (Ty->getIntegerBitWidth ()));
1040+ ReductionStyle Style = RdxUnordered{FindIVPhiR->getVFScaleFactor ()};
1041+ auto *MinIdxPhiR = new VPReductionPHIRecipe (
1042+ dyn_cast_or_null<PHINode>(FindIVPhiR->getUnderlyingValue ()),
1043+ RecurKind::UMin, *MaxInt, Style,
1044+ FindIVPhiR->hasUsesOutsideReductionChain ());
1045+ MinIdxPhiR->addOperand (FindIVPhiR->getBackedgeValue ());
1046+ MinIdxPhiR->insertBefore (FindIVPhiR);
10341047
1035- FindFirstIVPhiR->insertBefore (FindLastIVPhiR);
10361048 VPInstruction *FindLastIVResult =
1037- findUserOf<VPInstruction::ComputeFindIVResult>(FindLastIVPhiR);
1038- FindLastIVPhiR->replaceAllUsesWith (FindFirstIVPhiR);
1039- FindLastIVResult->setOperand (2 , NewSentinel);
1040- return FindFirstIVPhiR;
1049+ findUserOf<VPInstruction::ComputeFindIVResult>(FindIVPhiR);
1050+ VPInstruction *MinMaxresult =
1051+ findUserOf<VPInstruction::ComputeReductionResult>(MinMaxPhiR);
1052+ MinMaxresult->moveBefore (*FindLastIVResult->getParent (),
1053+ FindLastIVResult->getIterator ());
1054+
1055+ // The reduction using MinMaxPhiR needs adjusting to compute the correct
1056+ // result:
1057+ // 1. We need to find the first canonical IV for which the condition based on
1058+ // the
1059+ // min/max recurrence is true,
1060+ // 2. Compare the partial min/max reduction result to its final value and,
1061+ // 3. Select the lanes of the partial UMin reduction of the canonical wide IV
1062+ // which
1063+ // correspond to the lanes matching the min/max reduction result.
1064+ // 4. Scale the final select canonical IV back to the original IV using
1065+ // VPDerivedIVRecipe.
1066+ // 5. If the minimum value matches the start value, the condition in the loop
1067+ // was never true, return the start value in that case.
1068+ //
1069+ // The original reductions need adjusting:
1070+ // For example, this transforms
1071+ // vp<%min.result> = compute-reduction-result ir<%min.val>,
1072+ // ir<%min.val.next>
1073+ // vp<%find.iv.result = compute-find-iv-result ir<%min.idx>, ir<0>,
1074+ // SENTINEL, vp<%min.idx.next>
1075+ //
1076+ // into:
1077+ // vp<%min.result> = compute-reduction-result ir<%min.val>, ir<%min.val.next>
1078+ // vp<%final.min.cmp> = icmp eq ir<%min.val.next>, vp<%min.result>
1079+ // vp<%final.min.iv> = select vp<%final.min.cmp>, ir<%min.idx.next>, ir<-1>
1080+ // vp<%13> = compute-reduction-result ir<%min.idx>, vp<%final.min.iv>
1081+ // vp<%scaled.result.iv> = DERIVED-IV ir<20> + vp<%13> * ir<1>
1082+ // vp<%threshold.cmp> = icmp slt vp<%min.result>, ir<0>
1083+ // vp<%final.result> = select vp<%threshold.cmp>, vp<%scaled.result.iv>,
1084+ // ir<%original.start>
1085+
1086+ VPBuilder Builder (FindLastIVResult);
1087+ VPValue *MinMaxExiting = MinMaxResult->getOperand (1 );
1088+ auto *FinalMinMaxCmp =
1089+ Builder.createICmp (CmpInst::ICMP_EQ, MinMaxExiting, MinMaxResult);
1090+ VPValue *LastIVExiting = FindLastIVResult->getOperand (3 );
1091+ auto *FinalIVSelect =
1092+ Builder.createSelect (FinalMinMaxCmp, LastIVExiting, MaxInt);
1093+ VPSingleDefRecipe *FinalResult = Builder.createNaryOp (
1094+ VPInstruction::ComputeReductionResult, {MinIdxPhiR, FinalIVSelect}, {},
1095+ FindLastIVResult->getDebugLoc ());
1096+
1097+ // If we used a new wide canonical IV convert the reduction result back to the
1098+ // original IV scale before the final select.
1099+ if (!WideIV->isCanonical ()) {
1100+ auto *DerivedIVRecipe =
1101+ new VPDerivedIVRecipe (InductionDescriptor::IK_IntInduction,
1102+ nullptr , // No FPBinOp for integer induction
1103+ WideIV->getStartValue (), FinalResult,
1104+ WideIV->getStepValue (), " derived.iv.result" );
1105+ DerivedIVRecipe->insertBefore (&*Builder.getInsertPoint ());
1106+ FinalResult = DerivedIVRecipe;
1107+ }
1108+
1109+ auto GetPred = [&]() {
1110+ switch (MinMaxPhiR->getRecurrenceKind ()) {
1111+ case RecurKind::UMin:
1112+ return CmpInst::ICMP_ULT;
1113+ case RecurKind::SMin:
1114+ return CmpInst::ICMP_SLT;
1115+ case RecurKind::UMax:
1116+ return CmpInst::ICMP_UGT;
1117+ case RecurKind::SMax:
1118+ return CmpInst::ICMP_SGT;
1119+ default :
1120+ llvm_unreachable (" " );
1121+ }
1122+ };
1123+ // If the final min/max value matches the start value, the condition in the
1124+ // loop was always false, i.e. no induction value has been selected. If that's
1125+ // the case, use the original start value.
1126+ VPValue *MinMaxLT =
1127+ Builder.createICmp (GetPred (), MinMaxresult, MinMaxPhiR->getStartValue ());
1128+ VPValue *Res = Builder.createSelect (MinMaxLT, FinalResult,
1129+ FindLastIVResult->getOperand (1 ));
1130+ FindIVPhiR->replaceAllUsesWith (MinIdxPhiR);
1131+ FindLastIVResult->replaceAllUsesWith (Res);
1132+ return true ;
10411133}
10421134
1043- bool VPlanTransforms::handleMultiUseReductions (VPlan &Plan, ScalarEvolution &SE,
1044- const Loop *L) {
1135+ bool VPlanTransforms::handleMultiUseReductions (VPlan &Plan) {
10451136 for (auto &PhiR : make_early_inc_range (
10461137 Plan.getVectorLoopRegion ()->getEntryBasicBlock ()->phis ())) {
10471138 auto *MinMaxPhiR = dyn_cast<VPReductionPHIRecipe>(&PhiR);
@@ -1151,13 +1242,14 @@ bool VPlanTransforms::handleMultiUseReductions(VPlan &Plan, ScalarEvolution &SE,
11511242 if (!IsValidPredicate)
11521243 return false ;
11531244
1154- // For strict predicates, transform try to convert FindLastIV to
1155- // FindFirstIV.
1245+ // For strict predicates, use a UMin reduction to find the minimum index.
1246+ // Canonical IVs (0, 1, 2, ...) are guaranteed not to wrap in the vector
1247+ // loop, so UMin can always be used.
11561248 bool IsStrictPredicate = ICmpInst::isLT (Pred) || ICmpInst::isGT (Pred);
11571249 if (IsStrictPredicate) {
1158- FindIVPhiR = tryConvertToFindFirstIV (Plan, FindIVPhiR, IVOp, SE, L);
1159- if (!FindIVPhiR)
1160- return false ;
1250+ return handleStrictArgMinArgMax (Plan, MinMaxPhiR, FindIVPhiR,
1251+ cast<VPWidenIntOrFpInductionRecipe>(IVOp),
1252+ MinMaxResult) ;
11611253 }
11621254
11631255 // The reduction using MinMaxPhiR needs adjusting to compute the correct
0 commit comments