diff options
Diffstat (limited to 'llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp')
-rw-r--r-- | llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp | 585 |
1 files changed, 312 insertions, 273 deletions
diff --git a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp index 9f0ab9103d42..5bb1d54d7d12 100644 --- a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp +++ b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp @@ -64,7 +64,6 @@ #include "llvm/Analysis/PostDominators.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/CFG.h" -#include "llvm/IR/CallSite.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/DerivedTypes.h" @@ -126,16 +125,16 @@ struct AllocaDerivedValueTracker { switch (I->getOpcode()) { case Instruction::Call: case Instruction::Invoke: { - CallSite CS(I); + auto &CB = cast<CallBase>(*I); // If the alloca-derived argument is passed byval it is not an escape // point, or a use of an alloca. Calling with byval copies the contents // of the alloca into argument registers or stack slots, which exist // beyond the lifetime of the current frame. - if (CS.isArgOperand(U) && CS.isByValArgument(CS.getArgumentNo(U))) + if (CB.isArgOperand(U) && CB.isByValArgument(CB.getArgOperandNo(U))) continue; bool IsNocapture = - CS.isDataOperand(U) && CS.doesNotCapture(CS.getDataOperandNo(U)); - callUsesLocalStack(CS, IsNocapture); + CB.isDataOperand(U) && CB.doesNotCapture(CB.getDataOperandNo(U)); + callUsesLocalStack(CB, IsNocapture); if (IsNocapture) { // If the alloca-derived argument is passed in as nocapture, then it // can't propagate to the call's return. That would be capturing. @@ -168,17 +167,17 @@ struct AllocaDerivedValueTracker { } } - void callUsesLocalStack(CallSite CS, bool IsNocapture) { + void callUsesLocalStack(CallBase &CB, bool IsNocapture) { // Add it to the list of alloca users. - AllocaUsers.insert(CS.getInstruction()); + AllocaUsers.insert(&CB); // If it's nocapture then it can't capture this alloca. if (IsNocapture) return; // If it can write to memory, it can leak the alloca value. 
- if (!CS.onlyReadsMemory()) - EscapePoints.insert(CS.getInstruction()); + if (!CB.onlyReadsMemory()) + EscapePoints.insert(&CB); } SmallPtrSet<Instruction *, 32> AllocaUsers; @@ -342,7 +341,7 @@ static bool canMoveAboveCall(Instruction *I, CallInst *CI, AliasAnalysis *AA) { const DataLayout &DL = L->getModule()->getDataLayout(); if (isModSet(AA->getModRefInfo(CI, MemoryLocation::get(L))) || !isSafeToLoadUnconditionally(L->getPointerOperand(), L->getType(), - MaybeAlign(L->getAlignment()), DL, L)) + L->getAlign(), DL, L)) return false; } } @@ -355,89 +354,23 @@ static bool canMoveAboveCall(Instruction *I, CallInst *CI, AliasAnalysis *AA) { return !is_contained(I->operands(), CI); } -/// Return true if the specified value is the same when the return would exit -/// as it was when the initial iteration of the recursive function was executed. -/// -/// We currently handle static constants and arguments that are not modified as -/// part of the recursion. -static bool isDynamicConstant(Value *V, CallInst *CI, ReturnInst *RI) { - if (isa<Constant>(V)) return true; // Static constants are always dyn consts - - // Check to see if this is an immutable argument, if so, the value - // will be available to initialize the accumulator. - if (Argument *Arg = dyn_cast<Argument>(V)) { - // Figure out which argument number this is... - unsigned ArgNo = 0; - Function *F = CI->getParent()->getParent(); - for (Function::arg_iterator AI = F->arg_begin(); &*AI != Arg; ++AI) - ++ArgNo; - - // If we are passing this argument into call as the corresponding - // argument operand, then the argument is dynamically constant. - // Otherwise, we cannot transform this function safely. - if (CI->getArgOperand(ArgNo) == Arg) - return true; - } - - // Switch cases are always constant integers. If the value is being switched - // on and the return is only reachable from one of its cases, it's - // effectively constant. 
- if (BasicBlock *UniquePred = RI->getParent()->getUniquePredecessor()) - if (SwitchInst *SI = dyn_cast<SwitchInst>(UniquePred->getTerminator())) - if (SI->getCondition() == V) - return SI->getDefaultDest() != RI->getParent(); - - // Not a constant or immutable argument, we can't safely transform. - return false; -} - -/// Check to see if the function containing the specified tail call consistently -/// returns the same runtime-constant value at all exit points except for -/// IgnoreRI. If so, return the returned value. -static Value *getCommonReturnValue(ReturnInst *IgnoreRI, CallInst *CI) { - Function *F = CI->getParent()->getParent(); - Value *ReturnedValue = nullptr; - - for (BasicBlock &BBI : *F) { - ReturnInst *RI = dyn_cast<ReturnInst>(BBI.getTerminator()); - if (RI == nullptr || RI == IgnoreRI) continue; - - // We can only perform this transformation if the value returned is - // evaluatable at the start of the initial invocation of the function, - // instead of at the end of the evaluation. - // - Value *RetOp = RI->getOperand(0); - if (!isDynamicConstant(RetOp, CI, RI)) - return nullptr; - - if (ReturnedValue && RetOp != ReturnedValue) - return nullptr; // Cannot transform if differing values are returned. - ReturnedValue = RetOp; - } - return ReturnedValue; -} +static bool canTransformAccumulatorRecursion(Instruction *I, CallInst *CI) { + if (!I->isAssociative() || !I->isCommutative()) + return false; -/// If the specified instruction can be transformed using accumulator recursion -/// elimination, return the constant which is the start of the accumulator -/// value. Otherwise return null. -static Value *canTransformAccumulatorRecursion(Instruction *I, CallInst *CI) { - if (!I->isAssociative() || !I->isCommutative()) return nullptr; assert(I->getNumOperands() == 2 && "Associative/commutative operations should have 2 args!"); // Exactly one operand should be the result of the call instruction. 
if ((I->getOperand(0) == CI && I->getOperand(1) == CI) || (I->getOperand(0) != CI && I->getOperand(1) != CI)) - return nullptr; + return false; // The only user of this instruction we allow is a single return instruction. if (!I->hasOneUse() || !isa<ReturnInst>(I->user_back())) - return nullptr; + return false; - // Ok, now we have to check all of the other return instructions in this - // function. If they return non-constants or differing values, then we cannot - // transform the function safely. - return getCommonReturnValue(cast<ReturnInst>(I->user_back()), CI); + return true; } static Instruction *firstNonDbg(BasicBlock::iterator I) { @@ -446,11 +379,73 @@ static Instruction *firstNonDbg(BasicBlock::iterator I) { return &*I; } -static CallInst *findTRECandidate(Instruction *TI, - bool CannotTailCallElimCallsMarkedTail, - const TargetTransformInfo *TTI) { +namespace { +class TailRecursionEliminator { + Function &F; + const TargetTransformInfo *TTI; + AliasAnalysis *AA; + OptimizationRemarkEmitter *ORE; + DomTreeUpdater &DTU; + + // The below are shared state we want to have available when eliminating any + // calls in the function. Their values should be populated by + // createTailRecurseLoopHeader the first time we find a call we can eliminate. + BasicBlock *HeaderBB = nullptr; + SmallVector<PHINode *, 8> ArgumentPHIs; + bool RemovableCallsMustBeMarkedTail = false; + + // PHI node to store our return value. + PHINode *RetPN = nullptr; + + // i1 PHI node to track if we have a valid return value stored in RetPN. + PHINode *RetKnownPN = nullptr; + + // Vector of select instructions we inserted. These selects use RetKnownPN + // to either propagate RetPN or select a new return value. + SmallVector<SelectInst *, 8> RetSelects; + + // The below are shared state needed when performing accumulator recursion. + // Their values should be populated by insertAccumulator the first time we + // find an elimination that requires an accumulator.
+ + // PHI node to store our current accumulated value. + PHINode *AccPN = nullptr; + + // The instruction doing the accumulating. + Instruction *AccumulatorRecursionInstr = nullptr; + + TailRecursionEliminator(Function &F, const TargetTransformInfo *TTI, + AliasAnalysis *AA, OptimizationRemarkEmitter *ORE, + DomTreeUpdater &DTU) + : F(F), TTI(TTI), AA(AA), ORE(ORE), DTU(DTU) {} + + CallInst *findTRECandidate(Instruction *TI, + bool CannotTailCallElimCallsMarkedTail); + + void createTailRecurseLoopHeader(CallInst *CI); + + void insertAccumulator(Instruction *AccRecInstr); + + bool eliminateCall(CallInst *CI); + + bool foldReturnAndProcessPred(ReturnInst *Ret, + bool CannotTailCallElimCallsMarkedTail); + + bool processReturningBlock(ReturnInst *Ret, + bool CannotTailCallElimCallsMarkedTail); + + void cleanupAndFinalize(); + +public: + static bool eliminate(Function &F, const TargetTransformInfo *TTI, + AliasAnalysis *AA, OptimizationRemarkEmitter *ORE, + DomTreeUpdater &DTU); +}; +} // namespace + +CallInst *TailRecursionEliminator::findTRECandidate( + Instruction *TI, bool CannotTailCallElimCallsMarkedTail) { BasicBlock *BB = TI->getParent(); - Function *F = BB->getParent(); if (&BB->front() == TI) // Make sure there is something before the terminator. return nullptr; @@ -461,7 +456,7 @@ static CallInst *findTRECandidate(Instruction *TI, BasicBlock::iterator BBI(TI); while (true) { CI = dyn_cast<CallInst>(BBI); - if (CI && CI->getCalledFunction() == F) + if (CI && CI->getCalledFunction() == &F) break; if (BBI == BB->begin()) @@ -478,16 +473,14 @@ static CallInst *findTRECandidate(Instruction *TI, // double fabs(double f) { return __builtin_fabs(f); } // a 'fabs' call // and disable this xform in this case, because the code generator will // lower the call to fabs into inline code. 
- if (BB == &F->getEntryBlock() && + if (BB == &F.getEntryBlock() && firstNonDbg(BB->front().getIterator()) == CI && firstNonDbg(std::next(BB->begin())) == TI && CI->getCalledFunction() && !TTI->isLoweredToCall(CI->getCalledFunction())) { // A single-block function with just a call and a return. Check that // the arguments match. - CallSite::arg_iterator I = CallSite(CI).arg_begin(), - E = CallSite(CI).arg_end(); - Function::arg_iterator FI = F->arg_begin(), - FE = F->arg_end(); + auto I = CI->arg_begin(), E = CI->arg_end(); + Function::arg_iterator FI = F.arg_begin(), FE = F.arg_end(); for (; I != E && FI != FE; ++I, ++FI) if (*I != &*FI) break; if (I == E && FI == FE) @@ -497,27 +490,106 @@ static CallInst *findTRECandidate(Instruction *TI, return CI; } -static bool eliminateRecursiveTailCall( - CallInst *CI, ReturnInst *Ret, BasicBlock *&OldEntry, - bool &TailCallsAreMarkedTail, SmallVectorImpl<PHINode *> &ArgumentPHIs, - AliasAnalysis *AA, OptimizationRemarkEmitter *ORE, DomTreeUpdater &DTU) { - // If we are introducing accumulator recursion to eliminate operations after - // the call instruction that are both associative and commutative, the initial - // value for the accumulator is placed in this variable. If this value is set - // then we actually perform accumulator recursion elimination instead of - // simple tail recursion elimination. If the operation is an LLVM instruction - // (eg: "add") then it is recorded in AccumulatorRecursionInstr. If not, then - // we are handling the case when the return instruction returns a constant C - // which is different to the constant returned by other return instructions - // (which is recorded in AccumulatorRecursionEliminationInitVal). This is a - // special case of accumulator recursion, the operation being "return C". 
- Value *AccumulatorRecursionEliminationInitVal = nullptr; - Instruction *AccumulatorRecursionInstr = nullptr; +void TailRecursionEliminator::createTailRecurseLoopHeader(CallInst *CI) { + HeaderBB = &F.getEntryBlock(); + BasicBlock *NewEntry = BasicBlock::Create(F.getContext(), "", &F, HeaderBB); + NewEntry->takeName(HeaderBB); + HeaderBB->setName("tailrecurse"); + BranchInst *BI = BranchInst::Create(HeaderBB, NewEntry); + BI->setDebugLoc(CI->getDebugLoc()); + + // If this function has self recursive calls in the tail position where some + // are marked tail and some are not, only transform one flavor or another. + // We have to choose whether we move allocas in the entry block to the new + // entry block or not, so we can't make a good choice for both. We make this + // decision here based on whether the first call we found to remove is + // marked tail. + // NOTE: We could do slightly better here in the case that the function has + // no entry block allocas. + RemovableCallsMustBeMarkedTail = CI->isTailCall(); + + // If this tail call is marked 'tail' and if there are any allocas in the + // entry block, move them up to the new entry block. + if (RemovableCallsMustBeMarkedTail) + // Move all fixed sized allocas from HeaderBB to NewEntry. + for (BasicBlock::iterator OEBI = HeaderBB->begin(), E = HeaderBB->end(), + NEBI = NewEntry->begin(); + OEBI != E;) + if (AllocaInst *AI = dyn_cast<AllocaInst>(OEBI++)) + if (isa<ConstantInt>(AI->getArraySize())) + AI->moveBefore(&*NEBI); + + // Now that we have created a new block, which jumps to the entry + // block, insert a PHI node for each argument of the function. + // For now, we initialize each PHI to only have the real arguments + // which are passed in. 
+ Instruction *InsertPos = &HeaderBB->front(); + for (Function::arg_iterator I = F.arg_begin(), E = F.arg_end(); I != E; ++I) { + PHINode *PN = + PHINode::Create(I->getType(), 2, I->getName() + ".tr", InsertPos); + I->replaceAllUsesWith(PN); // Everyone use the PHI node now! + PN->addIncoming(&*I, NewEntry); + ArgumentPHIs.push_back(PN); + } + + // If the function doesn't return void, create the RetPN and RetKnownPN PHI + // nodes to track our return value. We initialize RetPN with undef and + // RetKnownPN with false since we can't know our return value at function + // entry. + Type *RetType = F.getReturnType(); + if (!RetType->isVoidTy()) { + Type *BoolType = Type::getInt1Ty(F.getContext()); + RetPN = PHINode::Create(RetType, 2, "ret.tr", InsertPos); + RetKnownPN = PHINode::Create(BoolType, 2, "ret.known.tr", InsertPos); + + RetPN->addIncoming(UndefValue::get(RetType), NewEntry); + RetKnownPN->addIncoming(ConstantInt::getFalse(BoolType), NewEntry); + } + + // The entry block was changed from HeaderBB to NewEntry. + // The forward DominatorTree needs to be recalculated when the EntryBB is + // changed. In this corner-case we recalculate the entire tree. + DTU.recalculate(*NewEntry->getParent()); +} + +void TailRecursionEliminator::insertAccumulator(Instruction *AccRecInstr) { + assert(!AccPN && "Trying to insert multiple accumulators"); + + AccumulatorRecursionInstr = AccRecInstr; + + // Start by inserting a new PHI node for the accumulator. + pred_iterator PB = pred_begin(HeaderBB), PE = pred_end(HeaderBB); + AccPN = PHINode::Create(F.getReturnType(), std::distance(PB, PE) + 1, + "accumulator.tr", &HeaderBB->front()); + + // Loop over all of the predecessors of the tail recursion block. For the + // real entry into the function we seed the PHI with the identity constant for + // the accumulation operation. For any other existing branches to this block + // (due to other tail recursions eliminated) the accumulator is not modified.
+ // Because we haven't added the branch in the current block to HeaderBB yet, + // it will not show up as a predecessor. + for (pred_iterator PI = PB; PI != PE; ++PI) { + BasicBlock *P = *PI; + if (P == &F.getEntryBlock()) { + Constant *Identity = ConstantExpr::getBinOpIdentity( + AccRecInstr->getOpcode(), AccRecInstr->getType()); + AccPN->addIncoming(Identity, P); + } else { + AccPN->addIncoming(AccPN, P); + } + } + + ++NumAccumAdded; +} + +bool TailRecursionEliminator::eliminateCall(CallInst *CI) { + ReturnInst *Ret = cast<ReturnInst>(CI->getParent()->getTerminator()); // Ok, we found a potential tail call. We can currently only transform the // tail call if all of the instructions between the call and the return are // movable to above the call itself, leaving the call next to the return. // Check that this is the case now. + Instruction *AccRecInstr = nullptr; BasicBlock::iterator BBI(CI); for (++BBI; &*BBI != Ret; ++BBI) { if (canMoveAboveCall(&*BBI, CI, AA)) @@ -526,39 +598,16 @@ static bool eliminateRecursiveTailCall( // If we can't move the instruction above the call, it might be because it // is an associative and commutative operation that could be transformed // using accumulator recursion elimination. Check to see if this is the - // case, and if so, remember the initial accumulator value for later. - if ((AccumulatorRecursionEliminationInitVal = - canTransformAccumulatorRecursion(&*BBI, CI))) { - // Yes, this is accumulator recursion. Remember which instruction - // accumulates. - AccumulatorRecursionInstr = &*BBI; - } else { - return false; // Otherwise, we cannot eliminate the tail recursion! - } - } + // case, and if so, remember which instruction accumulates for later. + if (AccPN || !canTransformAccumulatorRecursion(&*BBI, CI)) + return false; // We cannot eliminate the tail recursion! 
- // We can only transform call/return pairs that either ignore the return value - // of the call and return void, ignore the value of the call and return a - // constant, return the value returned by the tail call, or that are being - // accumulator recursion variable eliminated. - if (Ret->getNumOperands() == 1 && Ret->getReturnValue() != CI && - !isa<UndefValue>(Ret->getReturnValue()) && - AccumulatorRecursionEliminationInitVal == nullptr && - !getCommonReturnValue(nullptr, CI)) { - // One case remains that we are able to handle: the current return - // instruction returns a constant, and all other return instructions - // return a different constant. - if (!isDynamicConstant(Ret->getReturnValue(), CI, Ret)) - return false; // Current return instruction does not return a constant. - // Check that all other return instructions return a common constant. If - // so, record it in AccumulatorRecursionEliminationInitVal. - AccumulatorRecursionEliminationInitVal = getCommonReturnValue(Ret, CI); - if (!AccumulatorRecursionEliminationInitVal) - return false; + // Yes, this is accumulator recursion. Remember which instruction + // accumulates. + AccRecInstr = &*BBI; } BasicBlock *BB = Ret->getParent(); - Function *F = BB->getParent(); using namespace ore; ORE->emit([&]() { @@ -568,51 +617,10 @@ static bool eliminateRecursiveTailCall( // OK! We can transform this tail call. If this is the first one found, // create the new entry block, allowing us to branch back to the old entry. - if (!OldEntry) { - OldEntry = &F->getEntryBlock(); - BasicBlock *NewEntry = BasicBlock::Create(F->getContext(), "", F, OldEntry); - NewEntry->takeName(OldEntry); - OldEntry->setName("tailrecurse"); - BranchInst *BI = BranchInst::Create(OldEntry, NewEntry); - BI->setDebugLoc(CI->getDebugLoc()); - - // If this tail call is marked 'tail' and if there are any allocas in the - // entry block, move them up to the new entry block. 
- TailCallsAreMarkedTail = CI->isTailCall(); - if (TailCallsAreMarkedTail) - // Move all fixed sized allocas from OldEntry to NewEntry. - for (BasicBlock::iterator OEBI = OldEntry->begin(), E = OldEntry->end(), - NEBI = NewEntry->begin(); OEBI != E; ) - if (AllocaInst *AI = dyn_cast<AllocaInst>(OEBI++)) - if (isa<ConstantInt>(AI->getArraySize())) - AI->moveBefore(&*NEBI); - - // Now that we have created a new block, which jumps to the entry - // block, insert a PHI node for each argument of the function. - // For now, we initialize each PHI to only have the real arguments - // which are passed in. - Instruction *InsertPos = &OldEntry->front(); - for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); - I != E; ++I) { - PHINode *PN = PHINode::Create(I->getType(), 2, - I->getName() + ".tr", InsertPos); - I->replaceAllUsesWith(PN); // Everyone use the PHI node now! - PN->addIncoming(&*I, NewEntry); - ArgumentPHIs.push_back(PN); - } - // The entry block was changed from OldEntry to NewEntry. - // The forward DominatorTree needs to be recalculated when the EntryBB is - // changed. In this corner-case we recalculate the entire tree. - DTU.recalculate(*NewEntry->getParent()); - } + if (!HeaderBB) + createTailRecurseLoopHeader(CI); - // If this function has self recursive calls in the tail position where some - // are marked tail and some are not, only transform one flavor or another. We - // have to choose whether we move allocas in the entry block to the new entry - // block or not, so we can't make a good choice for both. NOTE: We could do - // slightly better here in the case that the function has no entry block - // allocas. 
- if (TailCallsAreMarkedTail && !CI->isTailCall()) + if (RemovableCallsMustBeMarkedTail && !CI->isTailCall()) return false; // Ok, now that we know we have a pseudo-entry block WITH all of the @@ -621,74 +629,53 @@ static bool eliminateRecursiveTailCall( for (unsigned i = 0, e = CI->getNumArgOperands(); i != e; ++i) ArgumentPHIs[i]->addIncoming(CI->getArgOperand(i), BB); - // If we are introducing an accumulator variable to eliminate the recursion, - // do so now. Note that we _know_ that no subsequent tail recursion - // eliminations will happen on this function because of the way the - // accumulator recursion predicate is set up. - // - if (AccumulatorRecursionEliminationInitVal) { - Instruction *AccRecInstr = AccumulatorRecursionInstr; - // Start by inserting a new PHI node for the accumulator. - pred_iterator PB = pred_begin(OldEntry), PE = pred_end(OldEntry); - PHINode *AccPN = PHINode::Create( - AccumulatorRecursionEliminationInitVal->getType(), - std::distance(PB, PE) + 1, "accumulator.tr", &OldEntry->front()); - - // Loop over all of the predecessors of the tail recursion block. For the - // real entry into the function we seed the PHI with the initial value, - // computed earlier. For any other existing branches to this block (due to - // other tail recursions eliminated) the accumulator is not modified. - // Because we haven't added the branch in the current block to OldEntry yet, - // it will not show up as a predecessor. - for (pred_iterator PI = PB; PI != PE; ++PI) { - BasicBlock *P = *PI; - if (P == &F->getEntryBlock()) - AccPN->addIncoming(AccumulatorRecursionEliminationInitVal, P); - else - AccPN->addIncoming(AccPN, P); - } + if (AccRecInstr) { + insertAccumulator(AccRecInstr); - if (AccRecInstr) { - // Add an incoming argument for the current block, which is computed by - // our associative and commutative accumulator instruction. 
- AccPN->addIncoming(AccRecInstr, BB); + // Rewrite the accumulator recursion instruction so that it does not use + // the result of the call anymore, instead, use the PHI node we just + // inserted. + AccRecInstr->setOperand(AccRecInstr->getOperand(0) != CI, AccPN); + } - // Next, rewrite the accumulator recursion instruction so that it does not - // use the result of the call anymore, instead, use the PHI node we just - // inserted. - AccRecInstr->setOperand(AccRecInstr->getOperand(0) != CI, AccPN); + // Update our return value tracking + if (RetPN) { + if (Ret->getReturnValue() == CI || AccRecInstr) { + // Defer selecting a return value + RetPN->addIncoming(RetPN, BB); + RetKnownPN->addIncoming(RetKnownPN, BB); } else { - // Add an incoming argument for the current block, which is just the - // constant returned by the current return instruction. - AccPN->addIncoming(Ret->getReturnValue(), BB); + // We found a return value we want to use, insert a select instruction to + // select it if we don't already know what our return value will be and + // store the result in our return value PHI node. + SelectInst *SI = SelectInst::Create( + RetKnownPN, RetPN, Ret->getReturnValue(), "current.ret.tr", Ret); + RetSelects.push_back(SI); + + RetPN->addIncoming(SI, BB); + RetKnownPN->addIncoming(ConstantInt::getTrue(RetKnownPN->getType()), BB); } - // Finally, rewrite any return instructions in the program to return the PHI - // node instead of the "initval" that they do currently. This loop will - // actually rewrite the return value we are destroying, but that's ok. - for (BasicBlock &BBI : *F) - if (ReturnInst *RI = dyn_cast<ReturnInst>(BBI.getTerminator())) - RI->setOperand(0, AccPN); - ++NumAccumAdded; + if (AccPN) + AccPN->addIncoming(AccRecInstr ? AccRecInstr : AccPN, BB); } // Now that all of the PHI nodes are in place, remove the call and // ret instructions, replacing them with an unconditional branch. 
- BranchInst *NewBI = BranchInst::Create(OldEntry, Ret); + BranchInst *NewBI = BranchInst::Create(HeaderBB, Ret); NewBI->setDebugLoc(CI->getDebugLoc()); BB->getInstList().erase(Ret); // Remove return. BB->getInstList().erase(CI); // Remove call. - DTU.applyUpdates({{DominatorTree::Insert, BB, OldEntry}}); + DTU.applyUpdates({{DominatorTree::Insert, BB, HeaderBB}}); ++NumEliminated; return true; } -static bool foldReturnAndProcessPred( - BasicBlock *BB, ReturnInst *Ret, BasicBlock *&OldEntry, - bool &TailCallsAreMarkedTail, SmallVectorImpl<PHINode *> &ArgumentPHIs, - bool CannotTailCallElimCallsMarkedTail, const TargetTransformInfo *TTI, - AliasAnalysis *AA, OptimizationRemarkEmitter *ORE, DomTreeUpdater &DTU) { +bool TailRecursionEliminator::foldReturnAndProcessPred( + ReturnInst *Ret, bool CannotTailCallElimCallsMarkedTail) { + BasicBlock *BB = Ret->getParent(); + bool Change = false; // Make sure this block is a trivial return block. @@ -711,10 +698,11 @@ static bool foldReturnAndProcessPred( while (!UncondBranchPreds.empty()) { BranchInst *BI = UncondBranchPreds.pop_back_val(); BasicBlock *Pred = BI->getParent(); - if (CallInst *CI = findTRECandidate(BI, CannotTailCallElimCallsMarkedTail, TTI)){ + if (CallInst *CI = + findTRECandidate(BI, CannotTailCallElimCallsMarkedTail)) { LLVM_DEBUG(dbgs() << "FOLDING: " << *BB << "INTO UNCOND BRANCH PRED: " << *Pred); - ReturnInst *RI = FoldReturnIntoUncondBranch(Ret, BB, Pred, &DTU); + FoldReturnIntoUncondBranch(Ret, BB, Pred, &DTU); // Cleanup: if all predecessors of BB have been eliminated by // FoldReturnIntoUncondBranch, delete it. 
It is important to empty it, @@ -723,8 +711,7 @@ static bool foldReturnAndProcessPred( if (!BB->hasAddressTaken() && pred_begin(BB) == pred_end(BB)) DTU.deleteBB(BB); - eliminateRecursiveTailCall(CI, RI, OldEntry, TailCallsAreMarkedTail, - ArgumentPHIs, AA, ORE, DTU); + eliminateCall(CI); ++NumRetDuped; Change = true; } @@ -733,23 +720,92 @@ static bool foldReturnAndProcessPred( return Change; } -static bool processReturningBlock( - ReturnInst *Ret, BasicBlock *&OldEntry, bool &TailCallsAreMarkedTail, - SmallVectorImpl<PHINode *> &ArgumentPHIs, - bool CannotTailCallElimCallsMarkedTail, const TargetTransformInfo *TTI, - AliasAnalysis *AA, OptimizationRemarkEmitter *ORE, DomTreeUpdater &DTU) { - CallInst *CI = findTRECandidate(Ret, CannotTailCallElimCallsMarkedTail, TTI); +bool TailRecursionEliminator::processReturningBlock( + ReturnInst *Ret, bool CannotTailCallElimCallsMarkedTail) { + CallInst *CI = findTRECandidate(Ret, CannotTailCallElimCallsMarkedTail); if (!CI) return false; - return eliminateRecursiveTailCall(CI, Ret, OldEntry, TailCallsAreMarkedTail, - ArgumentPHIs, AA, ORE, DTU); + return eliminateCall(CI); +} + +void TailRecursionEliminator::cleanupAndFinalize() { + // If we eliminated any tail recursions, it's possible that we inserted some + // silly PHI nodes which just merge an initial value (the incoming operand) + // with themselves. Check to see if we did and clean up our mess if so. This + // occurs when a function passes an argument straight through to its tail + // call. + for (PHINode *PN : ArgumentPHIs) { + // If the PHI Node is a dynamic constant, replace it with the value it is. + if (Value *PNV = SimplifyInstruction(PN, F.getParent()->getDataLayout())) { + PN->replaceAllUsesWith(PNV); + PN->eraseFromParent(); + } + } + + if (RetPN) { + if (RetSelects.empty()) { + // If we didn't insert any select instructions, then we know we didn't + // store a return value and we can remove the PHI nodes we inserted. 
+ RetPN->dropAllReferences(); + RetPN->eraseFromParent(); + + RetKnownPN->dropAllReferences(); + RetKnownPN->eraseFromParent(); + + if (AccPN) { + // We need to insert a copy of our accumulator instruction before any + // return in the function, and return its result instead. + Instruction *AccRecInstr = AccumulatorRecursionInstr; + for (BasicBlock &BB : F) { + ReturnInst *RI = dyn_cast<ReturnInst>(BB.getTerminator()); + if (!RI) + continue; + + Instruction *AccRecInstrNew = AccRecInstr->clone(); + AccRecInstrNew->setName("accumulator.ret.tr"); + AccRecInstrNew->setOperand(AccRecInstr->getOperand(0) == AccPN, + RI->getOperand(0)); + AccRecInstrNew->insertBefore(RI); + RI->setOperand(0, AccRecInstrNew); + } + } + } else { + // We need to insert a select instruction before any return left in the + // function to select our stored return value if we have one. + for (BasicBlock &BB : F) { + ReturnInst *RI = dyn_cast<ReturnInst>(BB.getTerminator()); + if (!RI) + continue; + + SelectInst *SI = SelectInst::Create( + RetKnownPN, RetPN, RI->getOperand(0), "current.ret.tr", RI); + RetSelects.push_back(SI); + RI->setOperand(0, SI); + } + + if (AccPN) { + // We need to insert a copy of our accumulator instruction before any + // of the selects we inserted, and select its result instead. 
+ Instruction *AccRecInstr = AccumulatorRecursionInstr; + for (SelectInst *SI : RetSelects) { + Instruction *AccRecInstrNew = AccRecInstr->clone(); + AccRecInstrNew->setName("accumulator.ret.tr"); + AccRecInstrNew->setOperand(AccRecInstr->getOperand(0) == AccPN, + SI->getFalseValue()); + AccRecInstrNew->insertBefore(SI); + SI->setFalseValue(AccRecInstrNew); + } + } + } + } } -static bool eliminateTailRecursion(Function &F, const TargetTransformInfo *TTI, - AliasAnalysis *AA, - OptimizationRemarkEmitter *ORE, - DomTreeUpdater &DTU) { +bool TailRecursionEliminator::eliminate(Function &F, + const TargetTransformInfo *TTI, + AliasAnalysis *AA, + OptimizationRemarkEmitter *ORE, + DomTreeUpdater &DTU) { if (F.getFnAttribute("disable-tail-calls").getValueAsString() == "true") return false; @@ -762,17 +818,15 @@ static bool eliminateTailRecursion(Function &F, const TargetTransformInfo *TTI, // If this function is a varargs function, we won't be able to PHI the args // right, so don't even try to convert it... if (F.getFunctionType()->isVarArg()) - return false; - - BasicBlock *OldEntry = nullptr; - bool TailCallsAreMarkedTail = false; - SmallVector<PHINode*, 8> ArgumentPHIs; + return MadeChange; // If false, we cannot perform TRE on tail calls marked with the 'tail' // attribute, because doing so would cause the stack size to increase (real // TRE would deallocate variable sized allocas, TRE doesn't). bool CanTRETailMarkedCall = canTRE(F); + TailRecursionEliminator TRE(F, TTI, AA, ORE, DTU); + // Change any tail recursive calls to loops. // // FIXME: The code generator produces really bad code when an 'escaping @@ -782,29 +836,14 @@ static bool eliminateTailRecursion(Function &F, const TargetTransformInfo *TTI, for (Function::iterator BBI = F.begin(), E = F.end(); BBI != E; /*in loop*/) { BasicBlock *BB = &*BBI++; // foldReturnAndProcessPred may delete BB. 
if (ReturnInst *Ret = dyn_cast<ReturnInst>(BB->getTerminator())) { - bool Change = processReturningBlock(Ret, OldEntry, TailCallsAreMarkedTail, - ArgumentPHIs, !CanTRETailMarkedCall, - TTI, AA, ORE, DTU); + bool Change = TRE.processReturningBlock(Ret, !CanTRETailMarkedCall); if (!Change && BB->getFirstNonPHIOrDbg() == Ret) - Change = foldReturnAndProcessPred( - BB, Ret, OldEntry, TailCallsAreMarkedTail, ArgumentPHIs, - !CanTRETailMarkedCall, TTI, AA, ORE, DTU); + Change = TRE.foldReturnAndProcessPred(Ret, !CanTRETailMarkedCall); MadeChange |= Change; } } - // If we eliminated any tail recursions, it's possible that we inserted some - // silly PHI nodes which just merge an initial value (the incoming operand) - // with themselves. Check to see if we did and clean up our mess if so. This - // occurs when a function passes an argument straight through to its tail - // call. - for (PHINode *PN : ArgumentPHIs) { - // If the PHI Node is a dynamic constant, replace it with the value it is. - if (Value *PNV = SimplifyInstruction(PN, F.getParent()->getDataLayout())) { - PN->replaceAllUsesWith(PNV); - PN->eraseFromParent(); - } - } + TRE.cleanupAndFinalize(); return MadeChange; } @@ -838,7 +877,7 @@ struct TailCallElim : public FunctionPass { // UpdateStrategy to Lazy if we find it profitable later. DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Eager); - return eliminateTailRecursion( + return TailRecursionEliminator::eliminate( F, &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F), &getAnalysis<AAResultsWrapperPass>().getAAResults(), &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(), DTU); @@ -871,7 +910,7 @@ PreservedAnalyses TailCallElimPass::run(Function &F, // UpdateStrategy based on some test results. It is feasible to switch the // UpdateStrategy to Lazy if we find it profitable later. 
DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Eager); - bool Changed = eliminateTailRecursion(F, &TTI, &AA, &ORE, DTU); + bool Changed = TailRecursionEliminator::eliminate(F, &TTI, &AA, &ORE, DTU); if (!Changed) return PreservedAnalyses::all(); |